• 第七章 解析PyTorch中Hook函数(工具)


    为了更深入地理解神经网络模型,有时候我们需要观察它训练得到的卷积核、特征图或者梯度等信息,这在CNN可视化研究中经常用到。其中,卷积核最易获取,将模型参数保存即可得到;特征图是中间变量,所对应的图像处理完即会被系统清除,否则将严重占用内存;梯度跟特征图类似,除了叶子结点外,其它中间变量的梯度都被会内存释放,因而不能直接获取。
    最容易想到的获取方法就是改变模型结构,在forward的最后不但返回模型的预测输出,还返回所需要的特征图等信息。

    如何在不改变模型结构的基础上获取特征图、梯度等信息呢?

    Pytorch的hook编程可以在不改变网络结构的基础上有效获取、改变模型中间变量以及梯度等信息。
    hook可以提取或改变Tensor的梯度,也可以获取nn.Module的输出和梯度(这里不能改变)。因此有4个hook函数用于实现以上功能:

    Tensor.register_hook(hook_fn)
    nn.Module.register_forward_hook(hook_fn)
    nn.Module.forward_register_forward_pre_hook(hook_fn)
    nn.Module.register_backward_hook(hook_fn)
    
    • 1
    • 2
    • 3
    • 4
    函数用途应用场景钩子函数签名
    register_hook在梯度计算时调用调试或修改梯度值hook_fn(grad) -> Tensor or None
    register_forward_hook在前向传播结束后调用检查、修改或记录模块的输入和输出hook_fn(module, input, output)
    register_forward_pre_hook在前向传播开始前调用修改进入模块的输入数据,或在模块执行任何操作前执行一些预处理步骤hook_fn(module, input)
    register_backward_hook在后向传播过程中调用检查、修改反向传播过程中的梯度hook_fn(module, grad_input, grad_output)

    下面对其用法进行一一介绍。

    register_hook_一

    功能:注册一个反向传播hook函数,用于自动记录Tensor的梯度。
    PyTorch对中间变量和非叶子节点的梯度运行完后会自动释放,以减缓内存占用。什么是中间变量?什么是非叶子节点?

    img

    上图中,a,b,d就是叶子节点,c,e,o是非叶子节点,也是中间变量。

    In [18]: a = torch.Tensor([1,2]).requires_grad_() 
        ...: b = torch.Tensor([3,4]).requires_grad_() 
        ...: d = torch.Tensor([2]).requires_grad_() 
        ...: c = a + b 
        ...: e = c * d 
        ...: o = e.sum()     
    In [19]: o.backward()
    
    In [20]: print(a.grad)
    tensor([2., 2.])
    
    In [21]: print(b.grad)
    tensor([2., 2.])
    
    In [22]: print(c.grad)
    None
    
    In [23]: print(d.grad)
    tensor([10.])
    
    In [24]: print(e.grad)
    None
    
    In [25]: print(o.grad)
    None
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25

    可以从程序的输出中看到,a,b,d作为叶子节点,经过反向传播后梯度值仍然保留,而其它非叶子节点的梯度已经被自动释放了,要想得到它们的梯度值,就需要使用hook了。

    我们首先自定义一个hook_fn函数,用于记录对Tensor梯度的操作,然后用Tensor.register_hook(hook_fn)对要获取梯度的非叶子结点的Tensor进行注册,然后重新反向传播一次:

    In [44]: def hook_fn(grad):
        ...:     print(grad)
        ...:
    
    In [45]: e.register_hook(hook_fn)
    Out[45]: 
    
    In [46]: o.backward()
    tensor([1., 1.])
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9

    这时就自动输出了e的梯度。

    自定义的hook_fn函数的函数名可以是任取的,它的参数是grad,表示Tensor的梯度。这个自定义函数主要是用于描述对Tensor梯度值的操作,上例中我们是对梯度直接进行输出,所以是print(grad)。我们也可以把梯度装在一个列表或字典里,甚至可以修改梯度,这样如果梯度很小的时候将其变大一点就可以防止梯度消失的问题了:

    In [28]: a = torch.Tensor([1,2]).requires_grad_() 
        ...: b = torch.Tensor([3,4]).requires_grad_() 
        ...: d = torch.Tensor([2]).requires_grad_() 
        ...: c = a + b 
        ...: e = c * d 
        ...: o = e.sum()                                                            
    
    In [29]: grad_list = []                                                         
    
    In [30]: def hook(grad): 
        ...:     grad_list.append(grad)    # 将梯度装在列表里
        ...:     return 2 * grad    # 将梯度放大两倍
        ...:                                                                        
    
    In [31]: c.register_hook(hook)                                                  
    Out[31]: 
    
    In [32]: o.backward()                                                           
    
    In [33]: grad_list                                                              
    Out[33]: [tensor([2., 2.])]
    
    In [34]: a.grad                                                                 
    Out[34]: tensor([4., 4.])
    
    In [35]: b.grad                                                                 
    Out[35]: tensor([4., 4.])
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27

    上例中,我们定义的hook函数执行了两个操作:一是将梯度装进列表grad_list中,二是把梯度放大两倍。从输出中我们可以看到,执行反向传播后,我们注册的非叶子节点c的梯度保存在了列表grad_list中,并且a和b的梯度都变为原来的两倍。这里需要注意的是,如果要将梯度值装在一个列表或字典里,那么首先要定义一个同名的全局变量的列表或字典,即使是局部变量,也要在自定义的hook函数外面。另一个需要注意的点就是如果要改变梯度值,hook函数要有返回值,返回改变后的梯度。

    这里总结一下,如果要获取非叶子节点Tensor的梯度值,我们需要在反向传播前
    1)自定义一个hook函数,描述对梯度的操作,函数名自拟,参数只有grad,表示Tensor的梯度;
    2)对要获取梯度的Tensor用方法Tensor.register_hook(hook)进行注册。
    3)执行反向传播。


    register_hook_二

    本节介绍张量的hook。在PyTorch的**计算图(computation graph)中,只有叶节点(leaf node)**的变量会保留梯度,而所有中间变量的梯度只在反向传播中使用,一旦反向传播完成,中间变量的梯度将自动释放,从而节约内存。

    下图是一个简单的计算图,其中 x , y , w x,y,w x,y,w是叶节点(直接给定数值的变量), z , o z,o z,o是中间变量(由其他变量计算得到的变量)。

    img

    import torch
    
    x = torch.Tensor([0, 1, 2, 3]).requires_grad_()
    y = torch.Tensor([4, 5, 6, 7]).requires_grad_()
    w = torch.Tensor([1, 2, 3, 4]).requires_grad_()
    z = x + y
    o = w.matmul(z)
    o.backward()
    
    print('x.requires_grad:', x.requires_grad)  # True
    print('y.requires_grad:', y.requires_grad)  # True
    print('z.requires_grad:', z.requires_grad)  # True
    print('w.requires_grad:', w.requires_grad)  # True
    print('o.requires_grad:', o.requires_grad)  # True
    
    print('x.grad:', x.grad)  # tensor([1., 2., 3., 4.])
    print('y.grad:', y.grad)  # tensor([1., 2., 3., 4.])
    print('w.grad:', w.grad)  # tensor([4., 6., 8., 10.])
    print('z.grad:', z.grad)  # None
    print('o.grad:', o.grad)  # None
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20

    从上面的例子中可以看出,由于 z , o z,o z,o是中间变量,它们虽然requires_grad的参数都是True,但反向传播后其梯度并没有保存下来,而是直接删除了,因此为None。如果想在反向传播后保留他们的梯度,则需要特殊指定:

    z.retain_grad()
    o.retain_grad()
    
    print('z.requires_grad:', z.requires_grad) # True
    print('o.requires_grad:', o.requires_grad) # True
    print('z.grad:', z.grad)  # tensor([1., 2., 3., 4.])
    print('o.grad:', o.grad)  # tensor(1.)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7

    但这种使用retain_grad()的方案会增加内存的占用,并不是一个好的方法。可以使用hook保存中间变量的梯度。

    对于中间变量 z z zhook的使用方法为:z.register_hook(hook_fn),其中hook_fn为一个用户自定义的函数:

    def hook_fn(grad): -> Tensor or None
    
    • 1

    该函数输入为变量 z z z的梯度,输出为一个TensorNoneNone一般用于直接打印梯度)。反向传播时,梯度传播到变量 z z z后,再继续往前传播之前,将会传入hook_fn函数。如果hook_fn的返回值是None,则梯度不改变,继续向前传播;如果hook_fn的返回值是Tensor类型,则该Tensor将取代变量 z z z原有的梯度,继续向前传播。

    下面的例子中hook_fn打印梯度值并修改为原来的两倍:

    def hook_fn(grad):
        print(g)
        g = 2 * grad
        return g
    
    z.register_hook(hook_fn)
    
    o.backward()  # tensor([1., 2., 3., 4.])
    print('z.grad:', z.grad)  # None
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9

    在实际代码中,为简化表示,也可以用lambda表达式代替函数,简写如下:

    z.register_hook(lambda x: print(x))
    z.register_hook(lambda x: 2*x)
    
    • 1
    • 2

    注意到一个变量可以绑定多个hook_fn函数,反向传播时,按绑定顺序依次执行。

    register_forward_hookregister_backward_hook

    这两个的操作对象都是nn.Module类,如神经网络中的卷积层(nn.Conv2d),全连接层(nn.Linear),池化层(nn.MaxPool2d, nn.AvgPool2d),激活层(nn.ReLU)或者nn.Sequential定义的小模块等,所以放在一起讲。

    对于模型的中间模块,也可以视作中间节点(非叶子节点),它的输出为特征图或激活值,反向传播的梯度值都会被系统自动释放,如果想要获取它们,就要用到hook功能。

    有名字即可看出,register_forward_hook是获取前向传播的输出的,即特征图或激活值;register_backward_hook是获取反向传播的输出的,即梯度值。它们的用法和上面介绍的register_hook类似。我们先看一下hook_fn的定义:

    对于register_forward_hook(hook_fn),其hook_fn函数定义如下:

    def forward_hook(module, input, output):
        operations
    
    • 1
    • 2

    这里有3个参数,分别表示:模块,模块的输入,模块的输出。函数用于描述对这些参数的操作,一般我们都是为了获取特征图,即只描述对output的操作即可。

    对于register_backward_hook(hook_fn),其hook_fn函数定义如下:

    def backward_hook(module, grad_in, grad_out):
        operations
    
    • 1
    • 2

    这里也有3个参数,分别表示:模块,模块输入端的梯度,模块输出端的梯度。这里需要特别注意的是,此处的输入端和输出端,是前向传播时的输入端和输出端,也就是说,上面的output的梯度对应这里的grad_out。例如线性模块:o=W*x+b,其输入端为 W,x 和 b,输出端为 o。

    如果模块有多个输入或者输出的话,grad_in和grad_out可以是 tuple 类型。对于线性模块:o=W*x+b ,它的输入端包括了W、x 和 b 三部分,因此 grad_input 就是一个包含三个元素的 tuple。

    这里注意和 forward hook 的不同:

    1. 在 forward hook 中,input 是 x,而不包括 W 和 b。
    2. 返回 Tensor 或者 None,backward hook 函数不能直接改变它的输入变量,但是可以返回新的 grad_in,反向传播到它上一个模块。

    此处的自定义的函数hook_fn也可以自拟名称,但如果两个hook函数同时使用的时候注意名称的区别,一般在函数名里添加对应的forward和backward就不易搞混了。

    下面看一个具体用例:

    #-*- utf-8 -*-
    
    '''本程序用于验证hook编程获取卷积层的输出特征图和特征图的梯度'''
    
    __author__ = 'puxitong from UESTC'
    
    import torch
    import torch.nn as nn
    import numpy as np 
    import torchvision.transforms as transforms
    
    
    class Net(nn.Module):
        def __init__(self):
            super(Net, self).__init__()
            self.conv1 = nn.Conv2d(3,6,3,1,1)
            self.relu1 = nn.ReLU()
            self.pool1 = nn.MaxPool2d(2,2)
            self.conv2 = nn.Conv2d(6,9,3,1,1)
            self.relu2 = nn.ReLU()
            self.pool2 = nn.MaxPool2d(2,2)
            self.fc1 = nn.Linear(8*8*9, 120)
            self.relu3 = nn.ReLU()
            self.fc2 = nn.Linear(120,10)
    
        def forward(self, x):
            out = self.pool1(self.relu1(self.conv1(x)))
            out = self.pool2(self.relu2(self.conv2(out)))
            out = out.view(out.shape[0], -1)
            out = self.relu3(self.fc1(out))
            out = self.fc2(out)
    
            return out
    
    
    def backward_hook(module, grad_in, grad_out):
        grad_block['grad_in'] = grad_in
        grad_block['grad_out'] = grad_out
    
    
    def farward_hook(module, inp, outp):
        fmap_block['input'] = inp
        fmap_block['output'] = outp
    
    
    loss_func = nn.CrossEntropyLoss()
    
    # 生成一个假标签以便演示
    label = torch.empty(1, dtype=torch.long).random_(3)
    
    # 生成一副假图像以便演示
    input_img = torch.randn(1,3,32,32).requires_grad_()  
    
    fmap_block = dict()  # 装feature map
    grad_block = dict()  # 装梯度
    
    net = Net()
    
    # 注册hook
    net.conv2.register_forward_hook(farward_hook)
    net.conv2.register_backward_hook(backward_hook)
    
    outs = net(input_img)
    loss = loss_func(outs, label)
    loss.backward()
    
    print('End.')
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62
    • 63
    • 64
    • 65
    • 66
    • 67

    上面的程序中,我们先定义了一个简单的卷积神经网络模型,我们对第二层卷积模块进行hook注册,既获取它的输入输出,又获取输入输出的梯度,并将它们分别装在字典里。为了达到验证效果,我们随机生成一个假图像,它的尺寸和cifar-10数据集的图像尺寸一致,并且给这个假图像定义一个类别标签,用损失函数进行反向传播,模拟神经网络的训练过程。

    在IPython中运行程序后,相应的特征图和梯度就会出现在两个列表fmap_block和grad_block中了。我们看一下它们的输入和输出的维度:

    In [17]: len(fmap_block['input'])                                               
    Out[17]: 1
    
    In [18]: len(fmap_block['output'])                                              
    Out[18]: 1
    
    In [19]: len(grad_block['grad_in'])                                             
    Out[19]: 3
    
    In [20]: len(grad_block['grad_out'])                                            
    Out[20]: 1
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11

    可以看出,第二层卷积模块的输入和输出都只有一个,即相应的特征图。而输入端的梯度值有3个,分别为权重的梯度,偏差的梯度,以及输入特征图的梯度。输出端的梯度值只有一个,即输出特征图的梯度。正如上面强调的,输入端即使有W, X和b三个,对于前项传播来说只有X是其输入,而对于反向传播来说,3个都是输入。输出端3项的梯度值排列的顺序是什么呢,我们来看一下3项梯度的具体维度:

    In [21]: grad_block['grad_in'][0].shape                                         
    Out[21]: torch.Size([1, 6, 16, 16])
    
    In [22]: grad_block['grad_in'][1].shape                                         
    Out[22]: torch.Size([9, 6, 3, 3])
    
    In [23]: grad_block['grad_in'][2].shape                                         
    Out[23]: torch.Size([9])
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8

    从输出端梯度的维度可以判断,第一个显然是特征图的梯度,第二个则是权重(卷积核/滤波器)的梯度,第三个是偏置的梯度。为了验证梯度和这些参数具有同样的维度,我们再来看看这三个值前向传播时的维度:

    In [24]: fmap_block['input'][0].shape                                           
    Out[24]: torch.Size([1, 6, 16, 16])
    
    In [25]: net.conv2.weight.shape         
    Out[25]: torch.Size([9, 6, 3, 3])
    
    In [26]: net.conv2.bias.shape                                                   
    Out[26]: torch.Size([9])
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8

    可以看到,我们的判断是正确的。

    最后需要注意的一点是,如果需要获取输入图像的梯度,一定要将输入Tensor的requires_grad属性设为True。

    本节介绍模块的hook。模块不像上一节介绍的Tensor一样拥有显式的变量名可以访问,而是被封装在神经网络中。通常只能获得网络整体的输入和输出,而对于网络中间的模块,不仅很难得到它输入和输出的梯度,甚至连输入输出的数值都无法获得。比较麻烦的做法是,在forward函数的返回值中包含中间模块的输出;或者把网络按照模块的名称拆分再组合,提取中间层的特征。

    Pytorch设计了两种hookregister_forward_hookregister_backward_hook,分别用来获取前向传播和反向传播时中间层模块的输入和输出特征及梯度,从而大大降低了获取模型内部信息流的难度。

    register_forward_hook_二

    register_forward_hook的作用是获取前向传播过程中,网络各模块的输入和输出。对于模块module,其使用方法为:module.register_forward_hook(hook_fn),其中hook_fn为一个用户自定义的函数:

    def hook_fn(module, input, output): -> Tensor or None
    
    • 1

    hook_fn函数的输入变量分别为模块、模块的输入和模块的输出。输出为NonePytorch1.2.0之后的版本也可以返回张量,用于修改模块的输出。借助这个hook,可以方便的使用预训练的神经网络提取特征,而不用改变预训练网络的结构。下面是一个简单的例子:

    import torch
    from torch import nn
    
    #  全局变量,用于存储中间层的特征
    total_feat_out = []
    total_feat_in = []
    
    #  定义 forward hook function
    def hook_fn_forward(module, input, output):
        print(module)  # 打印模块名,用于区分模块
        print('input', input)   # 打印该模块的输入
        print('output', output) # 打印该模块的输出
        total_feat_out.append(output) # 保存该模块的输出
        total_feat_in.append(input)   # 保存该模块的输入
    
    model = Model()
    
    modules = model.named_children()
    for name, module in modules:
        module.register_forward_hook(hook_fn_forward)
    
    #  注意下面代码中 x 的维度,第一维是 batch size
    #  forward hook 中看不出来,但是 backward hook 中是必要的。
    x = torch.Tensor([[1.0, 1.0, 1.0]]).requires_grad_() 
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24

    register_backward_hook_二

    register_backward_hook的作用是获取反向传播过程中,网络各模块输入端和输出端的梯度值。对于模块module,其使用方法为:module.register_backward_hook(hook_fn),其中hook_fn为一个用户自定义的函数:

    def hook_fn(module, grad_input, grad_output): -> Tensor or None
    
    • 1

    hook_fn函数的输入变量分别为模块、模块输入端的梯度和模块输出端的梯度(这里的输入端和输出端是站在前向传播的角度来说的)。如果模块有多个输入端或输出端,则对应的梯度是tuple类型(例如对于线性模块,其grad_input是一个三元组,排列顺序分别为:对bias的导数、对输入x的导数、对权重W的导数)。下面是一个简单的例子:

    import torch
    from torch import nn
    
    #  全局变量,用于存储中间层的梯度
    total_grad_out = []
    total_grad_in = []
    
    # 定义 backward hook function
    def hook_fn_backward(module, grad_input, grad_output):
        print(module)  # 打印模块名,用于区分模块
        print('grad_output', grad_output)  # 打印该模块输出端的梯度
        print('grad_input', grad_input)    # 打印该模块输入端的梯度
        total_grad_in.append(grad_input)   # 保存该模块输入端的梯度
        total_grad_out.append(grad_output) # 保存该模块输出端的梯度
    
    model = Model()
    
    modules = model.named_children()
    for name, module in modules:
        module.register_backward_hook(hook_fn_backward)
    
    #  这里的 requires_grad 很重要,如果不加,backward hook
    #  执行到第一层,对 x 的导数将为 None 。
    #  此外再强调一遍 x 的维度,第一维一定是 batch size
    x = torch.Tensor([[1.0, 1.0, 1.0]]).requires_grad_()
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25

    注意事项

    register_backward_hook在全连接层和卷积层中的表现是不一致的,具体如下:

    • 形状不一致

      1. 在卷积层中,weight的梯度和weight的形状相同;
      2. 在全连接层中,weight的梯度的形状是weight形状的转置。
    • grad_input
      
      • 1

      元组中梯度的顺序不一致

      1. 在卷积层中,梯度的顺序为:(对feature的梯度,对weight的梯度,对bias的梯度);
      2. 在全连接层中,梯度的顺序为:(对bias的梯度,对feature的梯度,对weight的梯度)。
    • 当batch size大于 1 1 1时,对bias的梯度处理不一致

      1. 在卷积层中,对bias的梯度为整个batch的数据在bias上的梯度之和:(对feature的梯度,对weight的梯度,对bias的梯度);
      2. 在全连接层中,对bias的梯度是分开的,batch中的每个数据对应一个bias的梯度:((data1bias的梯度,data2bias的梯度…),对feature的梯度,对weight的梯度)。

    特别地,如果已知某个模块的类型,也可以用下面的方式对其加hook

    for name, module in modules:
        if isinstance(module, nn.ReLU):
            module.register_forward_hook(forward_hook_fn)
            module.register_backward_hook(backward_hook_fn)
    
    • 1
    • 2
    • 3
    • 4

    forward_register_forward_pre_hook

    在PyTorch中,torch.nn.Module.register_forward_pre_hook方法用于在网络模块(Module)的前向传播过程开始前注册一个钩子(hook)。这个钩子函数在每次前向传播调用之前自动执行,可以用于检查、修改或记录模块的输入。

    • 参数:

      • module: 要注册钩子的网络层或模块。
      • hook: 一个函数,其签名为 hook(module, input),其中module是当前层,input是传入该层的输入数据(一个元组)。
    • 返回值:

      • 返回一个handle,可用于随后移除该钩子。
    • 用途:

      • 可用于调试、可视化、修改输入等目的。这在复杂网络的分析或修改中特别有用。

    示例

    下面的示例展示了如何使用register_forward_pre_hook来打印每个层的输入形状:

    import torch
    import torch.nn as nn
    
    # 定义一个简单的网络
    class Net(nn.Module):
        def __init__(self):
            super(Net, self).__init__()
            self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
            self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
    
        def forward(self, x):
            x = self.conv1(x)
            x = self.conv2(x)
            return x
    
    # 定义钩子函数
    def print_input_shape(module, input):
        print(f"Entering {module.__class__.__name__} with input shape: {input[0].shape}")
    
    # 创建网络实例
    net = Net()
    
    # 注册钩子
    hook_handle1 = net.conv1.register_forward_pre_hook(print_input_shape)
    hook_handle2 = net.conv2.register_forward_pre_hook(print_input_shape)
    
    # 输入数据
    input_data = torch.randn(1, 1, 28, 28)
    
    # 前向传播
    output = net(input_data)
    
    # 如果需要,可以移除钩子
    hook_handle1.remove()
    hook_handle2.remove()
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35

    在这个示例中,我们定义了一个包含两个卷积层的简单网络。对每个卷积层,我们注册了一个前向传播前钩子,该钩子会打印每个层的输入形状。这对于理解网络内部的数据流非常有帮助,特别是当你在调试复杂的网络结构时。

    当不再需要这些钩子时,可以通过调用handle.remove()来移除它们,以避免不必要的性能开销。

    可视化特征图

    要利用register_forward_pre_hook可视化中间特征图,您可以注册一个钩子函数,该函数在模型的前向传播过程中的特定层被调用时,捕获并处理该层的输入数据。以下是一个示例程序,展示了如何实现这一功能:

    首先,您需要一个基本的卷积神经网络模型。在这个示例中,我将创建一个简单的网络用于演示:

    import torch
    import torch.nn as nn
    import torchvision
    import matplotlib.pyplot as plt
    
    # 定义网络结构
    class SimpleCNN(nn.Module):
        def __init__(self):
            super(SimpleCNN, self).__init__()
            self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
            self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
    
        def forward(self, x):
            x = self.conv1(x)
            x = self.conv2(x)
            return x
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16

    然后,定义一个函数来注册为钩子,该函数将捕获并可视化特征图:

    def visualize_feature_map(module, input):
        # 获取输入特征图的第一个元素(批次中的第一个样本)
        x = input[0][0]
        x = x.detach()  # 确保不计算梯度
    
        # 转换为numpy数组并可视化
        num_feature_maps = x.shape[0]
        fig, axes = plt.subplots(1, num_feature_maps, figsize=(num_feature_maps * 2, 2))
        for i in range(num_feature_maps):
            axes[i].imshow(x[i].cpu().numpy(), cmap='gray')
            axes[i].axis('off')
        plt.show()
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12

    接下来,创建网络实例并注册钩子:

    # 创建网络实例
    model = SimpleCNN()
    
    # 注册钩子以可视化第一个卷积层的输入特征图
    hook_handle = model.conv1.register_forward_pre_hook(visualize_feature_map)
    
    • 1
    • 2
    • 3
    • 4
    • 5

    最后,进行一次前向传播以触发钩子并可视化特征图:

    # 创建一个随机输入
    input_tensor = torch.randn(1, 1, 28, 28)
    
    # 执行前向传播
    model(input_tensor)
    
    # 移除钩子
    hook_handle.remove()
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8

    在这个示例中,visualize_feature_map函数会在每次模型的第一个卷积层接收输入时被调用,并可视化该层的输入特征图。您可以根据需要调整网络结构、钩子函数和输入数据。

    请注意,这个示例假设您已经安装了必要的库,如torch, torchvision, matplotlib。此外,钩子函数中的可视化部分使用了matplotlib库进行特征图的展示,您可能需要根据自己的环境(例如,是否在Jupyter笔记本中运行)进行调整。


    另外一版可视化特征图:

    import torch
    import torch.nn as nn
    import torchvision
    from torch.utils.tensorboard import SummaryWriter
    
    # 定义一个简单的钩子函数,用于可视化中间特征图
    def visualize_feature_map(module, input):
        x = input[0]  # 获取输入
        image_grid = torchvision.utils.make_grid(x, normalize=True, scale_each=True)  # 创建图像网格
        writer.add_image("Feature Map", image_grid, global_step=0)  # 添加到 TensorBoard 中
    
    # 创建模型
    class MyModel(nn.Module):
        def __init__(self):
            super(MyModel, self).__init__()
            self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, stride=1, padding=1)
            self.pool = nn.MaxPool2d(2, 2)
    
        def forward(self, x):
            x = self.pool(self.conv1(x))
            return x
    
    # 初始化模型和数据
    model = MyModel()
    writer = SummaryWriter()  # 创建一个 TensorBoard SummaryWriter
    
    # 注册前向传播前的钩子
    hook_handle = model.conv1.register_forward_pre_hook(visualize_feature_map)
    
    # 创建一个随机输入
    input_data = torch.rand(1, 3, 64, 64)
    
    # 前向传播
    output = model(input_data)
    
    # 移除钩子
    hook_handle.remove()
    
    # 关闭 SummaryWriter
    writer.close()
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
  • 相关阅读:
    yolo训练时遇到GBK编码问题
    Lidar & IMU & GNSS in ENU
    修完这个 Bug 后,MySQL 性能提升了 300%
    干货 | BitSail Connector 开发详解系列一:Source
    在前端开发领域,如何将AI技术应用于产品开发中?
    【代码随想录】算法训练计划23
    SQL中LIKE和REGEXP简单对比
    Git_GitHub——基本操作、创建远程库、远程库操作、团队协作、SSH免密登录
    网络常见的小知识点
    Qt绘图机制
  • 原文地址:https://blog.csdn.net/weixin_44302770/article/details/134485676