trt_infer.py

# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import time
import numpy as np
import pycuda.autoinit
import pycuda.driver as cuda
import tensorrt as trt
from collections import OrderedDict
import os
import yaml
import json
import glob
import argparse

from preprocess import Compose
from preprocess import coco_clsid2catid

parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument("--infer_cfg", type=str, help="infer_cfg.yml")
parser.add_argument(
    "--trt_engine", required=True, type=str, help="trt engine path")
parser.add_argument("--image_dir", type=str)
parser.add_argument("--image_file", type=str)
parser.add_argument(
    "--repeats",
    type=int,
    default=1,
    help="Repeat the running test `repeats` times in benchmark")
parser.add_argument(
    "--save_coco",
    action='store_true',
    default=False,
    help="Whether to save coco results")
parser.add_argument(
    "--coco_file", type=str, default="results.json", help="coco results path")

TRT_LOGGER = trt.Logger()
trt.init_libnvinfer_plugins(TRT_LOGGER, namespace="")

# Global dictionary
SUPPORT_MODELS = {
    'YOLO', 'PPYOLOE', 'RCNN', 'SSD', 'Face', 'FCOS', 'SOLOv2', 'TTFNet',
    'S2ANet', 'JDE', 'FairMOT', 'DeepSORT', 'GFL', 'PicoDet', 'CenterNet',
    'TOOD', 'RetinaNet', 'StrongBaseline', 'STGCN', 'YOLOX', 'HRNet'
}


def get_test_images(infer_dir, infer_img):
    """
    Get image path list in TEST mode
    """
    assert infer_img is not None or infer_dir is not None, \
        "--image_file or --image_dir should be set"
    assert infer_img is None or os.path.isfile(infer_img), \
        "{} is not a file".format(infer_img)
    assert infer_dir is None or os.path.isdir(infer_dir), \
        "{} is not a directory".format(infer_dir)

    # infer_img has a higher priority
    if infer_img and os.path.isfile(infer_img):
        return [infer_img]

    images = set()
    infer_dir = os.path.abspath(infer_dir)
    assert os.path.isdir(infer_dir), \
        "infer_dir {} is not a directory".format(infer_dir)
    exts = ['jpg', 'jpeg', 'png', 'bmp']
    exts += [ext.upper() for ext in exts]
    for ext in exts:
        images.update(glob.glob('{}/*.{}'.format(infer_dir, ext)))
    images = list(images)

    assert len(images) > 0, "no image found in {}".format(infer_dir)
    print("Found {} inference images in total.".format(len(images)))
    return images


class PredictConfig(object):
    """set config of preprocess, postprocess and visualize
    Args:
        infer_config (str): path of infer_cfg.yml
    """

    def __init__(self, infer_config):
        # parsing Yaml config for Preprocess
        with open(infer_config) as f:
            yml_conf = yaml.safe_load(f)
        self.check_model(yml_conf)
        self.arch = yml_conf['arch']
        self.preprocess_infos = yml_conf['Preprocess']
        self.min_subgraph_size = yml_conf['min_subgraph_size']
        self.label_list = yml_conf['label_list']
        self.use_dynamic_shape = yml_conf['use_dynamic_shape']
        self.draw_threshold = yml_conf.get("draw_threshold", 0.5)
        self.mask = yml_conf.get("mask", False)
        self.tracker = yml_conf.get("tracker", None)
        self.nms = yml_conf.get("NMS", None)
        self.fpn_stride = yml_conf.get("fpn_stride", None)
        if self.arch == 'RCNN' and yml_conf.get('export_onnx', False):
            print(
                'The RCNN export model is used for ONNX and it only supports batch_size = 1'
            )
        self.print_config()

    def check_model(self, yml_conf):
        """
        Raises:
            ValueError: loaded model not in supported model type
        """
        for support_model in SUPPORT_MODELS:
            if support_model in yml_conf['arch']:
                return True
        raise ValueError("Unsupported arch: {}, expect {}".format(yml_conf[
            'arch'], SUPPORT_MODELS))

    def print_config(self):
        print('----------- Model Configuration -----------')
        print('%s: %s' % ('Model Arch', self.arch))
        print('%s: ' % ('Transform Order'))
        for op_info in self.preprocess_infos:
            print('--%s: %s' % ('transform op', op_info['type']))
        print('--------------------------------------------')


def load_trt_engine(engine_path):
    assert os.path.exists(engine_path)
    print("Reading engine from file {}".format(engine_path))
    with open(engine_path, "rb") as f, trt.Runtime(TRT_LOGGER) as runtime:
        return runtime.deserialize_cuda_engine(f.read())


def predict_image(infer_config, engine, img_list, save_coco=False, repeats=1):
    # load preprocess transforms
    transforms = Compose(infer_config.preprocess_infos)
    stream = cuda.Stream()
    coco_results = []
    num_data = len(img_list)
    avg_time = []
    with engine.create_execution_context() as context:
        # Allocate host and device buffers
        bindings = create_trt_bindings(engine, context)
        # warmup
        run_trt_context(context, bindings, stream, repeats=10)
        # predict image
        for i, img_path in enumerate(img_list):
            inputs = transforms(img_path)
            inputs_name = [k for k, v in bindings.items() if v['is_input']]
            inputs = {
                k: inputs[k][None, ]
                for k in inputs.keys() if k in inputs_name
            }
            # run infer
            for k, v in inputs.items():
                bindings[k]['cpu_data'][...] = v
            output = run_trt_context(context, bindings, stream, repeats=repeats)
            print(f"{i + 1}/{num_data} infer time: {output['infer_time']} ms.")
            avg_time.append(output['infer_time'])
            # get output
            for k, v in output.items():
                if k in bindings.keys():
                    output[k] = np.reshape(v, bindings[k]['shape'])
            if save_coco:
                coco_results.extend(
                    format_coco_results(os.path.split(img_path)[-1], output))

    avg_time = np.mean(avg_time)
    print(
        f"Run on {num_data} data, repeats {repeats} times, avg time: {avg_time} ms."
    )
    if save_coco:
        with open(FLAGS.coco_file, 'w') as f:
            json.dump(coco_results, f)
        print(f"save coco json to {FLAGS.coco_file}")


def create_trt_bindings(engine, context):
    """Allocate host and device buffers for every input/output binding of the engine."""
    bindings = OrderedDict()
    for name in engine:
        binding_idx = engine.get_binding_index(name)
        size = trt.volume(context.get_binding_shape(binding_idx))
        dtype = trt.nptype(engine.get_binding_dtype(name))
        shape = list(engine.get_binding_shape(binding_idx))
        if shape[0] == -1:
            shape[0] = 1
        bindings[name] = {
            "idx": binding_idx,
            "size": size,
            "dtype": dtype,
            "shape": shape,
            "cpu_data": None,
            "cuda_ptr": None,
            "is_input": True if engine.binding_is_input(name) else False
        }
        if engine.binding_is_input(name):
            # inputs start with a random numpy buffer, which is used for warmup
            bindings[name]['cpu_data'] = np.random.randn(*shape).astype(
                np.float32)
            bindings[name]['cuda_ptr'] = cuda.mem_alloc(bindings[name][
                'cpu_data'].nbytes)
        else:
            # outputs use page-locked host memory for async device-to-host copies
            bindings[name]['cpu_data'] = cuda.pagelocked_empty(size, dtype)
            bindings[name]['cuda_ptr'] = cuda.mem_alloc(bindings[name][
                'cpu_data'].nbytes)
    return bindings


def run_trt_context(context, bindings, stream, repeats=1):
    # Transfer input data to the GPU.
    for k, v in bindings.items():
        if v['is_input']:
            cuda.memcpy_htod_async(v['cuda_ptr'], v['cpu_data'], stream)
    in_bindings = [int(v['cuda_ptr']) for k, v in bindings.items()]
    output_data = {}
    avg_time = []
    for _ in range(repeats):
        # Run inference
        t1 = time.time()
        context.execute_async_v2(
            bindings=in_bindings, stream_handle=stream.handle)
        # Transfer prediction output from the GPU.
        for k, v in bindings.items():
            if not v['is_input']:
                cuda.memcpy_dtoh_async(v['cpu_data'], v['cuda_ptr'], stream)
                output_data[k] = v['cpu_data']
        # Synchronize the stream
        stream.synchronize()
        t2 = time.time()
        avg_time.append(t2 - t1)
    output_data['infer_time'] = np.mean(avg_time) * 1000
    return output_data


def format_coco_results(file_name, result):
    try:
        image_id = int(os.path.splitext(file_name)[0])
    except ValueError:
        image_id = file_name
    num_dets = result['num_dets'].tolist()
    det_classes = result['det_classes'].tolist()
    det_scores = result['det_scores'].tolist()
    det_boxes = result['det_boxes'].tolist()
    per_result = [
        {
            'image_id': image_id,
            'category_id': coco_clsid2catid[int(det_classes[0][idx])],
            'file_name': file_name,
            'bbox': [
                det_boxes[0][idx][0], det_boxes[0][idx][1],
                det_boxes[0][idx][2] - det_boxes[0][idx][0],
                det_boxes[0][idx][3] - det_boxes[0][idx][1]
            ],  # xyxy -> xywh
            'score': det_scores[0][idx]
        } for idx in range(num_dets[0][0])
    ]
    return per_result
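
# Illustrative shape of one entry returned by format_coco_results (the values
# here are made up; category_id is mapped through coco_clsid2catid and bbox is
# [x, y, width, height] in pixels):
#   {"image_id": 139, "category_id": 1, "file_name": "139.jpg",
#    "bbox": [412.5, 157.3, 53.1, 138.0], "score": 0.92}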


if __name__ == '__main__':
    FLAGS = parser.parse_args()
    # load image list
    img_list = get_test_images(FLAGS.image_dir, FLAGS.image_file)
    # load trt engine
    engine = load_trt_engine(FLAGS.trt_engine)
    # load infer config
    infer_config = PredictConfig(FLAGS.infer_cfg)
    predict_image(infer_config, engine, img_list, FLAGS.save_coco,
                  FLAGS.repeats)
    print('Done!')