• YOLOv8+swin_transformer


    测试环境:cuda11.3  pytorch1.11 rtx3090  wsl2 ubuntu20.04

    本科在读,中九以上老师或者课题组捞捞我,孩子想读书,求课题组师兄内推qaq

    踩了很多坑,网上很多博主的代码根本跑不通,自己去github仓库复现修改的

    网上博主的代码日常出现cpu,gpu混合,或许是人家分布式训练了,哈哈哈

    下面上干货吧,宝子们点个关注,点个赞,没有废话
    ————————————————
    首先上yaml文件,诚意满满,我是做的分割,做检测就修改最后那个检测头就好了

    1. # Ultralytics YOLO 🚀, AGPL-3.0 license
    2. # YOLOv8-seg instance segmentation model. For Usage examples see https://docs.ultralytics.com/tasks/segment
    3. # Parameters
    4. nc: 1 # number of classes
    5. scales: # model compound scaling constants, i.e. 'model=yolov8n-seg.yaml' will call yolov8-seg.yaml with scale 'n'
    6. # [depth, width, max_channels]
    7. n: [0.33, 0.25, 1024]
    8. s: [0.33, 0.50, 1024]
    9. m: [0.67, 0.75, 768]
    10. l: [1.00, 1.00, 512]
    11. x: [1.00, 1.25, 512]
    12. # YOLOv8.0n backbone
    13. backbone:
    14. # [from, repeats, module, args]
    15. - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
    16. - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
    17. - [-1, 3, C2f, [128, True]]
    18. - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
    19. - [-1, 6, C2f, [256, True]]
    20. - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
    21. - [-1, 9, C3STR, [512]]
    22. - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
    23. - [-1, 3, C3STR, [1024]]
    24. - [-1, 1, SPPF, [1024, 5]] # 9
    25. # YOLOv8.0n head
    26. head:
    27. - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
    28. - [[-1, 6], 1, Concat, [1]] # cat backbone P4
    29. - [-1, 3, C2f, [512]] # 12
    30. - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
    31. - [[-1, 4], 1, Concat, [1]] # cat backbone P3
    32. - [-1, 3, C2f, [256]] # 15 (P3/8-small)
    33. - [-1, 1, Conv, [256, 3, 2]]
    34. - [[-1, 12], 1, Concat, [1]] # cat head P4
    35. - [-1, 3, C2f, [512]] # 18 (P4/16-medium)
    36. - [-1, 1, Conv, [512, 3, 2]]
    37. - [[-1, 9], 1, Concat, [1]] # cat head P5
    38. - [-1, 3, C2f, [1024]] # 21 (P5/32-large)
    39. - [[15, 18, 21], 1, Segment, [nc, 32, 256]] # Segment(P3, P4, P5)

    然后在nn/modules/block.py最下面加入

    1. class SwinTransformerBlock(nn.Module):
    2. def __init__(self, c1, c2, num_heads, num_layers, window_size=8):
    3. super().__init__()
    4. self.conv = None
    5. if c1 != c2:
    6. self.conv = Conv(c1, c2)
    7. # remove input_resolution
    8. self.blocks = nn.Sequential(*[SwinTransformerLayer(dim=c2, num_heads=num_heads, window_size=window_size,
    9. shift_size=0 if (i % 2 == 0) else window_size // 2) for i in range(num_layers)])
    10. def forward(self, x):
    11. if self.conv is not None:
    12. x = self.conv(x)
    13. x = self.blocks(x)
    14. return x
    15. class WindowAttention(nn.Module):
    16. def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
    17. super().__init__()
    18. self.dim = dim
    19. self.window_size = window_size # Wh, Ww
    20. self.num_heads = num_heads
    21. head_dim = dim // num_heads
    22. self.scale = qk_scale or head_dim ** -0.5
    23. # define a parameter table of relative position bias
    24. self.relative_position_bias_table = nn.Parameter(
    25. torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
    26. # get pair-wise relative position index for each token inside the window
    27. coords_h = torch.arange(self.window_size[0])
    28. coords_w = torch.arange(self.window_size[1])
    29. coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
    30. coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
    31. relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
    32. relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
    33. relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
    34. relative_coords[:, :, 1] += self.window_size[1] - 1
    35. relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
    36. relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
    37. self.register_buffer("relative_position_index", relative_position_index)
    38. self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
    39. self.attn_drop = nn.Dropout(attn_drop)
    40. self.proj = nn.Linear(dim, dim)
    41. self.proj_drop = nn.Dropout(proj_drop)
    42. nn.init.normal_(self.relative_position_bias_table, std=.02)
    43. self.softmax = nn.Softmax(dim=-1)
    44. def forward(self, x, mask=None):
    45. B_, N, C = x.shape
    46. qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
    47. q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
    48. q = q * self.scale
    49. attn = (q @ k.transpose(-2, -1))
    50. relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
    51. self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
    52. relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
    53. attn = attn + relative_position_bias.unsqueeze(0)
    54. if mask is not None:
    55. nW = mask.shape[0]
    56. attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
    57. attn = attn.view(-1, self.num_heads, N, N)
    58. attn = self.softmax(attn)
    59. else:
    60. attn = self.softmax(attn)
    61. attn = self.attn_drop(attn)
    62. # print(attn.dtype, v.dtype)
    63. try:
    64. x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
    65. except:
    66. #print(attn.dtype, v.dtype)
    67. x = (attn.half() @ v).transpose(1, 2).reshape(B_, N, C)
    68. x = self.proj(x)
    69. x = self.proj_drop(x)
    70. return x
    71. class Mlp(nn.Module):
    72. def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.SiLU, drop=0.):
    73. super().__init__()
    74. out_features = out_features or in_features
    75. hidden_features = hidden_features or in_features
    76. self.fc1 = nn.Linear(in_features, hidden_features)
    77. self.act = act_layer()
    78. self.fc2 = nn.Linear(hidden_features, out_features)
    79. self.drop = nn.Dropout(drop)
    80. def forward(self, x):
    81. x = self.fc1(x)
    82. x = self.act(x)
    83. x = self.drop(x)
    84. x = self.fc2(x)
    85. x = self.drop(x)
    86. return x
    87. class SwinTransformerLayer(nn.Module):
    88. def __init__(self, dim, num_heads, window_size=8, shift_size=0,
    89. mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
    90. act_layer=nn.SiLU, norm_layer=nn.LayerNorm):
    91. super().__init__()
    92. self.dim = dim
    93. self.num_heads = num_heads
    94. self.window_size = window_size
    95. self.shift_size = shift_size
    96. self.mlp_ratio = mlp_ratio
    97. # if min(self.input_resolution) <= self.window_size:
    98. # # if window size is larger than input resolution, we don't partition windows
    99. # self.shift_size = 0
    100. # self.window_size = min(self.input_resolution)
    101. assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
    102. self.norm1 = norm_layer(dim)
    103. self.attn = WindowAttention(
    104. dim, window_size=(self.window_size, self.window_size), num_heads=num_heads,
    105. qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
    106. self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
    107. self.norm2 = norm_layer(dim)
    108. mlp_hidden_dim = int(dim * mlp_ratio)
    109. self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
    110. def create_mask(self, H, W):
    111. # calculate attention mask for SW-MSA
    112. img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
    113. h_slices = (slice(0, -self.window_size),
    114. slice(-self.window_size, -self.shift_size),
    115. slice(-self.shift_size, None))
    116. w_slices = (slice(0, -self.window_size),
    117. slice(-self.window_size, -self.shift_size),
    118. slice(-self.shift_size, None))
    119. cnt = 0
    120. for h in h_slices:
    121. for w in w_slices:
    122. img_mask[:, h, w, :] = cnt
    123. cnt += 1
    124. def window_partition(x, window_size):
    125. """
    126. Args:
    127. x: (B, H, W, C)
    128. window_size (int): window size
    129. Returns:
    130. windows: (num_windows*B, window_size, window_size, C)
    131. """
    132. B, H, W, C = x.shape
    133. x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    134. windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    135. return windows
    136. def window_reverse(windows, window_size, H, W):
    137. """
    138. Args:
    139. windows: (num_windows*B, window_size, window_size, C)
    140. window_size (int): Window size
    141. H (int): Height of image
    142. W (int): Width of image
    143. Returns:
    144. x: (B, H, W, C)
    145. """
    146. B = int(windows.shape[0] / (H * W / window_size / window_size))
    147. x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    148. x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    149. return x
    150. mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
    151. mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
    152. attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
    153. attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
    154. return attn_mask
    155. def forward(self, x):
    156. # reshape x[b c h w] to x[b l c]
    157. _, _, H_, W_ = x.shape
    158. Padding = False
    159. if min(H_, W_) < self.window_size or H_ % self.window_size!=0 or W_ % self.window_size!=0:
    160. Padding = True
    161. # print(f'img_size {min(H_, W_)} is less than (or not divided by) window_size {self.window_size}, Padding.')
    162. pad_r = (self.window_size - W_ % self.window_size) % self.window_size
    163. pad_b = (self.window_size - H_ % self.window_size) % self.window_size
    164. x = F.pad(x, (0, pad_r, 0, pad_b))
    165. # print('2', x.shape)
    166. B, C, H, W = x.shape
    167. L = H * W
    168. x = x.permute(0, 2, 3, 1).contiguous().view(B, L, C) # b, L, c
    169. # create mask from init to forward
    170. if self.shift_size > 0:
    171. attn_mask = self.create_mask(H, W).to(x.device)
    172. else:
    173. attn_mask = None
    174. shortcut = x
    175. x = self.norm1(x)
    176. x = x.view(B, H, W, C)
    177. # cyclic shift
    178. if self.shift_size > 0:
    179. shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
    180. else:
    181. shifted_x = x
    182. def window_partition(x, window_size):
    183. """
    184. Args:
    185. x: (B, H, W, C)
    186. window_size (int): window size
    187. Returns:
    188. windows: (num_windows*B, window_size, window_size, C)
    189. """
    190. B, H, W, C = x.shape
    191. x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    192. windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    193. return windows
    194. def window_reverse(windows, window_size, H, W):
    195. """
    196. Args:
    197. windows: (num_windows*B, window_size, window_size, C)
    198. window_size (int): Window size
    199. H (int): Height of image
    200. W (int): Width of image
    201. Returns:
    202. x: (B, H, W, C)
    203. """
    204. B = int(windows.shape[0] / (H * W / window_size / window_size))
    205. x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    206. x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    207. return x
    208. # partition windows
    209. x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
    210. x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
    211. # W-MSA/SW-MSA
    212. attn_windows = self.attn(x_windows, mask=attn_mask) # nW*B, window_size*window_size, C
    213. # merge windows
    214. attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
    215. shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
    216. # reverse cyclic shift
    217. if self.shift_size > 0:
    218. x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
    219. else:
    220. x = shifted_x
    221. x = x.view(B, H * W, C)
    222. # FFN
    223. x = shortcut + self.drop_path(x)
    224. x = x + self.drop_path(self.mlp(self.norm2(x)))
    225. x = x.permute(0, 2, 1).contiguous().view(-1, C, H, W) # b c h w
    226. if Padding:
    227. x = x[:, :, :H_, :W_] # reverse padding
    228. return x
    229. class C3STR(C3):
    230. # C3 module with SwinTransformerBlock()
    231. def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):
    232. super().__init__(c1, c2, n, shortcut, g, e)
    233. c_ = int(c2 * e)
    234. num_heads = c_ // 32
    235. self.m = SwinTransformerBlock(c_, c_, num_heads, n)

    然后是在目录的init.py里面把这个模块注册进去,参考我的另一篇博客的后半部分

    YOLOv8+swin_transfomerv2_不会写代码!!的博客-CSDN博客

    task.py需要修改两处

    1. n = n_ = max(round(n * depth), 1) if n > 1 else n # depth gain
    2. if m in (Classify, Conv, ConvTranspose, GhostConv, Bottleneck, GhostBottleneck, SPP, SPPF, DWConv, Focus,
    3. BottleneckCSP, C1, C2, C2f, C3, C3TR, C3Ghost, nn.ConvTranspose2d, DWConvTranspose2d, C3x, RepC3
    4. ,CBAM , GAM_Attention ,ResBlock_CBAM,GCT,C3STR,SwinV2_CSPB):
    1. args = [c1, c2, *args[1:]]
    2. if m in (BottleneckCSP, C1, C2, C2f, C3, C3TR, C3Ghost, C3x, RepC3,C3STR):
    3. args.insert(2, n) # number of repeats
    4. n = 1

    有问题私信

    结构图如下,可以对模块排列组合涨点

  • 相关阅读:
    Java 内部类 面试“变态题”
    解释一下React中的钩子(hooks),例如useState和useEffect。
    【趣味测试】
    python第三方库-requests的使用
    Pytorch API
    使用DIV+CSS技术设计的非遗文化网页与实现制作(web前端网页制作课作业)
    【DDPM论文解读】Denoising Diffusion Probabilistic Models
    Java文件流练习
    Lingolingo
    Blazor VS Vue
  • 原文地址:https://blog.csdn.net/xty123abc/article/details/133429072