• Tips for saving GPU memory when building deep learning networks


Preface

While building networks, I noticed that the ReLU function can be applied in several different ways. The different usages do not change the number of trainable parameters, but they do differ in how much memory they consume, and each has its pros and cons. The details are given below.

This article is best read alongside the following one, since the code here is a modification of the code there: https://mp.csdn.net/mp_blog/creation/editor/126259211

1. Different implementations of the residual block

[Version 1]

def BasicBlock(in_ch, out_ch, stride):
    return nn.Sequential(
        nn.Conv2d(in_ch, out_ch, 3, stride, padding=1, bias=False),
        nn.BatchNorm2d(out_ch),
        nn.ReLU(inplace=True),  # inplace=True modifies the tensor in place, saving activation memory
        nn.Conv2d(out_ch, out_ch, 3, stride=1, padding=1, bias=False),
        nn.BatchNorm2d(out_ch),
        nn.ReLU(inplace=True),
    )

class ResidualBlock_old(nn.Module):
    # Sub-module: Residual Block
    def __init__(self, in_ch, out_ch, stride=1, shortcut=None):
        super(ResidualBlock_old, self).__init__()
        self.BasicBlock = BasicBlock(in_ch, out_ch, stride)
        self.downsample = shortcut

    def forward(self, x):
        out = self.BasicBlock(x)
        residual = x if self.downsample is None else self.downsample(x)
        # Out-of-place add: the trailing inplace ReLU in BasicBlock saves this
        # tensor for its backward pass, so an in-place `out += residual` here
        # can raise a RuntimeError during backward.
        out = out + residual
        return out
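A minimal forward-pass check for this version (a sketch; it assumes the imports shown in the complete listing in section 2):

import torch

block = ResidualBlock_old(64, 64, stride=1)  # identity shortcut, so shapes must match
x = torch.randn(1, 64, 32, 32)
print(block(x).shape)                        # torch.Size([1, 64, 32, 32])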

[Version 2]

class ResidualBlock(nn.Module):
    # Sub-module: Residual Block
    def __init__(self, in_ch, out_ch, stride=1, shortcut=None):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_ch)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_ch)
        self.downsample = shortcut

    def forward(self, x):
        out = self.conv1(x)
        out = self.bn1(out)
        out = F.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        residual = x if self.downsample is None else self.downsample(x)
        out += residual
        return F.relu(out)
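As a quick sanity check (a sketch; it assumes both classes above are defined), the two versions register exactly the same trainable parameters, since ReLU itself has none:

old_block = ResidualBlock_old(64, 64)
new_block = ResidualBlock(64, 64)
n_old = sum(p.numel() for p in old_block.parameters() if p.requires_grad)
n_new = sum(p.numel() for p in new_block.parameters() if p.requires_grad)
print(n_old, n_new)  # identical: only the convolutions and BatchNorm layers carry weights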
The two versions correspond to the two ways of applying ReLU: nn.ReLU and F.relu. nn.ReLU is a layer, so it must be registered inside an nn.Module container (for example in an nn.Sequential, or as an attribute in __init__) before it can be used, while F.relu is simply called as a function inside forward. Note also that the two versions are not numerically identical: Version 1 applies its second ReLU before the residual addition and has no activation afterwards, while Version 2 adds the residual first and then applies the final ReLU, which matches the standard ResNet design.

Which style to use largely comes down to programming taste. In PyTorch, every nn.X layer has a corresponding functional version F.X, but not every F.X is suitable for use in forward or other arbitrary code sections.

The reason is state: when the trained model is saved, parameters that live only inside an F.X call in forward are not part of the module and therefore cannot be saved. In other words, the F.X functions used in forward should be stateless; F.relu, F.avg_pool2d and the like have no parameters, so they can safely be used in any code fragment.
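A small sketch of this difference (the class names here are illustrative, not from the original code): nn.PReLU carries a learnable slope, so declaring it as a module in __init__ registers that weight in state_dict(), whereas a plain tensor fed to F.prelu in forward is invisible to state_dict() and would be silently lost when the model is saved.

import torch
import torch.nn as nn
from torch.nn import functional as F

class WithModuleAct(nn.Module):
    def __init__(self):
        super().__init__()
        self.act = nn.PReLU()            # learnable slope, registered as a parameter

    def forward(self, x):
        return self.act(x)

class WithFunctionalAct(nn.Module):
    def __init__(self):
        super().__init__()
        self.slope = torch.tensor(0.25)  # plain tensor: NOT registered

    def forward(self, x):
        return F.prelu(x, self.slope)

print(WithModuleAct().state_dict().keys())      # odict_keys(['act.weight'])
print(WithFunctionalAct().state_dict().keys())  # odict_keys([]) -- the slope would not be saved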

2. Trainable parameters and GPU memory usage of the two versions

The number of trainable parameters stays the same, but the memory usage differs: the original code reports 589.16 MB, while the modified version reports 479.16 MB.
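For reference, here is a rough way to measure forward-pass memory directly on the GPU (a sketch using torch.cuda's peak-memory counters, which track actual allocations and are not exactly the same quantity that torchsummary estimates; it assumes the ResNet34 class from the complete listing below):

import torch

def peak_forward_mb(model, shape=(1, 3, 512, 512)):
    model = model.cuda()
    x = torch.randn(*shape, device='cuda')
    torch.cuda.reset_peak_memory_stats()
    out = model(x)                                 # forward pass only
    torch.cuda.synchronize()
    return torch.cuda.max_memory_allocated() / 2**20  # MiB

# e.g. print(peak_forward_mb(ResNet34(num_classes=1000)))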

Complete code

import torch.nn as nn
import torch
from torch.nn import functional as F
from torchsummary import summary
from torchvision import models

def BasicBlock(in_ch, out_ch, stride):
    return nn.Sequential(
        nn.Conv2d(in_ch, out_ch, 3, stride, padding=1, bias=False),
        nn.BatchNorm2d(out_ch),
        nn.ReLU(inplace=True),  # inplace=True modifies the tensor in place, saving activation memory
        nn.Conv2d(out_ch, out_ch, 3, stride=1, padding=1, bias=False),
        nn.BatchNorm2d(out_ch),
        nn.ReLU(inplace=True),
    )

class ResidualBlock_old(nn.Module):
    # Sub-module: Residual Block
    def __init__(self, in_ch, out_ch, stride=1, shortcut=None):
        super(ResidualBlock_old, self).__init__()
        self.BasicBlock = BasicBlock(in_ch, out_ch, stride)
        self.downsample = shortcut

    def forward(self, x):
        out = self.BasicBlock(x)
        residual = x if self.downsample is None else self.downsample(x)
        out = out + residual  # out-of-place add: see the note in section 1
        return out

class ResidualBlock(nn.Module):
    # Sub-module: Residual Block
    def __init__(self, in_ch, out_ch, stride=1, shortcut=None):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_ch)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_ch)
        self.downsample = shortcut

    def forward(self, x):
        out = self.conv1(x)
        out = self.bn1(out)
        out = F.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        residual = x if self.downsample is None else self.downsample(x)
        out += residual
        return F.relu(out)

class ResNet34(nn.Module):
    # Main module: ResNet34
    def __init__(self, num_classes=1):
        super(ResNet34, self).__init__()
        self.init_block = nn.Sequential(
            nn.Conv2d(3, 64, 7, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(3, 2, 1),
        )
        self.layer1 = self.make_layer(64, 64, 3)
        self.layer2 = self.make_layer(64, 128, 4, stride=2)
        self.layer3 = self.make_layer(128, 256, 6, stride=2)
        self.layer4 = self.make_layer(256, 512, 3, stride=2)
        # fully connected layer for classification
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512, num_classes)

    def make_layer(self, in_ch, out_ch, block_num, stride=1):
        shortcut = None
        # decide whether a projection shortcut is needed (channel count increases)
        if not in_ch == out_ch:
            shortcut = nn.Sequential(
                # 1x1 conv raises the channel count; stride=2 halves the spatial size; bias omitted for simplicity
                nn.Conv2d(in_ch, out_ch, 1, stride, bias=False),
                nn.BatchNorm2d(out_ch))
        layers = []
        layers.append(ResidualBlock(in_ch, out_ch, stride, shortcut))
        for i in range(1, block_num):
            layers.append(ResidualBlock(out_ch, out_ch))  # remaining blocks add the shortcut directly (identity)
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.init_block(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return nn.Sigmoid()(x)  # squash the output into (0, 1)

if __name__ == '__main__':
    filters = [64, 128, 256, 512]  # (unused)
    resnet = models.resnet34(pretrained=False)
    summary(resnet.cuda(), (3, 512, 512))
    print('***************\n*****************\n')
    # MY RESNET
    resnet_my = ResNet34(num_classes=1000)
    print(resnet_my)
    summary(resnet_my.cuda(), (3, 512, 512))
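Finally, one caveat on inplace=True: it is only safe when autograd does not need the overwritten tensor for the backward pass. A minimal sketch of the failure mode (sigmoid saves its output for backward, so overwriting that output in place breaks gradient computation):

import torch
from torch.nn import functional as F

a = torch.randn(4, requires_grad=True)
b = torch.sigmoid(a)         # sigmoid's backward needs its output b
c = F.relu(b, inplace=True)  # overwrites b in place
c.sum().backward()           # RuntimeError: a variable needed for gradient
                             # computation has been modified by an inplace operation

This is also why the residual addition after an inplace ReLU was written out-of-place in Version 1 above, while an inplace ReLU directly after a convolution or BatchNorm layer is fine: those layers save their inputs, not their outputs, for backward.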

• Original article: https://blog.csdn.net/weixin_44503976/article/details/126262686