SER（Semantic Entity Recognition，语义实体识别）：给定一段文本行，判断其所属的实体类别（如姓名、住址等）。本文采用基于 VI-LayoutXLM 的多模态语义实体识别方法，以增值税发票为例，依次介绍数据集、训练、评估、预测、导出与推理。
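1、增值税发票数据集（下文配置文件引用的路径：图片目录 my/zzsfp/imgs，标注文件 my/zzsfp/train.json 与 my/zzsfp/val.json，类别文件 my/zzsfp/class_list.txt）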
https://download.csdn.net/download/ronshi/88467149
2、训练模型
# Train the UDML distillation model (Teacher and Student are both VI-LayoutXLM SER);
# outputs, including best_accuracy, are written to ./output/ser_vi_layoutxlm_xfund_zh_udml
python tools/train.py \
    -c my/ser_vi_layoutxlm_xfund_zh_udml.yml
配置文件 my/ser_vi_layoutxlm_xfund_zh_udml.yml（UDML 蒸馏训练，Teacher 与 Student 均为 VI-LayoutXLM）：
# ser_vi_layoutxlm_xfund_zh_udml.yml
# UDML distillation training config: Teacher and Student are both VI-LayoutXLM SER
# models trained on the VAT-invoice data under my/zzsfp/.
Global:
  use_gpu: true
  epoch_num: &epoch_num 200
  log_smooth_window: 10
  print_batch_step: 10
  save_model_dir: ./output/ser_vi_layoutxlm_xfund_zh_udml
  save_epoch_step: 2000
  # evaluation is run every 19 iterations, starting after iteration 0
  eval_batch_step: [ 0, 19 ]
  cal_metric_during_train: False
  save_inference_dir:
  use_visualdl: False
  seed: 2022
  infer_img: my/b201.jpg
  save_res_path: ./output/ser_layoutxlm_xfund_zh/res


Architecture:
  model_type: &model_type "kie"
  name: DistillationModel
  algorithm: Distillation
  Models:
    Teacher:
      pretrained:
      freeze_params: false
      return_all_feats: true
      model_type: *model_type
      algorithm: &algorithm "LayoutXLM"
      Transform:
      Backbone:
        name: LayoutXLMForSer
        pretrained: True
        # one of base or vi
        mode: vi
        checkpoints:
        # NOTE(review): PaddleOCR's SER docs derive num_classes from class_list.txt
        # (BIO tagging, 2 * N - 1 for N classes incl. "OTHER") — confirm 5 matches
        # my/zzsfp/class_list.txt.
        num_classes: &num_classes 5
    Student:
      pretrained:
      freeze_params: false
      return_all_feats: true
      model_type: *model_type
      algorithm: *algorithm
      Transform:
      Backbone:
        name: LayoutXLMForSer
        pretrained: True
        # one of base or vi
        mode: vi
        checkpoints:
        num_classes: *num_classes


# Supervised SER loss on both models, a DML loss between Student and Teacher
# backbone outputs, and L2 distance losses on hidden_states_5 / hidden_states_8.
Loss:
  name: CombinedLoss
  loss_config_list:
    - DistillationVQASerTokenLayoutLMLoss:
        weight: 1.0
        model_name_list: ["Student", "Teacher"]
        key: backbone_out
        num_classes: *num_classes
    - DistillationSERDMLLoss:
        weight: 1.0
        act: "softmax"
        use_log: true
        model_name_pairs:
          - ["Student", "Teacher"]
        key: backbone_out
    - DistillationVQADistanceLoss:
        weight: 0.5
        mode: "l2"
        model_name_pairs:
          - ["Student", "Teacher"]
        key: hidden_states_5
        name: "loss_5"
    - DistillationVQADistanceLoss:
        weight: 0.5
        mode: "l2"
        model_name_pairs:
          - ["Student", "Teacher"]
        key: hidden_states_8
        name: "loss_8"



Optimizer:
  name: AdamW
  beta1: 0.9
  beta2: 0.999
  lr:
    name: Linear
    learning_rate: 0.00005
    epochs: *epoch_num
    warmup_epoch: 10
  regularizer:
    name: L2
    factor: 0.00000

PostProcess:
  name: DistillationSerPostProcess
  model_name: ["Student", "Teacher"]
  key: backbone_out
  class_path: &class_path my/zzsfp/class_list.txt

# metrics are reported for the sub-model selected by `key`
Metric:
  name: DistillationMetric
  base_metric_name: VQASerTokenMetric
  main_indicator: hmean
  key: "Student"

Train:
  dataset:
    name: SimpleDataSet
    data_dir: my/zzsfp/imgs
    label_file_list:
      - my/zzsfp/train.json
    ratio_list: [ 1.0 ]
    transforms:
      - DecodeImage: # load image
          img_mode: RGB
          channel_first: False
      - VQATokenLabelEncode: # Class handling label
          contains_re: False
          algorithm: *algorithm
          class_path: *class_path
          # one of [None, "tb-yx"]
          order_method: &order_method "tb-yx"
      - VQATokenPad:
          max_seq_len: &max_seq_len 512
          return_attention_mask: True
      - VQASerTokenChunk:
          max_seq_len: *max_seq_len
      - Resize:
          size: [224,224]
      - NormalizeImage:
          scale: 1
          mean: [ 123.675, 116.28, 103.53 ]
          std: [ 58.395, 57.12, 57.375 ]
          order: 'hwc'
      - ToCHWImage:
      - KeepKeys:
          keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels'] # dataloader will return list in this order
  loader:
    shuffle: True
    drop_last: False
    batch_size_per_card: 4
    num_workers: 4

Eval:
  dataset:
    name: SimpleDataSet
    data_dir: my/zzsfp/imgs
    label_file_list:
      - my/zzsfp/val.json
    transforms:
      - DecodeImage: # load image
          img_mode: RGB
          channel_first: False
      - VQATokenLabelEncode: # Class handling label
          contains_re: False
          algorithm: *algorithm
          class_path: *class_path
          order_method: *order_method
      - VQATokenPad:
          max_seq_len: *max_seq_len
          return_attention_mask: True
      - VQASerTokenChunk:
          max_seq_len: *max_seq_len
      - Resize:
          size: [224,224]
      - NormalizeImage:
          scale: 1
          mean: [ 123.675, 116.28, 103.53 ]
          std: [ 58.395, 57.12, 57.375 ]
          order: 'hwc'
      - ToCHWImage:
      - KeepKeys:
          keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels'] # dataloader will return list in this order
  loader:
    shuffle: False
    drop_last: False
    batch_size_per_card: 4
    num_workers: 4
3、模型评估
# Evaluate the UDML training output with the single-model (non-distillation) config;
# the trained weights are loaded through Architecture.Backbone.checkpoints
python tools/eval.py \
    -c my/ser_vi_layoutxlm_xfund_zh.yml \
    -o Architecture.Backbone.checkpoints=./output/ser_vi_layoutxlm_xfund_zh_udml/best_accuracy
配置文件 my/ser_vi_layoutxlm_xfund_zh.yml（单模型配置；评估和预测时通过 Architecture.Backbone.checkpoints 加载上一步训练得到的 best_accuracy）：
# ser_vi_layoutxlm_xfund_zh.yml
# Single-model VI-LayoutXLM SER config, used in this tutorial for evaluation,
# prediction and export of the model trained with the UDML config.
Global:
  use_gpu: True
  epoch_num: &epoch_num 200
  log_smooth_window: 10
  print_batch_step: 10
  save_model_dir: ./output/ser_vi_layoutxlm_xfund_zh
  save_epoch_step: 2000
  # evaluation is run every 19 iterations, starting after iteration 0
  eval_batch_step: [ 0, 19 ]
  cal_metric_during_train: False
  save_inference_dir:
  use_visualdl: False
  seed: 2022
  infer_img: my/b201.jpg
  d2s_train_image_shape: [3, 224, 224]
  # if you want to predict using the groundtruth ocr info,
  # you can use the following config
  # infer_img: my/zzsfp/val.json
  # infer_mode: False

  save_res_path: ./output/ser/xfund_zh/res
  kie_rec_model_dir:
  kie_det_model_dir:
  amp_custom_white_list: ['scale', 'concat', 'elementwise_add']

Architecture:
  model_type: kie
  algorithm: &algorithm "LayoutXLM"
  Transform:
  Backbone:
    name: LayoutXLMForSer
    pretrained: True
    # overridden on the command line with the trained best_accuracy directory
    checkpoints:
    # one of base or vi
    mode: vi
    # NOTE(review): PaddleOCR's SER docs derive num_classes from class_list.txt
    # (BIO tagging, 2 * N - 1 for N classes incl. "OTHER") — confirm 5 matches
    # my/zzsfp/class_list.txt.
    num_classes: &num_classes 5

Loss:
  name: VQASerTokenLayoutLMLoss
  num_classes: *num_classes
  key: "backbone_out"

Optimizer:
  name: AdamW
  beta1: 0.9
  beta2: 0.999
  lr:
    name: Linear
    learning_rate: 0.00005
    epochs: *epoch_num
    warmup_epoch: 2
  regularizer:
    name: L2
    factor: 0.00000

PostProcess:
  name: VQASerTokenLayoutLMPostProcess
  class_path: &class_path my/zzsfp/class_list.txt

Metric:
  name: VQASerTokenMetric
  main_indicator: hmean

Train:
  dataset:
    name: SimpleDataSet
    data_dir: my/zzsfp/imgs
    label_file_list:
      - my/zzsfp/train.json
    ratio_list: [ 1.0 ]
    transforms:
      - DecodeImage: # load image
          img_mode: RGB
          channel_first: False
      - VQATokenLabelEncode: # Class handling label
          contains_re: False
          algorithm: *algorithm
          class_path: *class_path
          use_textline_bbox_info: &use_textline_bbox_info True
          # one of [None, "tb-yx"]
          order_method: &order_method "tb-yx"
      - VQATokenPad:
          max_seq_len: &max_seq_len 512
          return_attention_mask: True
      - VQASerTokenChunk:
          max_seq_len: *max_seq_len
      - Resize:
          size: [224,224]
      - NormalizeImage:
          scale: 1
          mean: [ 123.675, 116.28, 103.53 ]
          std: [ 58.395, 57.12, 57.375 ]
          order: 'hwc'
      - ToCHWImage:
      - KeepKeys:
          keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels'] # dataloader will return list in this order
  loader:
    shuffle: True
    drop_last: False
    batch_size_per_card: 1
    num_workers: 1

Eval:
  dataset:
    name: SimpleDataSet
    data_dir: my/zzsfp/imgs
    label_file_list:
      - my/zzsfp/val.json
    transforms:
      - DecodeImage: # load image
          img_mode: RGB
          channel_first: False
      - VQATokenLabelEncode: # Class handling label
          contains_re: False
          algorithm: *algorithm
          class_path: *class_path
          use_textline_bbox_info: *use_textline_bbox_info
          order_method: *order_method
      - VQATokenPad:
          max_seq_len: *max_seq_len
          return_attention_mask: True
      - VQASerTokenChunk:
          max_seq_len: *max_seq_len
      - Resize:
          size: [224,224]
      - NormalizeImage:
          scale: 1
          mean: [ 123.675, 116.28, 103.53 ]
          std: [ 58.395, 57.12, 57.375 ]
          order: 'hwc'
      - ToCHWImage:
      - KeepKeys:
          keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels'] # dataloader will return list in this order
  loader:
    shuffle: False
    drop_last: False
    batch_size_per_card: 1
    num_workers: 1
4、模型预测
# Predict with the ground-truth OCR info stored in the label file (infer_mode=False)
python tools/infer_kie_token_ser.py \
    -c my/ser_vi_layoutxlm_xfund_zh.yml \
    -o Architecture.Backbone.checkpoints=./output/ser_vi_layoutxlm_xfund_zh_udml/best_accuracy \
       Global.infer_img=./my/zzsfp/val.json \
       Global.infer_mode=False
# Predict directly on an image file (infer_mode=True)
python tools/infer_kie_token_ser.py \
    -c my/ser_vi_layoutxlm_xfund_zh.yml \
    -o Architecture.Backbone.checkpoints=./output/ser_vi_layoutxlm_xfund_zh_udml/best_accuracy \
       Global.infer_img=./my/b201.jpg \
       Global.infer_mode=True
5、模型导出
# Export the trained SER model to ./inference/ser_vi_layoutxlm.
# The trained weights are passed through Architecture.Backbone.checkpoints, the same
# override the eval and infer commands use. NOTE(review): per PaddleOCR's load_model,
# Global.pretrained_model is not read for LayoutXLM KIE models, so passing the weights
# that way would silently export a model without the fine-tuned weights.
python tools/export_model.py \
    -c my/ser_vi_layoutxlm_xfund_zh.yml \
    -o Architecture.Backbone.checkpoints=./output/ser_vi_layoutxlm_xfund_zh_udml/best_accuracy \
       Global.save_inference_dir=./inference/ser_vi_layoutxlm
6、模型推理
# Run the exported inference model (./inference/ser_vi_layoutxlm) on one image.
# The class file and the "tb-yx" OCR ordering match the training configs.
python ppstructure/kie/predict_kie_token_ser.py \
  --kie_algorithm=LayoutXLM \
  --ser_model_dir=./inference/ser_vi_layoutxlm \
  --ser_dict_path=./my/zzsfp/class_list.txt \
  --ocr_order_method="tb-yx" \
  --image_dir=./my/b201.jpg \
  --vis_font_path=./doc/fonts/simfang.ttf