# ch_PP-OCRv2_det_dml.yml (4.3 KB)

  # Run-wide settings: device selection, schedule length, logging cadence,
  # checkpoint/evaluation cadence and the default inference input/output paths.
  Global:
    use_gpu: true
    epoch_num: 1200
    log_smooth_window: 20
    print_batch_step: 2
    save_model_dir: ./output/ch_db_mv3/
    save_epoch_step: 1200
    # evaluation is run every 2000 iterations after the 3000th iteration
    # (list layout: [first iteration to evaluate at, interval])
    eval_batch_step: [3000, 2000]
    cal_metric_during_train: false
    pretrained_model: ./pretrain_models/MobileNetV3_large_x0_5_pretrained
    # No value configured for the next two keys; an explicit null parses the
    # same as a bare `key:` but cannot be mistaken for a forgotten value.
    checkpoints: null
    save_inference_dir: null
    use_visualdl: false
    infer_img: doc/imgs_en/img_10.jpg
    save_res_path: ./output/det_db/predicts_db.txt
  # Distillation model made of two sub-models, Student and Teacher. Both
  # entries describe the same DB detector layout (MobileNetV3 backbone at
  # scale 0.5, DBFPN neck, DBHead) and point at the same pretrained weights;
  # neither has its parameters frozen.
  Architecture:
    name: DistillationModel
    algorithm: Distillation
    model_type: det
    Models:
      Student:
        pretrained: ./pretrain_models/MobileNetV3_large_x0_5_pretrained
        freeze_params: false
        return_all_feats: false
        model_type: det
        algorithm: DB
        Backbone:
          name: MobileNetV3
          scale: 0.5
          model_name: large
          disable_se: true
        Neck:
          name: DBFPN
          out_channels: 96
        Head:
          name: DBHead
          k: 50
      Teacher:
        pretrained: ./pretrain_models/MobileNetV3_large_x0_5_pretrained
        freeze_params: false
        return_all_feats: false
        model_type: det
        algorithm: DB
        # Key is present but carries no value (null); Student omits it.
        Transform: null
        Backbone:
          name: MobileNetV3
          scale: 0.5
          model_name: large
          disable_se: true
        Neck:
          name: DBFPN
          out_channels: 96
        Head:
          name: DBHead
          k: 50
  # Combined training loss built from two terms, each with weight 1.0:
  # a DML term between the Student and Teacher sub-models, and a DB loss term
  # that lists both sub-models.
  Loss:
    name: CombinedLoss
    loss_config_list:
      - DistillationDMLLoss:
          # Defined exactly once: a repeated mapping key is invalid YAML, and
          # last-wins loaders silently discard the earlier value.
          # NOTE(review): the flat [a, b] form is the value such loaders
          # already produced; confirm the consumer accepts it as well as a
          # list of pairs ([[a, b]]).
          model_name_pairs: ["Student", "Teacher"]
          # NOTE(review): spelling "thrink_maps" is kept verbatim; it must
          # match whatever key the consumer looks up.
          maps_name: "thrink_maps"
          weight: 1.0
          # act: None
          key: maps
      - DistillationDBLoss:
          weight: 1.0
          model_name_list: ["Student", "Teacher"]
          # key: maps
          name: DBLoss
          balance_loss: true
          main_loss_type: DiceLoss
          alpha: 5
          beta: 10
          ohem_ratio: 3
  # Adam optimizer with a cosine learning-rate schedule (base rate 0.001,
  # 2 warm-up epochs) and an L2 regularizer whose factor is 0.
  Optimizer:
    name: Adam
    beta1: 0.9
    beta2: 0.999
    lr:
      name: Cosine
      learning_rate: 0.001
      warmup_epoch: 2
    regularizer:
      name: 'L2'
      factor: 0
  # Post-processing applied to the `head_out` output of each listed sub-model.
  PostProcess:
    name: DistillationDBPostProcess
    model_name:
      - "Student"
      - "Teacher"
    key: head_out
    thresh: 0.3
    box_thresh: 0.6
    max_candidates: 1000
    unclip_ratio: 1.5
  # Evaluation metric: DetMetric wrapped for distillation, keyed on the
  # Student sub-model, with hmean as the main indicator.
  Metric:
    name: DistillationMetric
    base_metric_name: DetMetric
    main_indicator: hmean
    key: "Student"
  # Training split: dataset location, the ordered transform pipeline applied
  # to each sample, and the data-loader settings.
  Train:
    dataset:
      name: SimpleDataSet
      data_dir: ./train_data/icdar2015/text_localization/
      label_file_list:
        - ./train_data/icdar2015/text_localization/train_icdar2015_label.txt
      ratio_list: [1.0]
      # Transforms are listed in the order they appear in the pipeline.
      # Entries that take no arguments are written as explicit nulls.
      transforms:
        - DecodeImage:  # load image
            img_mode: BGR
            channel_first: false
        - DetLabelEncode: null  # Class handling label
        - CopyPaste: null
        - IaaAugment:
            augmenter_args:
              - type: Fliplr
                args:
                  p: 0.5
              - type: Affine
                args:
                  rotate: [-10, 10]
              - type: Resize
                args:
                  size: [0.5, 3]
        - EastRandomCropData:
            size: [960, 960]
            max_tries: 50
            keep_ratio: true
        - MakeBorderMap:
            shrink_ratio: 0.4
            thresh_min: 0.3
            thresh_max: 0.7
        - MakeShrinkMap:
            shrink_ratio: 0.4
            min_text_size: 8
        - NormalizeImage:
            # A string, not a number; presumably evaluated by the consumer.
            # TODO confirm.
            scale: '1./255.'
            mean: [0.485, 0.456, 0.406]
            std: [0.229, 0.224, 0.225]
            order: 'hwc'
        - ToCHWImage: null
        - KeepKeys:
            # the order of the dataloader list
            keep_keys:
              - 'image'
              - 'threshold_map'
              - 'threshold_mask'
              - 'shrink_map'
              - 'shrink_mask'
    loader:
      shuffle: true
      drop_last: false
      batch_size_per_card: 8
      num_workers: 4
  # Evaluation split: dataset location, the ordered transform pipeline and
  # the data-loader settings (no shuffling, one image per card).
  Eval:
    dataset:
      name: SimpleDataSet
      data_dir: ./train_data/icdar2015/text_localization/
      label_file_list:
        - ./train_data/icdar2015/text_localization/test_icdar2015_label.txt
      transforms:
        - DecodeImage:  # load image
            img_mode: BGR
            channel_first: false
        - DetLabelEncode: null  # Class handling label
        # No arguments are active for the resize step; the fixed shape below
        # is left commented out.
        - DetResizeForTest: null
        # image_shape: [736, 1280]
        - NormalizeImage:
            # A string, not a number; presumably evaluated by the consumer.
            # TODO confirm.
            scale: '1./255.'
            mean: [0.485, 0.456, 0.406]
            std: [0.229, 0.224, 0.225]
            order: 'hwc'
        - ToCHWImage: null
        - KeepKeys:
            keep_keys:
              - 'image'
              - 'shape'
              - 'polys'
              - 'ignore_tags'
    loader:
      shuffle: false
      drop_last: false
      batch_size_per_card: 1  # must be 1
      num_workers: 2
  169. num_workers: 2