首页 > 其他 > 详细

pytorch反向传播两次,梯度相加,retain_graph=True

时间:2020-10-15 12:09:03      阅读:288      评论:0      收藏:0      [点我收藏+]

在PyTorch中,loss.backward()之后,只有叶子节点的梯度会被保留,非叶子节点的梯度会被释放掉(可参考这篇博客)。

在默认情况下,PyTorch的网络只允许一次反向传播,如果要进行两次反向传播,则需要在第一次反向传播时设置retain_graph=True,即 loss.backwad(retain_graph=True) ,这样做可以保留第一次反向传播时非叶子节点的梯度,在第二次反向传播时,将自动和第二次的梯度相加。

示例:

import torch

input_ = torch.tensor([[1., 2.], [3., 4.]], requires_grad=False)
w1 = torch.tensor(2.0, requires_grad=True)
w2 = torch.tensor(3.0, requires_grad=True)

l1 = input_ * w1
l2 = l1 + w2
loss1 = l2.mean()
loss1.backward(retain_graph=True)

print(w1.grad)  # 输出:tensor(2.5)
print(w2.grad)  # 输出:tensor(1.)

loss2 = l2.sum()
loss2.backward()

print(w1.grad)  # 输出:tensor(12.5)
print(w2.grad)  # 输出:tensor(5.)

示例中的梯度推导很简单,我在这篇博客里推了一下。从输出结果来看,程序确实是把两次的梯度加起来了。

附注:如果网络要进行两次反向传播,却没有用retain_graph=True,则运行时会报错:RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.

 

pytorch反向传播两次,梯度相加,retain_graph=True

原文:https://www.cnblogs.com/picassooo/p/13818952.html

(0)
(0)
   
举报
评论 一句话评论(0
关于我们 - 联系我们 - 留言反馈 - 联系我们:wmxa8@hotmail.com
© 2014 bubuko.com 版权所有
打开技术之扣,分享程序人生!