re_vi_layoutxlm_xfund_zh.yml

Global:
  use_gpu: True
  epoch_num: &epoch_num 130
  log_smooth_window: 10
  print_batch_step: 10
  save_model_dir: ./output/re_vi_layoutxlm_xfund_zh
  save_epoch_step: 2000
  # evaluation is run every 19 iterations after the 0th iteration
  eval_batch_step: [ 0, 19 ]
  cal_metric_during_train: False
  save_inference_dir:
  use_visualdl: False
  seed: 2022
  infer_img: ppstructure/docs/kie/input/zh_val_21.jpg
  save_res_path: ./output/re/xfund_zh/with_gt
  kie_rec_model_dir:
  kie_det_model_dir:
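
# Notes: `eval_batch_step: [start, interval]` means evaluation begins at
# iteration 0 and then repeats every 19 iterations. `&epoch_num` is a YAML
# anchor (not dereferenced elsewhere in this file, but kept so related configs
# can share the value). `save_inference_dir`, `kie_rec_model_dir` and
# `kie_det_model_dir` are intentionally left empty and can be filled in at
# export/inference time.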

Architecture:
  model_type: kie
  algorithm: &algorithm "LayoutXLM"
  Transform:
  Backbone:
    name: LayoutXLMForRe
    pretrained: True
    mode: vi
    checkpoints:
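
# The backbone is the relation-extraction (RE) head of LayoutXLM. Per the
# upstream PaddleOCR docs, `mode: vi` selects the visual-feature-independent
# (VI) variant that drops the image branch; `pretrained: True` loads the
# published pretrained weights, and `checkpoints` may point to a local
# checkpoint to resume from.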

Loss:
  name: LossFromOutput
  key: loss
  reduction: mean
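
# LossFromOutput reads the value stored under `key` ("loss") directly from the
# model's output dict instead of computing a criterion here, then applies the
# given `reduction`.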

Optimizer:
  name: AdamW
  beta1: 0.9
  beta2: 0.999
  clip_norm: 10
  lr:
    learning_rate: 0.00005
    warmup_epoch: 10
  regularizer:
    name: L2
    factor: 0.00000
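
# AdamW with gradient clipping at norm 10, a peak learning rate of 5e-5, and a
# 10-epoch warmup. The L2 regularizer factor is deliberately 0.0, i.e. no
# extra weight decay is applied through the regularizer.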

PostProcess:
  name: VQAReTokenLayoutLMPostProcess

Metric:
  name: VQAReTokenMetric
  main_indicator: hmean
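
# Relation extraction is scored with precision and recall over predicted
# entity pairs; `hmean` (their harmonic mean, i.e. F1) is the indicator used
# to select the best checkpoint during training.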

Train:
  dataset:
    name: SimpleDataSet
    data_dir: train_data/XFUND/zh_train/image
    label_file_list:
      - train_data/XFUND/zh_train/train.json
    ratio_list: [ 1.0 ]
    transforms:
      - DecodeImage: # load image
          img_mode: RGB
          channel_first: False
      - VQATokenLabelEncode: # Class handling label
          contains_re: True
          algorithm: *algorithm
          class_path: &class_path train_data/XFUND/class_list_xfun.txt
          use_textline_bbox_info: &use_textline_bbox_info True
          order_method: &order_method "tb-yx"
      - VQATokenPad:
          max_seq_len: &max_seq_len 512
          return_attention_mask: True
      - VQAReTokenRelation:
      - VQAReTokenChunk:
          max_seq_len: *max_seq_len
      - TensorizeEntitiesRelations:
      - Resize:
          size: [224, 224]
      - NormalizeImage:
          scale: 1
          mean: [ 123.675, 116.28, 103.53 ]
          std: [ 58.395, 57.12, 57.375 ]
          order: 'hwc'
      - ToCHWImage:
      - KeepKeys:
          keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'entities', 'relations' ] # dataloader will return list in this order
  loader:
    shuffle: True
    drop_last: False
    batch_size_per_card: 2
    num_workers: 4
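
# Pipeline summary: decode image -> tokenize and encode labels (entity classes
# from class_list_xfun.txt; text lines ordered top-to-bottom by y, then x, via
# "tb-yx") -> pad/truncate to 512 tokens -> build relation pairs -> chunk long
# sequences -> tensorize entities/relations -> resize to 224x224 -> normalize
# -> HWC to CHW -> keep only the keys the RE model consumes. The YAML anchors
# (&class_path, &use_textline_bbox_info, &order_method, &max_seq_len) let the
# Eval section below reuse exactly the same settings.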

Eval:
  dataset:
    name: SimpleDataSet
    data_dir: train_data/XFUND/zh_val/image
    label_file_list:
      - train_data/XFUND/zh_val/val.json
    transforms:
      - DecodeImage: # load image
          img_mode: RGB
          channel_first: False
      - VQATokenLabelEncode: # Class handling label
          contains_re: True
          algorithm: *algorithm
          class_path: *class_path
          use_textline_bbox_info: *use_textline_bbox_info
          order_method: *order_method
      - VQATokenPad:
          max_seq_len: *max_seq_len
          return_attention_mask: True
      - VQAReTokenRelation:
      - VQAReTokenChunk:
          max_seq_len: *max_seq_len
      - TensorizeEntitiesRelations:
      - Resize:
          size: [224, 224]
      - NormalizeImage:
          scale: 1
          mean: [ 123.675, 116.28, 103.53 ]
          std: [ 58.395, 57.12, 57.375 ]
          order: 'hwc'
      - ToCHWImage:
      - KeepKeys:
          keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'entities', 'relations' ] # dataloader will return list in this order
  loader:
    shuffle: False
    drop_last: False
    batch_size_per_card: 8
    num_workers: 8
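
# Usage sketch (paths are illustrative; adjust to your checkout): training
# with this config is typically launched as
#   python3 tools/train.py -c configs/kie/vi_layoutxlm/re_vi_layoutxlm_xfund_zh.yml
# and evaluation as
#   python3 tools/eval.py -c configs/kie/vi_layoutxlm/re_vi_layoutxlm_xfund_zh.yml \
#       -o Global.checkpoints=./output/re_vi_layoutxlm_xfund_zh/best_accuracy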