ser_layoutxlm_xfund_zh.yml

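# PaddleOCR KIE config: semantic entity recognition (SER) with a
# LayoutXLM backbone on XFUND-zh, the Chinese subset of XFUND.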
Global:
  use_gpu: True
  epoch_num: &epoch_num 200
  log_smooth_window: 10
  print_batch_step: 10
  save_model_dir: ./output/ser_layoutxlm_xfund_zh
  save_epoch_step: 2000
  # evaluation is run every 187 iterations after the 0th iteration
  eval_batch_step: [ 0, 187 ]
  cal_metric_during_train: False
  save_inference_dir:
  use_visualdl: False
  seed: 2022
  infer_img: ppstructure/docs/kie/input/zh_val_42.jpg
  save_res_path: ./output/ser_layoutxlm_xfund_zh/res

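# SER is framed as token classification: num_classes is 7 because the
# XFUND label set is OTHER plus B-/I- tags for QUESTION, ANSWER and
# HEADER (see the class list referenced under PostProcess below).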
Architecture:
  model_type: kie
  algorithm: &algorithm "LayoutXLM"
  Transform:
  Backbone:
    name: LayoutXLMForSer
    pretrained: True
    checkpoints:
    num_classes: &num_classes 7

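# Token-level classification loss computed on the backbone output,
# selected by key "backbone_out".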
Loss:
  name: VQASerTokenLayoutLMLoss
  num_classes: *num_classes
  key: "backbone_out"

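# AdamW with a Linear learning-rate schedule (2 warmup epochs); the L2
# regularizer factor of 0.0 disables weight decay here.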
Optimizer:
  name: AdamW
  beta1: 0.9
  beta2: 0.999
  lr:
    name: Linear
    learning_rate: 0.00005
    epochs: *epoch_num
    warmup_epoch: 2
  regularizer:
    name: L2
    factor: 0.00000

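# Decodes per-token predictions back into entity labels using the class
# list anchored as &class_path (reused by the dataset transforms below).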
PostProcess:
  name: VQASerTokenLayoutLMPostProcess
  class_path: &class_path train_data/XFUND/class_list_xfun.txt

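# hmean is the harmonic mean of precision and recall (F1) over predicted
# entities; PaddleOCR selects the best checkpoint by this indicator.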
Metric:
  name: VQASerTokenMetric
  main_indicator: hmean

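# Training data: XFUND-zh images plus JSON annotations. The transforms
# tokenize text and boxes, pad and chunk token sequences to max_seq_len,
# and resize/normalize images to 224x224 before batching.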
Train:
  dataset:
    name: SimpleDataSet
    data_dir: train_data/XFUND/zh_train/image
    label_file_list:
      - train_data/XFUND/zh_train/train.json
    ratio_list: [ 1.0 ]
    transforms:
      - DecodeImage: # load image
          img_mode: RGB
          channel_first: False
      - VQATokenLabelEncode: # encode class labels
          contains_re: False
          algorithm: *algorithm
          class_path: *class_path
      - VQATokenPad:
          max_seq_len: &max_seq_len 512
          return_attention_mask: True
      - VQASerTokenChunk:
          max_seq_len: *max_seq_len
      - Resize:
          size: [224,224]
      - NormalizeImage:
          scale: 1
          mean: [ 123.675, 116.28, 103.53 ]
          std: [ 58.395, 57.12, 57.375 ]
          order: 'hwc'
      - ToCHWImage:
      - KeepKeys:
          keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels'] # dataloader will return list in this order
  loader:
    shuffle: True
    drop_last: False
    batch_size_per_card: 8
    num_workers: 4

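# Evaluation mirrors the training pipeline on the validation split, but
# without shuffling.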
Eval:
  dataset:
    name: SimpleDataSet
    data_dir: train_data/XFUND/zh_val/image
    label_file_list:
      - train_data/XFUND/zh_val/val.json
    transforms:
      - DecodeImage: # load image
          img_mode: RGB
          channel_first: False
      - VQATokenLabelEncode: # encode class labels
          contains_re: False
          algorithm: *algorithm
          class_path: *class_path
      - VQATokenPad:
          max_seq_len: *max_seq_len
          return_attention_mask: True
      - VQASerTokenChunk:
          max_seq_len: *max_seq_len
      - Resize:
          size: [224,224]
      - NormalizeImage:
          scale: 1
          mean: [ 123.675, 116.28, 103.53 ]
          std: [ 58.395, 57.12, 57.375 ]
          order: 'hwc'
      - ToCHWImage:
      - KeepKeys:
          keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels'] # dataloader will return list in this order
  loader:
    shuffle: False
    drop_last: False
    batch_size_per_card: 8
    num_workers: 4
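
# Typical usage, as a sketch (assuming the standard PaddleOCR repo
# layout, where this file lives under configs/kie/layoutxlm/):
#   train: python3 tools/train.py -c configs/kie/layoutxlm/ser_layoutxlm_xfund_zh.yml
#   eval:  python3 tools/eval.py -c configs/kie/layoutxlm/ser_layoutxlm_xfund_zh.yml \
#            -o Architecture.Backbone.checkpoints=./output/ser_layoutxlm_xfund_zh/best_accuracy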