• Source code comparison: Swin Transformer v2 vs. Swin Transformer v1


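    For the Swin Transformer v1 source code, see my blog post:

    swin_transformer源码详解 (a detailed walkthrough of the Swin Transformer source code), 樱花的浪漫的博客, CSDN

    This post covers only the differences between v1 and v2.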

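    1. Projection of q, k, v

            When x is projected to q, k and v, Swin Transformer v2 keeps the weight and the bias as separate parameters. The qkv linear layer is created with bias=False, and q_bias and v_bias are registered as learnable parameters of their own; the key gets no learnable bias (in forward, its slot is filled with zeros). Possibly the authors found the ordinary linear projection too restrictive and expected that initializing and updating the biases separately makes it easier to find suitable parameters.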

    self.qkv = nn.Linear(dim, dim * 3, bias=False)
    # the biases are separate learnable parameters (q and v only)
    if qkv_bias:
        self.q_bias = nn.Parameter(torch.zeros(dim))
        self.v_bias = nn.Parameter(torch.zeros(dim))
    else:
        self.q_bias = None
        self.v_bias = None
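The way these separate biases are used in forward can be sketched as follows. This is a minimal illustration, not the repository code: the sizes are hypothetical, and the biases are drawn randomly (instead of zeros) only so that their effect is visible.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

dim = 8  # hypothetical channel count, chosen for illustration

qkv = nn.Linear(dim, dim * 3, bias=False)   # weight only, as in v2
q_bias = nn.Parameter(torch.randn(dim))     # random here, zeros in the real code
v_bias = nn.Parameter(torch.randn(dim))

# Assemble the full bias vector: [q_bias, zeros for k, v_bias]
k_zero = torch.zeros_like(v_bias, requires_grad=False)
qkv_bias = torch.cat((q_bias, k_zero, v_bias))

x = torch.randn(2, 4, dim)                  # (batch, tokens, channels)
out = F.linear(x, qkv.weight, qkv_bias)     # (2, 4, 3 * dim)
q, k, v = out.chunk(3, dim=-1)

# The key projection is bias-free: it equals the plain weight product.
k_no_bias = F.linear(x, qkv.weight[dim:2 * dim])
print(torch.allclose(k, k_no_bias, atol=1e-6))
```

Only q_bias and v_bias receive gradients here; the zero block for k is rebuilt on every call and is never trained.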

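    2. Cosine attention

            The authors argue that the standard attention mechanism tends to push the network toward extreme values, so they replace dot-product attention with scaled cosine attention. Scaled cosine attention makes the computation independent of the amplitude of the block input, so the attention values are less likely to reach extremes.

            In the code, q and k are each divided by their norm along the last dimension and then multiplied, i.e. both are normalized along the last dimension before the dot product. The result is multiplied by a learnable scale 1/τ, one value per head: logit_scale stores log(1/τ), is initialized to log(10) (a multiplier of 10, i.e. τ = 0.1), and is clamped at log(1/0.01), so the multiplier never exceeds 100 (τ ≥ 0.01).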

    # in __init__: learnable log(1/tau), one value per head
    self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))), requires_grad=True)
    # in forward: cosine similarity, then multiply by the clamped scale
    attn = (F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1))
    logit_scale = torch.clamp(self.logit_scale, max=torch.log(torch.tensor(1. / 0.01).to(attn.device))).exp()
    attn = attn * logit_scale
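The amplitude independence mentioned above can be checked with a small experiment. The sketch below uses hypothetical tensor sizes and compares cosine logits with the scaled dot-product logits of v1 when q and k are both multiplied by 100; the helper names cosine_logits and dot_logits are made up for this illustration.

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
num_heads = 2
q = torch.randn(1, num_heads, 5, 4)  # (batch, heads, tokens, head_dim), hypothetical sizes
k = torch.randn(1, num_heads, 5, 4)

# log(1/tau) initialized to log(10) and clamped at log(100), as in the snippet above
logit_scale = torch.log(10 * torch.ones((num_heads, 1, 1)))
scale = torch.clamp(logit_scale, max=math.log(1.0 / 0.01)).exp()

def cosine_logits(q, k, scale):
    # cosine similarity between every query and key, times 1/tau
    return (F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1)) * scale

def dot_logits(q, k):
    # standard scaled dot-product logits
    return (q @ k.transpose(-2, -1)) * q.shape[-1] ** -0.5

cos_small, cos_large = cosine_logits(q, k, scale), cosine_logits(100 * q, 100 * k, scale)
dot_small, dot_large = dot_logits(q, k), dot_logits(100 * q, 100 * k)

print(torch.allclose(cos_small, cos_large, atol=1e-4))  # cosine logits do not change
print(dot_large.abs().max() / dot_small.abs().max())    # dot-product logits grow by about 1e4
```

Because a cosine lies in [-1, 1], the cosine logits stay within [-1/τ, 1/τ] whatever the input amplitude, which is why the clamp on logit_scale directly bounds the logits before the position bias is added.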

             

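    3. Relative position encoding

            For relative position encoding, v1 creates a learnable bias table of shape ((2*Wh-1) * (2*Ww-1), nH) and adds its entries to the attention logits according to a relative position index. When the window size changes, that is, when weights trained at low resolution are transferred to higher-resolution images, this table can only be resized by bicubic interpolation, which is a limitation. The authors therefore introduce a log-spaced continuous position bias (Log-CPB): a small meta network is applied to log-spaced coordinates and generates bias values for an arbitrary coordinate range.

            In the code, a table of relative coordinates with shape (1, 2*Wh-1, 2*Ww-1, 2) is built first. Its entries start as the relative offsets in [-(w-1), w-1]. They are divided by the window size minus one (the pre-training window size if one is given, otherwise the current one) and multiplied by 8, so that without a window-size transfer they lie in [-8, 8]. The table is then mapped to log space with sign(x) * log2(|x| + 1) / log2(8). According to the authors, the purpose is to limit the extrapolation ratio when the window size changes. Next, a small MLP (two linear layers with a ReLU in between) maps every 2-D coordinate to one value per head, giving position parameters of shape ((2*Wh-1) * (2*Ww-1), num_heads), which are gathered by the index and added to the attention logits.

            Finally, the gathered bias is passed through

     16 * torch.sigmoid(relative_position_bias), which bounds every bias value to the interval (0, 16), a range as wide as the [-8, 8] used for the coordinates at initialization.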
    
    # in __init__
    self.cpb_mlp = nn.Sequential(nn.Linear(2, 512, bias=True),
                                 nn.ReLU(inplace=True),
                                 nn.Linear(512, num_heads, bias=False))

    # get relative_coords_table
    relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0], dtype=torch.float32)
    relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1], dtype=torch.float32)
    # relative offsets in [-(w-1), w-1], scaled to [-8, 8] below
    relative_coords_table = torch.stack(
        torch.meshgrid([relative_coords_h,
                        relative_coords_w])).permute(1, 2, 0).contiguous().unsqueeze(0)  # 1, 2*Wh-1, 2*Ww-1, 2
    if pretrained_window_size[0] > 0:
        relative_coords_table[:, :, :, 0] /= (pretrained_window_size[0] - 1)
        relative_coords_table[:, :, :, 1] /= (pretrained_window_size[1] - 1)
    else:
        relative_coords_table[:, :, :, 0] /= (self.window_size[0] - 1)
        relative_coords_table[:, :, :, 1] /= (self.window_size[1] - 1)
    relative_coords_table *= 8  # normalize to -8, 8
    # log-spaced coordinates, normalized by log2(8)
    relative_coords_table = torch.sign(relative_coords_table) * torch.log2(
        torch.abs(relative_coords_table) + 1.0) / np.log2(8)

    self.register_buffer("relative_coords_table", relative_coords_table)

    # get pair-wise relative position index for each token inside the window
    coords_h = torch.arange(self.window_size[0])
    coords_w = torch.arange(self.window_size[1])
    coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
    coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
    relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
    relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
    relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
    relative_coords[:, :, 1] += self.window_size[1] - 1
    relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
    relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
    self.register_buffer("relative_position_index", relative_position_index)

    # in forward
    relative_position_bias_table = self.cpb_mlp(self.relative_coords_table).view(-1, self.num_heads)
    relative_position_bias = relative_position_bias_table[self.relative_position_index.view(-1)].view(
        self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH
    relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
    relative_position_bias = 16 * torch.sigmoid(relative_position_bias)
    attn = attn + relative_position_bias.unsqueeze(0)
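To see how the log-spaced coordinates limit extrapolation, the table construction can be pulled out into a standalone function. This is a sketch under assumptions: the function name log_spaced_coords is made up, it handles square windows only, and it passes indexing="ij" to torch.meshgrid, which needs PyTorch 1.10 or later.

```python
import math
import torch

def log_spaced_coords(window_size, pretrained_window_size=0):
    """Build the (1, 2W-1, 2W-1, 2) log-spaced coordinate table for a square window."""
    rel = torch.arange(-(window_size - 1), window_size, dtype=torch.float32)
    table = torch.stack(torch.meshgrid(rel, rel, indexing="ij"))
    table = table.permute(1, 2, 0).contiguous().unsqueeze(0)
    # Normalize by the pre-training window when one is given, else by the current one.
    denom = (pretrained_window_size if pretrained_window_size > 0 else window_size) - 1
    table = table / denom * 8
    return torch.sign(table) * torch.log2(table.abs() + 1.0) / math.log2(8)

same = log_spaced_coords(8)                               # trained and used at 8x8
bigger = log_spaced_coords(16, pretrained_window_size=8)  # trained at 8x8, used at 16x16

print(same.shape, same.max().item())      # largest input seen during pre-training
print(bigger.shape, bigger.max().item())  # largest input after the transfer
```

With these settings the largest raw offset grows from 7 to 15, roughly a factor of 2.1, while the largest value fed to the MLP only grows from about 1.06 to about 1.39, so the MLP has to extrapolate far less than it would on linear coordinates.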

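    4. Post normalization

            In their experiments the authors found that, in the pre-normalization configuration, the output activations of the attention branch are merged directly into the residual connection. This makes the activations fluctuate, and the amplitude of the main branch grows larger and larger at deeper layers. The large amplitude differences between layers make training unstable.

            The authors therefore change pre-normalization to post-normalization, as the block diagram in the paper shows: the normalization in front of each sub-layer is removed, the sub-layer output is normalized first, and only then is it added to the shortcut. In the code below this is x = shortcut + self.drop_path(self.norm1(x)) for attention and x = x + self.drop_path(self.norm2(self.mlp(x))) for the MLP.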

    import numpy as np
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from timm.models.layers import DropPath, to_2tuple


    class WindowAttention(nn.Module):
        r""" Window based multi-head self attention (W-MSA) module with relative position bias.
        It supports both of shifted and non-shifted window.
        Args:
            dim (int): Number of input channels.
            window_size (tuple[int]): The height and width of the window.
            num_heads (int): Number of attention heads.
            qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
            attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
            proj_drop (float, optional): Dropout ratio of output. Default: 0.0
            pretrained_window_size (tuple[int]): The height and width of the window in pre-training.
        """

        def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.,
                     pretrained_window_size=[0, 0]):

            super().__init__()
            self.dim = dim
            self.window_size = window_size  # Wh, Ww
            self.pretrained_window_size = pretrained_window_size
            self.num_heads = num_heads
            # scale of the cosine attention, stored as log(1/tau)
            self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))), requires_grad=True)

            # mlp to generate continuous relative position bias (learnable position encoding)
            self.cpb_mlp = nn.Sequential(nn.Linear(2, 512, bias=True),
                                         nn.ReLU(inplace=True),
                                         nn.Linear(512, num_heads, bias=False))

            # get relative_coords_table
            relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0], dtype=torch.float32)
            relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1], dtype=torch.float32)
            # relative offsets in [-(w-1), w-1], scaled to [-8, 8] below
            relative_coords_table = torch.stack(
                torch.meshgrid([relative_coords_h,
                                relative_coords_w])).permute(1, 2, 0).contiguous().unsqueeze(0)  # 1, 2*Wh-1, 2*Ww-1, 2
            if pretrained_window_size[0] > 0:
                relative_coords_table[:, :, :, 0] /= (pretrained_window_size[0] - 1)
                relative_coords_table[:, :, :, 1] /= (pretrained_window_size[1] - 1)
            else:
                relative_coords_table[:, :, :, 0] /= (self.window_size[0] - 1)
                relative_coords_table[:, :, :, 1] /= (self.window_size[1] - 1)
            relative_coords_table *= 8  # normalize to -8, 8
            # log-spaced coordinates, normalized by log2(8)
            relative_coords_table = torch.sign(relative_coords_table) * torch.log2(
                torch.abs(relative_coords_table) + 1.0) / np.log2(8)

            self.register_buffer("relative_coords_table", relative_coords_table)

            # get pair-wise relative position index for each token inside the window
            coords_h = torch.arange(self.window_size[0])
            coords_w = torch.arange(self.window_size[1])
            coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
            coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
            relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
            relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
            relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
            relative_coords[:, :, 1] += self.window_size[1] - 1
            relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
            relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
            self.register_buffer("relative_position_index", relative_position_index)

            self.qkv = nn.Linear(dim, dim * 3, bias=False)
            # the biases are separate learnable parameters (q and v only)
            if qkv_bias:
                self.q_bias = nn.Parameter(torch.zeros(dim))
                self.v_bias = nn.Parameter(torch.zeros(dim))
            else:
                self.q_bias = None
                self.v_bias = None
            self.attn_drop = nn.Dropout(attn_drop)
            self.proj = nn.Linear(dim, dim)
            self.proj_drop = nn.Dropout(proj_drop)
            self.softmax = nn.Softmax(dim=-1)

        def forward(self, x, mask=None):
            """
            Args:
                x: input features with shape of (num_windows*B, N, C)
                mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
            """
            B_, N, C = x.shape
            qkv_bias = None
            if self.q_bias is not None:
                # the key bias is fixed to zero; only q and v have learnable biases
                qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
            qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
            qkv = qkv.reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
            q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

            # cosine attention
            attn = (F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1))
            logit_scale = torch.clamp(self.logit_scale, max=torch.log(torch.tensor(1. / 0.01).to(attn.device))).exp()
            attn = attn * logit_scale

            relative_position_bias_table = self.cpb_mlp(self.relative_coords_table).view(-1, self.num_heads)
            relative_position_bias = relative_position_bias_table[self.relative_position_index.view(-1)].view(
                self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH
            relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
            relative_position_bias = 16 * torch.sigmoid(relative_position_bias)
            attn = attn + relative_position_bias.unsqueeze(0)

            if mask is not None:
                nW = mask.shape[0]
                attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
                attn = attn.view(-1, self.num_heads, N, N)
                attn = self.softmax(attn)
            else:
                attn = self.softmax(attn)

            attn = self.attn_drop(attn)

            x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
            x = self.proj(x)
            x = self.proj_drop(x)
            return x

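The index lookup that turns the MLP output into a per-head bias matrix is easier to follow on a tiny window. The sketch below assumes a hypothetical 2x2 window and 3 heads, and replaces the MLP output with random numbers; it again uses indexing="ij" for torch.meshgrid, which needs PyTorch 1.10 or later.

```python
import torch

Wh = Ww = 2      # tiny hypothetical window so the index is easy to read
num_heads = 3

coords = torch.stack(torch.meshgrid(torch.arange(Wh), torch.arange(Ww), indexing="ij"))  # 2, Wh, Ww
flat = torch.flatten(coords, 1)                                            # 2, Wh*Ww
rel = (flat[:, :, None] - flat[:, None, :]).permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
rel[:, :, 0] += Wh - 1          # shift row offsets to start from 0
rel[:, :, 1] += Ww - 1          # shift column offsets to start from 0
rel[:, :, 0] *= 2 * Ww - 1      # row-major flattening of the (2*Wh-1, 2*Ww-1) grid
index = rel.sum(-1)             # Wh*Ww, Wh*Ww

# Stand-in for the MLP output: one row per relative offset, one column per head.
table = torch.randn((2 * Wh - 1) * (2 * Ww - 1), num_heads)
bias = table[index.view(-1)].view(Wh * Ww, Wh * Ww, -1).permute(2, 0, 1)  # nH, Wh*Ww, Wh*Ww
bias = 16 * torch.sigmoid(bias)

print(index)
print(bias.shape, bias.min().item(), bias.max().item())
```

Every token pair with the same relative offset shares one table row, so the diagonal of index (offset zero) holds a single value, and the sigmoid keeps all biases between 0 and 16.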
    # Mlp, window_partition and window_reverse are defined elsewhere in the same file (see the v1 post)
    class SwinTransformerBlock(nn.Module):
        r""" Swin Transformer Block.
        Args:
            dim (int): Number of input channels.
            input_resolution (tuple[int]): Input resolution.
            num_heads (int): Number of attention heads.
            window_size (int): Window size.
            shift_size (int): Shift size for SW-MSA.
            mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
            qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
            drop (float, optional): Dropout rate. Default: 0.0
            attn_drop (float, optional): Attention dropout rate. Default: 0.0
            drop_path (float, optional): Stochastic depth rate. Default: 0.0
            act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
            norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
            pretrained_window_size (int): Window size in pre-training.
        """

        def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
                     mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0.,
                     act_layer=nn.GELU, norm_layer=nn.LayerNorm, pretrained_window_size=0):
            super().__init__()
            self.dim = dim
            self.input_resolution = input_resolution
            self.num_heads = num_heads
            self.window_size = window_size
            self.shift_size = shift_size
            self.mlp_ratio = mlp_ratio
            if min(self.input_resolution) <= self.window_size:
                # if window size is larger than input resolution, we don't partition windows
                self.shift_size = 0
                self.window_size = min(self.input_resolution)
            assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"

            self.norm1 = norm_layer(dim)
            self.attn = WindowAttention(
                dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
                qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop,
                pretrained_window_size=to_2tuple(pretrained_window_size))

            self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
            self.norm2 = norm_layer(dim)
            mlp_hidden_dim = int(dim * mlp_ratio)
            self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

            if self.shift_size > 0:
                # calculate attention mask for SW-MSA
                H, W = self.input_resolution
                img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
                h_slices = (slice(0, -self.window_size),
                            slice(-self.window_size, -self.shift_size),
                            slice(-self.shift_size, None))
                w_slices = (slice(0, -self.window_size),
                            slice(-self.window_size, -self.shift_size),
                            slice(-self.shift_size, None))
                cnt = 0
                for h in h_slices:
                    for w in w_slices:
                        img_mask[:, h, w, :] = cnt
                        cnt += 1

                mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
                mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
                attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
                attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
            else:
                attn_mask = None

            self.register_buffer("attn_mask", attn_mask)

        def forward(self, x):
            H, W = self.input_resolution
            B, L, C = x.shape
            assert L == H * W, "input feature has wrong size"

            shortcut = x
            x = x.view(B, H, W, C)

            # cyclic shift
            if self.shift_size > 0:
                shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
            else:
                shifted_x = x

            # partition windows
            x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
            x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C

            # W-MSA/SW-MSA
            attn_windows = self.attn(x_windows, mask=self.attn_mask)  # nW*B, window_size*window_size, C

            # merge windows
            attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
            shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C

            # reverse cyclic shift
            if self.shift_size > 0:
                x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
            else:
                x = shifted_x
            x = x.view(B, H * W, C)
            # post normalization: normalize the attention output, then add the shortcut
            x = shortcut + self.drop_path(self.norm1(x))

            # FFN, also with post normalization
            x = x + self.drop_path(self.norm2(self.mlp(x)))

            return x
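The effect of moving the normalization can be illustrated with two toy residual units. This is a sketch, not the repository code: the class names PreNormBlock and ResPostNormBlock are made up, a single nn.Linear stands in for the attention or MLP sub-layer, and its weights are scaled up on purpose to exaggerate the difference.

```python
import torch
import torch.nn as nn

class PreNormBlock(nn.Module):
    """v1-style residual unit: normalize the input of the sub-layer."""
    def __init__(self, dim, sublayer):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.sublayer = sublayer

    def forward(self, x):
        return x + self.sublayer(self.norm(x))

class ResPostNormBlock(nn.Module):
    """v2-style residual unit: normalize the sub-layer output before adding it."""
    def __init__(self, dim, sublayer):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.sublayer = sublayer

    def forward(self, x):
        return x + self.norm(self.sublayer(x))

torch.manual_seed(0)
dim = 16
sub = nn.Linear(dim, dim)   # stand-in for attention or the MLP
with torch.no_grad():
    sub.weight.mul_(50.0)   # deliberately large weights to exaggerate the effect

x = torch.randn(4, dim)
pre_out = PreNormBlock(dim, sub)(x)
post_out = ResPostNormBlock(dim, sub)(x)

print(pre_out.abs().max().item())   # the un-normalized branch output dominates
print(post_out.abs().max().item())  # the branch contribution is bounded by LayerNorm
```

In the post-norm unit, whatever the sub-layer produces is rescaled before it reaches the main branch, so each block can add only a bounded amount to the activations; in the pre-norm unit the raw sub-layer output is added as is.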

     

     

     

  • Original article: https://blog.csdn.net/qq_52053775/article/details/127794977