码农知识堂 - 1000bd
  •   Python
  •   PHP
  •   JS/TS
  •   JAVA
  •   C/C++
  •   C#
  •   GO
  •   Kotlin
  •   Swift
  • pytorch_YOLOX剪枝【附代码】


    目录

    环境

    安装包

    feature视化

    网络剪枝

    剪枝后的微调训练

    训练自己的数据集

    预测

    Conv与BN层的融合推理加速

    日志文件的保存

    权重


    环境

    pytorch 1.7

    loguru 0.5.3

    NVIDIA 1650 4G

    intel i5-9th

    torch-pruning 0.2.7


    安装包

    pip install torch_pruning

    Note:本项目是在b站up主Bubbliiiing和原YOLOX官方代码进行了整合。

    1.添加了feature可视化功能

    2.训练中可开启EMA功能

    3.网络剪枝(支持s,m,l,x)

    3.1支持单个卷积剪枝

    3.2支持网络层剪枝

    4.剪枝后微调训练

    5.Conv与BN层的融合推理加速

    6.保存log信息

    数据集格式:采用voc数据集格式


    feature视化

    在tools/Net_Vision.py为可视化代码实现。可以通过在网络层导入NV函数,实现通道可视化。

    eg:

    1. features = [out_features[f] for f in self.in_features]
    2. [x2, x1, x0] = features  # shape is (batch_size,channels,w,h)
    3. NV(x2)

     


    网络剪枝

    参考论文:Pruning Filters for Efficient ConvNets

    导入剪枝工具

    import torch_pruning as tp

    如果需要看yolov4的,可以看:YOLOv4剪枝【附代码】_爱吃肉的鹏的博客-CSDN博客_yolov4剪枝

    采用通道剪枝,而不是权重剪枝。

    在剪枝之前需要通过tools/prunmodel.py save_whole_model(weights_path, num_classes) 函数将模型的权重和结构都保存下来。

    weights_path:权重路径

    num_classes:自己类别数量

    model = YOLOX(num_classes, 's')  # 这里需要根据自己的类数量修改  s指yolox-s

    支持对某个卷积的剪枝:调用Conv_pruning(whole_model_weights):

    pruning_idxs = strategy(v, amount=0.4)  # 0.4是剪枝率 根据需要自己修改,数越大剪的越多

    对于单独一个卷积的剪枝,需要修改两个地方值,这里的卷积层需要打印模型获得,不要自己盲目瞎猜:

    if k == 'backbone.backbone.dark2.0.conv.weight'
    pruning_plan = DG.get_pruning_plan((model.backbone.backbone.dark2)[0].conv, tp.prune_conv, idxs=pruning_idxs)

    支持网络层的剪枝:调用layer_pruning(whole_model_weights):

    included_layers = list((model.backbone.backbone.dark2.modules()))  # 针对某层剪枝

    Note:剪枝成功以后,会打印模型的参数变化量!如果没有打印,说明你剪的不对,好好检查一下!

    剪枝以后的log日志文件会保存在logs文件下

    剪枝后的微调训练

    将train.py中的pruned_train设置为True.

    False为正常训练,然后自己修改batch_size。

    注意修改model_path和classes_path,不然会报错!

    剪枝前的网络输入大小和微调训练以及推理时的大小必须一致!

    训练自己的数据集

    如果你有用过Bubbliiiing up主的代码,你将很快就能上手。数据集采用的是VOC的形式

    VOCdevkit/
    `-- VOC2007
        |-- Annotations  (存放xml标签文件)
        |-- ImageSets
        |   `-- Main
        `-- JPEGImages (存放图像)
    ​

    在model_data中新建一个new_classes.txt,里面写入自己的类。运行voc_annotation.py,会在当前目录生成2007_train.txt文件和2007_val.txt文件。(可以检查一下里面有没有生成成功)

    在train.py中,将classes_path修改为model_data/new_classes.txt【等预测的时候,也是需要在yolo.py中修改这里】

    然后根据需要修改其他超参即可训练,训练权重会保存在logs文件中(默认保存权值,不含网络结构)

    预测

     

    参数说明:下面终端的输入都是可选的

    --predict:预测模式

    --pruned:开启剪枝预测或训练

    --image:图像检测

    --video:开始视频检测

    --video_path:视频路径

    --camid:摄像头id 默认0

    --fps:FPS测试

    --dir_predict:对一个文件夹下图像进行预测

    --phi:可以选择s,m,l,x等

    --input_shape:网络输入大小,默认640

    --confidence:置信度阈值

    --nms_iou:iou阈值

    --num_classes:类别数量,默认80

    --fuse:是否开启卷积层和BN层融合加速,默认False

    终端输入:

    # 图像预测
    python demo.py --predict --image
    # 视频预测
    python demo.py --predict --video --camdi 0
    # fps测试
    python demo.py --predict --fps

    默认预测都为yolox_s,如果要指定其他网络,输入:(需要注意的是在yolo.py修改权重路径,如果是自己数据集,还需要修改classes_path)

    # 使用yolox_l预测
    python demo.py --predict --image --phi l

    Conv与BN层的融合推理加速

    其他命令可以搭配使用,比如采用conv和bn融合的方式进行推理

    python demo.py --predict --image --fuse

    通过测试发现FPS提升了3帧/s左右(我的GPU是1650)

    日志文件的保存

    本项目采用loguru工具捕获日志,检测和训练中的一些日志记录会自动记录,保存在logs文件下,一个log文件的大小我设置的上限大小为1 MB,如果超过该范围,会自动生成一个新的.log文件,可以自己修改这个值,或者修改日志保存时间(以免保存了太多的日志)。如果不想要这个功能,可以找到相应的位置注释掉即可。

    这里只是帮助大家造轮子,用尽可能简单的代码实现一些功能,不用大家再去看复杂的工程代码,最终的效果需要自己耐心调试,慢慢“炼丹”!

    权重

    链接:百度网盘 请输入提取码icon-default.png?t=M5H6https://pan.baidu.com/s/1Jbq8dCv893rZ7RkaANUZgQ 提取码:yypn

    代码(如果有帮助,麻烦点个star呗~):

    GitHub - YINYIPENG-EN/Pruning_for_YOLOX: 实现对YOLOX的剪枝操作,添加了卷积层和BN层融合推理,添加中间层可视化功能,可实现预测和训练日志保存实现对YOLOX的剪枝操作,添加了卷积层和BN层融合推理,添加中间层可视化功能,可实现预测和训练日志保存 - GitHub - YINYIPENG-EN/Pruning_for_YOLOX: 实现对YOLOX的剪枝操作,添加了卷积层和BN层融合推理,添加中间层可视化功能,可实现预测和训练日志保存https://github.com/YINYIPENG-EN/Pruning_for_YOLOX

  • 相关阅读:
    DevOps(十二)Jenkins实战之Web发布到远程服务器
    同一项目如何连接多个mongo服务器地址
    Element-ui源码解析(二):最简单的组件Button
    ‘face_alignment‘ has no attribute ‘FaceAlignment‘
    LeetCode+ 71 - 75
    ERC20
    直播电商企业“快反”模式与数字化营销转型:兼论开源 AI 智能名片 S2B2C 商城小程序的应用
    (WSI分类)WSI分类文献小综述
    C++ 泛型编程-模板
    TL072ACDR 丝印072AC SOP-8 双路JFET输入运算放大器芯片
  • 原文地址:https://blog.csdn.net/z240626191s/article/details/125608886
  • 最新文章
  • 攻防演习之三天拿下官网站群
    数据安全治理学习——前期安全规划和安全管理体系建设
    企业安全 | 企业内一次钓鱼演练准备过程
    内网渗透测试 | Kerberos协议及其部分攻击手法
    0day的产生 | 不懂代码的"代码审计"
    安装scrcpy-client模块av模块异常,环境问题解决方案
    leetcode hot100【LeetCode 279. 完全平方数】java实现
    OpenWrt下安装Mosquitto
    AnatoMask论文汇总
    【AI日记】24.11.01 LangChain、openai api和github copilot
  • 热门文章
  • 十款代码表白小特效 一个比一个浪漫 赶紧收藏起来吧!!!
    奉劝各位学弟学妹们,该打造你的技术影响力了!
    五年了,我在 CSDN 的两个一百万。
    Java俄罗斯方块,老程序员花了一个周末,连接中学年代!
    面试官都震惊,你这网络基础可以啊!
    你真的会用百度吗?我不信 — 那些不为人知的搜索引擎语法
    心情不好的时候,用 Python 画棵樱花树送给自己吧
    通宵一晚做出来的一款类似CS的第一人称射击游戏Demo!原来做游戏也不是很难,连憨憨学妹都学会了!
    13 万字 C 语言从入门到精通保姆级教程2021 年版
    10行代码集2000张美女图,Python爬虫120例,再上征途
Copyright © 2022 侵权请联系2656653265@qq.com    京ICP备2022015340号-1
正则表达式工具 cron表达式工具 密码生成工具

京公网安备 11010502049817号