# ch_PP-OCRv3_det_dml.yml
# PP-OCRv3 text-detection config: two ResNet_vd-50 + LKPAN + DB students trained with DML.
  Global:
    use_gpu: true
    epoch_num: 1200
    log_smooth_window: 20
    print_batch_step: 2
    # NOTE(review): "mv3" in the directory name and the MobileNetV3 checkpoint
    # below do not match the ResNet_vd (layers: 50) backbones in Architecture;
    # a MobileNetV3 checkpoint presumably will not load into them — confirm.
    save_model_dir: ./output/ch_db_mv3/
    save_epoch_step: 1200
    # [start_iter, interval]: evaluation is run every 2000 iterations
    # after the 3000th iteration.
    eval_batch_step: [3000, 2000]
    cal_metric_during_train: false
    pretrained_model: ./pretrain_models/MobileNetV3_large_x0_5_pretrained
    checkpoints: null
    save_inference_dir: null
    use_visualdl: false
    infer_img: doc/imgs_en/img_10.jpg
    save_res_path: ./output/det_db/predicts_db.txt
  Architecture:
    name: DistillationModel
    algorithm: Distillation
    model_type: det
    # Two identical DB detectors (ResNet_vd-50 backbone, LKPAN neck, DBHead)
    # trained against each other. The sub-configs are written out twice on
    # purpose: a YAML anchor/alias would make both students share one mapping
    # object, which is unsafe if the model builder mutates its config.
    Models:
      Student:
        return_all_feats: false
        model_type: det
        algorithm: DB
        Backbone:
          name: ResNet_vd
          in_channels: 3
          layers: 50
        Neck:
          name: LKPAN
          out_channels: 256
        Head:
          name: DBHead
          kernel_list: [7, 2, 2]
          k: 50
      Student2:
        return_all_feats: false
        model_type: det
        algorithm: DB
        Backbone:
          name: ResNet_vd
          in_channels: 3
          layers: 50
        Neck:
          name: LKPAN
          out_channels: 256
        Head:
          name: DBHead
          kernel_list: [7, 2, 2]
          k: 50
  Loss:
    name: CombinedLoss
    loss_config_list:
      # Mutual-learning term between the two students.
      - DistillationDMLLoss:
          # Previously declared twice (a nested [["Student", "Student2"]] and
          # this flat form); duplicate keys are last-wins in most loaders, so
          # only the value that was actually in effect is kept.
          model_name_pairs: ["Student", "Student2"]
          # "thrink_maps" spelling is kept verbatim: it is presumably the
          # literal key the loss implementation matches on — do not "fix"
          # without checking the loss code.
          maps_name: "thrink_maps"
          weight: 1.0
          # act: None
          key: maps
      # Standard DB loss applied to each student's own prediction.
      - DistillationDBLoss:
          weight: 1.0
          model_name_list: ["Student", "Student2"]
          # key: maps
          name: DBLoss
          balance_loss: true
          main_loss_type: DiceLoss
          alpha: 5
          beta: 10
          ohem_ratio: 3
  Optimizer:
    name: Adam
    beta1: 0.9
    beta2: 0.999
    lr:
      name: Cosine
      learning_rate: 0.001
      warmup_epoch: 2
    regularizer:
      name: 'L2'
      # A factor of 0 means the L2 term contributes nothing.
      factor: 0
  PostProcess:
    name: DistillationDBPostProcess
    # Sub-models whose outputs are post-processed; `key` names the output
    # entry that is read (presumably the head output — confirm).
    model_name: ["Student", "Student2"]
    key: head_out
    thresh: 0.3
    box_thresh: 0.6
    max_candidates: 1000
    unclip_ratio: 1.5
  Metric:
    name: DistillationMetric
    base_metric_name: DetMetric
    main_indicator: hmean
    # Sub-model whose predictions are scored (presumably only "Student").
    key: "Student"
  Train:
    dataset:
      name: SimpleDataSet
      data_dir: ./train_data/icdar2015/text_localization/
      label_file_list:
        - ./train_data/icdar2015/text_localization/train_icdar2015_label.txt
      ratio_list: [1.0]
      # Operators are applied in list order; an operator with no parameters
      # is written as a bare key (null value).
      transforms:
        - DecodeImage:  # load image
            img_mode: BGR
            channel_first: false
        - DetLabelEncode:  # Class handling label
        - CopyPaste:
        - IaaAugment:
            augmenter_args:
              - { 'type': Fliplr, 'args': { 'p': 0.5 } }
              - { 'type': Affine, 'args': { 'rotate': [-10, 10] } }
              - { 'type': Resize, 'args': { 'size': [0.5, 3] } }
        - EastRandomCropData:
            size: [960, 960]
            max_tries: 50
            keep_ratio: true
        - MakeBorderMap:
            shrink_ratio: 0.4
            thresh_min: 0.3
            thresh_max: 0.7
        - MakeShrinkMap:
            shrink_ratio: 0.4
            min_text_size: 8
        - NormalizeImage:
            # Quoted to make explicit that this is an expression string, not a
            # float (YAML already reads the unquoted form as a string).
            scale: '1./255.'
            mean: [0.485, 0.456, 0.406]
            std: [0.229, 0.224, 0.225]
            order: 'hwc'
        - ToCHWImage:
        - KeepKeys:
            keep_keys:  # the order of the dataloader list
              - image
              - threshold_map
              - threshold_mask
              - shrink_map
              - shrink_mask
    loader:
      shuffle: true
      drop_last: false
      batch_size_per_card: 8
      num_workers: 4
  Eval:
    dataset:
      name: SimpleDataSet
      data_dir: ./train_data/icdar2015/text_localization/
      label_file_list:
        - ./train_data/icdar2015/text_localization/test_icdar2015_label.txt
      # Operators are applied in list order; an operator with no parameters
      # is written as a bare key (null value).
      transforms:
        - DecodeImage:  # load image
            img_mode: BGR
            channel_first: false
        - DetLabelEncode:  # Class handling label
        - DetResizeForTest:
            # image_shape: [736, 1280]
        - NormalizeImage:
            # Quoted to make explicit that this is an expression string, not a
            # float (YAML already reads the unquoted form as a string).
            scale: '1./255.'
            mean: [0.485, 0.456, 0.406]
            std: [0.229, 0.224, 0.225]
            order: 'hwc'
        - ToCHWImage:
        - KeepKeys:
            keep_keys: ['image', 'shape', 'polys', 'ignore_tags']
    loader:
      shuffle: false
      drop_last: false
      batch_size_per_card: 1  # must be 1
      num_workers: 2