• pytorch 保存和加载模型


    保存和加载模型

    在本节中,我们将了解如何通过保存、加载和运行模型预测来保持模型状态。

    import torch
    import torchvision.models as models
    
    • 1
    • 2

    1. 保存和加载模型权重

    PyTorch 模型将学习到的参数存储在内部状态字典中,称为state_dict. 这些可以通过以下torch.save 方法持久化:

    model = models.vgg16(pretrained=True)
    torch.save(model.state_dict(), 'model_weights.pth')
    
    • 1
    • 2
    /opt/conda/lib/python3.7/site-packages/torchvision/models/_utils.py:209: UserWarning:
    
    The parameter 'pretrained' is deprecated since 0.13 and will be removed in 0.15, please use 'weights' instead.
    
    /opt/conda/lib/python3.7/site-packages/torchvision/models/_utils.py:223: UserWarning:
    
    Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and will be removed in 0.15. The current behavior is equivalent to passing `weights=VGG16_Weights.IMAGENET1K_V1`. You can also use `weights=VGG16_Weights.DEFAULT` to get the most up-to-date weights.
    
    Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to /var/lib/jenkins/.cache/torch/hub/checkpoints/vgg16-397923af.pth
    
      0%|          | 0.00/528M [00:00
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34

    要加载模型权重,首先需要创建一个相同模型的实例,然后使用load_state_dict()方法加载参数。

    model = models.vgg16() # we do not specify pretrained=True, i.e. do not load default weights
    model.load_state_dict(torch.load('model_weights.pth'))
    model.eval()
    
    • 1
    • 2
    • 3
    VGG(
      (features): Sequential(
        (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU(inplace=True)
        (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (3): ReLU(inplace=True)
        (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
        (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (6): ReLU(inplace=True)
        (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (8): ReLU(inplace=True)
        (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
        (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (11): ReLU(inplace=True)
        (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (13): ReLU(inplace=True)
        (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (15): ReLU(inplace=True)
        (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
        (17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (18): ReLU(inplace=True)
        (19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (20): ReLU(inplace=True)
        (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (22): ReLU(inplace=True)
        (23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
        (24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (25): ReLU(inplace=True)
        (26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (27): ReLU(inplace=True)
        (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (29): ReLU(inplace=True)
        (30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      )
      (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
      (classifier): Sequential(
        (0): Linear(in_features=25088, out_features=4096, bias=True)
        (1): ReLU(inplace=True)
        (2): Dropout(p=0.5, inplace=False)
        (3): Linear(in_features=4096, out_features=4096, bias=True)
        (4): ReLU(inplace=True)
        (5): Dropout(p=0.5, inplace=False)
        (6): Linear(in_features=4096, out_features=1000, bias=True)
      )
    )
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45

    注意事项:

    一定要model.eval()在推理之前调用方法,将 dropout 和 batch normalization 层设置为评估模式。不这样做会产生不一致的推理结果。

    2. 保存和加载模型结构和参数

    加载模型权重时,我们需要先实例化模型类,因为该类定义了网络的结构。我们可能希望将此类的结构与模型一起保存,在这种情况下,我们可以将model(而不是model.state_dict())传递给保存函数:

    torch.save(model, 'model.pth')
    
    • 1

    然后我们可以像这样加载模型:

    model = torch.load('model.pth')
    
    • 1

    注意事项:

    这种方法在序列化模型时使用 Python pickle模块,因此它依赖于在加载模型时可用的实际类定义。

    原文地址

  • 相关阅读:
    idea导入tomcat8源码搭建源码调试环境
    操作系统学习
    从软件测试培训班出来后找工作这段时间的经历,教会了我这五件事...
    Verilog设计参数化的译码器与编码器,以及设计4位格雷码计数器
    linux操作系统期末考试题库
    低功耗无线扫描唤醒技术,重塑物联网蓝牙新体验
    失败率高达80%,数字化转型如何正确完成战略规划?
    家政服务预约小程序,推拿spa上门预约系统
    位图和布隆过滤器
    golang学习笔记系列之复杂数据类型
  • 原文地址:https://blog.csdn.net/xuejianxinokok/article/details/127614240