【Pytorch】Visualization of Feature Maps (1)







    Filter activation values

    Principle: find an image that maximizes the activation of a chosen filter in a chosen layer. That image is then a depiction of what the filter detects.

    A worked example. The procedure:

    1. Initialize an image, 56×56;
    2. Use a pretrained VGG16 and freeze its parameters;
    3. To visualize the k-th filter of layer 40, set the loss to (-1 × that filter's activation);
    4. Update the initial image by gradient descent;
    5. Scale the resulting image up by 1.2×, then repeat the steps above on the new image.

    Step 5 is the key one. Note that the initial image is small, only 56×56. The original author found that starting from a larger image produces higher-frequency features, i.e. a much less readable visualization than the one obtained this way.
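Before the full VGG16 implementation below, the five steps can be sketched end-to-end with plain NumPy on a toy differentiable "filter response" (everything here — the filter, sizes, learning rate — is illustrative, not from the original code; step 3's loss corresponds to the negative of this response, and steps 2–3 are stood in for by a fixed `w` with a closed-form gradient instead of a frozen network plus autograd):

```python
import numpy as np

rng = np.random.default_rng(0)

sz = 8
img = rng.uniform(150, 180, (sz, sz)) / 255.0         # step 1: small random image

def make_filter(n):
    # toy stand-in for one conv filter; the "activation" is mean(img * w)
    r = np.linspace(-1.0, 1.0, n)
    return np.outer(np.sin(3 * r), np.cos(3 * r))

start = float((img * make_filter(sz)).mean())         # response before optimization

lr = 0.5
history = []
for scale_step in range(3):                           # step 5: enlarge and repeat
    w = make_filter(img.shape[0])
    for _ in range(25):                               # step 4: gradient ascent
        grad = w / img.size                           # d(mean(img * w)) / d(img)
        img = np.clip(img + lr * grad, 0.0, 1.0)      # ascend, keep a valid image
    history.append(float((img * w).mean()))
    img = np.repeat(np.repeat(img, 2, axis=0), 2, axis=1)  # upscale (factor 2 here)

print(history[0] > start)  # the response grew at the first scale
print(img.shape)           # 8 doubled three times -> (64, 64)
```

In the real script below, the role of `w` is played by VGG16 up to layer 42 plus a forward hook, and the gradient comes from autograd rather than a closed form.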

    import torch
    from torch.autograd import Variable
    from PIL import Image, ImageOps
    import torchvision.transforms as transforms
    import torchvision.models as models
    
    import numpy as np
    import cv2
    from matplotlib import pyplot as plt
    
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    "initialize input image"
    sz = 56
    img = np.uint8(np.random.uniform(150, 180, (3, sz, sz))) / 255  # (3, 56, 56), values roughly in [0.59, 0.71]
    img = torch.from_numpy(img[None]).float().to(device)  # (1, 3, 56, 56)
    
    "pretrained model"
    model_vgg16 = models.vgg16_bn(pretrained=True).features.to(device).eval()
    # downloading /home/xxx/.cache/torch/hub/checkpoints/vgg16_bn-6c64b313.pth, 500M+
    # print(model_vgg16)
    # print(len(list(model_vgg16.children())))  # 44
    # print(list(model_vgg16.children()))
    
    "get the filter's output of one layer"
    # use a forward hook to capture an intermediate layer's output
    class SaveFeatures():
        def __init__(self, module):
            self.hook = module.register_forward_hook(self.hook_fn)
        def hook_fn(self, module, input, output):
            self.features = output.clone()
        def close(self):
            self.hook.remove()
    
    layer = 42  # the ReLU right after conv layer 40 in vgg16_bn.features
    activations = SaveFeatures(list(model_vgg16.children())[layer])
    
    "backpropagation, setting hyper-parameters"
    lr = 0.1
    opt_steps = 25  # gradient steps per scale
    filters = 265  # maximize the activation of filter 265 of layer 42
    upscaling_steps = 13  # number of upscaling rounds
    blur = 3
    upscaling_factor = 1.2  # scale factor per round
    
    "preprocessing of datasets"
    cnn_normalization_mean = torch.tensor([0.485, 0.456, 0.406]).view(-1, 1, 1).to(device)
    cnn_normalization_std = torch.tensor([0.229, 0.224, 0.225]).view(-1, 1, 1).to(device)
    
    "gradient descent"
    for epoch in range(upscaling_steps):  # scale the image up up_scaling_steps times
        img = (img - cnn_normalization_mean) / cnn_normalization_std
        img[img > 1] = 1
        img[img < 0] = 0
        print("Image Shape1:", img.shape)
    img_var = img.clone().requires_grad_(True)  # leaf tensor to optimize (Variable is deprecated)
        "optimizer"
        optimizer = torch.optim.Adam([img_var], lr=lr, weight_decay=1e-6)
        for n in range(opt_steps):
            optimizer.zero_grad()
            model_vgg16(img_var)  # forward
            loss = -activations.features[0, filters].mean()  # max the activations
            loss.backward()
            optimizer.step()
    
        "restore the image"
        print("Loss:", loss.cpu().detach().numpy())
        img = img_var * cnn_normalization_std + cnn_normalization_mean
        img[img>1] = 1
        img[img<0] = 0
        img = img.data.cpu().numpy()[0].transpose(1,2,0)
        sz = int(upscaling_factor * sz)  # calculate new image size
        img = cv2.resize(img, (sz, sz), interpolation=cv2.INTER_CUBIC)  # scale image up
        if blur is not None:
            img = cv2.blur(img, (blur, blur))  # blur image to reduce high frequency patterns
        print("Image Shape2:", img.shape)
    
        img = torch.from_numpy(img.transpose(2, 0, 1)[None]).to(device)
        print("Image Shape3:", img.shape)
        print(str(epoch), ", Finished")
        print("="*10)
    
    activations.close()  # remove the hook
    
    image = img.cpu().clone()
    image = image.squeeze(0)
    unloader = transforms.ToPILImage()
    
    image = unloader(image)
    image = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR)
    cv2.imwrite("res1.jpg", image)
    torch.cuda.empty_cache()
    
    
    """
    Image Shape1: torch.Size([1, 3, 56, 56])
    Loss: -6.0634975
    Image Shape2: (67, 67, 3)
    Image Shape3: torch.Size([1, 3, 67, 67])
    0 , Finished
    ==========
    Image Shape1: torch.Size([1, 3, 67, 67])
    Loss: -7.8898916
    Image Shape2: (80, 80, 3)
    Image Shape3: torch.Size([1, 3, 80, 80])
    1 , Finished
    ==========
    Image Shape1: torch.Size([1, 3, 80, 80])
    Loss: -8.730318
    Image Shape2: (96, 96, 3)
    Image Shape3: torch.Size([1, 3, 96, 96])
    2 , Finished
    ==========
    Image Shape1: torch.Size([1, 3, 96, 96])
    Loss: -9.697872
    Image Shape2: (115, 115, 3)
    Image Shape3: torch.Size([1, 3, 115, 115])
    3 , Finished
    ==========
    Image Shape1: torch.Size([1, 3, 115, 115])
    Loss: -10.190881
    Image Shape2: (138, 138, 3)
    Image Shape3: torch.Size([1, 3, 138, 138])
    4 , Finished
    ==========
    Image Shape1: torch.Size([1, 3, 138, 138])
    Loss: -10.315895
    Image Shape2: (165, 165, 3)
    Image Shape3: torch.Size([1, 3, 165, 165])
    5 , Finished
    ==========
    Image Shape1: torch.Size([1, 3, 165, 165])
    Loss: -9.73861
    Image Shape2: (198, 198, 3)
    Image Shape3: torch.Size([1, 3, 198, 198])
    6 , Finished
    ==========
    Image Shape1: torch.Size([1, 3, 198, 198])
    Loss: -9.503629
    Image Shape2: (237, 237, 3)
    Image Shape3: torch.Size([1, 3, 237, 237])
    7 , Finished
    ==========
    Image Shape1: torch.Size([1, 3, 237, 237])
    Loss: -9.488493
    Image Shape2: (284, 284, 3)
    Image Shape3: torch.Size([1, 3, 284, 284])
    8 , Finished
    ==========
    Image Shape1: torch.Size([1, 3, 284, 284])
    Loss: -9.100454
    Image Shape2: (340, 340, 3)
    Image Shape3: torch.Size([1, 3, 340, 340])
    9 , Finished
    ==========
    Image Shape1: torch.Size([1, 3, 340, 340])
    Loss: -8.699549
    Image Shape2: (408, 408, 3)
    Image Shape3: torch.Size([1, 3, 408, 408])
    10 , Finished
    ==========
    Image Shape1: torch.Size([1, 3, 408, 408])
    Loss: -8.90135
    Image Shape2: (489, 489, 3)
    Image Shape3: torch.Size([1, 3, 489, 489])
    11 , Finished
    ==========
    Image Shape1: torch.Size([1, 3, 489, 489])
    Loss: -8.838546
    Image Shape2: (586, 586, 3)
    Image Shape3: torch.Size([1, 3, 586, 586])
    12 , Finished
    ==========
    
    Process finished with exit code 0
    """
    

    The resulting feature map:

    (image: the optimized visualization for filter 265)

    Now find an image online and check whether this filter's response to it is indeed the highest.

    Test image:

    (image: a test photo of a bird, ./bird.jpg)

    import torch
    from torch.autograd import Variable
    from PIL import Image, ImageOps
    import torchvision.transforms as transforms
    import torchvision.models as models
    
    import numpy as np
    import cv2
    from matplotlib import pyplot as plt
    
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    class SaveFeatures():
        def __init__(self, module):
            self.hook = module.register_forward_hook(self.hook_fn)
        def hook_fn(self, module, input, output):
            self.features = output.clone()
        def close(self):
            self.hook.remove()
    
    size = (224, 224)
    picture = Image.open("./bird.jpg").convert("RGB")
    picture = ImageOps.fit(picture, size, Image.LANCZOS)  # ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter
    
    loader = transforms.ToTensor()
    picture = loader(picture).to(device)
    print(picture.shape)
    
    cnn_normalization_mean = torch.tensor([0.485, 0.456, 0.406]).view(-1, 1, 1).to(device)
    cnn_normalization_std = torch.tensor([0.229, 0.224, 0.225]).view(-1, 1, 1).to(device)
    
    picture = (picture-cnn_normalization_mean) / cnn_normalization_std
    
    model_vgg16 = models.vgg16_bn(pretrained=True).features.to(device).eval()
    print(list(model_vgg16.children())[40])  # Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    print(list(model_vgg16.children())[41])  # BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    print(list(model_vgg16.children())[42])  # ReLU(inplace=True)
    
    layer = 42
    filters = 265
    activations = SaveFeatures(list(model_vgg16.children())[layer])
    
    with torch.no_grad():
        model_vgg16(picture[None])  # forward pass only; the hook captures the activations
    activations.close()
    
    print(activations.features.shape)  # torch.Size([1, 512, 14, 14])
    
    # plot each filter's mean activation
    mean_act = [activations.features[0, i].mean().item() for i in range(activations.features.shape[1])]
    plt.figure(figsize=(7,5))
    act = plt.plot(mean_act, linewidth=2.)
    extraticks = [filters]
    ax = act[0].axes
    ax.set_xlim(0, 500)
    plt.axvline(x=filters, color="gray", linestyle="--")
    ax.set_xlabel("feature map")
    ax.set_ylabel("mean activation")
    ax.set_xticks([0, 200, 400] + extraticks)
    plt.show()
    
    """
    torch.Size([3, 224, 224])
    Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    ReLU(inplace=True)
    torch.Size([1, 512, 14, 14])
    """
    

    (image: plot of per-filter mean activations, with a dashed line at filter 265)

    As the plot shows, feature map 265 has the highest response to this input.

    Summary: testing other layers and filters shows that the chosen filter's response is not always the highest one in the resulting plot, although it is still high. Most likely the test image found is simply not the best match for that particular filter of the chosen layer.
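That check can also be done directly in code: rank all filters by mean activation and look at the top of the list. A minimal sketch, using a synthetic tensor standing in for `activations.features[0]` (the injected peak at index 265 is purely illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)
# stand-in for the hooked activations of one image: (filters, H, W)
feats = rng.random((512, 14, 14))
feats[265] += 0.5                        # pretend filter 265 responds strongly

mean_act = feats.mean(axis=(1, 2))       # one mean activation per filter
top5 = np.argsort(mean_act)[-5:][::-1]   # indices of the 5 strongest filters
print(top5.tolist())                     # 265 comes first here
```

With the real hook output, replacing `feats` by `activations.features[0].cpu().numpy()` gives the actual ranking, which makes it easy to see where the chosen filter lands for a given test image.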

  • Original article: https://blog.csdn.net/bryant_meng/article/details/134526106