ser_layoutlmv2_xfund_zh.yml

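# PaddleOCR KIE config: semantic entity recognition (SER) on the Chinese
# subset of the XFUND form-understanding dataset, with LayoutLMv2 as backbone.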
Global:
  use_gpu: True
  epoch_num: &epoch_num 200
  log_smooth_window: 10
  print_batch_step: 10
  save_model_dir: ./output/ser_layoutlmv2_xfund_zh/
  save_epoch_step: 2000
  # evaluation is run every 19 iterations after the 0th iteration
  eval_batch_step: [ 0, 19 ]
  cal_metric_during_train: False
  save_inference_dir:
  use_visualdl: False
  seed: 2022
  infer_img: ppstructure/docs/kie/input/zh_val_42.jpg
  save_res_path: ./output/ser_layoutlmv2_xfund_zh/res/
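
# LayoutLMv2ForSer wraps LayoutLMv2 with a token-classification head for SER.
# pretrained: True loads pretrained backbone weights; checkpoints is left empty
# and (per common PaddleOCR convention; an assumption here) can be pointed at a
# fine-tuned model to resume from. num_classes must stay consistent with the
# label set behind class_path below.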
Architecture:
  model_type: kie
  algorithm: &algorithm "LayoutLMv2"
  Transform:
  Backbone:
    name: LayoutLMv2ForSer
    pretrained: True
    checkpoints:
    num_classes: &num_classes 7
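
# Token-level classification loss; `key` selects which entry of the model
# output the loss reads (an assumption about PaddleOCR's output routing).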
Loss:
  name: VQASerTokenLayoutLMLoss
  num_classes: *num_classes
  key: "backbone_out"
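
# AdamW with PaddleOCR's Linear schedule: the learning rate warms up for
# 2 epochs, then decays linearly over the remaining epochs (tied to epoch_num
# via the YAML anchor). factor 0.00000 effectively disables L2 regularization.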
Optimizer:
  name: AdamW
  beta1: 0.9
  beta2: 0.999
  lr:
    name: Linear
    learning_rate: 0.00005
    epochs: *epoch_num
    warmup_epoch: 2
  regularizer:
    name: L2
    factor: 0.00000
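
# Decodes per-token class ids back into entity labels. Assuming the XFUND class
# list holds 4 raw labels (e.g. HEADER/QUESTION/ANSWER/OTHER), B-/I- tagging of
# the 3 entity types plus the O label yields the 7 in num_classes (2*4 - 1).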
PostProcess:
  name: VQASerTokenLayoutLMPostProcess
  class_path: &class_path train_data/XFUND/class_list_xfun.txt
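
# Entity-level precision and recall are combined into hmean (harmonic mean,
# i.e. F1), which is the indicator used to pick the best checkpoint.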
Metric:
  name: VQASerTokenMetric
  main_indicator: hmean
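
# Training pipeline: decode the image, encode words and boxes into LayoutLMv2
# token inputs, pad and chunk sequences to max_seq_len 512, resize the page to
# 224x224, normalize with ImageNet statistics on the 0-255 scale, and emit the
# tensors in the exact order KeepKeys lists.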
Train:
  dataset:
    name: SimpleDataSet
    data_dir: train_data/XFUND/zh_train/image
    label_file_list:
      - train_data/XFUND/zh_train/train.json
    transforms:
      - DecodeImage: # load image
          img_mode: RGB
          channel_first: False
      - VQATokenLabelEncode: # Class handling label
          contains_re: False
          algorithm: *algorithm
          class_path: *class_path
      - VQATokenPad:
          max_seq_len: &max_seq_len 512
          return_attention_mask: True
      - VQASerTokenChunk:
          max_seq_len: *max_seq_len
      - Resize:
          size: [224,224]
      - NormalizeImage:
          scale: 1
          mean: [ 123.675, 116.28, 103.53 ]
          std: [ 58.395, 57.12, 57.375 ]
          order: 'hwc'
      - ToCHWImage:
      - KeepKeys:
          keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels'] # dataloader will return list in this order
  loader:
    shuffle: True
    drop_last: False
    batch_size_per_card: 8
    num_workers: 4
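
# Evaluation mirrors the training pipeline on the validation split; shuffling
# is disabled and max_seq_len reuses the anchor defined under Train.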
Eval:
  dataset:
    name: SimpleDataSet
    data_dir: train_data/XFUND/zh_val/image
    label_file_list:
      - train_data/XFUND/zh_val/val.json
    transforms:
      - DecodeImage: # load image
          img_mode: RGB
          channel_first: False
      - VQATokenLabelEncode: # Class handling label
          contains_re: False
          algorithm: *algorithm
          class_path: *class_path
      - VQATokenPad:
          max_seq_len: *max_seq_len
          return_attention_mask: True
      - VQASerTokenChunk:
          max_seq_len: *max_seq_len
      - Resize:
          size: [224,224]
      - NormalizeImage:
          scale: 1
          mean: [ 123.675, 116.28, 103.53 ]
          std: [ 58.395, 57.12, 57.375 ]
          order: 'hwc'
      - ToCHWImage:
      - KeepKeys:
          keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels'] # dataloader will return list in this order
  loader:
    shuffle: False
    drop_last: False
    batch_size_per_card: 8
    num_workers: 4
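
# A typical single-GPU launch (the config path is an assumption; adjust it to
# where this file lives in your PaddleOCR checkout):
#   python3 tools/train.py -c configs/kie/layoutlm_series/ser_layoutlmv2_xfund_zh.yml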