eval.py 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from __future__ import print_function
  17. import os
  18. import sys
  19. __dir__ = os.path.dirname(os.path.abspath(__file__))
  20. sys.path.insert(0, __dir__)
  21. sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..')))
  22. import paddle
  23. from ppocr.data import build_dataloader
  24. from ppocr.modeling.architectures import build_model
  25. from ppocr.postprocess import build_post_process
  26. from ppocr.metrics import build_metric
  27. from ppocr.utils.save_load import load_model
  28. import tools.program as program
  # Recognition algorithms whose eval pass needs extra inputs beyond the image
  # (forwarded to program.eval as ``extra_input``).
  _EXTRA_INPUT_MODELS = (
      "SRN", "NRTR", "SAR", "SEED", "SVTR", "VisionLAN", "RobustScanner"
  )


  def _multi_head_channels(char_num, sar_decode):
      """Return the ``out_channels_list`` dict for a CTC + SAR ``MultiHead``.

      Args:
          char_num: length of the post-process ``character`` list.
          sar_decode: True when the configured post-process is a SAR decoder.

      Returns:
          A fresh dict with 'CTCLabelDecode' and 'SARLabelDecode' sizes; the
          SAR size is always the CTC size plus 2.
      """
      # With a SAR decoder the charset is 2 entries larger than what the CTC
      # branch uses (presumably SAR-specific special tokens -- confirm).
      ctc_num = char_num - 2 if sar_decode else char_num
      return {
          'CTCLabelDecode': ctc_num,
          'SARLabelDecode': ctc_num + 2,
      }


  def _set_rec_head_channels(config, post_process_class):
      """Write recognition head output sizes into ``config['Architecture']``.

      Sizes come from ``len(post_process_class.character)``; nothing is done
      when the post-process object has no ``character`` attribute. ``config``
      is modified in place.
      """
      if not hasattr(post_process_class, 'character'):
          return
      char_num = len(getattr(post_process_class, 'character'))
      arch = config['Architecture']
      post_process_name = config['PostProcess']['name']

      if arch['algorithm'] == 'Distillation':  # distillation model
          for sub_model in arch['Models'].values():
              head = sub_model['Head']
              if head['name'] == 'MultiHead':  # for multi head
                  # char_num is NOT mutated here: every sub-model shares the
                  # same charset, so each one must see the same base count.
                  head['out_channels_list'] = _multi_head_channels(
                      char_num,
                      post_process_name == 'DistillationSARLabelDecode')
              else:
                  head['out_channels'] = char_num
      elif arch['Head']['name'] == 'MultiHead':  # for multi head
          arch['Head']['out_channels_list'] = _multi_head_channels(
              char_num, post_process_name == 'SARLabelDecode')
      else:  # base rec model
          arch['Head']['out_channels'] = char_num


  def _uses_extra_input(arch_config):
      """Return True if the model (or any distillation sub-model) needs extra inputs."""
      if arch_config['algorithm'] == 'Distillation':
          return any(sub_model['algorithm'] in _EXTRA_INPUT_MODELS
                     for sub_model in arch_config['Models'].values())
      return arch_config['algorithm'] in _EXTRA_INPUT_MODELS


  def _resolve_model_type(arch_config):
      """Return the model_type passed to program.eval ('can' for CAN, None if unset)."""
      if 'model_type' not in arch_config:
          return None
      if arch_config['algorithm'] == 'CAN':
          return 'can'
      return arch_config['model_type']


  def _enable_amp(global_config, model, amp_level):
      """Set AMP-related Paddle flags, build a GradScaler, decorate the model for O2.

      Returns:
          ``(model, scaler)`` -- the model is replaced by the decorated one
          only when ``amp_level == 'O2'``.
      """
      # Public paddle.set_flags instead of the legacy paddle.fluid namespace,
      # which newer Paddle releases no longer ship.
      paddle.set_flags({
          'FLAGS_cudnn_batchnorm_spatial_persistent': 1,
          'FLAGS_max_inplace_grad_add': 8,
      })
      scaler = paddle.amp.GradScaler(
          init_loss_scaling=global_config.get("scale_loss", 1.0),
          use_dynamic_loss_scaling=global_config.get(
              "use_dynamic_loss_scaling", False))
      if amp_level == "O2":
          model = paddle.amp.decorate(
              models=model, level=amp_level, master_weight=True)
      return model, scaler


  def main():
      """Evaluate a model on the 'Eval' dataset and log the resulting metrics.

      Reads the module-level ``config``, ``device`` and ``logger`` globals,
      which the ``__main__`` guard assigns from ``program.preprocess()``.
      ``config['Architecture']`` is updated in place with head sizes derived
      from the post-process charset before the model is built.
      """
      global_config = config['Global']
      arch_config = config['Architecture']

      # build dataloader
      valid_dataloader = build_dataloader(config, 'Eval', device, logger)

      # build post process
      post_process_class = build_post_process(config['PostProcess'],
                                              global_config)

      # build model (recognition heads are sized from the decoder charset)
      _set_rec_head_channels(config, post_process_class)
      # NOTE(review): the pasted original lost its indentation, so the scope
      # of this override is inferred; it is applied whenever a top-level Head
      # exists (distillation configs have none) -- confirm intended scope.
      if "num_classes" in global_config and 'Head' in arch_config:
          arch_config["Head"]['num_classes'] = global_config["num_classes"]
      model = build_model(arch_config)

      extra_input = _uses_extra_input(arch_config)
      model_type = _resolve_model_type(arch_config)

      # build metric
      eval_class = build_metric(config['Metric'])

      # amp
      amp_level = global_config.get("amp_level", 'O2')
      amp_custom_black_list = global_config.get('amp_custom_black_list', [])
      scaler = None
      if global_config.get("use_amp", False):
          model, scaler = _enable_amp(global_config, model, amp_level)

      # Forward model_type only when the config defines it; the previous
      # direct lookup raised KeyError for configs without it, and omitting
      # the argument lets load_model fall back to its own default.
      load_kwargs = {}
      if 'model_type' in arch_config:
          load_kwargs['model_type'] = arch_config['model_type']
      best_model_dict = load_model(config, model, **load_kwargs)
      if len(best_model_dict):
          logger.info('metric in ckpt ***************')
          for k, v in best_model_dict.items():
              logger.info('{}:{}'.format(k, v))

      # start eval
      metric = program.eval(model, valid_dataloader, post_process_class,
                            eval_class, model_type, extra_input, scaler,
                            amp_level, amp_custom_black_list)
      logger.info('metric eval ***************')
      for k, v in metric.items():
          logger.info('{}:{}'.format(k, v))
  if __name__ == "__main__":
      # program.preprocess() supplies the module-level globals that main()
      # reads (config, device, logger); vdl_writer is unpacked but unused here.
      (config, device, logger, vdl_writer) = program.preprocess()
      main()