# re_layoutlmv2_xfund_zh.yml
# PaddleOCR KIE (relation extraction) training config: LayoutLMv2 on the XFUND-zh dataset.
  1. Global:
  2. use_gpu: True
  3. epoch_num: &epoch_num 200
  4. log_smooth_window: 10
  5. print_batch_step: 10
  6. save_model_dir: ./output/re_layoutlmv2_xfund_zh
  7. save_epoch_step: 2000
  8. # evaluation is run every 10 iterations after the 0th iteration
  9. eval_batch_step: [ 0, 19 ]
  10. cal_metric_during_train: False
  11. save_inference_dir:
  12. use_visualdl: False
  13. seed: 2022
  14. infer_img: ppstructure/docs/kie/input/zh_val_21.jpg
  15. save_res_path: ./output/re_layoutlmv2_xfund_zh/res/
  16. Architecture:
  17. model_type: kie
  18. algorithm: &algorithm "LayoutLMv2"
  19. Transform:
  20. Backbone:
  21. name: LayoutLMv2ForRe
  22. pretrained: True
  23. checkpoints:
  24. Loss:
  25. name: LossFromOutput
  26. key: loss
  27. reduction: mean
  28. Optimizer:
  29. name: AdamW
  30. beta1: 0.9
  31. beta2: 0.999
  32. clip_norm: 10
  33. lr:
  34. learning_rate: 0.00005
  35. warmup_epoch: 10
  36. regularizer:
  37. name: L2
  38. factor: 0.00000
  39. PostProcess:
  40. name: VQAReTokenLayoutLMPostProcess
  41. Metric:
  42. name: VQAReTokenMetric
  43. main_indicator: hmean
  44. Train:
  45. dataset:
  46. name: SimpleDataSet
  47. data_dir: train_data/XFUND/zh_train/image
  48. label_file_list:
  49. - train_data/XFUND/zh_train/train.json
  50. ratio_list: [ 1.0 ]
  51. transforms:
  52. - DecodeImage: # load image
  53. img_mode: RGB
  54. channel_first: False
  55. - VQATokenLabelEncode: # Class handling label
  56. contains_re: True
  57. algorithm: *algorithm
  58. class_path: &class_path train_data/XFUND/class_list_xfun.txt
  59. - VQATokenPad:
  60. max_seq_len: &max_seq_len 512
  61. return_attention_mask: True
  62. - VQAReTokenRelation:
  63. - VQAReTokenChunk:
  64. max_seq_len: *max_seq_len
  65. - Resize:
  66. size: [224,224]
  67. - NormalizeImage:
  68. scale: 1./255.
  69. mean: [0.485, 0.456, 0.406]
  70. std: [0.229, 0.224, 0.225]
  71. order: 'hwc'
  72. - ToCHWImage:
  73. - KeepKeys:
  74. keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids','image', 'entities', 'relations'] # dataloader will return list in this order
  75. loader:
  76. shuffle: True
  77. drop_last: False
  78. batch_size_per_card: 8
  79. num_workers: 8
  80. collate_fn: ListCollator
  81. Eval:
  82. dataset:
  83. name: SimpleDataSet
  84. data_dir: train_data/XFUND/zh_val/image
  85. label_file_list:
  86. - train_data/XFUND/zh_val/val.json
  87. transforms:
  88. - DecodeImage: # load image
  89. img_mode: RGB
  90. channel_first: False
  91. - VQATokenLabelEncode: # Class handling label
  92. contains_re: True
  93. algorithm: *algorithm
  94. class_path: *class_path
  95. - VQATokenPad:
  96. max_seq_len: *max_seq_len
  97. return_attention_mask: True
  98. - VQAReTokenRelation:
  99. - VQAReTokenChunk:
  100. max_seq_len: *max_seq_len
  101. - Resize:
  102. size: [224,224]
  103. - NormalizeImage:
  104. scale: 1./255.
  105. mean: [0.485, 0.456, 0.406]
  106. std: [0.229, 0.224, 0.225]
  107. order: 'hwc'
  108. - ToCHWImage:
  109. - KeepKeys:
  110. keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image','entities', 'relations'] # dataloader will return list in this order
  111. loader:
  112. shuffle: False
  113. drop_last: False
  114. batch_size_per_card: 8
  115. num_workers: 8
  116. collate_fn: ListCollator