• 深度学习基础知识 register_buffer 与 register_parameter用法分析


    1、问题引入

    思考问题:定义的weight与bias是否会被保存到网络的参数中,可否在优化器的作用下进行学习

    验证方案:定义网络模型,设置weigut与bias,遍历网络结构参数net.named_parameters(),如果定义的weight与bias在里面,则说明是可学习参数;否则,是不可学习参数

    import torch
    import torch.nn as nn
    
    # 思考两个问题,定义的weight与bias是否会被保存到网络的参数中,可否在优化器的作用下进行学习
    
    class MyModule(nn.Module):
        def __init__(self):
            super(MyModule,self).__init__()
            self.conv1=nn.Conv2d(in_channels= 3,
                                out_channels= 6,
                                kernel_size=3,
                                stride = 1,
                                padding=1,
                                bias=False)
            
            self.conv2=nn.Conv2d(in_channels= 6,
                                out_channels= 9,
                                kernel_size=3,
                                stride = 1,
                                padding=1,
                                bias=False)
            
    
            self.waight=torch.ones(10,10)
            self.bias=torch.zeros(10)
    
        def forward(self,x):
            x=self.conv1(x)
            x=self.conv2(x)
            x = x * self.weight + self.bias
            return x
        
    net=MyModule()
    
    for name,param in net.named_parameters():  # 如果weight与bias在里面,说明其是可学习参数;否则,是不可学习参数
        print(name,param.shape)
    
    print("\n","-"*40,"\n")
    
    for key,val in net.state_dict().items():  # 说明weight与bias是不会被state_dict转化为字典中的元素的
        print(key,val.shape)
    
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42

    打印分析结果:
    在这里插入图片描述
    可以看到,weight与bias不在其中,所以此种定义方式不会是的weight与bias成为可训练参数

    2、register_parameter()

    register_parameter()是 torch.nn.Module 类中的一个方法

    2.1 作用

    1、可将 self.weight 和 self.bias 定义为可学习的参数,保存到网络对象的参数中,被优化器作用进行学习
    2、self.weight 和 self.bias 可被保存到 state_dict 中,进而可以 保存到网络文件 / 网络参数文件中

    2.2 用法

    register_parameter(name,param)

    • name:参数名称
    • param:参数张量, 须是 torch.nn.Parameter() 对象 或 None ,

    否则报错如下
    在这里插入图片描述

    import torch
    import torch.nn as nn
    
    
    class MyModule(nn.Module):
        def __init__(self):
            super(MyModule, self).__init__()
            self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=3, stride=1, padding=1, bias=False)
            self.conv2 = nn.Conv2d(in_channels=6, out_channels=9, kernel_size=3, stride=1, padding=1, bias=False)
    
            self.register_parameter('weight', torch.nn.Parameter(torch.ones(10, 10)))
            self.register_parameter('bias', torch.nn.Parameter(torch.zeros(10)))
    
    
        def forward(self, x):
            x = self.conv1(x)
            x = self.conv2(x)
            x = x * self.weight + self.bias
            return x
    
    
    net = MyModule()
    
    for name, param in net.named_parameters():
        print(name, param.shape)
    
    print('\n', '*'*40, '\n')
    
    for key, val in net.state_dict().items():
        print(key, val.shape)
    
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31

    结果显示:
    在这里插入图片描述

    3、register_buffer()

    register_buffer()是 torch.nn.Module() 类中的一个方法

    3.1 作用

    • 将 self.weight 和 self.bias 定义为不可学习的参数,不会被保存到网络对象的参数中,不会被优化器作用进行学习

    • self.weight 和 self.bias 可被保存到 state_dict 中,进而可以 保存到网络文件 / 网络参数文件中

    它用于在网络实例中 注册缓冲区,存储在缓冲区中的数据,类似于参数(但不是参数)

    • 参数:可以被优化器更新 (requires_grad=False / True)
    • buffer 中的数据 : 不会被优化器更新

    3.2 用法

    register_buffer(name,tensor)

    • name:参数名称
    • tensor:张量

    代码:

    import torch
    import torch.nn as nn
    
    
    class MyModule(nn.Module):
        def __init__(self):
            super(MyModule, self).__init__()
            self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=3, stride=1, padding=1, bias=False)
            self.conv2 = nn.Conv2d(in_channels=6, out_channels=9, kernel_size=3, stride=1, padding=1, bias=False)
    
            self.register_buffer('weight', torch.ones(10, 10))   # 注意:定义的方式
            self.register_buffer('bias', torch.zeros(10))
    
    
        def forward(self, x):
            x = self.conv1(x)
            x = self.conv2(x)
            x = x * self.weight + self.bias
            return x
    
    
    net = MyModule()
    
    for name, param in net.named_parameters():
        print(name, param.shape)
    
    print('\n', '*'*40, '\n')
    
    for key, val in net.state_dict().items():
        print(key, val.shape)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30

    效果如下所示:
    在这里插入图片描述

  • 相关阅读:
    Win11共享文件打不开怎么办?Win11共享文件打不开的解决方法
    那些年用Python踩过的坑
    【典型案例】验证号码
    新能源充电桩物联网应用之工业4G路由器
    八、互联网技术——物联网
    如何启用启用WordPress调试模式
    老陈打码老陈打码
    实战SRC漏洞挖掘全过程,流程详细【网络安全】
    【RTOS训练营】上节回顾、轻量级队列、轻量级事件组和晚课提问
    会声会影2022智能、快速、简单的视频剪辑软件
  • 原文地址:https://blog.csdn.net/guoqingru0311/article/details/133708652