# re_vi_layoutxlm_xfund_zh_udml.yml
# PaddleOCR KIE (RE task) config: VI-LayoutXLM on XFUND-zh with UDML distillation.
  1. Global:
  2. use_gpu: True
  3. epoch_num: &epoch_num 130
  4. log_smooth_window: 10
  5. print_batch_step: 10
  6. save_model_dir: ./output/re_vi_layoutxlm_xfund_zh_udml
  7. save_epoch_step: 2000
  8. # evaluation is run every 10 iterations after the 0th iteration
  9. eval_batch_step: [ 0, 19 ]
  10. cal_metric_during_train: False
  11. save_inference_dir:
  12. use_visualdl: False
  13. seed: 2022
  14. infer_img: ppstructure/docs/kie/input/zh_val_21.jpg
  15. save_res_path: ./output/re/xfund_zh/with_gt
  16. Architecture:
  17. model_type: &model_type "kie"
  18. name: DistillationModel
  19. algorithm: Distillation
  20. Models:
  21. Teacher:
  22. pretrained:
  23. freeze_params: false
  24. return_all_feats: true
  25. model_type: *model_type
  26. algorithm: &algorithm "LayoutXLM"
  27. Transform:
  28. Backbone:
  29. name: LayoutXLMForRe
  30. pretrained: True
  31. mode: vi
  32. checkpoints:
  33. Student:
  34. pretrained:
  35. freeze_params: false
  36. return_all_feats: true
  37. model_type: *model_type
  38. algorithm: *algorithm
  39. Transform:
  40. Backbone:
  41. name: LayoutXLMForRe
  42. pretrained: True
  43. mode: vi
  44. checkpoints:
  45. Loss:
  46. name: CombinedLoss
  47. loss_config_list:
  48. - DistillationLossFromOutput:
  49. weight: 1.0
  50. model_name_list: ["Student", "Teacher"]
  51. key: loss
  52. reduction: mean
  53. - DistillationVQADistanceLoss:
  54. weight: 0.5
  55. mode: "l2"
  56. model_name_pairs:
  57. - ["Student", "Teacher"]
  58. key: hidden_states
  59. index: 5
  60. name: "loss_5"
  61. - DistillationVQADistanceLoss:
  62. weight: 0.5
  63. mode: "l2"
  64. model_name_pairs:
  65. - ["Student", "Teacher"]
  66. key: hidden_states
  67. index: 8
  68. name: "loss_8"
  69. Optimizer:
  70. name: AdamW
  71. beta1: 0.9
  72. beta2: 0.999
  73. clip_norm: 10
  74. lr:
  75. learning_rate: 0.00005
  76. warmup_epoch: 10
  77. regularizer:
  78. name: L2
  79. factor: 0.00000
  80. PostProcess:
  81. name: DistillationRePostProcess
  82. model_name: ["Student", "Teacher"]
  83. key: null
  84. Metric:
  85. name: DistillationMetric
  86. base_metric_name: VQAReTokenMetric
  87. main_indicator: hmean
  88. key: "Student"
  89. Train:
  90. dataset:
  91. name: SimpleDataSet
  92. data_dir: train_data/XFUND/zh_train/image
  93. label_file_list:
  94. - train_data/XFUND/zh_train/train.json
  95. ratio_list: [ 1.0 ]
  96. transforms:
  97. - DecodeImage: # load image
  98. img_mode: RGB
  99. channel_first: False
  100. - VQATokenLabelEncode: # Class handling label
  101. contains_re: True
  102. algorithm: *algorithm
  103. class_path: &class_path train_data/XFUND/class_list_xfun.txt
  104. use_textline_bbox_info: &use_textline_bbox_info True
  105. # [None, "tb-yx"]
  106. order_method: &order_method "tb-yx"
  107. - VQATokenPad:
  108. max_seq_len: &max_seq_len 512
  109. return_attention_mask: True
  110. - VQAReTokenRelation:
  111. - VQAReTokenChunk:
  112. max_seq_len: *max_seq_len
  113. - TensorizeEntitiesRelations:
  114. - Resize:
  115. size: [224,224]
  116. - NormalizeImage:
  117. scale: 1
  118. mean: [ 123.675, 116.28, 103.53 ]
  119. std: [ 58.395, 57.12, 57.375 ]
  120. order: 'hwc'
  121. - ToCHWImage:
  122. - KeepKeys:
  123. keep_keys: [ 'input_ids', 'bbox','attention_mask', 'token_type_ids', 'entities', 'relations'] # dataloader will return list in this order
  124. loader:
  125. shuffle: True
  126. drop_last: False
  127. batch_size_per_card: 2
  128. num_workers: 4
  129. Eval:
  130. dataset:
  131. name: SimpleDataSet
  132. data_dir: train_data/XFUND/zh_val/image
  133. label_file_list:
  134. - train_data/XFUND/zh_val/val.json
  135. transforms:
  136. - DecodeImage: # load image
  137. img_mode: RGB
  138. channel_first: False
  139. - VQATokenLabelEncode: # Class handling label
  140. contains_re: True
  141. algorithm: *algorithm
  142. class_path: *class_path
  143. use_textline_bbox_info: *use_textline_bbox_info
  144. order_method: *order_method
  145. - VQATokenPad:
  146. max_seq_len: *max_seq_len
  147. return_attention_mask: True
  148. - VQAReTokenRelation:
  149. - VQAReTokenChunk:
  150. max_seq_len: *max_seq_len
  151. - TensorizeEntitiesRelations:
  152. - Resize:
  153. size: [224,224]
  154. - NormalizeImage:
  155. scale: 1
  156. mean: [ 123.675, 116.28, 103.53 ]
  157. std: [ 58.395, 57.12, 57.375 ]
  158. order: 'hwc'
  159. - ToCHWImage:
  160. - KeepKeys:
  161. keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'entities', 'relations'] # dataloader will return list in this order
  162. loader:
  163. shuffle: False
  164. drop_last: False
  165. batch_size_per_card: 8
  166. num_workers: 8