• pytorch的模型保存加载和继续训练


    随着现在模型越来越大,一次性训练完模型在低算力平台也越来越难以实现,因此很有必要在训练过程中保存模型,以便下次之前训练的基础上进行继续训练,节约时间。代码如下:

    导包

    import torch
    from torch import nn
    import numpy as np
    
    • 1
    • 2
    • 3

    定义模型

    定义一个三层的MLP分类模型

    class MyModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = nn.Linear(64, 32)
            self.linear1 = nn.Linear(32, 10)
            self.relu = nn.ReLU()
    
        def forward(self, x):
            x = self.linear(x)
            x = self.relu(x)
            x = self.linear1(x)
            return x
    
    ## 随机生成2组带标签的数据
    rand1 = torch.rand((100, 64)).to(torch.float)
    label1 = np.random.randint(0, 10, size=100)
    label1 = torch.from_numpy(label1).to(torch.long)
    rand2 = torch.rand((100, 64)).to(torch.float)
    label2 = np.random.randint(0, 10, size=100)
    label2 = torch.from_numpy(label2).to(torch.long)
    
    model = MyModel()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    loss = nn.CrossEntropyLoss()
    
    ## 训练10个epoch
    epoch = 10
    for i in range(epoch):
        output = model(rand1)
        my_loss = loss(output, label1)
        optimizer.zero_grad()
        my_loss.backward()
        optimizer.step()
        print("epoch:{} loss:{}".format(i, my_loss))
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34

    结果如下:记下这些loss值,观察下次继续训练的初始loss

    epoch:0 loss:2.3494179248809814
    epoch:1 loss:2.287858009338379
    epoch:2 loss:2.2486231327056885
    epoch:3 loss:2.2189149856567383
    epoch:4 loss:2.193182945251465
    epoch:5 loss:2.167125940322876
    epoch:6 loss:2.140075206756592
    epoch:7 loss:2.1100614070892334
    epoch:8 loss:2.0764594078063965
    epoch:9 loss:2.0402779579162598
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10

    模型保存

    采用torch.save函数保存模型,一般分为两种模式,分别是简单的保存所有参数,第二种是保存各部分参数,到一个字典结构里面。

    # 保存模型的整体参数
    save_path = r'model_para/'
    torch.save(model, save_path+'model_full.pth')
    
    • 1
    • 2
    • 3

    保存模型参数,优化器参数和epoch情况。

    def save_model(save_path, epoch, optimizer, model):
        torch.save({'epoch': epoch+1,
                    'optimizer_dict': optimizer.state_dict(),
                    'model_dict': model.state_dict()},
                    save_path)
        print("model save success")
    save_model(save_path+'model_dict.pth',epoch, optimizer, model)
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7

    加载模型

    对于保存的pth参数文件,使用torch.load进行加载,代码如下:

    def load_model(save_name, optimizer, model):
        model_data = torch.load(save_name)
        model.load_state_dict(model_data['model_dict'])
        optimizer.load_state_dict(model_data['optimizer_dict'])
        print("model load success")
    
    • 1
    • 2
    • 3
    • 4
    • 5

    观察当前训练模型的权重参数

    print(model.state_dict()['linear.weight'])
    
    • 1
    tensor([[-0.0215,  0.0299, -0.0255,  ..., -0.0997, -0.0899,  0.0499],
            [-0.0113, -0.0974,  0.1020,  ...,  0.0874, -0.0744,  0.0801],
            [ 0.0471,  0.1373,  0.0069,  ..., -0.0573, -0.0199, -0.0654],
            ...,
            [ 0.0693,  0.1900,  0.0013,  ..., -0.0348,  0.1541,  0.1372],
            [ 0.1672, -0.0086,  0.0189,  ...,  0.0926,  0.1545,  0.0934],
            [-0.0773,  0.0645, -0.1544,  ..., -0.1130,  0.0213, -0.0613]])
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7

    命名一个新模型,加载之前保存的参数文件,并打印出层参数

    new_model = MyModel()
    new_optimizer = torch.optim.Adam(new_model.parameters(), lr=0.01)
    load_model(save_path+'model_dict.pth', new_optimizer, new_model)
    print(new_model.state_dict()['linear.weight'])
    
    • 1
    • 2
    • 3
    • 4

    可以看出新模型和当前模型的参数一致,说明参数加载成功。

    model load success
    tensor([[-0.0215,  0.0299, -0.0255,  ..., -0.0997, -0.0899,  0.0499],
            [-0.0113, -0.0974,  0.1020,  ...,  0.0874, -0.0744,  0.0801],
            [ 0.0471,  0.1373,  0.0069,  ..., -0.0573, -0.0199, -0.0654],
            ...,
            [ 0.0693,  0.1900,  0.0013,  ..., -0.0348,  0.1541,  0.1372],
            [ 0.1672, -0.0086,  0.0189,  ...,  0.0926,  0.1545,  0.0934],
            [-0.0773,  0.0645, -0.1544,  ..., -0.1130,  0.0213, -0.0613]])
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8

    继续训练

    在新模型加载原来模型参数的基础上,继续训练,观察loss值,是在之前训练的最终loss,继续下降,说明模型继续训练成功。

    epoch = 10
    for i in range(epoch):
        output = new_model(rand1)
        my_loss = loss(output, label1)
        new_optimizer.zero_grad()
        my_loss.backward()
        new_optimizer.step()
        print("epoch:{} loss:{}".format(i, my_loss))
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    epoch:0 loss:2.0036799907684326
    epoch:1 loss:1.965193271636963
    epoch:2 loss:1.924098253250122
    epoch:3 loss:1.881495714187622
    epoch:4 loss:1.835693359375
    epoch:5 loss:1.7865667343139648
    epoch:6 loss:1.7352293729782104
    epoch:7 loss:1.6832704544067383
    epoch:8 loss:1.6308385133743286
    epoch:9 loss:1.5763107538223267
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10

    数据分布不一致带来的问题

    同样,在这里我发现一个问题,因为之前随机产生了2组数据,之前模型训练使用的rand1,这里只有继续训练rand1,之前模型的参数才有效,如果使用rand2,模型相当于从0训练(如下loss),这是因为,两组数据都是随机生成的,数据分布几乎不一样,所以上一组数据训练的模型在第二组数据几乎无效。

    epoch:0 loss:2.523787498474121
    epoch:1 loss:2.469816207885742
    epoch:2 loss:2.4141526222229004
    epoch:3 loss:2.379054069519043
    epoch:4 loss:2.3563807010650635
    epoch:5 loss:2.319946765899658
    epoch:6 loss:2.271805763244629
    epoch:7 loss:2.2274367809295654
    epoch:8 loss:2.186885118484497
    epoch:9 loss:2.144239902496338
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10

    但是在真实情况中,由于batch数据都是假设同一分布,所以不用考虑这个问题,

    那么以上,就完成了pytorch的模型保存,加载和继续训练的三种重要过程,希望能够帮到您!!!

    祝您训练愉快。

  • 相关阅读:
    异构数据库
    linux oracle 2022年10月份补丁集:11.2.0.4.221018 PSU补丁包已发布,包含 database,ojvm和GI
    球幕投影有哪些常见的物理表现形式?
    20220728使用电脑上的蓝牙和汇承科技的蓝牙模块HC-05配对蓝牙串口传输
    长安链大规模数据存储及数据膨胀分析
    Python入门自学进阶-Web框架——20、Django其他相关知识2
    基于Matlab实现多个图像融合案例(附上源码+数据集)
    Linux开发讲课18--- “>file 2>&1“ 和 “2>&1 >file“ 的区别
    从 dpdk-20.11 移植 intel E810 百 G 网卡 pmd 驱动到 dpdk-16.04 中
    微信小程序(五)--- Vant组件库,API Promise化,MboX全局数据共享,分包相关
  • 原文地址:https://blog.csdn.net/weixin_42327752/article/details/125405980