# ser_vi_layoutxlm_xfund_zh.yml
# SER (semantic entity recognition) config: VI-LayoutXLM on the XFUND-zh dataset.
  1. Global:
  2. use_gpu: True
  3. epoch_num: &epoch_num 200
  4. log_smooth_window: 10
  5. print_batch_step: 10
  6. save_model_dir: ./output/ser_vi_layoutxlm_xfund_zh
  7. save_epoch_step: 2000
  8. # evaluation is run every 10 iterations after the 0th iteration
  9. eval_batch_step: [ 0, 19 ]
  10. cal_metric_during_train: False
  11. save_inference_dir:
  12. use_visualdl: False
  13. seed: 2022
  14. infer_img: ppstructure/docs/kie/input/zh_val_42.jpg
  15. # if you want to predict using the groundtruth ocr info,
  16. # you can use the following config
  17. # infer_img: train_data/XFUND/zh_val/val.json
  18. # infer_mode: False
  19. save_res_path: ./output/ser/xfund_zh/res
  20. kie_rec_model_dir:
  21. kie_det_model_dir:
  22. Architecture:
  23. model_type: kie
  24. algorithm: &algorithm "LayoutXLM"
  25. Transform:
  26. Backbone:
  27. name: LayoutXLMForSer
  28. pretrained: True
  29. checkpoints:
  30. # one of base or vi
  31. mode: vi
  32. num_classes: &num_classes 7
  33. Loss:
  34. name: VQASerTokenLayoutLMLoss
  35. num_classes: *num_classes
  36. key: "backbone_out"
  37. Optimizer:
  38. name: AdamW
  39. beta1: 0.9
  40. beta2: 0.999
  41. lr:
  42. name: Linear
  43. learning_rate: 0.00005
  44. epochs: *epoch_num
  45. warmup_epoch: 2
  46. regularizer:
  47. name: L2
  48. factor: 0.00000
  49. PostProcess:
  50. name: VQASerTokenLayoutLMPostProcess
  51. class_path: &class_path train_data/XFUND/class_list_xfun.txt
  52. Metric:
  53. name: VQASerTokenMetric
  54. main_indicator: hmean
  55. Train:
  56. dataset:
  57. name: SimpleDataSet
  58. data_dir: train_data/XFUND/zh_train/image
  59. label_file_list:
  60. - train_data/XFUND/zh_train/train.json
  61. ratio_list: [ 1.0 ]
  62. transforms:
  63. - DecodeImage: # load image
  64. img_mode: RGB
  65. channel_first: False
  66. - VQATokenLabelEncode: # Class handling label
  67. contains_re: False
  68. algorithm: *algorithm
  69. class_path: *class_path
  70. use_textline_bbox_info: &use_textline_bbox_info True
  71. # one of [None, "tb-yx"]
  72. order_method: &order_method "tb-yx"
  73. - VQATokenPad:
  74. max_seq_len: &max_seq_len 512
  75. return_attention_mask: True
  76. - VQASerTokenChunk:
  77. max_seq_len: *max_seq_len
  78. - Resize:
  79. size: [224,224]
  80. - NormalizeImage:
  81. scale: 1
  82. mean: [ 123.675, 116.28, 103.53 ]
  83. std: [ 58.395, 57.12, 57.375 ]
  84. order: 'hwc'
  85. - ToCHWImage:
  86. - KeepKeys:
  87. keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels'] # dataloader will return list in this order
  88. loader:
  89. shuffle: True
  90. drop_last: False
  91. batch_size_per_card: 8
  92. num_workers: 4
  93. Eval:
  94. dataset:
  95. name: SimpleDataSet
  96. data_dir: train_data/XFUND/zh_val/image
  97. label_file_list:
  98. - train_data/XFUND/zh_val/val.json
  99. transforms:
  100. - DecodeImage: # load image
  101. img_mode: RGB
  102. channel_first: False
  103. - VQATokenLabelEncode: # Class handling label
  104. contains_re: False
  105. algorithm: *algorithm
  106. class_path: *class_path
  107. use_textline_bbox_info: *use_textline_bbox_info
  108. order_method: *order_method
  109. - VQATokenPad:
  110. max_seq_len: *max_seq_len
  111. return_attention_mask: True
  112. - VQASerTokenChunk:
  113. max_seq_len: *max_seq_len
  114. - Resize:
  115. size: [224,224]
  116. - NormalizeImage:
  117. scale: 1
  118. mean: [ 123.675, 116.28, 103.53 ]
  119. std: [ 58.395, 57.12, 57.375 ]
  120. order: 'hwc'
  121. - ToCHWImage:
  122. - KeepKeys:
  123. keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels'] # dataloader will return list in this order
  124. loader:
  125. shuffle: False
  126. drop_last: False
  127. batch_size_per_card: 8
  128. num_workers: 4