首页 > 其他 > 详细

[torch] pytorch hook学习

时间:2019-08-04 10:27:24      阅读:69      评论:0      收藏:0      [点我收藏+]

pytorch hook学习

register_hook

import torch
x = torch.Tensor([0,1,2,3]).requires_grad_()
y = torch.Tensor([4,5,6,7]).requires_grad_()
w = torch.Tensor([1,2,3,4]).requires_grad_()
z = x+y;
o = w.matmul(z) # o = w(x+y) 中间变量z
o.backward()
print(x.grad,y.grad,z.grad,w.grad,o.grad)

这里的o和z都是中间变量,不是通过指定值来定义的变量,所以是中间变量,所以pytorch并不存储这些变量的梯度。

对于中间变量z,hook的使用方式为: z.register_hook(hook_fn),其中 hook_fn为一个用户自定义的函数,其签名为:hook_fn(grad) -> Tensor or None。

它的输入为变量 z 的梯度,输出为一个 Tensor 或者是 None (None 一般用于直接打印梯度)。反向传播时,梯度传播到变量 z,再继续向前传播之前,将会传入 hook_fn。如果 hook_fn的返回值是 None,那么梯度将不改变,继续向前传播,如果 hook_fn的返回值是 Tensor 类型,则该 Tensor 将取代 z 原有的梯度,向前传播。

import torch
x = torch.Tensor([0,1,2,3]).requires_grad_()
y = torch.Tensor([4,5,6,7]).requires_grad_()
w = torch.Tensor([1,2,3,4]).requires_grad_()
z = x+y;
def hook_fn(grad):
    print(grad)
    return None

z.register_hook(hook_fn)
o = w.matmul(z) # o = w(x+y) 中间变量z
o.backward()
print(x.grad,y.grad,w.grad,z.grad,o.grad)

register_forward_hook

register_forward_hook的作用是获取前向传播过程中,各个网络模块的输入和输出。对于模块 module,其使用方式为:module.register_forward_hook(hook_fn) 。其中 hook_fn的签名为:

hook_fn(module, input, output) -> None

eg

import torch
from torch import nn
class Model(nn.Module):
    def __init__(self):
        super(Model,self).__init__()
        self.fc1 = nn.Linear(3,4) # WT * X + bias
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(4,1)
        self.init()
    def init(self):
        with torch.no_grad():
            # WT * X + bias,所以W为4*3的矩阵,bias为1*4
            self.fc1.weight = torch.nn.Parameter(
                torch.Tensor([[1., 2., 3.],
                              [-4., -5., -6.],
                              [7., 8., 9.],
                              [-10., -11., -12.]]))
            self.fc1.bias = torch.nn.Parameter(torch.Tensor([1.0, 2.0, 3.0, 4.0]))
            self.fc2.weight = torch.nn.Parameter(torch.Tensor([[1.0, 2.0, 3.0, 4.0]]))
            self.fc2.bias = torch.nn.Parameter(torch.Tensor([1.0]))

    def forward(self,x):
        o = self.fc1(x)
        o = self.relu1(o)
        o = self.fc2(o)
        return o
def hook_fn_forward(module,input,output):
    print(module)
    print(input)
    print(output)


model = Model()
modules = model.named_children()
'''
named_children()
Returns an iterator over immediate children modules, yielding both the name of the module as well as the module itself.
'''
for name,module in modules:
    # 这里的name就是自己定义的self.xx的xx。如上面的fc1,fc2.
    # module代指的就是fc1代表的module等等
    module.register_forward_hook(hook_fn_forward)
x = torch.Tensor([[1.0,1.0,1.0]]).requires_grad_()
o = model(x)
o.backward()
 '''
 Linear(in_features=3, out_features=4, bias=True)
(tensor([[1., 1., 1.]], requires_grad=True),)
tensor([[  7., -13.,  27., -29.]], grad_fn=<AddmmBackward>)
ReLU()
(tensor([[  7., -13.,  27., -29.]], grad_fn=<AddmmBackward>),)
tensor([[ 7.,  0., 27.,  0.]], grad_fn=<ReluBackward0>)
Linear(in_features=4, out_features=1, bias=True)
(tensor([[ 7.,  0., 27.,  0.]], grad_fn=<ReluBackward0>),)
tensor([[89.]], grad_fn=<AddmmBackward>)
 
 '''

register_backward_hook

理同前者。得到梯度值。

hook_fn(module, grad_input, grad_output) -> Tensor or None

上面的代码forward全部替换为backward,结果为:

'''
Linear(in_features=4, out_features=1, bias=True)
(tensor([1.]), tensor([[1., 2., 3., 4.]]), tensor([[ 7.],
        [ 0.],
        [27.],
        [ 0.]]))
(tensor([[1.]]),)
ReLU()
(tensor([[1., 0., 3., 0.]]),)
(tensor([[1., 2., 3., 4.]]),)
Linear(in_features=3, out_features=4, bias=True)
(tensor([1., 0., 3., 0.]), tensor([[22., 26., 30.]]), tensor([[1., 0., 3., 0.],
        [1., 0., 3., 0.],
        [1., 0., 3., 0.]]))
(tensor([[1., 0., 3., 0.]]),)
'''

register_backward_hook只能操作简单模块,而不能操作包含多个子模块的复杂模块。 如果对复杂模块用了 backward hook,那么我们只能得到该模块最后一次简单操作的梯度信息。

可以这么用,可以得到一个模块的梯度。

class Mymodel(nn.Module):
    ......
    
model = Mymodel()
model.register_backward_hook(hook_fn_backward)

[torch] pytorch hook学习

原文:https://www.cnblogs.com/aoru45/p/11297066.html

(0)
(0)
   
举报
评论 一句话评论(0
关于我们 - 联系我们 - 留言反馈 - 联系我们:wmxa8@hotmail.com
© 2014 bubuko.com 版权所有
打开技术之扣,分享程序人生!