# Distillation (UDML) training config for Semantic Entity Recognition (SER)
# on XFUND-zh with two VI-LayoutXLM branches (Teacher / Student) trained jointly.
# Reconstructed from a line-number-mangled extraction; YAML indentation restored
# from the key hierarchy and the anchor/alias references.

Global:
  use_gpu: True
  epoch_num: &epoch_num 200
  log_smooth_window: 10
  print_batch_step: 10
  save_model_dir: ./output/ser_vi_layoutxlm_xfund_zh_udml
  save_epoch_step: 2000
  # evaluation is run every 19 iterations after the 0th iteration
  eval_batch_step: [ 0, 19 ]
  cal_metric_during_train: False
  save_inference_dir:
  use_visualdl: False
  seed: 2022
  infer_img: ppstructure/docs/kie/input/zh_val_42.jpg
  save_res_path: ./output/ser_layoutxlm_xfund_zh/res

Architecture:
  model_type: &model_type "kie"
  name: DistillationModel
  algorithm: Distillation
  Models:
    Teacher:
      pretrained:
      freeze_params: false
      return_all_feats: true
      model_type: *model_type
      algorithm: &algorithm "LayoutXLM"
      Transform:
      Backbone:
        name: LayoutXLMForSer
        pretrained: True
        # one of base or vi
        mode: vi
        checkpoints:
        num_classes: &num_classes 7
    Student:
      pretrained:
      freeze_params: false
      return_all_feats: true
      model_type: *model_type
      algorithm: *algorithm
      Transform:
      Backbone:
        name: LayoutXLMForSer
        pretrained: True
        # one of base or vi
        mode: vi
        checkpoints:
        num_classes: *num_classes

Loss:
  name: CombinedLoss
  loss_config_list:
  # hard-label SER loss on each branch's backbone output
  - DistillationVQASerTokenLayoutLMLoss:
      weight: 1.0
      model_name_list: ["Student", "Teacher"]
      key: backbone_out
      num_classes: *num_classes
  # mutual DML loss between the two branches' softmax outputs
  - DistillationSERDMLLoss:
      weight: 1.0
      act: "softmax"
      use_log: true
      model_name_pairs:
      - ["Student", "Teacher"]
      key: backbone_out
  # L2 distance between intermediate hidden states of the two branches
  - DistillationVQADistanceLoss:
      weight: 0.5
      mode: "l2"
      model_name_pairs:
      - ["Student", "Teacher"]
      key: hidden_states_5
      name: "loss_5"
  - DistillationVQADistanceLoss:
      weight: 0.5
      mode: "l2"
      model_name_pairs:
      - ["Student", "Teacher"]
      key: hidden_states_8
      name: "loss_8"

Optimizer:
  name: AdamW
  beta1: 0.9
  beta2: 0.999
  lr:
    name: Linear
    learning_rate: 0.00005
    epochs: *epoch_num
    warmup_epoch: 10
  regularizer:
    name: L2
    factor: 0.00000

PostProcess:
  name: DistillationSerPostProcess
  model_name: ["Student", "Teacher"]
  key: backbone_out
  class_path: &class_path train_data/XFUND/class_list_xfun.txt

Metric:
  name: DistillationMetric
  base_metric_name: VQASerTokenMetric
  main_indicator: hmean
  key: "Student"

Train:
  dataset:
    name: SimpleDataSet
    data_dir: train_data/XFUND/zh_train/image
    label_file_list:
    - train_data/XFUND/zh_train/train.json
    ratio_list: [ 1.0 ]
    transforms:
    - DecodeImage: # load image
        img_mode: RGB
        channel_first: False
    - VQATokenLabelEncode: # Class handling label
        contains_re: False
        algorithm: *algorithm
        class_path: *class_path
        # one of [None, "tb-yx"]
        order_method: &order_method "tb-yx"
    - VQATokenPad:
        max_seq_len: &max_seq_len 512
        return_attention_mask: True
    - VQASerTokenChunk:
        max_seq_len: *max_seq_len
    - Resize:
        size: [224,224]
    - NormalizeImage:
        scale: 1
        mean: [ 123.675, 116.28, 103.53 ]
        std: [ 58.395, 57.12, 57.375 ]
        order: 'hwc'
    - ToCHWImage:
    - KeepKeys:
        keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels'] # dataloader will return list in this order
  loader:
    shuffle: True
    drop_last: False
    batch_size_per_card: 4
    num_workers: 4

Eval:
  dataset:
    name: SimpleDataSet
    data_dir: train_data/XFUND/zh_val/image
    label_file_list:
    - train_data/XFUND/zh_val/val.json
    transforms:
    - DecodeImage: # load image
        img_mode: RGB
        channel_first: False
    - VQATokenLabelEncode: # Class handling label
        contains_re: False
        algorithm: *algorithm
        class_path: *class_path
        order_method: *order_method
    - VQATokenPad:
        max_seq_len: *max_seq_len
        return_attention_mask: True
    - VQASerTokenChunk:
        max_seq_len: *max_seq_len
    - Resize:
        size: [224,224]
    - NormalizeImage:
        scale: 1
        mean: [ 123.675, 116.28, 103.53 ]
        std: [ 58.395, 57.12, 57.375 ]
        order: 'hwc'
    - ToCHWImage:
    - KeepKeys:
        keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels'] # dataloader will return list in this order
  loader:
    shuffle: False
    drop_last: False
    batch_size_per_card: 8
    num_workers: 4