• Python:torch.nn.Conv1d(), torch.nn.Conv2d()和torch.nn.Conv3d()函数理解


    Python:torch.nn.Conv1d(), torch.nn.Conv2d()和torch.nn.Conv3d()函数理解

    1. 函数参数

    在torch中的卷积操作有三个,torch.nn.Conv1d(),torch.nn.Conv2d()还有torch.nn.Conv3d(),这是搭建网络过程中常用的网络层,为了用好卷积层,需要知道这些参数代表的含义。

    这三种不同的卷积的输入参数是相同的,所以只看一个就可以。

    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            kernel_size: _size_2_t,
            stride: _size_2_t = 1,
            padding: Union[str, _size_2_t] = 0,
            dilation: _size_2_t = 1,
            groups: int = 1,
            bias: bool = True,
            padding_mode: str = 'zeros',  # TODO: refine this type
            device=None,
            dtype=None
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13

    这里面的参数网上有很多说明,重点是怎么理解和使用。

    2. 参数理解

    这里面重点是in_channels参数,这个是代表数据输入的通道,很多说明这个通道是利用torch.nn.Conv2d处理图片数据来进行说明的,代表的是图片的通道数,然后面的两个参数对应着图片的长度和宽度。

    下面是本人对这参数的理解过程:

    • 首先对于torch.nn.Conv函数,所接受的数据是可以带有batch维度的,也可以不带有batch维度,这就表示对于torch.nn.Conv2d可以接受的数据包括3维数据或者4维数据,

    如:

    conv2 = torch.nn.Conv2d(16, 120, 3, stride=2)
    input2_3 = torch.randn(16, 5, 5)
    output2_3 = conv2(input2_3)
    print(output2_3.shape)
    
    input2_4 = torch.randn(20, 16, 5, 5)
    output2_4 = conv2(input2_4)
    print(output2_4.shape)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8

    该段得到的输出为:

    torch.Size([120, 2, 2])
    torch.Size([20, 120, 2, 2])
    
    • 1
    • 2

    这是因为input2_4只是多了一个维度batch在第一个维度上,如果输入的数据是2维的或者5维的,就会提示如下的错误:指明只能接受3维的数据或者4维的数据.

    RuntimeError: Expected 3D (unbatched) or 4D (batched) input to conv2d, but got input of size: [20, 20, 16, 5, 5]
    
    • 1

    这其实就说明了根据自己数据维度选择合适的torch.nn.Conv, 例如,如果数据是2维的,那么就选择torch.nn.Conv1d,这个可以接收传入的数据维度可以是2维,或者是带有batch维度的3维数据。

    之后需要注意的是in_channels参数其实对应的就是传入数据的第一个维度(不带有batch)或者带有batch的第二个维度,这个要和in_channels参数相同。

    可以理解成这个in_channels就是表示了有多个卷积核在参与计算,那么剩下的维度正好就是卷积核的维度,

    如对于torch.nn.Conv3d,传入的数据最少是4维数据,(不带有batch),那么第一维的数据应该等于in_channels,然后剩下三维正好的是卷积核的维度。
    如:

    conv3 = torch.nn.Conv3d(16, 120, 3, stride=2)
    input3 = torch.randn(16, 5, 5, 5)
    output3 = conv3(input3)
    print(output3.shape)
    
    • 1
    • 2
    • 3
    • 4

    会得到

    torch.Size([120, 2, 2, 2])
    
    • 1

    这个卷积核是333,相当于有16个卷积核,每个卷积核在16维的数据上依次计算。

    其他的作为输出影响的是数据的维度大小,但是out_channels又决定了输出数据的第一个维度,(不带有batch),就可以依然用这个方式思考。

    针对后面几维数据的大小,由其他的参数决定,这个有公式可以计算,懒得算也可以直接打印输出看一下维度。

  • 相关阅读:
    EasyRecovery数据恢复软件 恢复了我两年前的照片视频数据
    《机器学习》第6章 支持向量机
    【冒泡排序设计】
    全球与中国板上芯片LED行业发展规模及投资前景预测报告2022-2028年
    业主方怎么管理固定资产
    ​中南建设2022年半年报“韧”字当头,经营性现金流持续为正​
    Word文档Aspose.Words使用教程:构建适用于Android的Word转PDF应用程序
    [附源码]计算机毕业设计基于SpringBoot的黄河文化科普网站
    什么是软件EV代码签名证书
    【微机原理笔记】第 3 章 - 8086/8088的指令系统
  • 原文地址:https://blog.csdn.net/qudunan6468/article/details/133591805