eval.py 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from __future__ import print_function
  17. import os
  18. import sys
  19. # add python path of PadleDetection to sys.path
  20. parent_path = os.path.abspath(os.path.join(__file__, *(['..'] * 2)))
  21. sys.path.insert(0, parent_path)
  22. # ignore warning log
  23. import warnings
  24. warnings.filterwarnings('ignore')
  25. import paddle
  26. from ppdet.core.workspace import load_config, merge_config
  27. from ppdet.utils.check import check_gpu, check_npu, check_xpu, check_mlu, check_version, check_config
  28. from ppdet.utils.cli import ArgsParser, merge_args
  29. from ppdet.engine import Trainer, init_parallel_env
  30. from ppdet.metrics.coco_utils import json_eval_results
  31. from ppdet.slim import build_slim_model
  32. from ppdet.utils.logger import setup_logger
  33. logger = setup_logger('eval')
  34. def parse_args():
  35. parser = ArgsParser()
  36. parser.add_argument(
  37. "--output_eval",
  38. default=None,
  39. type=str,
  40. help="Evaluation directory, default is current directory.")
  41. parser.add_argument(
  42. '--json_eval',
  43. action='store_true',
  44. default=False,
  45. help='Whether to re eval with already exists bbox.json or mask.json')
  46. parser.add_argument(
  47. "--slim_config",
  48. default=None,
  49. type=str,
  50. help="Configuration file of slim method.")
  51. # TODO: bias should be unified
  52. parser.add_argument(
  53. "--bias",
  54. action="store_true",
  55. help="whether add bias or not while getting w and h")
  56. parser.add_argument(
  57. "--classwise",
  58. action="store_true",
  59. help="whether per-category AP and draw P-R Curve or not.")
  60. parser.add_argument(
  61. '--save_prediction_only',
  62. action='store_true',
  63. default=False,
  64. help='Whether to save the evaluation results only')
  65. parser.add_argument(
  66. "--amp",
  67. action='store_true',
  68. default=False,
  69. help="Enable auto mixed precision eval.")
  70. # for smalldet slice_infer
  71. parser.add_argument(
  72. "--slice_infer",
  73. action='store_true',
  74. help="Whether to slice the image and merge the inference results for small object detection."
  75. )
  76. parser.add_argument(
  77. '--slice_size',
  78. nargs='+',
  79. type=int,
  80. default=[640, 640],
  81. help="Height of the sliced image.")
  82. parser.add_argument(
  83. "--overlap_ratio",
  84. nargs='+',
  85. type=float,
  86. default=[0.25, 0.25],
  87. help="Overlap height ratio of the sliced image.")
  88. parser.add_argument(
  89. "--combine_method",
  90. type=str,
  91. default='nms',
  92. help="Combine method of the sliced images' detection results, choose in ['nms', 'nmm', 'concat']."
  93. )
  94. parser.add_argument(
  95. "--match_threshold",
  96. type=float,
  97. default=0.6,
  98. help="Combine method matching threshold.")
  99. parser.add_argument(
  100. "--match_metric",
  101. type=str,
  102. default='ios',
  103. help="Combine method matching metric, choose in ['iou', 'ios'].")
  104. args = parser.parse_args()
  105. return args
  106. def run(FLAGS, cfg):
  107. if FLAGS.json_eval:
  108. logger.info(
  109. "In json_eval mode, PaddleDetection will evaluate json files in "
  110. "output_eval directly. And proposal.json, bbox.json and mask.json "
  111. "will be detected by default.")
  112. json_eval_results(
  113. cfg.metric,
  114. json_directory=FLAGS.output_eval,
  115. dataset=cfg['EvalDataset'])
  116. return
  117. # init parallel environment if nranks > 1
  118. init_parallel_env()
  119. # build trainer
  120. trainer = Trainer(cfg, mode='eval')
  121. # load weights
  122. trainer.load_weights(cfg.weights)
  123. # training
  124. if FLAGS.slice_infer:
  125. trainer.evaluate_slice(
  126. slice_size=FLAGS.slice_size,
  127. overlap_ratio=FLAGS.overlap_ratio,
  128. combine_method=FLAGS.combine_method,
  129. match_threshold=FLAGS.match_threshold,
  130. match_metric=FLAGS.match_metric)
  131. else:
  132. trainer.evaluate()
  133. def main():
  134. FLAGS = parse_args()
  135. cfg = load_config(FLAGS.config)
  136. merge_args(cfg, FLAGS)
  137. merge_config(FLAGS.opt)
  138. # disable npu in config by default
  139. if 'use_npu' not in cfg:
  140. cfg.use_npu = False
  141. # disable xpu in config by default
  142. if 'use_xpu' not in cfg:
  143. cfg.use_xpu = False
  144. if 'use_gpu' not in cfg:
  145. cfg.use_gpu = False
  146. # disable mlu in config by default
  147. if 'use_mlu' not in cfg:
  148. cfg.use_mlu = False
  149. if cfg.use_gpu:
  150. place = paddle.set_device('gpu')
  151. elif cfg.use_npu:
  152. place = paddle.set_device('npu')
  153. elif cfg.use_xpu:
  154. place = paddle.set_device('xpu')
  155. elif cfg.use_mlu:
  156. place = paddle.set_device('mlu')
  157. else:
  158. place = paddle.set_device('cpu')
  159. if FLAGS.slim_config:
  160. cfg = build_slim_model(cfg, FLAGS.slim_config, mode='eval')
  161. check_config(cfg)
  162. check_gpu(cfg.use_gpu)
  163. check_npu(cfg.use_npu)
  164. check_xpu(cfg.use_xpu)
  165. check_mlu(cfg.use_mlu)
  166. check_version()
  167. run(FLAGS, cfg)
  168. if __name__ == '__main__':
  169. main()