• Saving a PyTorch model as ONNX


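    For various reasons, I need to use PyTorch to create, train, and save models. PyTorch usually saves models in .pth, .pt, or .pkl format, but models in these formats cannot be loaded directly by other frameworks (for example TensorFlow), so the model has to be saved in a different format. After some reading online, I arrived at the following two points: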

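    • PyTorch can export a model directly to ONNX, and the ONNX model can then be converted to other formats (for example pb);
    • PyTorch can also save a model directly as a caffemodel, but this takes a fair amount of code.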

    Prerequisites: onnx and onnxruntime must be installed, which can be done with pip install onnx and pip install onnxruntime.

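    Implementation code

    import torch
    import torch.onnx

    # `model` is the trained torch.nn.Module to export (defined elsewhere).
    # The dummy input only needs the right shape and dtype, and it must live on
    # the same device as the model (here the GPU).
    x = torch.randn(1, 3, 32, 32).cuda()
    # export() writes the file; do not rely on its return value.
    torch.onnx.export(model, x,
                      "test.onnx",
                      export_params=True,
                      verbose=True)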

    The export API

    export(model, args, f, export_params=True, verbose=False, training=False,
           input_names=None, output_names=None, aten=False, export_raw_ir=False,
           operator_export_type=None, opset_version=None, _retain_param_name=True,
           do_constant_folding=False, example_outputs=None, strip_doc_string=True, dynamic_axes=None)

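    Parameter descriptions

    • model: the PyTorch model to export.
    • args: the model's input arguments. A dummy tensor is enough as long as its shape matches the input layer.
    • f: where to write the ONNX model, for example 'yolov5.onnx'.
    • export_params: whether to export the trained parameters (weights) with the model. default=True exports the trained model; False exports an untrained one.
    • verbose: whether to print information about the conversion. default=False.
    • input_names: names of the input nodes. default=None.
    • output_names: names of the output nodes. default=None.
    • do_constant_folding: whether to apply constant folding, an optimization that replaces ops whose inputs are all constants with pre-computed constant nodes. The default is usually fine; in the signature above it is default=False.
    • dynamic_axes: the input and output sizes of a model are sometimes variable, for example with an RNN or when the batch size of the images varies; this parameter declares which axes are dynamic. For an input layer of shape (b,3,h,w), batch, height, and width are variable while the channel dimension is fixed at 3. The keys are the names given in input_names and output_names. The accepted formats are:
      • 1) list(int) only: dynamic_axes={'input':[0,2,3],'output':[0,1]}
      • 2) dict only: dynamic_axes={'input':{0:'batch',2:'height',3:'width'},'output':{0:'batch',1:'c'}}
      • 3) mixed: dynamic_axes={'input':{0:'batch',2:'height',3:'width'},'output':[0,1]}
    • opset_version: the opset version. Lower versions do not support some operations such as upsample.

    For reference, the docstring of export reads:
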
    Export a model into ONNX format. This exporter runs your model
    once in order to get a trace of its execution to be exported;
    at the moment, it supports a limited set of dynamic models (e.g., RNNs.)
    See also: :ref:`onnx-export`
    Arguments:
        model (torch.nn.Module): the model to be exported.
        args (tuple of arguments): the inputs to
            the model, e.g., such that ``model(*args)`` is a valid
            invocation of the model. Any non-Tensor arguments will
            be hard-coded into the exported model; any Tensor arguments
            will become inputs of the exported model, in the order they
            occur in args. If args is a Tensor, this is equivalent
            to having called it with a 1-ary tuple of that Tensor.
            (Note: passing keyword arguments to the model is not currently
            supported. Give us a shout if you need it.)
        f: a file-like object (has to implement fileno that returns a file descriptor)
            or a string containing a file name. A binary Protobuf will be written
            to this file.
        export_params (bool, default True): if specified, all parameters will
            be exported. Set this to False if you want to export an untrained model.
            In this case, the exported model will first take all of its parameters
            as arguments, the ordering as specified by ``model.state_dict().values()``
        verbose (bool, default False): if specified, we will print out a debug
            description of the trace being exported.
        training (bool, default False): export the model in training mode. At
            the moment, ONNX is oriented towards exporting models for inference
            only, so you will generally not need to set this to True.
        input_names (list of strings, default empty list): names to assign to the
            input nodes of the graph, in order
        output_names (list of strings, default empty list): names to assign to the
            output nodes of the graph, in order
        aten (bool, default False): [DEPRECATED. use operator_export_type] export the
            model in aten mode. If using aten mode, all the ops originally exported
            by the functions in symbolic_opset.py are exported as ATen ops.
        export_raw_ir (bool, default False): [DEPRECATED. use operator_export_type]
            export the internal IR directly instead of converting it to ONNX ops.
        operator_export_type (enum, default OperatorExportTypes.ONNX):
            OperatorExportTypes.ONNX: all ops are exported as regular ONNX ops.
            OperatorExportTypes.ONNX_ATEN: all ops are exported as ATen ops.
            OperatorExportTypes.ONNX_ATEN_FALLBACK: if symbolic is missing,
                fall back on ATen op.
            OperatorExportTypes.RAW: export raw ir.
        opset_version (int, default is 9): by default we export the model to the
            opset version of the onnx submodule. Since ONNX's latest opset may
            evolve before next stable release, by default we export to one stable
            opset version. Right now, supported stable opset version is 9.
            The opset_version must be _onnx_master_opset or in _onnx_stable_opsets
            which are defined in torch/onnx/symbolic_helper.py
        do_constant_folding (bool, default False): If True, the constant-folding
            optimization is applied to the model during export. Constant-folding
            optimization will replace some of the ops that have all constant
            inputs, with pre-computed constant nodes.
        example_outputs (tuple of Tensors, default None): example outputs of the
            model that is being exported. Must be provided when exporting a
            ScriptModule or TorchScript Function.
        strip_doc_string (bool, default True): if True, strips the field
            "doc_string" from the exported model, which contains information
            about the stack trace.
        dynamic_axes (dict<string, dict<int, string>> or dict<string, list(int)>, default empty dict):
            a dictionary to specify dynamic axes of input/output, such that:
            - KEY: input and/or output names
            - VALUE: index of dynamic axes for given key and potentially the name to be used for
              exported dynamic axes. In general the value is defined according to one of the following
              ways or a combination of both:
              (1). A list of integers specifying the dynamic axes of provided input. In this scenario
                   automated names will be generated and applied to dynamic axes of provided input/output
                   during export.
              OR (2). An inner dictionary that specifies a mapping FROM the index of dynamic axis in
                   corresponding input/output TO the name that is desired to be applied on such axis of
                   such input/output during export.
            Example. if we have the following shape for inputs and outputs:
                shape(input_1) = ('b', 3, 'w', 'h')
                and shape(input_2) = ('b', 4)
                and shape(output) = ('b', 'd', 5)
            Then dynamic axes can be defined either as:
                (a). ONLY INDICES:
                    dynamic_axes = {'input_1':[0, 2, 3], 'input_2':[0], 'output':[0, 1]}
                    where automatic names will be generated for exported dynamic axes
                (b). INDICES WITH CORRESPONDING NAMES:
                    dynamic_axes = {'input_1':{0:'batch', 2:'width', 3:'height'},
                                    'input_2':{0:'batch'},
                                    'output':{0:'batch', 1:'detections'}}
                    where provided names will be applied to exported dynamic axes
                (c). MIXED MODE OF (a) and (b)
                    dynamic_axes = {'input_1':[0, 2, 3], 'input_2':{0:'batch'}, 'output':[0,1]}
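
    To make the three dynamic_axes formats concrete, here is a small sketch in plain Python. The helper normalize_dynamic_axes is hypothetical (it is not a PyTorch function); it rewrites list-style entries into the fully named dict form, leaving dict-style entries untouched. The generated name pattern `<tensor>_dynamic_axes_<k>` is an assumption made for illustration; the names the exporter actually generates may differ between PyTorch versions.

```python
def normalize_dynamic_axes(dynamic_axes):
    """Return a copy in which every value is an {axis_index: axis_name} dict.

    Hypothetical helper for illustration only; not part of torch.onnx.
    """
    normalized = {}
    for tensor_name, axes in dynamic_axes.items():
        if isinstance(axes, dict):
            # Format 2: names were given explicitly, keep them as they are.
            normalized[tensor_name] = dict(axes)
        else:
            # Format 1: only indices were given, so invent a name per axis.
            normalized[tensor_name] = {
                axis: f"{tensor_name}_dynamic_axes_{k + 1}"
                for k, axis in enumerate(axes)
            }
    return normalized


# Format 3 (mixed): named axes for the input, bare indices for the output.
mixed = {
    "input": {0: "batch", 2: "height", 3: "width"},
    "output": [0, 1],
}
named = normalize_dynamic_axes(mixed)
print(named)
```

    Either form can be passed as dynamic_axes to export; the keys have to match the names given in input_names and output_names.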

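    Running the ONNX model

    Check the ONNX model, then run it with onnxruntime.

    import numpy as np
    import onnx
    import onnxruntime as ort

    model = onnx.load('best.onnx')  # load the ONNX model
    onnx.checker.check_model(model)  # check that the generated model is well-formed
    session = ort.InferenceSession('best.onnx')
    x = np.random.randn(1, 3, 32, 32).astype(np.float32)  # the input dtype must be np.float32
    # x = torch.randn(batch_size, channel, h, w)
    # The second argument is a dict keyed by input name; 'input' must be the
    # name that was set through input_names when the model was exported.
    outputs = session.run(None, {'input': x})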

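    Parameters of run():

    • output_names: default=None. Selects which outputs are returned, and in what order.
      • If None, all outputs are returned in order, i.e. [output_0, output_1];
      • If ['output_1','output_0'], it returns [output_1, output_0];
      • If ['output_0'], it returns only [output_0].
    • input_feed: a dict that maps input names to arrays. The name of the first input can be obtained with session.get_inputs()[0].name. The keys must match the names set in torch.onnx.export.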

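    Comparing the ONNX model output with the PyTorch model output

    import numpy as np
    # torch_result: output of the PyTorch model for input x
    # onnx_result: the matching array returned by session.run for the same input
    np.testing.assert_allclose(torch_result[0].detach().cpu().numpy(), onnx_result, rtol=0.0001)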
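
    assert_allclose(actual, desired) passes when |actual - desired| <= atol + rtol * |desired| holds for every element, and atol defaults to 0. With only rtol set, outputs that are very close to zero can therefore fail the check even when the difference is negligible. The sketch below shows this with made-up arrays standing in for the two results; the values and the helper outputs_match are invented for this example.

```python
import numpy as np

# Made-up stand-ins for the PyTorch and onnxruntime outputs.
torch_out = np.array([1.0, 250.0, 0.0], dtype=np.float32)
onnx_out = torch_out + np.array([5e-5, 1e-2, 1e-6], dtype=np.float32)


def outputs_match(actual, desired, rtol=1e-4, atol=0.0):
    """Return True when assert_allclose accepts the pair, False otherwise."""
    try:
        np.testing.assert_allclose(actual, desired, rtol=rtol, atol=atol)
        return True
    except AssertionError:
        return False


# Relative tolerance only: the near-zero third element fails the check.
strict = outputs_match(torch_out, onnx_out)
# A small absolute tolerance absorbs differences around zero.
relaxed = outputs_match(torch_out, onnx_out, atol=1e-5)
print(strict, relaxed)
```

    When comparing real model outputs, a small atol next to rtol is a common choice for this reason; the right values depend on the model and its dtype.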

  • Original article: https://blog.csdn.net/weixin_43570470/article/details/126358511