• mmdetection训练得到的权重/checkpoints文件分析


    这篇文章对mmdetection(包括mmlab的其他例如mmclassification等)训练得到的模型权重,或者说checkpoints文件进行分析,一般模型保存在work-dir文件夹下,具体路径要参考训练用到的config,即配置文件。保存的模型一般是.pt的文件。

    读取.pth文件具体数值

    修改.pth文件具体数值(比如修改卷积核通道数)

    读取.pth文件具体数值

    .pt模型文件读取方法

    这种模型文件可以用torch.load()函数进行解析

    1. import torch
    2. pth_path = 'work-dir/your_check_point.pt'
    3. model = torch.load(pth_path)

    这里我们就可以看到这个model实际上不是什么复杂的东西,就是一个很大的dict

    这个model一般包括三个key、value。

    meta

    第一个:meta,包含一些基本信息。就是告诉你这个模型是在什么背景下被训练得到的,用的mmdet是什么版本,随机种子seed是多少,config是什么,方便你复现复刻出来这个model 

    state_dict 

    这个是模型关键。一般网上下载的预训练权重只有这个,其是一个大的OrderedDict里面包含了这个模型按顺序得到的各层参数,看下图就明白个大概了。

    一般要利用一个checkpoint(.pt的模型权重文件) ,也就是主要读取这里面的信息,来进行refine或者infer。

    optimizer 

    里面存放的是优化器的状态,方便用这个.pt文件进行resume,即意外中断实验的时候进行继续实验,结合mmdet的train.py里的resume_from命令理解。

    修改.pth文件具体数值(比如修改卷积核通道数)

    有时候我们修改了网络部分,会导致预训练权重的shape跟网络修改后的shape不匹配。最经典的例子就是我们希望输入图片从3channels变成6channels。就会报诸如

    torch: input tensor shape not match

    的错误,那么这个时候为了正常用预训练模型,我们就需要手动去修改模型权重。当然如果相关module的封装特别好比如MMdetection的一些backbone和neck module,是可以手动选择in_channels来自动匹配input的shape修改的,但是这种匹配也是建立在已封装了修改具体权重的基础上,实际上我们更希望手动去控制一些修改权重的方法(比如初始化法)

  • 相关阅读:
    process information unavailable解决方案
    【MySQL架构篇】存储引擎
    计算几何算法模板
    Mac和Linux中的chmod +x命令详解
    展商企业【广东伟创科技开发有限公司】| 2024水科技大会暨技术装备成果展
    HDU 1712:ACboy needs your help ← 分组背包问题
    【Linux】vim_gcc_动静态库
    (免费领源码)JAVA#springboot#MYSQL 社区医院病历管理平台11271-计算机毕业设计项目选题推荐
    RocketMQ保姆级教程
    k8s-dynamic-pvc
  • 原文地址:https://blog.csdn.net/jiangqixing0728/article/details/127940476