infer.py 6.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from __future__ import print_function
  17. import os
  18. import sys
  19. # add python path of PadleDetection to sys.path
  20. parent_path = os.path.abspath(os.path.join(__file__, *(['..'] * 2)))
  21. sys.path.insert(0, parent_path)
  22. # ignore warning log
  23. import warnings
  24. warnings.filterwarnings('ignore')
  25. import glob
  26. import ast
  27. import paddle
  28. from ppdet.core.workspace import load_config, merge_config
  29. from ppdet.engine import Trainer
  30. from ppdet.utils.check import check_gpu, check_npu, check_xpu, check_mlu, check_version, check_config
  31. from ppdet.utils.cli import ArgsParser, merge_args
  32. from ppdet.slim import build_slim_model
  33. from ppdet.utils.logger import setup_logger
  34. logger = setup_logger('train')
  35. def parse_args():
  36. parser = ArgsParser()
  37. parser.add_argument(
  38. "--infer_dir",
  39. type=str,
  40. default=None,
  41. help="Directory for images to perform inference on.")
  42. parser.add_argument(
  43. "--infer_img",
  44. type=str,
  45. default=None,
  46. help="Image path, has higher priority over --infer_dir")
  47. parser.add_argument(
  48. "--output_dir",
  49. type=str,
  50. default="output",
  51. help="Directory for storing the output visualization files.")
  52. parser.add_argument(
  53. "--draw_threshold",
  54. type=float,
  55. default=0.5,
  56. help="Threshold to reserve the result for visualization.")
  57. parser.add_argument(
  58. "--slim_config",
  59. default=None,
  60. type=str,
  61. help="Configuration file of slim method.")
  62. parser.add_argument(
  63. "--use_vdl",
  64. type=bool,
  65. default=False,
  66. help="Whether to record the data to VisualDL.")
  67. parser.add_argument(
  68. '--vdl_log_dir',
  69. type=str,
  70. default="vdl_log_dir/image",
  71. help='VisualDL logging directory for image.')
  72. parser.add_argument(
  73. "--save_results",
  74. type=bool,
  75. default=False,
  76. help="Whether to save inference results to output_dir.")
  77. parser.add_argument(
  78. "--slice_infer",
  79. action='store_true',
  80. help="Whether to slice the image and merge the inference results for small object detection."
  81. )
  82. parser.add_argument(
  83. '--slice_size',
  84. nargs='+',
  85. type=int,
  86. default=[640, 640],
  87. help="Height of the sliced image.")
  88. parser.add_argument(
  89. "--overlap_ratio",
  90. nargs='+',
  91. type=float,
  92. default=[0.25, 0.25],
  93. help="Overlap height ratio of the sliced image.")
  94. parser.add_argument(
  95. "--combine_method",
  96. type=str,
  97. default='nms',
  98. help="Combine method of the sliced images' detection results, choose in ['nms', 'nmm', 'concat']."
  99. )
  100. parser.add_argument(
  101. "--match_threshold",
  102. type=float,
  103. default=0.6,
  104. help="Combine method matching threshold.")
  105. parser.add_argument(
  106. "--match_metric",
  107. type=str,
  108. default='ios',
  109. help="Combine method matching metric, choose in ['iou', 'ios'].")
  110. parser.add_argument(
  111. "--visualize",
  112. type=ast.literal_eval,
  113. default=True,
  114. help="Whether to save visualize results to output_dir.")
  115. args = parser.parse_args()
  116. return args
  117. def get_test_images(infer_dir, infer_img):
  118. """
  119. Get image path list in TEST mode
  120. """
  121. assert infer_img is not None or infer_dir is not None, \
  122. "--infer_img or --infer_dir should be set"
  123. assert infer_img is None or os.path.isfile(infer_img), \
  124. "{} is not a file".format(infer_img)
  125. assert infer_dir is None or os.path.isdir(infer_dir), \
  126. "{} is not a directory".format(infer_dir)
  127. # infer_img has a higher priority
  128. if infer_img and os.path.isfile(infer_img):
  129. return [infer_img]
  130. images = set()
  131. infer_dir = os.path.abspath(infer_dir)
  132. assert os.path.isdir(infer_dir), \
  133. "infer_dir {} is not a directory".format(infer_dir)
  134. exts = ['jpg', 'jpeg', 'png', 'bmp']
  135. exts += [ext.upper() for ext in exts]
  136. for ext in exts:
  137. images.update(glob.glob('{}/*.{}'.format(infer_dir, ext)))
  138. images = list(images)
  139. assert len(images) > 0, "no image found in {}".format(infer_dir)
  140. logger.info("Found {} inference images in total.".format(len(images)))
  141. return images
  142. def run(FLAGS, cfg):
  143. # build trainer
  144. trainer = Trainer(cfg, mode='test')
  145. # load weights
  146. trainer.load_weights(cfg.weights)
  147. # get inference images
  148. images = get_test_images(FLAGS.infer_dir, FLAGS.infer_img)
  149. # inference
  150. if FLAGS.slice_infer:
  151. trainer.slice_predict(
  152. images,
  153. slice_size=FLAGS.slice_size,
  154. overlap_ratio=FLAGS.overlap_ratio,
  155. combine_method=FLAGS.combine_method,
  156. match_threshold=FLAGS.match_threshold,
  157. match_metric=FLAGS.match_metric,
  158. draw_threshold=FLAGS.draw_threshold,
  159. output_dir=FLAGS.output_dir,
  160. save_results=FLAGS.save_results,
  161. visualize=FLAGS.visualize)
  162. else:
  163. trainer.predict(
  164. images,
  165. draw_threshold=FLAGS.draw_threshold,
  166. output_dir=FLAGS.output_dir,
  167. save_results=FLAGS.save_results,
  168. visualize=FLAGS.visualize)
  169. def main():
  170. FLAGS = parse_args()
  171. cfg = load_config(FLAGS.config)
  172. merge_args(cfg, FLAGS)
  173. merge_config(FLAGS.opt)
  174. # disable npu in config by default
  175. if 'use_npu' not in cfg:
  176. cfg.use_npu = False
  177. # disable xpu in config by default
  178. if 'use_xpu' not in cfg:
  179. cfg.use_xpu = False
  180. if 'use_gpu' not in cfg:
  181. cfg.use_gpu = False
  182. # disable mlu in config by default
  183. if 'use_mlu' not in cfg:
  184. cfg.use_mlu = False
  185. if cfg.use_gpu:
  186. place = paddle.set_device('gpu')
  187. elif cfg.use_npu:
  188. place = paddle.set_device('npu')
  189. elif cfg.use_xpu:
  190. place = paddle.set_device('xpu')
  191. elif cfg.use_mlu:
  192. place = paddle.set_device('mlu')
  193. else:
  194. place = paddle.set_device('cpu')
  195. if FLAGS.slim_config:
  196. cfg = build_slim_model(cfg, FLAGS.slim_config, mode='test')
  197. check_config(cfg)
  198. check_gpu(cfg.use_gpu)
  199. check_npu(cfg.use_npu)
  200. check_xpu(cfg.use_xpu)
  201. check_mlu(cfg.use_mlu)
  202. check_version()
  203. run(FLAGS, cfg)
  204. if __name__ == '__main__':
  205. main()