首页 > 其他 > 详细

pytorch 踩坑笔记之w.grad.data.zero_()

时间:2019-07-22 17:43:00      阅读:1316      评论:0      收藏:0      [点我收藏+]

  在使用pytorch实现多项线性回归中,在grad更新时,每一次运算后都需要将上一次的梯度记录清空,运用如下方法:

     w.grad.data.zero_()
     b.grad.data.zero_() 

   但是,运行程序就会报如下错误:

技术分享图片

  报错,grad没有data这个属性,

  原因是,在系统将w的grad值初始化为none,第一次求梯度计算是在none值上进行报错,自然会没有data属性

  修改方法:添加一个判断语句,从第二次循环开始执行求导运算

for i in range(100):
    y_pred = multi_linear(x_train)
    loss = getloss(y_pred,y_train)
    if i != 0:
        w.grad.data.zero_()
        b.grad.data.zero_()
    loss.backward()
    w.data = w.data - 0.001 * w.grad.data
    b.data = b.data - 0.001 * b.grad.data

 

pytorch 踩坑笔记之w.grad.data.zero_()

原文:https://www.cnblogs.com/keep-s/p/11227159.html

(1)
(0)
   
举报
评论 一句话评论(0
关于我们 - 联系我们 - 留言反馈 - 联系我们:wmxa8@hotmail.com
© 2014 bubuko.com 版权所有
打开技术之扣,分享程序人生!