# ch_PP-OCRv3_det_cml.yml
# PP-OCRv3 text-detection distillation training config.
---
  # Run-level settings: device, schedule, logging, checkpointing and output paths.
  Global:
    debug: false
    use_gpu: true
    epoch_num: 500
    log_smooth_window: 20
    print_batch_step: 10
    save_model_dir: ./output/ch_PP-OCR_v3_det/
    save_epoch_step: 100
    # NOTE(review): presumably [start_iteration, interval] for evaluation — confirm.
    eval_batch_step:
      - 0
      - 400
    cal_metric_during_train: false
    pretrained_model: null
    checkpoints: null
    save_inference_dir: null
    use_visualdl: false
    infer_img: doc/imgs_en/img_10.jpg
    save_res_path: ./checkpoints/det_db/predicts_db.txt
    distributed: true
  # Distillation model with three DB detectors:
  # - two identically configured MobileNetV3 + RSEFPN students (Student, Student2);
  # - a ResNet_vd-50 + LKPAN teacher whose parameters are frozen.
  Architecture:
    name: DistillationModel
    algorithm: Distillation
    model_type: det
    Models:
      Student:
        pretrained: null
        model_type: det
        algorithm: DB
        Transform: null
        Backbone:
          name: MobileNetV3
          scale: 0.5
          model_name: large
          disable_se: true
        Neck:
          name: RSEFPN
          out_channels: 96
          shortcut: true
        Head:
          name: DBHead
          k: 50
      # Same configuration as Student; trained jointly with it.
      Student2:
        pretrained: null
        model_type: det
        algorithm: DB
        Transform: null
        Backbone:
          name: MobileNetV3
          scale: 0.5
          model_name: large
          disable_se: true
        Neck:
          name: RSEFPN
          out_channels: 96
          shortcut: true
        Head:
          name: DBHead
          k: 50
      Teacher:
        freeze_params: true
        return_all_feats: false
        model_type: det
        algorithm: DB
        Backbone:
          name: ResNet_vd
          in_channels: 3
          layers: 50
        Neck:
          name: LKPAN
          out_channels: 256
        Head:
          name: DBHead
          kernel_list: [7, 2, 2]
          k: 50
  # Sum of three losses over the Student/Student2/Teacher branches.
  Loss:
    name: CombinedLoss
    loss_config_list:
      # Teacher -> student distillation on the "maps" output, for both students.
      - DistillationDilaDBLoss:
          weight: 1.0
          model_name_pairs:
            - ["Student", "Teacher"]
            - ["Student2", "Teacher"]
          key: maps
          balance_loss: true
          main_loss_type: DiceLoss
          alpha: 5
          beta: 10
          ohem_ratio: 3
      # Mutual learning between the two students.
      - DistillationDMLLoss:
          # Given once only (duplicate keys are invalid YAML). This is the value
          # PyYAML's last-wins rule used. NOTE(review): a flat pair, unlike the
          # list of pairs above — presumably both forms are accepted; confirm.
          model_name_pairs: ["Student", "Student2"]
          # NOTE(review): looks like a typo of "shrink_maps", but it must match the
          # key the loss implementation looks up — verify there before changing.
          maps_name: "thrink_maps"
          weight: 1.0
          key: maps
      # DB loss applied to each model in model_name_list.
      - DistillationDBLoss:
          weight: 1.0
          model_name_list: ["Student", "Student2"]
          balance_loss: true
          main_loss_type: DiceLoss
          alpha: 5
          beta: 10
          ohem_ratio: 3
  # Adam with a cosine learning-rate schedule (2 warmup epochs) and L2 regularization.
  Optimizer:
    name: Adam
    beta1: 0.9
    beta2: 0.999
    lr:
      name: Cosine
      learning_rate: 0.001
      warmup_epoch: 2
    regularizer:
      name: L2
      factor: 5.0e-05
  # Box extraction is run only on the Student branch's "head_out" output.
  PostProcess:
    name: DistillationDBPostProcess
    model_name: ["Student"]
    key: head_out
    thresh: 0.3
    box_thresh: 0.6
    max_candidates: 1000
    unclip_ratio: 1.5
  # Detection metric computed for the Student branch; hmean is the main indicator.
  Metric:
    name: DistillationMetric
    base_metric_name: DetMetric
    main_indicator: hmean
    key: "Student"
  # Training data: ICDAR2015 text localization, augmented and cropped to 960x960.
  Train:
    dataset:
      name: SimpleDataSet
      data_dir: ./train_data/icdar2015/text_localization/
      label_file_list:
        - ./train_data/icdar2015/text_localization/train_icdar2015_label.txt
      ratio_list: [1.0]
      # Transforms are applied in the order listed.
      transforms:
        - DecodeImage:
            img_mode: BGR
            channel_first: false
        - DetLabelEncode: null
        - CopyPaste: null
        - IaaAugment:
            augmenter_args:
              - type: Fliplr
                args:
                  p: 0.5
              - type: Affine
                args:
                  rotate:
                    - -10
                    - 10
              - type: Resize
                args:
                  size:
                    - 0.5
                    - 3
        - EastRandomCropData:
            size:
              - 960
              - 960
            max_tries: 50
            keep_ratio: true
        - MakeBorderMap:
            shrink_ratio: 0.4
            thresh_min: 0.3
            thresh_max: 0.7
        - MakeShrinkMap:
            shrink_ratio: 0.4
            min_text_size: 8
        - NormalizeImage:
            # Kept as the string "1./255."; presumably evaluated by the transform — confirm.
            scale: 1./255.
            mean:
              - 0.485
              - 0.456
              - 0.406
            std:
              - 0.229
              - 0.224
              - 0.225
            order: hwc
        - ToCHWImage: null
        - KeepKeys:
            keep_keys:
              - image
              - threshold_map
              - threshold_mask
              - shrink_map
              - shrink_mask
    loader:
      shuffle: true
      drop_last: false
      batch_size_per_card: 8
      num_workers: 4
  # Evaluation data: ICDAR2015 test split, one image per batch.
  Eval:
    dataset:
      name: SimpleDataSet
      data_dir: ./train_data/icdar2015/text_localization/
      label_file_list:
        - ./train_data/icdar2015/text_localization/test_icdar2015_label.txt
      transforms:
        - DecodeImage:  # decode the image file
            img_mode: BGR
            channel_first: false
        - DetLabelEncode: null  # encode the detection label
        - DetResizeForTest: null
        - NormalizeImage:
            scale: 1./255.
            mean: [0.485, 0.456, 0.406]
            std: [0.229, 0.224, 0.225]
            order: hwc
        - ToCHWImage: null
        - KeepKeys:
            keep_keys: ['image', 'shape', 'polys', 'ignore_tags']
    loader:
      shuffle: false
      drop_last: false
      batch_size_per_card: 1  # must be 1
      num_workers: 2