首页 > 其他 > 详细

Pytorch手写线性回归

时间:2019-08-19 09:29:18      阅读:132      评论:0      收藏:0      [点我收藏+]

pytorch手写线性回归

 

import torch
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation

LEARN_RATE = 0.1
#1.准备数据
x = torch.randn([500,1])
y_true = x*0.8+3

#2.计算预测值 t_tred = x*w + b

w = torch.rand([],requires_grad=True)
b = torch.tensor(0.,requires_grad=True)

plt.figure()
plt.grid(True)

#开启交互模式
plt.ion()
for i in range(50):

    plt.cla()

    for j in [w,b]:
        if j.grad is not None:
            j.grad.zero_()
    y_predict = x*w+b

    #3.计算损失,把参数的梯度置为0,进行反向传播

    loss = (y_predict-y_true).pow(2).mean()

    loss.backward()

    #4.更新参数,grad表示导数

    w.data = w.data - LEARN_RATE*w.grad
    b.data = b.data - LEARN_RATE*b.grad


    plt.scatter(x.numpy(),y_true.numpy())
    plt.plot(x.numpy(),y_predict.detach().numpy(),color="g")

    plt.pause(0.1)


    if i %50 ==0:
        print( "第{}次,损失{},权重w={},偏执b={}".format(i,loss.data,w.data,b.data))

#关闭交互模式
plt.ioff()
plt.show()

  

Pytorch手写线性回归

原文:https://www.cnblogs.com/LiuXinyu12378/p/11374748.html

(0)
(0)
   
举报
评论 一句话评论(0
关于我们 - 联系我们 - 留言反馈 - 联系我们:wmxa8@hotmail.com
© 2014 bubuko.com 版权所有
打开技术之扣,分享程序人生!