• 发票关键信息抽取SER


    SER(Semantic Entity Recognition):语义实体识别。语义实体识别指的是给定一段文本行,确定其所属的语义类别(如姓名、住址等类别)。本文采用基于VI-LayoutXLM的多模态语义实体识别方法。

    1、增值税发票数据集 

        https://download.csdn.net/download/ronshi/88467149

    2、训练模型

    python tools/train.py -c my/ser_vi_layoutxlm_xfund_zh_udml.yml 

    配置文件 ser_vi_layoutxlm_xfund_zh_udml.yml

    Global:
      use_gpu: true
      epoch_num: &epoch_num 200
      log_smooth_window: 10
      print_batch_step: 10
      save_model_dir: ./output/ser_vi_layoutxlm_xfund_zh_udml
      save_epoch_step: 2000
      # evaluation is run every 19 iterations after the 0th iteration
      eval_batch_step: [ 0, 19 ]
      cal_metric_during_train: False
      save_inference_dir:
      use_visualdl: False
      seed: 2022
      infer_img: my/b201.jpg
      save_res_path: ./output/ser_layoutxlm_xfund_zh/res
    Architecture:
      model_type: &model_type "kie"
      name: DistillationModel
      algorithm: Distillation
      Models:
        Teacher:
          pretrained:
          freeze_params: false
          return_all_feats: true
          model_type: *model_type
          algorithm: &algorithm "LayoutXLM"
          Transform:
          Backbone:
            name: LayoutXLMForSer
            pretrained: True
            # one of base or vi
            mode: vi
            checkpoints:
            num_classes: &num_classes 5
        Student:
          pretrained:
          freeze_params: false
          return_all_feats: true
          model_type: *model_type
          algorithm: *algorithm
          Transform:
          Backbone:
            name: LayoutXLMForSer
            pretrained: True
            # one of base or vi
            mode: vi
            checkpoints:
            num_classes: *num_classes
    Loss:
      name: CombinedLoss
      loss_config_list:
      - DistillationVQASerTokenLayoutLMLoss:
          weight: 1.0
          model_name_list: ["Student", "Teacher"]
          key: backbone_out
          num_classes: *num_classes
      - DistillationSERDMLLoss:
          weight: 1.0
          act: "softmax"
          use_log: true
          model_name_pairs:
          - ["Student", "Teacher"]
          key: backbone_out
      - DistillationVQADistanceLoss:
          weight: 0.5
          mode: "l2"
          model_name_pairs:
          - ["Student", "Teacher"]
          key: hidden_states_5
          name: "loss_5"
      - DistillationVQADistanceLoss:
          weight: 0.5
          mode: "l2"
          model_name_pairs:
          - ["Student", "Teacher"]
          key: hidden_states_8
          name: "loss_8"
    Optimizer:
      name: AdamW
      beta1: 0.9
      beta2: 0.999
      lr:
        name: Linear
        learning_rate: 0.00005
        epochs: *epoch_num
        warmup_epoch: 10
      regularizer:
        name: L2
        factor: 0.00000
    PostProcess:
      name: DistillationSerPostProcess
      model_name: ["Student", "Teacher"]
      key: backbone_out
      class_path: &class_path my/zzsfp/class_list.txt
    Metric:
      name: DistillationMetric
      base_metric_name: VQASerTokenMetric
      main_indicator: hmean
      key: "Student"
    Train:
      dataset:
        name: SimpleDataSet
        data_dir: my/zzsfp/imgs
        label_file_list:
          - my/zzsfp/train.json
        ratio_list: [ 1.0 ]
        transforms:
          - DecodeImage: # load image
              img_mode: RGB
              channel_first: False
          - VQATokenLabelEncode: # Class handling label
              contains_re: False
              algorithm: *algorithm
              class_path: *class_path
              # one of [None, "tb-yx"]
              order_method: &order_method "tb-yx"
          - VQATokenPad:
              max_seq_len: &max_seq_len 512
              return_attention_mask: True
          - VQASerTokenChunk:
              max_seq_len: *max_seq_len
          - Resize:
              size: [224,224]
          - NormalizeImage:
              scale: 1
              mean: [ 123.675, 116.28, 103.53 ]
              std: [ 58.395, 57.12, 57.375 ]
              order: 'hwc'
          - ToCHWImage:
          - KeepKeys:
              keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels'] # dataloader will return list in this order
      loader:
        shuffle: True
        drop_last: False
        batch_size_per_card: 4
        num_workers: 4
    Eval:
      dataset:
        name: SimpleDataSet
        data_dir: my/zzsfp/imgs
        label_file_list:
          - my/zzsfp/val.json
        transforms:
          - DecodeImage: # load image
              img_mode: RGB
              channel_first: False
          - VQATokenLabelEncode: # Class handling label
              contains_re: False
              algorithm: *algorithm
              class_path: *class_path
              order_method: *order_method
          - VQATokenPad:
              max_seq_len: *max_seq_len
              return_attention_mask: True
          - VQASerTokenChunk:
              max_seq_len: *max_seq_len
          - Resize:
              size: [224,224]
          - NormalizeImage:
              scale: 1
              mean: [ 123.675, 116.28, 103.53 ]
              std: [ 58.395, 57.12, 57.375 ]
              order: 'hwc'
          - ToCHWImage:
          - KeepKeys:
              keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels'] # dataloader will return list in this order
      loader:
        shuffle: False
        drop_last: False
        batch_size_per_card: 4
        num_workers: 4

    3、模型评估

    python tools/eval.py -c my/ser_vi_layoutxlm_xfund_zh.yml -o Architecture.Backbone.checkpoints=./output/ser_vi_layoutxlm_xfund_zh_udml/best_accuracy

     配置文件 ser_vi_layoutxlm_xfund_zh.yml

    Global:
      use_gpu: True
      epoch_num: &epoch_num 200
      log_smooth_window: 10
      print_batch_step: 10
      save_model_dir: ./output/ser_vi_layoutxlm_xfund_zh
      save_epoch_step: 2000
      # evaluation is run every 19 iterations after the 0th iteration
      eval_batch_step: [ 0, 19 ]
      cal_metric_during_train: False
      save_inference_dir:
      use_visualdl: False
      seed: 2022
      infer_img: my/b201.jpg
      d2s_train_image_shape: [3, 224, 224]
      # if you want to predict using the groundtruth ocr info,
      # you can use the following config
      # infer_img: train_data/XFUND/zh_val/val.json
      # infer_mode: False
      save_res_path: ./output/ser/xfund_zh/res
      kie_rec_model_dir:
      kie_det_model_dir:
      amp_custom_white_list: ['scale', 'concat', 'elementwise_add']
    Architecture:
      model_type: kie
      algorithm: &algorithm "LayoutXLM"
      Transform:
      Backbone:
        name: LayoutXLMForSer
        pretrained: True
        checkpoints:
        # one of base or vi
        mode: vi
        num_classes: &num_classes 5
    Loss:
      name: VQASerTokenLayoutLMLoss
      num_classes: *num_classes
      key: "backbone_out"
    Optimizer:
      name: AdamW
      beta1: 0.9
      beta2: 0.999
      lr:
        name: Linear
        learning_rate: 0.00005
        epochs: *epoch_num
        warmup_epoch: 2
      regularizer:
        name: L2
        factor: 0.00000
    PostProcess:
      name: VQASerTokenLayoutLMPostProcess
      class_path: &class_path my/zzsfp/class_list.txt
    Metric:
      name: VQASerTokenMetric
      main_indicator: hmean
    Train:
      dataset:
        name: SimpleDataSet
        data_dir: my/zzsfp/imgs
        label_file_list:
          - my/zzsfp/train.json
        ratio_list: [ 1.0 ]
        transforms:
          - DecodeImage: # load image
              img_mode: RGB
              channel_first: False
          - VQATokenLabelEncode: # Class handling label
              contains_re: False
              algorithm: *algorithm
              class_path: *class_path
              use_textline_bbox_info: &use_textline_bbox_info True
              # one of [None, "tb-yx"]
              order_method: &order_method "tb-yx"
          - VQATokenPad:
              max_seq_len: &max_seq_len 512
              return_attention_mask: True
          - VQASerTokenChunk:
              max_seq_len: *max_seq_len
          - Resize:
              size: [224,224]
          - NormalizeImage:
              scale: 1
              mean: [ 123.675, 116.28, 103.53 ]
              std: [ 58.395, 57.12, 57.375 ]
              order: 'hwc'
          - ToCHWImage:
          - KeepKeys:
              keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels'] # dataloader will return list in this order
      loader:
        shuffle: True
        drop_last: False
        batch_size_per_card: 1
        num_workers: 1
    Eval:
      dataset:
        name: SimpleDataSet
        data_dir: my/zzsfp/imgs
        label_file_list:
          - my/zzsfp/val.json
        transforms:
          - DecodeImage: # load image
              img_mode: RGB
              channel_first: False
          - VQATokenLabelEncode: # Class handling label
              contains_re: False
              algorithm: *algorithm
              class_path: *class_path
              use_textline_bbox_info: *use_textline_bbox_info
              order_method: *order_method
          - VQATokenPad:
              max_seq_len: *max_seq_len
              return_attention_mask: True
          - VQASerTokenChunk:
              max_seq_len: *max_seq_len
          - Resize:
              size: [224,224]
          - NormalizeImage:
              scale: 1
              mean: [ 123.675, 116.28, 103.53 ]
              std: [ 58.395, 57.12, 57.375 ]
              order: 'hwc'
          - ToCHWImage:
          - KeepKeys:
              keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels'] # dataloader will return list in this order
      loader:
        shuffle: False
        drop_last: False
        batch_size_per_card: 1
        num_workers: 1

    4、模型预测 

    python tools/infer_kie_token_ser.py -c my/ser_vi_layoutxlm_xfund_zh.yml -o Architecture.Backbone.checkpoints=./output/ser_vi_layoutxlm_xfund_zh_udml/best_accuracy Global.infer_img=./my/zzsfp/val.json Global.infer_mode=False

    python tools/infer_kie_token_ser.py -c my/ser_vi_layoutxlm_xfund_zh.yml -o Architecture.Backbone.checkpoints=./output/ser_vi_layoutxlm_xfund_zh_udml/best_accuracy Global.infer_img=./my/b201.jpg Global.infer_mode=True 

    5、模型导出 

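    python tools/export_model.py -c my/ser_vi_layoutxlm_xfund_zh.yml -o Architecture.Backbone.checkpoints=./output/ser_vi_layoutxlm_xfund_zh_udml/best_accuracy Global.save_inference_dir=./inference/ser_vi_layoutxlm

    (注:与上文评估、预测命令一致,SER 模型的训练权重需通过 Architecture.Backbone.checkpoints 指定;使用 Global.pretrained_model 不会加载训练好的权重。)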

     6、模型推理

    python ppstructure/kie/predict_kie_token_ser.py \
      --kie_algorithm=LayoutXLM \
      --ser_model_dir=./inference/ser_vi_layoutxlm \
      --image_dir=./my/b201.jpg \
      --ser_dict_path=./my/zzsfp/class_list.txt \
      --vis_font_path=./doc/fonts/simfang.ttf \
      --ocr_order_method="tb-yx"

  • 相关阅读:
    python与pycharm配置http服务
    线性同余方程( 数学知识 + 同余 + 扩展欧几里得算法 )
    计算机毕业设计Java自行车在线租赁管理系统2021(源码+系统+mysql数据库+Lw文档)
    AAC算法
    tcpdump工具使用
    Java自定义注解解析
    python版超市信息管理系统源代码,基于tkinter带界面
    关于#搜索引擎#的问题:虽然静下心学英语,更适合中国宝宝认知的学习英语,应试英语的方法,针对于听力、阅读、单词量(最头疼)、作文,有什么方向或经验能推荐、分享一下吗
    Java版分布式微服务云开发架构 Spring Cloud+Spring Boot+Mybatis 电子招标采购系统功能清单
    JuiceFS 在多云存储架构中的应用 | 深势科技分享
  • 原文地址:https://blog.csdn.net/ronshi/article/details/134017751