• YOLOv5、v7改进之三十三:引入GAMAttention注意力机制


     前 言:作为当前先进的深度学习目标检测算法YOLOv7,已经集合了大量的trick,但是还是有提高和改进的空间,针对具体应用场景下的检测难点,可以不同的改进方法。此后的系列文章,将重点对YOLOv7的如何改进进行详细的介绍,目的是为了给那些搞科研的同学需要创新点或者搞工程项目的朋友需要达到更好的效果提供自己的微薄帮助和参考。由于出到YOLOv7,YOLOv5算法2020年至今已经涌现出大量改进论文,这个不论对于搞科研的同学或者已经工作的朋友来说,研究的价值和新颖度都不太够了,为与时俱进,以后改进算法以YOLOv7为基础,此前YOLOv5改进方法在YOLOv7同样适用,所以继续YOLOv5系列改进的序号。另外改进方法在YOLOv5等其他算法同样可以适用进行改进。希望能够对大家有帮助。

    具体改进办法请关注后私信留言!

    解决问题:之前改进增加了很多注意力机制的方法,包括比较常规的SE、CBAM等,本文加入SKAttention注意力机制,该注意力机制了保留通道和空间方面的信息以增强跨维度交互的重要性。因此,我们提出了一种全局调度机制,通过减少信息缩减和放大全局交互表示来提高深度神经网络的性能,提高检测效果。

    基本原理:

          为了提高各种计算机视觉任务的性能,人们研究了各种注意机制。然而,以往的方法忽略了保留通道和空间方面的信息以增强跨维度交互的重要性。因此,我们提出了一种全局调度机制,通过减少信息缩减和放大全局交互表示来提高深度神经网络的性能。我们沿着卷积空间注意子模块引入了用于通道注意的多层感知器3D置换。对CIFAR-100和ImageNet-1K上提出的图像分类任务机制的评估表明,我们的方法在ResNet和轻量级MobileNet上稳定地优于最近的几种注意机制。。

          对于ImageNet-1K,我们将图像预处理为224×224(He et al.[2016])。我们包括ResNet18和ResNet50(He et al.[2016]),以验证不同网络深度的方法推广。对于ResNet50,我们将其与群卷积进行了比较,以防止参数显著增加。我们将起始学习率设置为0.1,并每隔30个阶段降低一次。我们总共使用90个训练时段。在空间注意子模块中,我们将第一个块的第一步从1切换到2,以匹配特征的大小。为了进行公平比较,CBAM保留了其他设置,包括在空间注意子模块中使用最大池。3 MobileNet V2是用于图像分类的最高效的轻量级模型之一。我们对MobileNet V2使用相同的ResNet设置,只是使用了0.045的初始学习率和4×10的权重衰减−5.对ImageNet-1K的评估如表所示。它表明GAM可以稳定地提高不同神经架构的性能。尤其是对于ResNet18,GAM以更少的参数和更好的效率优于ABN。

      添加方法:

    第一步:确定添加的位置,作为即插即用的注意力模块,可以添加到YOLOv5网络中的任何地方。

     第二步:common.py构建GAMAttention模块。部分代码如下,关注文章末尾,私信后领取。

    1. import numpy as np
    2. import torch
    3. from torch import nn
    4. from torch.nn import init
    5. class GAMAttention(nn.Module):
    6. #https://paperswithcode.com/paper/global-attention-mechanism-retain-information
    7. def __init__(self, c1, c2, group=True,rate=4):
    8. super(GAMAttention, self).__init__()
    9. self.channel_attention = nn.Sequential(
    10. nn.Linear(c1, int(c1 / rate)),
    11. nn.ReLU(inplace=True),
    12. nn.Linear(int(c1 / rate), c1)
    13. )
    14. self.spatial_attention = nn.Sequential(
    15. nn.Conv2d(c1, c1//rate, kernel_size=7, padding=3,groups=rate)if group else nn.Conv2d(c1, int(c1 / rate), kernel_size=7, padding=3),
    16. nn.BatchNorm2d(int(c1 /rate)),
    17. nn.ReLU(inplace=True),
    18. nn.Conv2d(c1//rate, c2, kernel_size=7, padding=3,groups=rate) if group else nn.Conv2d(int(c1 / rate), c2, kernel_size=7, padding=3),
    19. nn.BatchNorm2d(c2)
    20. )
    21. def forward(self, x):
    22. b, c, h, w = x.shape
    23. x_permute = x.permute(0, 2, 3, 1).view(b, -1, c)
    24. x_att_permute = self.channel_attention(x_permute).view(b, h, w, c)
    25. x_channel_att = x_att_permute.permute(0, 3, 1, 2)
    26. x = x * x_channel_att
    27. x_spatial_att = self.spatial_attention(x).sigmoid()
    28. x_spatial_att=channel_shuffle(x_spatial_att,4) #last shuffle
    29. out = x * x_spatial_att
    30. return out
    31. def channel_shuffle(x, groups=2): ##shuffle channel
    32. #RESHAPE----->transpose------->Flatten
    33. B, C, H, W = x.size()
    34. out = x.view(B, groups, C // groups, H, W).permute(0, 2, 1, 3, 4).contiguous()
    35. out=out.view(B, C, H, W)
    36. return out

    第三步:yolo.py中注册  GAMAttention模块

    1. elif m is GAMAttention:
    2. c1, c2 = ch[f], args[0]
    3. if c2 != no:
    4. c2 = make_divisible(c2 * gw, 8)

    第四步:修改yaml文件,本文以修改backbone为例,将原C3模块后加入该模块。

    1. # YOLOv5 backbone
    2. backbone:
    3. # [from, number, module, args] # [c=channels,module,kernlsize,strides]
    4. [[-1, 1, Conv, [64, 6, 2, 2]], # 0-P1/2 [c=3,64*0.5=32,3]
    5. [-1, 1, Conv, [128, 3, 2]], # 1-P2/4
    6. [-1, 3, C3, [128]],
    7. [-1, 1, Conv, [256, 3, 2]], # 3-P3/8
    8. [-1, 6, C3, [256]],
    9. [-1, 1, Conv, [512, 3, 2]], # 5-P4/16
    10. [-1, 9, C3, [512]],
    11. [-1, 1, Conv, [1024, 3, 2]], # 7-P5/32
    12. [-1, 3, C3, [1024]],
    13. [-1, 1, GAMAttention, [1024,1024]], #9
    14. [-1, 1, SPPF, [1024,5]], #10
    15. ]

    第五步:将train.py中改为本文的yaml文件即可,开始训练。

    结 果:本人在遥感数据集上进行实验,有涨点效果。需要请关注留言。

    预告一下:下一篇内容将继续分享深度学习算法相关改进方法。有兴趣的朋友可以关注一下我,有问题可以留言或者私聊我哦

    PS:该方法不仅仅是适用改进YOLOv5,也可以改进其他的YOLO网络以及目标检测网络,比如YOLOv7、v6、v4、v3,Faster rcnn ,ssd等。

    最后,希望能互粉一下,做个朋友,一起学习交流。

  • 相关阅读:
    淘宝商品采集上架拼多多店铺(无货源数据采集接口,拼多多商品详情数据,淘宝商品详情数据)接口代码对接教程
    GZ038 物联网应用开发赛题第8套
    oracle面试题
    docker jenkins
    MODBUS协议下,能否实现MCGS触摸屏与FX5U之间无线通讯?
    项目成本管理
    目标检测算法 - YOLOv1
    【XGBoost】第 2 章:深度决策树
    软件测试书单/书籍推荐(整理更新中)
    dbLinq最新版linq sqlite
  • 原文地址:https://blog.csdn.net/m0_70388905/article/details/127330819