
pytorch0 - Gradient Reversal

Posted: 2021-06-21 20:41:30

Gradient reversal can be implemented with `Function` from `torch.autograd`: the forward pass is the identity, and the backward pass negates the incoming gradient.

import torch
from torch.autograd import Function
import torch.nn as nn


class ReverseLayer(Function):
    """Identity in the forward pass; negates the gradient in the backward pass."""
    @staticmethod
    def forward(ctx, x):
        return x

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output.neg()


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.parameter1 = nn.Parameter(torch.ones(10, 10))
        self.parameter2 = nn.Parameter(torch.ones(10, 10))
        self.parameter3 = nn.Parameter(torch.ones(10, 10))
    def forward(self, x):
        return x@self.parameter1@self.parameter2@self.parameter3


class ReverseNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.parameter1 = nn.Parameter(torch.ones(10, 10))
        self.parameter2 = nn.Parameter(torch.ones(10, 10))
        self.parameter3 = nn.Parameter(torch.ones(10, 10))
    def forward(self, x):
        x1 = x@self.parameter1
        x2 = ReverseLayer.apply(x1@self.parameter2)
        return x2@self.parameter3


dataInput = torch.randn(2, 10)
dataTarget = torch.randn(2, 10)

net1 = Net()
net2 = ReverseNet()
loss1 = torch.mean(net1(dataInput) - dataTarget)
loss1.backward()
loss2 = torch.mean(net2(dataInput) - dataTarget)
loss2.backward()
print('=======================PARAMETER1============================')
print(net1.parameter1.grad[0])
print(net2.parameter1.grad[0])
print('=======================PARAMETER2============================')
print(net1.parameter2.grad[0])
print(net2.parameter2.grad[0])
print('=======================PARAMETER3============================')
print(net1.parameter3.grad[0])
print(net2.parameter3.grad[0])

'''
By the chain rule, the gradients of all parameters before the reversal
layer (parameter1 and parameter2) are negated, while the gradient of
parameter3, which sits after it, is unchanged.
'''
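In practice (e.g. domain-adversarial training), the reversal layer usually also takes a scaling coefficient lambda so the strength of the reversed gradient can be scheduled. A minimal sketch under that assumption, extending the same `Function` pattern used above (`ScaledReverseLayer` is a hypothetical name, not from the original post):

```python
import torch
from torch.autograd import Function


class ScaledReverseLayer(Function):
    """Identity in the forward pass; multiplies the gradient by -lambd backward."""
    @staticmethod
    def forward(ctx, x, lambd):
        ctx.lambd = lambd  # stash the (non-tensor) coefficient for backward
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # One gradient per forward input; lambd itself needs no gradient, so None.
        return grad_output.neg() * ctx.lambd, None


w = torch.ones(3, requires_grad=True)
y = ScaledReverseLayer.apply(w * 2.0, 0.5).sum()
y.backward()
print(w.grad)  # each entry is -0.5 * 2.0 = -1.0
```

With `lambd=1.0` this reduces to the `ReverseLayer` above; annealing `lambd` from 0 to 1 over training is a common schedule.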


Original: https://www.cnblogs.com/tensorzhang/p/14913886.html
