• 【Python】【Torch】神经网络中各层输出的特征图可视化详解和示例


    本文对神经网络各层特征图可视化的过程进行运行示例,方便大家使用,有助于更好的理解深度学习的过程,尤其是每层的结果。

    神经网络各层特征图可视化的好处和特点如下:

    可视化过程可以了解网络对图像像素的权重分布,可以了解网络对图像特征的提取过程,还可以剔除对特征表达无关紧要的像素,缩短网络训练时间,减少模型复杂度。
    可以将复杂多维数据以图像形式呈现,帮助科研人员更好的理解数据特征,同时可以建立定量化的图像与病理切片的对应关系,为后续病理研究提供可视化依据。

    本示例以一幅图象经过一层卷积输出为例进行。在自己运行时可以多加几层卷积和调整相应的输出通道等操作。

    import torch
    import torch.nn as nn
    import matplotlib.pyplot as plt
    from PIL import Image
    import numpy as np
    import math
    from torchvision import transforms
    # 定义一个卷积层
    conv_layer = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=2, padding=1)
    
    # 输入图像(随机生成)
    image = Image.open("../11111.jpg")
    #input_image = torch.randn(1, 3, 224, 224)
    transform = transforms.Compose([
        transforms.ToTensor()
    ])
    
    # 对图像应用转换操作
    input_image= transform(image)
    input_image = input_image.unsqueeze(0)
    
    # 通过卷积层获取特征图
    feature_map = conv_layer(input_image)
    
    batch, channels, height, width = feature_map.shape
    blocks = torch.chunk(feature_map[0].cpu(), channels, dim=0)
    n = min(32, channels)  # number of plots
    fig, ax = plt.subplots(math.ceil(n / 8), 8, tight_layout=True)  # 8 rows x n/8 cols
    ax = ax.ravel()
    plt.subplots_adjust(wspace=0.05, hspace=0.05)
    for i in range(n):
        ax[i].imshow(blocks[i].squeeze().detach().numpy())  # cmap='gray'
        ax[i].axis('off')
    plt.savefig('./tezhengtu.jpg', dpi=300, bbox_inches='tight')
    plt.show()
    plt.close()
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36

    代码解释:
    步骤1 定义一个卷积层(Convolutional Layer):conv_layer,该卷积层有3个输入通道,64个输出通道, kernel size为3x3,步长为2,填充为1。
    步骤2输入图像:这里使用了一个真实的图像文件路径"…/11111.jpg"作为输入图像。你可以替换为你自己的图像文件路径。
    步骤3定义一个图像转换操作(transform)序列,用于将输入图像转换为PyTorch需要的张量格式。这里仅包含一个操作:转换为张量(ToTensor)。
    步骤4对输入图像应用转换操作:通过transform(image)将图像转换为PyTorch张量,然后通过unsqueeze(0)增加一个额外的维度(batch维度),使得输入图像的形状变为(1, 3, H, W)。

    步骤5通过卷积层获取特征图:将输入图像传递给卷积层conv_layer,得到特征图feature_map。
    步骤6将特征图转换为numpy数组:为了可视化,需要将特征图从PyTorch张量转换为numpy数组。这里使用了detach().numpy()方法来实现转换。
    步骤7获取特征图的一些属性:使用shape属性获取特征图的batch大小、通道数、高度和宽度。
    步骤8分块显示特征图:为了在图像中显示特征图,需要将特征图分块处理。这里使用torch.chunk方法将特征图按照通道数分割成若干块,每一块代表一个通道的输出。然后使用Matplotlib库中的subplot功能将分块后的图像显示在画布上。具体地,这段代码将分块后的图像显示在一个8x8的画布上,每个小图的尺寸为256x256像素(因为最后一块图像可能不足8个通道,所以使用了最少的小图数量)。最后使用savefig方法保存图像到文件,并关闭Matplotlib的画布。

    输入的图像为:
    在这里插入图片描述
    经过一层卷积之后的特征图为:

    在这里插入图片描述

  • 相关阅读:
    L3-005 垃圾箱分布
    关于 Java Lambda 表达式看这一篇就够了(强烈建议收藏)
    计算机出现msvcr110.dll丢失是什么意思?七种方法解决msvcr110.dll丢失
    Android BottomSheetDialog最大展开高度问题
    YOLOv5 PyQt5 | PyQt5快速入门 | 2/3
    vue3.2 导出pdf文件或表格数据
    Python基于Excel生成矢量图层及属性表信息:ArcPy
    广西建筑模板厂家-能强优品木业
    岛屿问题,矩阵:DFS+标记剪枝+回溯
    股票 SQL
  • 原文地址:https://blog.csdn.net/qq_22734027/article/details/134545730