# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import yaml
import glob
import json
from pathlib import Path
from functools import reduce

import cv2
import numpy as np
import math
import paddle
from paddle.inference import Config
from paddle.inference import create_predictor

import sys
# add deploy path of PaddleDetection to sys.path
parent_path = os.path.abspath(os.path.join(__file__, *(['..'])))
sys.path.insert(0, parent_path)

from benchmark_utils import PaddleInferBenchmark
from picodet_postprocess import PicoDetPostProcess
from preprocess import preprocess, Resize, NormalizeImage, Permute, PadStride, LetterBoxResize, WarpAffine, Pad, decode_image
from keypoint_preprocess import EvalAffine, TopDownEvalAffine, expand_crop
from visualize import visualize_box_mask
from utils import argsparser, Timer, get_current_memory_mb, multiclass_nms, coco_clsid2catid

# Global dictionary
SUPPORT_MODELS = {
    'YOLO', 'PPYOLOE', 'RCNN', 'SSD', 'Face', 'FCOS', 'SOLOv2', 'TTFNet',
    'S2ANet', 'JDE', 'FairMOT', 'DeepSORT', 'GFL', 'PicoDet', 'CenterNet',
    'TOOD', 'RetinaNet', 'StrongBaseline', 'STGCN', 'YOLOX', 'YOLOF', 'PPHGNet',
    'PPLCNet', 'DETR', 'CenterTrack'
}

TUNED_TRT_DYNAMIC_MODELS = {'DETR'}


def bench_log(detector, img_list, model_info, batch_size=1, name=None):
    mems = {
        'cpu_rss_mb': detector.cpu_mem / len(img_list),
        'gpu_rss_mb': detector.gpu_mem / len(img_list),
        'gpu_util': detector.gpu_util * 100 / len(img_list)
    }
    perf_info = detector.det_times.report(average=True)
    data_info = {
        'batch_size': batch_size,
        'shape': "dynamic_shape",
        'data_num': perf_info['img_num']
    }
    log = PaddleInferBenchmark(detector.config, model_info, data_info,
                               perf_info, mems)
    log(name)
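
# `model_info` is expected to be shaped like the dict built in main() below;
# a minimal sketch with illustrative values:
#     model_info = {'model_name': 'ppyoloe_crn_s_300e_coco', 'precision': 'fp32'}
#     bench_log(detector, img_list, model_info, name='DET')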


class Detector(object):
    """
    Args:
        model_dir (str): root path of model.pdiparams, model.pdmodel and infer_cfg.yml
        device (str): Choose the device you want to run, it can be: CPU/GPU/XPU/NPU, default is CPU
        run_mode (str): mode of running (paddle/trt_fp32/trt_fp16)
        batch_size (int): batch size for inference
        trt_min_shape (int): min shape for dynamic shape in trt
        trt_max_shape (int): max shape for dynamic shape in trt
        trt_opt_shape (int): opt shape for dynamic shape in trt
        trt_calib_mode (bool): If the model is produced by TensorRT offline quantization
            calibration, trt_calib_mode should be set to True
        cpu_threads (int): number of CPU threads
        enable_mkldnn (bool): whether to enable MKLDNN
        enable_mkldnn_bfloat16 (bool): whether to enable MKLDNN bfloat16
        output_dir (str): directory to save output
        threshold (float): score threshold for visualization
        delete_shuffle_pass (bool): whether to remove shuffle_channel_detect_pass in TensorRT.
            Used by the action model.
    """

    def __init__(self,
                 model_dir,
                 device='CPU',
                 run_mode='paddle',
                 batch_size=1,
                 trt_min_shape=1,
                 trt_max_shape=1280,
                 trt_opt_shape=640,
                 trt_calib_mode=False,
                 cpu_threads=1,
                 enable_mkldnn=False,
                 enable_mkldnn_bfloat16=False,
                 output_dir='output',
                 threshold=0.5,
                 delete_shuffle_pass=False):
        self.pred_config = self.set_config(model_dir)
        self.predictor, self.config = load_predictor(
            model_dir,
            self.pred_config.arch,
            run_mode=run_mode,
            batch_size=batch_size,
            min_subgraph_size=self.pred_config.min_subgraph_size,
            device=device,
            use_dynamic_shape=self.pred_config.use_dynamic_shape,
            trt_min_shape=trt_min_shape,
            trt_max_shape=trt_max_shape,
            trt_opt_shape=trt_opt_shape,
            trt_calib_mode=trt_calib_mode,
            cpu_threads=cpu_threads,
            enable_mkldnn=enable_mkldnn,
            enable_mkldnn_bfloat16=enable_mkldnn_bfloat16,
            delete_shuffle_pass=delete_shuffle_pass)
        self.det_times = Timer()
        self.cpu_mem, self.gpu_mem, self.gpu_util = 0, 0, 0
        self.batch_size = batch_size
        self.output_dir = output_dir
        self.threshold = threshold

    def set_config(self, model_dir):
        return PredictConfig(model_dir)

    def preprocess(self, image_list):
        preprocess_ops = []
        for op_info in self.pred_config.preprocess_infos:
            new_op_info = op_info.copy()
            op_type = new_op_info.pop('type')
            preprocess_ops.append(eval(op_type)(**new_op_info))

        input_im_lst = []
        input_im_info_lst = []
        for im_path in image_list:
            im, im_info = preprocess(im_path, preprocess_ops)
            input_im_lst.append(im)
            input_im_info_lst.append(im_info)
        inputs = create_inputs(input_im_lst, input_im_info_lst)
        input_names = self.predictor.get_input_names()
        for i in range(len(input_names)):
            input_tensor = self.predictor.get_input_handle(input_names[i])
            if input_names[i] == 'x':
                input_tensor.copy_from_cpu(inputs['image'])
            else:
                input_tensor.copy_from_cpu(inputs[input_names[i]])
        return inputs

    def postprocess(self, inputs, result):
        # postprocess output of predictor
        np_boxes_num = result['boxes_num']
        assert isinstance(np_boxes_num, np.ndarray), \
            '`np_boxes_num` should be a `numpy.ndarray`'
        result = {k: v for k, v in result.items() if v is not None}
        return result

    def filter_box(self, result, threshold):
        np_boxes_num = result['boxes_num']
        boxes = result['boxes']
        start_idx = 0
        filter_boxes = []
        filter_num = []
        for i in range(len(np_boxes_num)):
            boxes_num = np_boxes_num[i]
            boxes_i = boxes[start_idx:start_idx + boxes_num, :]
            idx = boxes_i[:, 1] > threshold
            filter_boxes_i = boxes_i[idx, :]
            filter_boxes.append(filter_boxes_i)
            filter_num.append(filter_boxes_i.shape[0])
            start_idx += boxes_num
        boxes = np.concatenate(filter_boxes)
        filter_num = np.array(filter_num)
        filter_res = {'boxes': boxes, 'boxes_num': filter_num}
        return filter_res
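
    # Note: each row of `result['boxes']` is laid out as
    # [class, score, x_min, y_min, x_max, y_max], so `boxes_i[:, 1]` above
    # selects the score column. A minimal sketch with hypothetical values:
    #     result = {'boxes': np.array([[0., 0.9, 10, 10, 50, 50],
    #                                  [1., 0.2, 20, 20, 60, 60]]),
    #               'boxes_num': np.array([2])}
    #     detector.filter_box(result, threshold=0.5)
    #     # -> keeps only the first row; 'boxes_num' becomes [1]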

    def predict(self, repeats=1):
        '''
        Args:
            repeats (int): number of repeats for prediction
        Returns:
            result (dict): include 'boxes': np.ndarray: shape:[N,6], N: number of boxes,
                matrix element: [class, score, x_min, y_min, x_max, y_max]
                MaskRCNN's result include 'masks': np.ndarray:
                shape: [N, im_h, im_w]
        '''
        # model prediction
        np_boxes_num, np_boxes, np_masks = np.array([0]), None, None
        for i in range(repeats):
            self.predictor.run()
            output_names = self.predictor.get_output_names()
            boxes_tensor = self.predictor.get_output_handle(output_names[0])
            np_boxes = boxes_tensor.copy_to_cpu()
            if len(output_names) == 1:
                # some exported models can not get tensor 'bbox_num'
                np_boxes_num = np.array([len(np_boxes)])
            else:
                boxes_num = self.predictor.get_output_handle(output_names[1])
                np_boxes_num = boxes_num.copy_to_cpu()
            if self.pred_config.mask:
                masks_tensor = self.predictor.get_output_handle(output_names[2])
                np_masks = masks_tensor.copy_to_cpu()
        result = dict(boxes=np_boxes, masks=np_masks, boxes_num=np_boxes_num)
        return result

    def merge_batch_result(self, batch_result):
        if len(batch_result) == 1:
            return batch_result[0]
        res_key = batch_result[0].keys()
        results = {k: [] for k in res_key}
        for res in batch_result:
            for k, v in res.items():
                results[k].append(v)
        for k, v in results.items():
            if k not in ['masks', 'segm']:
                results[k] = np.concatenate(v)
        return results

    def get_timer(self):
        return self.det_times

    def predict_image_slice(self,
                            img_list,
                            slice_size=[640, 640],
                            overlap_ratio=[0.25, 0.25],
                            combine_method='nms',
                            match_threshold=0.6,
                            match_metric='ios',
                            run_benchmark=False,
                            repeats=1,
                            visual=True,
                            save_results=False):
        # slice infer only supports bs=1
        results = []
        try:
            import sahi
            from sahi.slicing import slice_image
        except Exception as e:
            print(
                'sahi not found, please install sahi. '
                'for example: `pip install sahi`, see https://github.com/obss/sahi.'
            )
            raise e
        num_classes = len(self.pred_config.labels)
        for i in range(len(img_list)):
            ori_image = img_list[i]
            slice_image_result = sahi.slicing.slice_image(
                image=ori_image,
                slice_height=slice_size[0],
                slice_width=slice_size[1],
                overlap_height_ratio=overlap_ratio[0],
                overlap_width_ratio=overlap_ratio[1])
            sub_img_num = len(slice_image_result)
            merged_bboxs = []
            print('slice to {} sub_samples.'.format(sub_img_num))

            batch_image_list = [
                slice_image_result.images[_ind] for _ind in range(sub_img_num)
            ]
            if run_benchmark:
                # preprocess
                inputs = self.preprocess(batch_image_list)  # warmup
                self.det_times.preprocess_time_s.start()
                inputs = self.preprocess(batch_image_list)
                self.det_times.preprocess_time_s.end()
                # model prediction
                result = self.predict(repeats=50)  # warmup
                self.det_times.inference_time_s.start()
                result = self.predict(repeats=repeats)
                self.det_times.inference_time_s.end(repeats=repeats)
                # postprocess
                result_warmup = self.postprocess(inputs, result)  # warmup
                self.det_times.postprocess_time_s.start()
                result = self.postprocess(inputs, result)
                self.det_times.postprocess_time_s.end()
                self.det_times.img_num += 1

                cm, gm, gu = get_current_memory_mb()
                self.cpu_mem += cm
                self.gpu_mem += gm
                self.gpu_util += gu
            else:
                # preprocess
                self.det_times.preprocess_time_s.start()
                inputs = self.preprocess(batch_image_list)
                self.det_times.preprocess_time_s.end()
                # model prediction
                self.det_times.inference_time_s.start()
                result = self.predict()
                self.det_times.inference_time_s.end()
                # postprocess
                self.det_times.postprocess_time_s.start()
                result = self.postprocess(inputs, result)
                self.det_times.postprocess_time_s.end()
                self.det_times.img_num += 1

            st, ed = 0, result['boxes_num'][0]  # start_index, end_index
            for _ind in range(sub_img_num):
                boxes_num = result['boxes_num'][_ind]
                ed = st + boxes_num
                shift_amount = slice_image_result.starting_pixels[_ind]
                result['boxes'][st:ed][:, 2:4] = result['boxes'][
                    st:ed][:, 2:4] + shift_amount
                result['boxes'][st:ed][:, 4:6] = result['boxes'][
                    st:ed][:, 4:6] + shift_amount
                merged_bboxs.append(result['boxes'][st:ed])
                st = ed

            merged_results = {'boxes': []}
            if combine_method == 'nms':
                final_boxes = multiclass_nms(
                    np.concatenate(merged_bboxs), num_classes, match_threshold,
                    match_metric)
                merged_results['boxes'] = np.concatenate(final_boxes)
            elif combine_method == 'concat':
                merged_results['boxes'] = np.concatenate(merged_bboxs)
            else:
                raise ValueError(
                    "Now only support 'nms' or 'concat' to fuse detection results."
                )
            merged_results['boxes_num'] = np.array(
                [len(merged_results['boxes'])], dtype=np.int32)

            if visual:
                visualize(
                    [ori_image],  # should be list
                    merged_results,
                    self.pred_config.labels,
                    output_dir=self.output_dir,
                    threshold=self.threshold)

            results.append(merged_results)
            print('Test iter {}'.format(i))

        results = self.merge_batch_result(results)
        if save_results:
            Path(self.output_dir).mkdir(exist_ok=True)
            self.save_coco_results(
                img_list, results, use_coco_category=FLAGS.use_coco_category)
        return results
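
    # A minimal slice-inference sketch (assuming `detector` is a constructed
    # Detector and 'large.jpg' is a hypothetical high-resolution image): the
    # image is cut into 640x640 tiles with 25% overlap, each tile is predicted,
    # tile boxes are shifted back to full-image coordinates, then fused by NMS.
    #     results = detector.predict_image_slice(
    #         ['large.jpg'],
    #         slice_size=[640, 640],
    #         overlap_ratio=[0.25, 0.25],
    #         combine_method='nms')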

    def predict_image(self,
                      image_list,
                      run_benchmark=False,
                      repeats=1,
                      visual=True,
                      save_results=False):
        batch_loop_cnt = math.ceil(float(len(image_list)) / self.batch_size)
        results = []
        for i in range(batch_loop_cnt):
            start_index = i * self.batch_size
            end_index = min((i + 1) * self.batch_size, len(image_list))
            batch_image_list = image_list[start_index:end_index]
            if run_benchmark:
                # preprocess
                inputs = self.preprocess(batch_image_list)  # warmup
                self.det_times.preprocess_time_s.start()
                inputs = self.preprocess(batch_image_list)
                self.det_times.preprocess_time_s.end()
                # model prediction
                result = self.predict(repeats=50)  # warmup
                self.det_times.inference_time_s.start()
                result = self.predict(repeats=repeats)
                self.det_times.inference_time_s.end(repeats=repeats)
                # postprocess
                result_warmup = self.postprocess(inputs, result)  # warmup
                self.det_times.postprocess_time_s.start()
                result = self.postprocess(inputs, result)
                self.det_times.postprocess_time_s.end()
                self.det_times.img_num += len(batch_image_list)

                cm, gm, gu = get_current_memory_mb()
                self.cpu_mem += cm
                self.gpu_mem += gm
                self.gpu_util += gu
            else:
                # preprocess
                self.det_times.preprocess_time_s.start()
                inputs = self.preprocess(batch_image_list)
                self.det_times.preprocess_time_s.end()
                # model prediction
                self.det_times.inference_time_s.start()
                result = self.predict()
                self.det_times.inference_time_s.end()
                # postprocess
                self.det_times.postprocess_time_s.start()
                result = self.postprocess(inputs, result)
                self.det_times.postprocess_time_s.end()
                self.det_times.img_num += len(batch_image_list)

            if visual:
                visualize(
                    batch_image_list,
                    result,
                    self.pred_config.labels,
                    output_dir=self.output_dir,
                    threshold=self.threshold)
            results.append(result)
            print('Test iter {}'.format(i))

        results = self.merge_batch_result(results)
        if save_results:
            Path(self.output_dir).mkdir(exist_ok=True)
            self.save_coco_results(
                image_list, results, use_coco_category=FLAGS.use_coco_category)
        return results

    def predict_video(self, video_file, camera_id):
        video_out_name = 'output.mp4'
        if camera_id != -1:
            capture = cv2.VideoCapture(camera_id)
        else:
            capture = cv2.VideoCapture(video_file)
            video_out_name = os.path.split(video_file)[-1]
        # Get Video info : resolution, fps, frame count
        width = int(capture.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fps = int(capture.get(cv2.CAP_PROP_FPS))
        frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
        print("fps: %d, frame_count: %d" % (fps, frame_count))

        if not os.path.exists(self.output_dir):
            os.makedirs(self.output_dir)
        out_path = os.path.join(self.output_dir, video_out_name)
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height))
        index = 1
        while True:
            ret, frame = capture.read()
            if not ret:
                break
            print('detect frame: %d' % (index))
            index += 1
            results = self.predict_image([frame[:, :, ::-1]], visual=False)

            im = visualize_box_mask(
                frame,
                results,
                self.pred_config.labels,
                threshold=self.threshold)
            im = np.array(im)
            writer.write(im)
            if camera_id != -1:
                cv2.imshow('Mask Detection', im)
                if cv2.waitKey(1) & 0xFF == ord('q'):
                    break
        writer.release()

    def save_coco_results(self, image_list, results, use_coco_category=False):
        bbox_results = []
        mask_results = []
        idx = 0
        print("Start saving coco json files...")
        for i, box_num in enumerate(results['boxes_num']):
            file_name = os.path.split(image_list[i])[-1]
            if use_coco_category:
                img_id = int(os.path.splitext(file_name)[0])
            else:
                img_id = i

            if 'boxes' in results:
                boxes = results['boxes'][idx:idx + box_num].tolist()
                bbox_results.extend([{
                    'image_id': img_id,
                    'category_id': coco_clsid2catid[int(box[0])] \
                        if use_coco_category else int(box[0]),
                    'file_name': file_name,
                    'bbox': [box[2], box[3], box[4] - box[2],
                             box[5] - box[3]],  # xyxy -> xywh
                    'score': box[1]
                } for box in boxes])

            if 'masks' in results:
                import pycocotools.mask as mask_util

                boxes = results['boxes'][idx:idx + box_num].tolist()
                masks = results['masks'][i][:box_num].astype(np.uint8)
                seg_res = []
                for box, mask in zip(boxes, masks):
                    rle = mask_util.encode(
                        np.array(
                            mask[:, :, None], dtype=np.uint8, order="F"))[0]
                    if 'counts' in rle:
                        rle['counts'] = rle['counts'].decode("utf8")
                    seg_res.append({
                        'image_id': img_id,
                        'category_id': coco_clsid2catid[int(box[0])] \
                            if use_coco_category else int(box[0]),
                        'file_name': file_name,
                        'segmentation': rle,
                        'score': box[1]
                    })
                mask_results.extend(seg_res)
            idx += box_num

        if bbox_results:
            bbox_file = os.path.join(self.output_dir, "bbox.json")
            with open(bbox_file, 'w') as f:
                json.dump(bbox_results, f)
            print(f"The bbox result is saved to {bbox_file}")
        if mask_results:
            mask_file = os.path.join(self.output_dir, "mask.json")
            with open(mask_file, 'w') as f:
                json.dump(mask_results, f)
            print(f"The mask result is saved to {mask_file}")
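

# A minimal end-to-end sketch (assumptions: an exported model directory
# './output_inference/ppyoloe_crn_s_300e_coco' containing model.pdmodel,
# model.pdiparams and infer_cfg.yml, plus a hypothetical test image 'demo.jpg'):
#     detector = Detector(
#         model_dir='./output_inference/ppyoloe_crn_s_300e_coco',
#         device='GPU',
#         batch_size=1)
#     results = detector.predict_image(['demo.jpg'], visual=True)
#     # results['boxes'] has shape [N, 6]:
#     # [class, score, x_min, y_min, x_max, y_max] per row.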


class DetectorSOLOv2(Detector):
    """
    Args:
        model_dir (str): root path of model.pdiparams, model.pdmodel and infer_cfg.yml
        device (str): Choose the device you want to run, it can be: CPU/GPU/XPU/NPU, default is CPU
        run_mode (str): mode of running (paddle/trt_fp32/trt_fp16)
        batch_size (int): batch size for inference
        trt_min_shape (int): min shape for dynamic shape in trt
        trt_max_shape (int): max shape for dynamic shape in trt
        trt_opt_shape (int): opt shape for dynamic shape in trt
        trt_calib_mode (bool): If the model is produced by TensorRT offline quantization
            calibration, trt_calib_mode should be set to True
        cpu_threads (int): number of CPU threads
        enable_mkldnn (bool): whether to enable MKLDNN
        enable_mkldnn_bfloat16 (bool): whether to enable MKLDNN bfloat16
        output_dir (str): directory to save output
        threshold (float): score threshold for visualization
    """

    def __init__(
            self,
            model_dir,
            device='CPU',
            run_mode='paddle',
            batch_size=1,
            trt_min_shape=1,
            trt_max_shape=1280,
            trt_opt_shape=640,
            trt_calib_mode=False,
            cpu_threads=1,
            enable_mkldnn=False,
            enable_mkldnn_bfloat16=False,
            output_dir='./',
            threshold=0.5, ):
        super(DetectorSOLOv2, self).__init__(
            model_dir=model_dir,
            device=device,
            run_mode=run_mode,
            batch_size=batch_size,
            trt_min_shape=trt_min_shape,
            trt_max_shape=trt_max_shape,
            trt_opt_shape=trt_opt_shape,
            trt_calib_mode=trt_calib_mode,
            cpu_threads=cpu_threads,
            enable_mkldnn=enable_mkldnn,
            enable_mkldnn_bfloat16=enable_mkldnn_bfloat16,
            output_dir=output_dir,
            threshold=threshold, )

    def predict(self, repeats=1):
        '''
        Args:
            repeats (int): repeat number for prediction
        Returns:
            result (dict): 'segm': np.ndarray, shape: [N, im_h, im_w]
                'cate_label': label of segm, shape: [N]
                'cate_score': confidence score of segm, shape: [N]
        '''
        np_label, np_score, np_segms = None, None, None
        for i in range(repeats):
            self.predictor.run()
            output_names = self.predictor.get_output_names()
            np_boxes_num = self.predictor.get_output_handle(output_names[
                0]).copy_to_cpu()
            np_label = self.predictor.get_output_handle(output_names[
                1]).copy_to_cpu()
            np_score = self.predictor.get_output_handle(output_names[
                2]).copy_to_cpu()
            np_segms = self.predictor.get_output_handle(output_names[
                3]).copy_to_cpu()

        result = dict(
            segm=np_segms,
            label=np_label,
            score=np_score,
            boxes_num=np_boxes_num)
        return result


class DetectorPicoDet(Detector):
    """
    Args:
        model_dir (str): root path of model.pdiparams, model.pdmodel and infer_cfg.yml
        device (str): Choose the device you want to run, it can be: CPU/GPU/XPU/NPU, default is CPU
        run_mode (str): mode of running (paddle/trt_fp32/trt_fp16)
        batch_size (int): batch size for inference
        trt_min_shape (int): min shape for dynamic shape in trt
        trt_max_shape (int): max shape for dynamic shape in trt
        trt_opt_shape (int): opt shape for dynamic shape in trt
        trt_calib_mode (bool): If the model is produced by TensorRT offline quantization
            calibration, trt_calib_mode should be set to True
        cpu_threads (int): number of CPU threads
        enable_mkldnn (bool): whether to enable MKLDNN
        enable_mkldnn_bfloat16 (bool): whether to enable MKLDNN bfloat16
    """

    def __init__(
            self,
            model_dir,
            device='CPU',
            run_mode='paddle',
            batch_size=1,
            trt_min_shape=1,
            trt_max_shape=1280,
            trt_opt_shape=640,
            trt_calib_mode=False,
            cpu_threads=1,
            enable_mkldnn=False,
            enable_mkldnn_bfloat16=False,
            output_dir='./',
            threshold=0.5, ):
        super(DetectorPicoDet, self).__init__(
            model_dir=model_dir,
            device=device,
            run_mode=run_mode,
            batch_size=batch_size,
            trt_min_shape=trt_min_shape,
            trt_max_shape=trt_max_shape,
            trt_opt_shape=trt_opt_shape,
            trt_calib_mode=trt_calib_mode,
            cpu_threads=cpu_threads,
            enable_mkldnn=enable_mkldnn,
            enable_mkldnn_bfloat16=enable_mkldnn_bfloat16,
            output_dir=output_dir,
            threshold=threshold, )

    def postprocess(self, inputs, result):
        # postprocess output of predictor; note that DetectorPicoDet.predict
        # stores the raw per-level score tensors under 'boxes' and the raw
        # per-level box tensors under 'boxes_num'
        np_score_list = result['boxes']
        np_boxes_list = result['boxes_num']
        postprocessor = PicoDetPostProcess(
            inputs['image'].shape[2:],
            inputs['im_shape'],
            inputs['scale_factor'],
            strides=self.pred_config.fpn_stride,
            nms_threshold=self.pred_config.nms['nms_threshold'])
        np_boxes, np_boxes_num = postprocessor(np_score_list, np_boxes_list)
        result = dict(boxes=np_boxes, boxes_num=np_boxes_num)
        return result

    def predict(self, repeats=1):
        '''
        Args:
            repeats (int): repeat number for prediction
        Returns:
            result (dict): include 'boxes': np.ndarray: shape:[N,6], N: number of boxes,
                matrix element: [class, score, x_min, y_min, x_max, y_max]
        '''
        np_score_list, np_boxes_list = [], []
        for i in range(repeats):
            self.predictor.run()
            np_score_list.clear()
            np_boxes_list.clear()
            output_names = self.predictor.get_output_names()
            num_outs = int(len(output_names) / 2)
            for out_idx in range(num_outs):
                np_score_list.append(
                    self.predictor.get_output_handle(output_names[out_idx])
                    .copy_to_cpu())
                np_boxes_list.append(
                    self.predictor.get_output_handle(output_names[
                        out_idx + num_outs]).copy_to_cpu())
        result = dict(boxes=np_score_list, boxes_num=np_boxes_list)
        return result


def create_inputs(imgs, im_info):
    """generate input for different model type
    Args:
        imgs (list(numpy)): list of images (np.ndarray)
        im_info (list(dict)): list of image info
    Returns:
        inputs (dict): input of model
    """
    inputs = {}
    im_shape = []
    scale_factor = []
    if len(imgs) == 1:
        inputs['image'] = np.array((imgs[0], )).astype('float32')
        inputs['im_shape'] = np.array(
            (im_info[0]['im_shape'], )).astype('float32')
        inputs['scale_factor'] = np.array(
            (im_info[0]['scale_factor'], )).astype('float32')
        return inputs

    for e in im_info:
        im_shape.append(np.array((e['im_shape'], )).astype('float32'))
        scale_factor.append(np.array((e['scale_factor'], )).astype('float32'))

    inputs['im_shape'] = np.concatenate(im_shape, axis=0)
    inputs['scale_factor'] = np.concatenate(scale_factor, axis=0)

    imgs_shape = [[e.shape[1], e.shape[2]] for e in imgs]
    max_shape_h = max([e[0] for e in imgs_shape])
    max_shape_w = max([e[1] for e in imgs_shape])
    padding_imgs = []
    for img in imgs:
        im_c, im_h, im_w = img.shape[:]
        padding_im = np.zeros(
            (im_c, max_shape_h, max_shape_w), dtype=np.float32)
        padding_im[:, :im_h, :im_w] = img
        padding_imgs.append(padding_im)
    inputs['image'] = np.stack(padding_imgs, axis=0)
    return inputs
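
# Batch-padding sketch (hypothetical shapes): a single image is stacked without
# padding, while a batch is zero-padded to the per-batch maximum. Given two CHW
# images of shapes (3, 320, 480) and (3, 416, 416), both are padded to
# (3, 416, 480) and stacked, so inputs['image'] has shape (2, 3, 416, 480);
# 'im_shape' and 'scale_factor' stay per-image.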


class PredictConfig():
    """set config of preprocess, postprocess and visualize
    Args:
        model_dir (str): root path of infer_cfg.yml
    """

    def __init__(self, model_dir):
        # parsing Yaml config for Preprocess
        deploy_file = os.path.join(model_dir, 'infer_cfg.yml')
        with open(deploy_file) as f:
            yml_conf = yaml.safe_load(f)
        self.check_model(yml_conf)
        self.arch = yml_conf['arch']
        self.preprocess_infos = yml_conf['Preprocess']
        self.min_subgraph_size = yml_conf['min_subgraph_size']
        self.labels = yml_conf['label_list']
        self.mask = False
        self.use_dynamic_shape = yml_conf['use_dynamic_shape']
        if 'mask' in yml_conf:
            self.mask = yml_conf['mask']
        self.tracker = None
        if 'tracker' in yml_conf:
            self.tracker = yml_conf['tracker']
        if 'NMS' in yml_conf:
            self.nms = yml_conf['NMS']
        if 'fpn_stride' in yml_conf:
            self.fpn_stride = yml_conf['fpn_stride']
        if self.arch == 'RCNN' and yml_conf.get('export_onnx', False):
            print(
                'The RCNN export model is used for ONNX and it only supports batch_size = 1'
            )
        self.print_config()

    def check_model(self, yml_conf):
        """
        Raises:
            ValueError: loaded model not in supported model type
        """
        for support_model in SUPPORT_MODELS:
            if support_model in yml_conf['arch']:
                return True
        raise ValueError("Unsupported arch: {}, expect {}".format(yml_conf[
            'arch'], SUPPORT_MODELS))

    def print_config(self):
        print('----------- Model Configuration -----------')
        print('%s: %s' % ('Model Arch', self.arch))
        print('%s: ' % ('Transform Order'))
        for op_info in self.preprocess_infos:
            print('--%s: %s' % ('transform op', op_info['type']))
        print('--------------------------------------------')
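
# A minimal sketch of the infer_cfg.yml fields PredictConfig reads
# (values are illustrative, not from a real export):
#     arch: YOLO
#     min_subgraph_size: 3
#     use_dynamic_shape: false
#     label_list: ['person', 'car']
#     Preprocess:
#     - type: Resize
#       target_size: [640, 640]
#       keep_ratio: false
#     - type: NormalizeImage
#       mean: [0.485, 0.456, 0.406]
#       std: [0.229, 0.224, 0.225]
#     - type: Permute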


def load_predictor(model_dir,
                   arch,
                   run_mode='paddle',
                   batch_size=1,
                   device='CPU',
                   min_subgraph_size=3,
                   use_dynamic_shape=False,
                   trt_min_shape=1,
                   trt_max_shape=1280,
                   trt_opt_shape=640,
                   trt_calib_mode=False,
                   cpu_threads=1,
                   enable_mkldnn=False,
                   enable_mkldnn_bfloat16=False,
                   delete_shuffle_pass=False,
                   tuned_trt_shape_file="shape_range_info.pbtxt"):
    """set AnalysisConfig, generate AnalysisPredictor
    Args:
        model_dir (str): root path of model.pdmodel and model.pdiparams
        device (str): Choose the device you want to run, it can be: CPU/GPU/XPU/NPU, default is CPU
        run_mode (str): mode of running (paddle/trt_fp32/trt_fp16/trt_int8)
        use_dynamic_shape (bool): use dynamic shape or not
        trt_min_shape (int): min shape for dynamic shape in trt
        trt_max_shape (int): max shape for dynamic shape in trt
        trt_opt_shape (int): opt shape for dynamic shape in trt
        trt_calib_mode (bool): If the model is produced by TensorRT offline quantization
            calibration, trt_calib_mode should be set to True
        delete_shuffle_pass (bool): whether to remove shuffle_channel_detect_pass in TensorRT.
            Used by the action model.
    Returns:
        predictor (PaddlePredictor): AnalysisPredictor
    Raises:
        ValueError: predicting with TensorRT requires device == 'GPU'.
    """
    if device != 'GPU' and run_mode != 'paddle':
        raise ValueError(
            "Predict by TensorRT mode: {}, expect device=='GPU', but device == {}"
            .format(run_mode, device))
    infer_model = os.path.join(model_dir, 'model.pdmodel')
    infer_params = os.path.join(model_dir, 'model.pdiparams')
    if not os.path.exists(infer_model):
        infer_model = os.path.join(model_dir, 'inference.pdmodel')
        infer_params = os.path.join(model_dir, 'inference.pdiparams')
        if not os.path.exists(infer_model):
            raise ValueError(
                "Cannot find any inference model in dir: {},".format(model_dir))
    config = Config(infer_model, infer_params)
    if device == 'GPU':
        # initial GPU memory(M), device ID
        config.enable_use_gpu(200, 0)
        # optimize graph and fuse op
        config.switch_ir_optim(True)
    elif device == 'XPU':
        if config.lite_engine_enabled():
            config.enable_lite_engine()
        config.enable_xpu(10 * 1024 * 1024)
    elif device == 'NPU':
        if config.lite_engine_enabled():
            config.enable_lite_engine()
        config.enable_npu()
    else:
        config.disable_gpu()
        config.set_cpu_math_library_num_threads(cpu_threads)
        if enable_mkldnn:
            try:
                # cache 10 different shapes for mkldnn to avoid memory leak
                config.set_mkldnn_cache_capacity(10)
                config.enable_mkldnn()
                if enable_mkldnn_bfloat16:
                    config.enable_mkldnn_bfloat16()
            except Exception as e:
                print(
                    "The current environment does not support `mkldnn`, so disable mkldnn."
                )
                pass

    precision_map = {
        'trt_int8': Config.Precision.Int8,
        'trt_fp32': Config.Precision.Float32,
        'trt_fp16': Config.Precision.Half
    }
    if run_mode in precision_map.keys():
        if arch in TUNED_TRT_DYNAMIC_MODELS:
            config.collect_shape_range_info(tuned_trt_shape_file)
        config.enable_tensorrt_engine(
            workspace_size=(1 << 25) * batch_size,
            max_batch_size=batch_size,
            min_subgraph_size=min_subgraph_size,
            precision_mode=precision_map[run_mode],
            use_static=False,
            use_calib_mode=trt_calib_mode)
        if arch in TUNED_TRT_DYNAMIC_MODELS:
            config.enable_tuned_tensorrt_dynamic_shape(tuned_trt_shape_file,
                                                       True)

        if use_dynamic_shape:
            min_input_shape = {
                'image': [batch_size, 3, trt_min_shape, trt_min_shape],
                'scale_factor': [batch_size, 2]
            }
            max_input_shape = {
                'image': [batch_size, 3, trt_max_shape, trt_max_shape],
                'scale_factor': [batch_size, 2]
            }
            opt_input_shape = {
                'image': [batch_size, 3, trt_opt_shape, trt_opt_shape],
                'scale_factor': [batch_size, 2]
            }
            config.set_trt_dynamic_shape_info(min_input_shape, max_input_shape,
                                              opt_input_shape)
            print('trt set dynamic shape done!')

    # disable print log when predict
    config.disable_glog_info()
    # enable shared memory
    config.enable_memory_optim()
    # disable feed, fetch OP, needed by zero_copy_run
    config.switch_use_feed_fetch_ops(False)
    if delete_shuffle_pass:
        config.delete_pass("shuffle_channel_detect_pass")
    predictor = create_predictor(config)
    return predictor, config
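
# A minimal sketch of calling load_predictor directly (same hypothetical
# export directory as above); any trt_* run_mode requires device='GPU':
#     predictor, config = load_predictor(
#         './output_inference/ppyoloe_crn_s_300e_coco',
#         arch='YOLO',
#         run_mode='paddle',
#         device='CPU',
#         cpu_threads=4)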


def get_test_images(infer_dir, infer_img):
    """
    Get image path list in TEST mode
    """
    assert infer_img is not None or infer_dir is not None, \
        "--image_file or --image_dir should be set"
    assert infer_img is None or os.path.isfile(infer_img), \
        "{} is not a file".format(infer_img)
    assert infer_dir is None or os.path.isdir(infer_dir), \
        "{} is not a directory".format(infer_dir)

    # infer_img has a higher priority
    if infer_img and os.path.isfile(infer_img):
        return [infer_img]

    images = set()
    infer_dir = os.path.abspath(infer_dir)
    assert os.path.isdir(infer_dir), \
        "infer_dir {} is not a directory".format(infer_dir)
    exts = ['jpg', 'jpeg', 'png', 'bmp']
    exts += [ext.upper() for ext in exts]
    for ext in exts:
        images.update(glob.glob('{}/*.{}'.format(infer_dir, ext)))
    images = list(images)

    assert len(images) > 0, "no image found in {}".format(infer_dir)
    print("Found {} inference images in total.".format(len(images)))
    return images
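
# Usage sketch (hypothetical paths): a single file takes priority over a
# directory; directory scans match jpg/jpeg/png/bmp in either letter case.
#     get_test_images(None, 'demo.jpg')   # -> ['demo.jpg']
#     get_test_images('./images', None)   # -> all matching images in ./images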


def visualize(image_list, result, labels, output_dir='output/', threshold=0.5):
    # visualize the predict result
    start_idx = 0
    for idx, image_file in enumerate(image_list):
        im_bboxes_num = result['boxes_num'][idx]
        im_results = {}
        if 'boxes' in result:
            im_results['boxes'] = result['boxes'][start_idx:start_idx +
                                                  im_bboxes_num, :]
        if 'masks' in result:
            im_results['masks'] = result['masks'][start_idx:start_idx +
                                                  im_bboxes_num, :]
        if 'segm' in result:
            im_results['segm'] = result['segm'][start_idx:start_idx +
                                                im_bboxes_num, :]
        if 'label' in result:
            im_results['label'] = result['label'][start_idx:start_idx +
                                                  im_bboxes_num]
        if 'score' in result:
            im_results['score'] = result['score'][start_idx:start_idx +
                                                  im_bboxes_num]
        start_idx += im_bboxes_num
        im = visualize_box_mask(
            image_file, im_results, labels, threshold=threshold)
        img_name = os.path.split(image_file)[-1]
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)
        out_path = os.path.join(output_dir, img_name)
        im.save(out_path, quality=95)
        print("save result to: " + out_path)


def print_arguments(args):
    print('----------- Running Arguments -----------')
    for arg, value in sorted(vars(args).items()):
        print('%s: %s' % (arg, value))
    print('------------------------------------------')


def main():
    deploy_file = os.path.join(FLAGS.model_dir, 'infer_cfg.yml')
    with open(deploy_file) as f:
        yml_conf = yaml.safe_load(f)
    arch = yml_conf['arch']
    detector_func = 'Detector'
    if arch == 'SOLOv2':
        detector_func = 'DetectorSOLOv2'
    elif arch == 'PicoDet':
        detector_func = 'DetectorPicoDet'

    detector = eval(detector_func)(
        FLAGS.model_dir,
        device=FLAGS.device,
        run_mode=FLAGS.run_mode,
        batch_size=FLAGS.batch_size,
        trt_min_shape=FLAGS.trt_min_shape,
        trt_max_shape=FLAGS.trt_max_shape,
        trt_opt_shape=FLAGS.trt_opt_shape,
        trt_calib_mode=FLAGS.trt_calib_mode,
        cpu_threads=FLAGS.cpu_threads,
        enable_mkldnn=FLAGS.enable_mkldnn,
        enable_mkldnn_bfloat16=FLAGS.enable_mkldnn_bfloat16,
        threshold=FLAGS.threshold,
        output_dir=FLAGS.output_dir)

    # predict from video file or camera video stream
    if FLAGS.video_file is not None or FLAGS.camera_id != -1:
        detector.predict_video(FLAGS.video_file, FLAGS.camera_id)
    else:
        # predict from image
        if FLAGS.image_dir is None and FLAGS.image_file is not None:
            assert FLAGS.batch_size == 1, "batch_size should be 1, when image_file is not None"
        img_list = get_test_images(FLAGS.image_dir, FLAGS.image_file)
        if FLAGS.slice_infer:
            detector.predict_image_slice(
                img_list,
                FLAGS.slice_size,
                FLAGS.overlap_ratio,
                FLAGS.combine_method,
                FLAGS.match_threshold,
                FLAGS.match_metric,
                visual=FLAGS.save_images,
                save_results=FLAGS.save_results)
        else:
            detector.predict_image(
                img_list,
                FLAGS.run_benchmark,
                repeats=100,
                visual=FLAGS.save_images,
                save_results=FLAGS.save_results)
        if not FLAGS.run_benchmark:
            detector.det_times.info(average=True)
        else:
            mode = FLAGS.run_mode
            model_dir = FLAGS.model_dir
            model_info = {
                'model_name': model_dir.strip('/').split('/')[-1],
                'precision': mode.split('_')[-1]
            }
            bench_log(detector, img_list, model_info, name='DET')


if __name__ == '__main__':
    paddle.enable_static()
    parser = argsparser()
    FLAGS = parser.parse_args()
    print_arguments(FLAGS)
    FLAGS.device = FLAGS.device.upper()
    assert FLAGS.device in ['CPU', 'GPU', 'XPU', 'NPU'
                            ], "device should be CPU, GPU, XPU or NPU"
    assert not FLAGS.use_gpu, "use_gpu has been deprecated, please use --device"

    assert not (
        FLAGS.enable_mkldnn == False and FLAGS.enable_mkldnn_bfloat16 == True
    ), 'To enable mkldnn bfloat16, please turn on both enable_mkldnn and enable_mkldnn_bfloat16'

    main()
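
# Example invocations (paths are hypothetical; the flags come from
# utils.argsparser and are the FLAGS attributes used above):
#     python infer.py --model_dir=./output_inference/ppyoloe_crn_s_300e_coco \
#         --image_file=demo.jpg --device=GPU
#     python infer.py --model_dir=./output_inference/ppyoloe_crn_s_300e_coco \
#         --video_file=demo.mp4 --device=GPU --run_mode=trt_fp16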