• 九、池化层


    一、Pooling layers

    Pooling layers官网文档
    在这里插入图片描述
    MaxPool最大池化下采样
    MaxUnpool最大池化层上采样
    AvgPool最大池化层平均采样

    例如:池化核为(3,3),输入图像为(5,5),步长为1,不加边
    最大池化就是选出在池化核为单位图像中的最大的值
    在这里插入图片描述

    二、MaxPool2d

    torch.nn.MaxPool2d官网API
    torch.nn.MaxPool2d(kernel_size, stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=False)

    参数解释
    kernel_size池化核的大小
    stride池化核移动时的步长,默认值为kernel_size
    padding是否对原图像进行加边操作
    dilation池化核映射到图像上每个池化核单元是否有距离,又称空洞池化或空洞卷积
    return_indices
    ceil_modeTrue时,使用ceil模式(2.31取2,向下取整,舍去小数,保留边缘剩余数据);False时使用floor模式,默认为False (2.31取3,向上取整,整数加一,舍去边缘剩余数据)

    三、代码实操

    对一个(5,5)的数据进行(kernel_size=3)池化核为(3,3),步长(stride)没传 默认为池化核大小(kernel_size=3)为3,padding也没传参默认为0表示不加边,ceil_mode=True保留边缘剩余数据
    MaxPool2d(kernel_size=3,ceil_mode=True)
    通过池化层处理之后,输出一个结果为二维tensor[[5,5],[5,5,]]
    在这里插入图片描述

    import torch
    from torch import nn
    from torch.nn import MaxPool2d
    
    input = torch.tensor([[1,2,3,4,5],
                          [5,4,3,2,1],
                          [1,2,3,4,5],
                          [5,4,3,2,1],
                          [1,2,3,4,5]],dtype=torch.float32)
    #RuntimeError: "max_pool2d" not implemented for 'Long'  需要将数据进行转换dtype=torch.float32
    
    input = torch.reshape(input,(-1,1,5,5))
    
    print(input.shape)#torch.Size([1, 1, 5, 5])
    
    class Beyond(nn.Module):
        def __init__(self):
            super(Beyond,self).__init__()
            self.maxpool_1 = MaxPool2d(kernel_size=3,ceil_mode=True)
    
        def forward(self,input):
            output = self.maxpool_1(input)
            return output
    
    beyond = Beyond()
    output = beyond(input)
    print(output)
    """
    tensor([[[[5., 5.],
              [5., 5.]]]])
    """
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31

    四、池化层意义

    保留输入图像的特征,同时把数据量给减小
    (5,5)的图像,经过池化层就会变成(3,3)甚至(1,1)
    数据量的减小对于神经网络而言,参数少了,训练的速度会更快

    五、通过池化层处理CIFAR-10数据集

    对CIFAR10的1w张测试集数据进行(kernel_size=3)池化核为(3,3),步长(stride)没传 默认为池化核大小(kernel_size=3)为3,padding也没传参默认为0表示不加边,ceil_mode=True保留边缘剩余数据
    MaxPool2d(kernel_size=3,ceil_mode=True)

    import torch
    import torchvision
    from torch import nn
    from torch.nn import MaxPool2d
    from torch.utils.data import DataLoader
    from torch.utils.tensorboard import SummaryWriter
    
    dataset_test = torchvision.datasets.CIFAR10("CIFAR_10",train=False,transform=torchvision.transforms.ToTensor(),download=True)
    
    dataloader = DataLoader(dataset=dataset_test,batch_size=64)#64张为一组
    
    class Beyond(nn.Module):
        def __init__(self):
            super(Beyond,self).__init__()
            self.maxpool_1 = MaxPool2d(kernel_size=3,ceil_mode=True)
    
        def forward(self,input):
            output = self.maxpool_1(input)
            return output
    
    beyond = Beyond()
    writer = SummaryWriter("y_log")
    
    i=0
    for data in dataloader:
        imgs,targets = data
        writer.add_images("input_maxpool",imgs,i)
        output = beyond(imgs)
        writer.add_images("output_maxpool",output,i)
        i = i + 1
    
    writer.close()
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32

    在Terminal下运行tensorboard --logdir=y_log --port=7870,logdir为打开事件文件的路径,port为指定端口打开;
    通过指定端口7870进行打开tensorboard,若不设置port参数,默认通过6006端口进行打开。
    在这里插入图片描述
    点击该链接或者复制链接到浏览器打开即可
    在这里插入图片描述

  • 相关阅读:
    c语言练习58:⾃定义类型:结构体
    client-go学习(6)Informer
    策略模式与模板模式的区别
    spark算子简单案例 - Python
    NL2SQL技术方案系列(6):金融领域知识检索,NL2SQL技术方案以及行业案例实战讲解4
    [附源码]java毕业设计游戏装备交易网站论文2022
    基于单片机的汽车智能仪表的设计
    【vue】主分支外的一些知识点
    人人都是艺术家!AI工具Doodly让潦草手绘变精美画作
    亮相2022南京软博会,创邻科技携Galaxybase图平台展现信创硬核实力
  • 原文地址:https://blog.csdn.net/qq_41264055/article/details/126442687