# infer.py
  1. # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import yaml
  16. import argparse
  17. import numpy as np
  18. import glob
  19. from onnxruntime import InferenceSession
  20. from preprocess import Compose
  21. # Global dictionary
  22. SUPPORT_MODELS = {
  23. 'YOLO', 'PPYOLOE', 'RCNN', 'SSD', 'Face', 'FCOS', 'SOLOv2', 'TTFNet',
  24. 'S2ANet', 'JDE', 'FairMOT', 'DeepSORT', 'GFL', 'PicoDet', 'CenterNet',
  25. 'TOOD', 'RetinaNet', 'StrongBaseline', 'STGCN', 'YOLOX', 'HRNet'
  26. }
  27. parser = argparse.ArgumentParser(description=__doc__)
  28. parser.add_argument("--infer_cfg", type=str, help="infer_cfg.yml")
  29. parser.add_argument(
  30. '--onnx_file', type=str, default="model.onnx", help="onnx model file path")
  31. parser.add_argument("--image_dir", type=str)
  32. parser.add_argument("--image_file", type=str)
  33. def get_test_images(infer_dir, infer_img):
  34. """
  35. Get image path list in TEST mode
  36. """
  37. assert infer_img is not None or infer_dir is not None, \
  38. "--image_file or --image_dir should be set"
  39. assert infer_img is None or os.path.isfile(infer_img), \
  40. "{} is not a file".format(infer_img)
  41. assert infer_dir is None or os.path.isdir(infer_dir), \
  42. "{} is not a directory".format(infer_dir)
  43. # infer_img has a higher priority
  44. if infer_img and os.path.isfile(infer_img):
  45. return [infer_img]
  46. images = set()
  47. infer_dir = os.path.abspath(infer_dir)
  48. assert os.path.isdir(infer_dir), \
  49. "infer_dir {} is not a directory".format(infer_dir)
  50. exts = ['jpg', 'jpeg', 'png', 'bmp']
  51. exts += [ext.upper() for ext in exts]
  52. for ext in exts:
  53. images.update(glob.glob('{}/*.{}'.format(infer_dir, ext)))
  54. images = list(images)
  55. assert len(images) > 0, "no image found in {}".format(infer_dir)
  56. print("Found {} inference images in total.".format(len(images)))
  57. return images
  58. class PredictConfig(object):
  59. """set config of preprocess, postprocess and visualize
  60. Args:
  61. infer_config (str): path of infer_cfg.yml
  62. """
  63. def __init__(self, infer_config):
  64. # parsing Yaml config for Preprocess
  65. with open(infer_config) as f:
  66. yml_conf = yaml.safe_load(f)
  67. self.check_model(yml_conf)
  68. self.arch = yml_conf['arch']
  69. self.preprocess_infos = yml_conf['Preprocess']
  70. self.min_subgraph_size = yml_conf['min_subgraph_size']
  71. self.label_list = yml_conf['label_list']
  72. self.use_dynamic_shape = yml_conf['use_dynamic_shape']
  73. self.draw_threshold = yml_conf.get("draw_threshold", 0.5)
  74. self.mask = yml_conf.get("mask", False)
  75. self.tracker = yml_conf.get("tracker", None)
  76. self.nms = yml_conf.get("NMS", None)
  77. self.fpn_stride = yml_conf.get("fpn_stride", None)
  78. if self.arch == 'RCNN' and yml_conf.get('export_onnx', False):
  79. print(
  80. 'The RCNN export model is used for ONNX and it only supports batch_size = 1'
  81. )
  82. self.print_config()
  83. def check_model(self, yml_conf):
  84. """
  85. Raises:
  86. ValueError: loaded model not in supported model type
  87. """
  88. for support_model in SUPPORT_MODELS:
  89. if support_model in yml_conf['arch']:
  90. return True
  91. raise ValueError("Unsupported arch: {}, expect {}".format(yml_conf[
  92. 'arch'], SUPPORT_MODELS))
  93. def print_config(self):
  94. print('----------- Model Configuration -----------')
  95. print('%s: %s' % ('Model Arch', self.arch))
  96. print('%s: ' % ('Transform Order'))
  97. for op_info in self.preprocess_infos:
  98. print('--%s: %s' % ('transform op', op_info['type']))
  99. print('--------------------------------------------')
  100. def predict_image(infer_config, predictor, img_list):
  101. # load preprocess transforms
  102. transforms = Compose(infer_config.preprocess_infos)
  103. # predict image
  104. for img_path in img_list:
  105. inputs = transforms(img_path)
  106. inputs_name = [var.name for var in predictor.get_inputs()]
  107. inputs = {k: inputs[k][None, ] for k in inputs_name}
  108. outputs = predictor.run(output_names=None, input_feed=inputs)
  109. print("ONNXRuntime predict: ")
  110. if infer_config.arch in ["HRNet"]:
  111. print(np.array(outputs[0]))
  112. else:
  113. bboxes = np.array(outputs[0])
  114. for bbox in bboxes:
  115. if bbox[0] > -1 and bbox[1] > infer_config.draw_threshold:
  116. print(f"{int(bbox[0])} {bbox[1]} "
  117. f"{bbox[2]} {bbox[3]} {bbox[4]} {bbox[5]}")
  118. if __name__ == '__main__':
  119. FLAGS = parser.parse_args()
  120. # load image list
  121. img_list = get_test_images(FLAGS.image_dir, FLAGS.image_file)
  122. # load predictor
  123. predictor = InferenceSession(FLAGS.onnx_file)
  124. # load infer config
  125. infer_config = PredictConfig(FLAGS.infer_cfg)
  126. predict_image(infer_config, predictor, img_list)