• PyTorch implementation of attention mechanisms: GC


    The GC attention mechanism

    The GC (Global Context) attention mechanism comes from the paper "GCNet: Non-local Networks Meet Squeeze-Excitation Networks and Beyond". Starting from Non-local Networks, the authors observed that the attention maps computed for different query positions are nearly identical, which means that computing a separate attention map for every position in the non-local block wastes a great deal of computation. They therefore proposed a simplified non-local block, SNL.
    Going further, the authors studied the connection to SENet and, building on SNL and SENet, proposed a unified framework. Combining the strengths of both, they arrived at GCNet, which fuses global information effectively at a comparatively small computational cost.
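
    To make the simplification concrete, below is a minimal sketch of the SNL context step (my own illustration, not code from the paper or this post): a single global attention map is computed with a 1x1 convolution and shared by all query positions, instead of one map per query position as in the original non-local block. The helper name snl_context and the caller-supplied w_k are hypothetical.

    import torch
    import torch.nn.functional as F
    from torch import nn

    def snl_context(x: torch.Tensor, w_k: nn.Conv2d) -> torch.Tensor:
        # x: (B, C, H, W); w_k: Conv2d(C, 1, kernel_size=1), one attention logit per position
        B, C, H, W = x.shape
        attn = F.softmax(w_k(x).reshape(B, 1, H * W), dim=-1)    # (B, 1, HW): one global map shared by all positions
        context = x.reshape(B, C, H * W) @ attn.transpose(1, 2)  # (B, C, 1): attention-weighted sum over positions
        return context.view(B, C, 1, 1)

    # e.g. snl_context(torch.randn(2, 64, 7, 7), nn.Conv2d(64, 1, kernel_size=1)).shape -> (2, 64, 1, 1)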

    Paper: GCNet: Non-local Networks Meet Squeeze-Excitation Networks and Beyond

    SNL structure diagram
    GC attention structure diagram
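
    For reference, the GC block with additive fusion computes, for each of the $N_p = HW$ positions (this is the formulation given in the paper; the timm code below additionally offers a sigmoid-gated scale fusion in place of the addition):

    $$z_i = x_i + W_{v2}\,\mathrm{ReLU}\!\left(\mathrm{LN}\!\left(W_{v1} \sum_{j=1}^{N_p} \frac{e^{W_k x_j}}{\sum_{m=1}^{N_p} e^{W_k x_m}}\, x_j\right)\right)$$

    where $W_k$ is the 1x1 attention convolution and $W_{v1}$, $W_{v2}$ form the bottleneck transform.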

    Code implementation:

    import torch
    from torch import nn as nn
    import torch.nn.functional as F
    from timm.models.layers.create_act import create_act_layer, get_act_layer
    from timm.models.layers import make_divisible
    from timm.models.layers.mlp import ConvMlp
    from timm.models.layers.norm import LayerNorm2d
    
    
    class GlobalContext(nn.Module):
    
        def __init__(self, channels, use_attn=True, fuse_add=False, fuse_scale=True, init_last_zero=False,
                     rd_ratio=1./8, rd_channels=None, rd_divisor=1, act_layer=nn.ReLU, gate_layer='sigmoid'):
            super(GlobalContext, self).__init__()
            act_layer = get_act_layer(act_layer)
    
            # 1x1 conv producing one attention logit per spatial position (context modeling)
            self.conv_attn = nn.Conv2d(channels, 1, kernel_size=1, bias=True) if use_attn else None
    
            if rd_channels is None:
                # bottleneck width for the transform MLPs (default: channels / 8)
                rd_channels = make_divisible(channels * rd_ratio, rd_divisor, round_limit=0.)
            if fuse_add:
                self.mlp_add = ConvMlp(channels, rd_channels, act_layer=act_layer, norm_layer=LayerNorm2d)
            else:
                self.mlp_add = None
            if fuse_scale:
                self.mlp_scale = ConvMlp(channels, rd_channels, act_layer=act_layer, norm_layer=LayerNorm2d)
            else:
                self.mlp_scale = None
    
            self.gate = create_act_layer(gate_layer)
            self.init_last_zero = init_last_zero
            self.reset_parameters()
    
        def reset_parameters(self):
            if self.conv_attn is not None:
                nn.init.kaiming_normal_(self.conv_attn.weight, mode='fan_in', nonlinearity='relu')
            if self.mlp_add is not None:
                # zero-init the last layer so the additive branch starts as an identity mapping
                nn.init.zeros_(self.mlp_add.fc2.weight)
    
        def forward(self, x):
            B, C, H, W = x.shape
    
            if self.conv_attn is not None:
                attn = self.conv_attn(x).reshape(B, 1, H * W)  # (B, 1, H * W)
                attn = F.softmax(attn, dim=-1).unsqueeze(3)  # (B, 1, H * W, 1)
                context = x.reshape(B, C, H * W).unsqueeze(1) @ attn  # (B, 1, C, 1)
                context = context.view(B, C, 1, 1)
            else:
                context = x.mean(dim=(2, 3), keepdim=True)  # SE-style global average pooling
    
            if self.mlp_scale is not None:
                mlp_x = self.mlp_scale(context)
                x = x * self.gate(mlp_x)  # scale fusion: SE-style channel-wise gating
            if self.mlp_add is not None:
                mlp_x = self.mlp_add(context)
                x = x + mlp_x  # add fusion: broadcast the transformed context to every position
    
            return x
    
    if __name__ == '__main__':
        x = torch.randn(50, 512, 7, 7)
        gc = GlobalContext(512)
        out = gc(x)
        print(out.shape)  # torch.Size([50, 512, 7, 7])
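
    For reference, the constructor flags select the other variants described above (a quick, hypothetical sanity check, not from the original post): use_attn=False swaps the softmax context pooling for SE-style global average pooling, and fuse_add=True enables the additive fusion branch.

    gc_add = GlobalContext(512, fuse_add=True, fuse_scale=False)  # additive fusion only
    se_like = GlobalContext(512, use_attn=False)  # average-pooled context with scale fusion
    print(se_like(torch.randn(2, 512, 7, 7)).shape)  # torch.Size([2, 512, 7, 7])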
    
  • Original article: https://blog.csdn.net/DM_zx/article/details/132731901