• 【AI设计模式】05-检查点模式(CheckPoints):如何定期存储模型?


    作者:王磊
    更多精彩分享,欢迎访问和关注:https://www.zhihu.com/people/wldandan

    目录

    在前一篇文章《数据处理-Eager模式》中分享了数据处理-Eager模式,那么在模型训练时,有哪些设计模式可以使用呢?在数据库领域,为了防止执行时间较长的存储过程失败重新执行,会将中间的过程状态以检查点的形式持续记录下来,每次失败时不需要重头执行,而是加载最近的检查点,继续执行,避免浪费时间。和存储过程类似,模型的训练时间会更长,如果缺乏一定的可靠性机制,过程中一旦失败,就需要重头开始训练,浪费时间较多。因此,需要实现类似的机制来保证可靠性问题,这种机制被称为检查点(CheckPoints)模式

    AI设计模式总览

    模式定义

    检查点模式(CheckPoints)是指通过周期性(迭代/时间)的保存模型的完整状态,在模型训练失败时,可以从保存的检查点模型继续训练,以避免训练失败时每次都需要从头开始带来的训练时间浪费。检查点模式适用于模型训练时间长、训练需要提前结束、fine-tune等场景,也可以拓展到异常时的断点续训场景。

    问题

    1. 训练耗时的网络在训练过程中失败,从头开始训练的代价高:对于层数比较深的神经网络,或者需要大规模训练数据的模型,训练的时间会很长。因为有更多的参数以及更多的数据样本需要处理。比如对于VGG16的网络,cifar-10的数据集,普通的NVIDIA显卡训练需要3-4小时;一旦过程中失败,需要重头开始训练,时间成本高。
    2. 训练时间越长,精度可能不发生变化,或者产生过拟合的现象。这种场景时,提前结束(early stopping)获得中间的模型状态收益会更高。
    3. Fine-Tune时,通常需要最终模型前面的一些模型状态进行基础上进行调优,这样可以更好的针对新数据进行训练,获得更好的泛化性。

    解决方案

    在每轮训练结束时,都保存当前的模型状态作为检查点,如果下轮训练失败时,可以从这个检查点模型继续训练。和训练完成导出的模型(以神经网络为例,最终的模型包含权重、激活函数以及隐藏层信息)相比,这个中间模型状态需要额外的轮、当前的批量计数等信息,以保证基于这个中间模型继续训练。通常这个中间模型被称为检查点(CheckPoints)。检查点的模型状态中通常不包括学习率,因为训练过程中它可能会动态调整。

    如果在每个批量数据训练完,权重更新后都保存检查点,中间模型的数量和占用的空间会非常大。所以实践中通常会在每轮结束后保存检查点,或者保留最近的几个检查点。

    案例

    AI框架通常都提供了模型训练的检查点保存能力。在MindSpore中,通过训练API提供了ModelCheckPointCheckpointConfig模块来帮助开发者保存模型训练过程中的检查点。MindSpore提供了三种检查点保存策略,包括直接保存、周期保存(迭代次数或者训练时长)、和异常保存(在训练失败的异常情况下保存的策略)。

    说明:检查点文件是一个二进制文件,存储了所有训练参数的值;且检查点的实现上采用了Protocol Buffers机制,与开发语言、平台无关,具有良好的可扩展性。

    在此,我们重点介绍下如何在MindSpore中周期性保存模型状态、以及在异常情况下保存故障点的模型状态。

    周期保存

    1)迭代次数方式保存

    下面的MindSpore代码片段展示了使用迭代次数配置检查点保存策略,以及在模型训练时通过回调的方式应用保存策略。训练开始后,会每隔1785个step保存一次检查点模型,并最多保留10个中间模型,模型的名称格式为checkpoint_lenet-1_1875.ckpt

    1. from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
    2. # 设置模型保存参数,设置模型保存的策略,如本例中设置最多保存10个checkpoints,每隔1875个step保存一次
    3. config_ck = CheckpointConfig(save_checkpoint_steps=1875, keep_checkpoint_max=10)
    4. # 应用模型保存参数
    5. ckpoint = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ck)
    6. #通过回调的方式配置在模型训练的过程中
    7. model.train(epoch_size, ds_train, callbacks=[ckpoint_cb])

    加载CheckPoint可以通过load_checkpointload_param_into_net方法来完成,如下面的代码,通过load_checkpoint方法从保存好的checkpoint中加载网络的参数,再通过load_param_into_net将参数导入到具体的网络实例中,方便后面续训或者评估。

    1. from mindspore import load_checkpoint, load_param_into_net
    2. # 加载已经保存的用于测试的模型
    3. param_dict = load_checkpoint("checkpoint_lenet-1_1875.ckpt")
    4. # 加载参数到网络中
    5. load_param_into_net(net, param_dict)

    完整的代码可以参考[1]中的案例。

    2)周期时间方式保存

    时间策略提供了按照秒和分钟配置参数,如下面的代码,每隔30秒保存一个CheckPoint文件,每隔3分钟保留一个CheckPoint文件。

    1. from mindspore import CheckpointConfig
    2. # 每隔30秒保存一个CheckPoint文件,每隔3分钟保留一个CheckPoint文件
    3. config_ck = CheckpointConfig(save_checkpoint_seconds=30, keep_checkpoint_per_n_minutes=3)

    异常保存

    如果模型较大,通常会减少梳理保留的检查点模型,间隔的时间会拉长。如盘古大模型的检查点保存间隔在4-5小时,如果在两个检查点之间失败,那么从上个检查点重新训练的时间损失会比较大。MindSpore在1.7版本扩展了检查点功能,提供断点续训能力,保证在训练异常时触发检查点,保证下次可以从发生故障时的模型状态继续训练,训练时间无损失。引入断点续训功能,只需在策略配置时增加“exception_save=True”的参数即可。

    1. from mindspore import ModelCheckpoint, CheckpointConfig
    2. # 配置断点续训功能开启
    3. config_ck = CheckpointConfig(save_checkpoint_steps=32, keep_checkpoint_max=10, exception_save=True)

    总结

    检查点(CheckPoints)模式最大的作用在于保证了模型训练的可靠性,同时也可以让开发者更容易的做早停。断点续训能力对于大模型的价值较大,异常状态下续训无时间损失,检查点模式也有利于转移学习时做fine-tune,这也是我们下一个要介绍的模式。

    参考资料

    [1] MindSpore完整案例:https://mindspore.cn/tutorials/zh-CN/r1.7/beginner/quick_start.html

    [2] MindSpore模型保存:https://gitee.com/mindspore/docs/blob/master/tutorials/source_zh_cn/advanced/train/save.ipynb

    [3] 机器学习设计模式:https://www.oreilly.com/library/vie

    说明:严禁转载本文内容,否则视为侵权。 

  • 相关阅读:
    MES管理系统在电子行业的作用和效益
    redis学习(008 实战:黑马点评:缓存介绍)
    Bio-MOF-100 金属有机骨架材料
    三星大规模生产3nm芯片?预计明年就能流通各大手机厂商手中
    培训机构招生电子传单制作教程:突出核心竞争力的方法
    shell脚本中循环语句(极其粗糙版)
    计算机网络 —— 运输层(UDP和TCP)
    MySQL复习资料(附加)case when
    01 时钟配置初始化,debug
    国产AI绘画海克斯科技——爱作画AIGC开放平台
  • 原文地址:https://blog.csdn.net/Kenji_Shinji/article/details/126704608