# keypoint_infer.py

# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import time
import yaml
import glob
from functools import reduce

from PIL import Image
import cv2
import math
import numpy as np
import paddle
import sys

# add deploy path of PaddleDetection to sys.path
parent_path = os.path.abspath(os.path.join(__file__, *(['..'])))
sys.path.insert(0, parent_path)

from preprocess import preprocess, NormalizeImage, Permute
from keypoint_preprocess import EvalAffine, TopDownEvalAffine, expand_crop
from keypoint_postprocess import HrHRNetPostProcess, HRNetPostProcess
from visualize import visualize_pose
from paddle.inference import Config
from paddle.inference import create_predictor
from utils import argsparser, Timer, get_current_memory_mb
from benchmark_utils import PaddleInferBenchmark
from infer import Detector, get_test_images, print_arguments

# Global dictionary
KEYPOINT_SUPPORT_MODELS = {
    'HigherHRNet': 'keypoint_bottomup',
    'HRNet': 'keypoint_topdown'
}
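
# Note (orientation only, not taken from this file): 'keypoint_bottomup'
# models predict keypoints over the whole image and group them into persons
# afterwards, while 'keypoint_topdown' models run on one cropped person at a
# time.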


class KeyPointDetector(Detector):
    """
    Args:
        model_dir (str): root path of model.pdiparams, model.pdmodel and infer_cfg.yml
        device (str): Choose the device you want to run, it can be: CPU/GPU/XPU, default is CPU
        run_mode (str): mode of running (paddle/trt_fp32/trt_fp16)
        batch_size (int): size of per batch in inference
        trt_min_shape (int): min shape for dynamic shape in trt
        trt_max_shape (int): max shape for dynamic shape in trt
        trt_opt_shape (int): opt shape for dynamic shape in trt
        trt_calib_mode (bool): If the model is produced by TRT offline quantitative
            calibration, trt_calib_mode should be set to True
        cpu_threads (int): cpu threads
        enable_mkldnn (bool): whether to turn on MKLDNN
        use_dark (bool): whether to use DARK postprocessing
    """

    def __init__(self,
                 model_dir,
                 device='CPU',
                 run_mode='paddle',
                 batch_size=1,
                 trt_min_shape=1,
                 trt_max_shape=1280,
                 trt_opt_shape=640,
                 trt_calib_mode=False,
                 cpu_threads=1,
                 enable_mkldnn=False,
                 output_dir='output',
                 threshold=0.5,
                 use_dark=True):
        super(KeyPointDetector, self).__init__(
            model_dir=model_dir,
            device=device,
            run_mode=run_mode,
            batch_size=batch_size,
            trt_min_shape=trt_min_shape,
            trt_max_shape=trt_max_shape,
            trt_opt_shape=trt_opt_shape,
            trt_calib_mode=trt_calib_mode,
            cpu_threads=cpu_threads,
            enable_mkldnn=enable_mkldnn,
            output_dir=output_dir,
            threshold=threshold, )
        self.use_dark = use_dark

    def set_config(self, model_dir):
        return PredictConfig_KeyPoint(model_dir)

    def get_person_from_rect(self, image, results):
        # crop the person result from image
        self.det_times.preprocess_time_s.start()
        valid_rects = results['boxes']
        rect_images = []
        new_rects = []
        org_rects = []
        for rect in valid_rects:
            rect_image, new_rect, org_rect = expand_crop(image, rect)
            if rect_image is None or rect_image.size == 0:
                continue
            rect_images.append(rect_image)
            new_rects.append(new_rect)
            org_rects.append(org_rect)
        self.det_times.preprocess_time_s.end()
        return rect_images, new_rects, org_rects

    def postprocess(self, inputs, result):
        np_heatmap = result['heatmap']
        np_masks = result['masks']
        # postprocess output of predictor
        if KEYPOINT_SUPPORT_MODELS[
                self.pred_config.arch] == 'keypoint_bottomup':
            results = {}
            h, w = inputs['im_shape'][0]
            preds = [np_heatmap]
            if np_masks is not None:
                preds += np_masks
            preds += [h, w]
            keypoint_postprocess = HrHRNetPostProcess()
            kpts, scores = keypoint_postprocess(*preds)
            results['keypoint'] = kpts
            results['score'] = scores
            return results
        elif KEYPOINT_SUPPORT_MODELS[
                self.pred_config.arch] == 'keypoint_topdown':
            results = {}
            imshape = inputs['im_shape'][:, ::-1]
            center = np.round(imshape / 2.)
            scale = imshape / 200.
            keypoint_postprocess = HRNetPostProcess(use_dark=self.use_dark)
            kpts, scores = keypoint_postprocess(np_heatmap, center, scale)
            results['keypoint'] = kpts
            results['score'] = scores
            return results
        else:
            raise ValueError("Unsupported arch: {}, expect {}".format(
                self.pred_config.arch, KEYPOINT_SUPPORT_MODELS))
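
    # Layout note (an assumption based on typical HRNet/HigherHRNet exports,
    # not read from the model): the top-down heatmap is [N, num_joints, H, W]
    # with one crop per person; the bottom-up branch additionally returns the
    # tag/association maps in 'masks' that HrHRNetPostProcess uses to group
    # keypoints into persons.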

    def predict(self, repeats=1):
        '''
        Args:
            repeats (int): repeat number for prediction
        Returns:
            result (dict): 'heatmap': np.ndarray, the model's heatmap output;
                'masks': list of extra outputs (tag/association maps and top-k
                tensors) for bottom-up models, otherwise None
        '''
        # model prediction
        np_heatmap, np_masks = None, None
        for i in range(repeats):
            self.predictor.run()
            output_names = self.predictor.get_output_names()
            heatmap_tensor = self.predictor.get_output_handle(output_names[0])
            np_heatmap = heatmap_tensor.copy_to_cpu()
            if self.pred_config.tagmap:
                masks_tensor = self.predictor.get_output_handle(output_names[1])
                heat_k = self.predictor.get_output_handle(output_names[2])
                inds_k = self.predictor.get_output_handle(output_names[3])
                np_masks = [
                    masks_tensor.copy_to_cpu(), heat_k.copy_to_cpu(),
                    inds_k.copy_to_cpu()
                ]
        result = dict(heatmap=np_heatmap, masks=np_masks)
        return result

    def predict_image(self,
                      image_list,
                      run_benchmark=False,
                      repeats=1,
                      visual=True):
        results = []
        batch_loop_cnt = math.ceil(float(len(image_list)) / self.batch_size)
        for i in range(batch_loop_cnt):
            start_index = i * self.batch_size
            end_index = min((i + 1) * self.batch_size, len(image_list))
            batch_image_list = image_list[start_index:end_index]
            if run_benchmark:
                # preprocess
                inputs = self.preprocess(batch_image_list)  # warmup
                self.det_times.preprocess_time_s.start()
                inputs = self.preprocess(batch_image_list)
                self.det_times.preprocess_time_s.end()

                # model prediction
                result_warmup = self.predict(repeats=repeats)  # warmup
                self.det_times.inference_time_s.start()
                result = self.predict(repeats=repeats)
                self.det_times.inference_time_s.end(repeats=repeats)

                # postprocess
                result_warmup = self.postprocess(inputs, result)  # warmup
                self.det_times.postprocess_time_s.start()
                result = self.postprocess(inputs, result)
                self.det_times.postprocess_time_s.end()
                self.det_times.img_num += len(batch_image_list)

                cm, gm, gu = get_current_memory_mb()
                self.cpu_mem += cm
                self.gpu_mem += gm
                self.gpu_util += gu
            else:
                # preprocess
                self.det_times.preprocess_time_s.start()
                inputs = self.preprocess(batch_image_list)
                self.det_times.preprocess_time_s.end()

                # model prediction
                self.det_times.inference_time_s.start()
                result = self.predict()
                self.det_times.inference_time_s.end()

                # postprocess
                self.det_times.postprocess_time_s.start()
                result = self.postprocess(inputs, result)
                self.det_times.postprocess_time_s.end()
                self.det_times.img_num += len(batch_image_list)

                if visual:
                    if not os.path.exists(self.output_dir):
                        os.makedirs(self.output_dir)
                    visualize(
                        batch_image_list,
                        result,
                        visual_thresh=self.threshold,
                        save_dir=self.output_dir)
            results.append(result)
            if visual:
                print('Test iter {}'.format(i))
        results = self.merge_batch_result(results)
        return results

    def predict_video(self, video_file, camera_id):
        video_name = 'output.mp4'
        if camera_id != -1:
            capture = cv2.VideoCapture(camera_id)
        else:
            capture = cv2.VideoCapture(video_file)
            video_name = os.path.split(video_file)[-1]
        # Get Video info : resolution, fps, frame count
        width = int(capture.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fps = int(capture.get(cv2.CAP_PROP_FPS))
        frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
        print("fps: %d, frame_count: %d" % (fps, frame_count))

        if not os.path.exists(self.output_dir):
            os.makedirs(self.output_dir)
        out_path = os.path.join(self.output_dir, video_name)
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height))
        index = 1
        while True:
            ret, frame = capture.read()
            if not ret:
                break
            print('detect frame: %d' % (index))
            index += 1
            # BGR -> RGB for the model; keep the BGR frame for drawing
            results = self.predict_image([frame[:, :, ::-1]], visual=False)
            im_results = {}
            im_results['keypoint'] = [results['keypoint'], results['score']]
            im = visualize_pose(
                frame, im_results, visual_thresh=self.threshold, returnimg=True)
            writer.write(im)
            if camera_id != -1:
                cv2.imshow('Keypoint Detection', im)
                if cv2.waitKey(1) & 0xFF == ord('q'):
                    break
        writer.release()
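

# A minimal usage sketch (the model directory and image name below are
# placeholders, not part of this file):
#
#   detector = KeyPointDetector('./hrnet_model_dir', device='GPU')
#   results = detector.predict_image(['demo.jpg'], visual=True)
#   # results['keypoint'] holds per-person keypoints as (x, y, confidence)
#   # triplets; results['score'] holds the matching per-person scores.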


def create_inputs(imgs, im_info):
    """generate input for different model type
    Args:
        imgs (list(numpy)): list of image (np.ndarray)
        im_info (list(dict)): list of image info
    Returns:
        inputs (dict): input of model
    """
    inputs = {}
    inputs['image'] = np.stack(imgs, axis=0).astype('float32')
    im_shape = []
    for e in im_info:
        im_shape.append(np.array((e['im_shape'])).astype('float32'))
    inputs['im_shape'] = np.stack(im_shape, axis=0)
    return inputs
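
# For reference, a sketch of what create_inputs returns for a single 640x640
# RGB image (shapes are illustrative assumptions):
#   inputs['image']:    float32 ndarray of shape [1, 3, 640, 640]
#   inputs['im_shape']: float32 ndarray of shape [1, 2], e.g. [[640., 640.]]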


class PredictConfig_KeyPoint():
    """set config of preprocess, postprocess and visualize
    Args:
        model_dir (str): root path of infer_cfg.yml
    """

    def __init__(self, model_dir):
        # parsing Yaml config for Preprocess
        deploy_file = os.path.join(model_dir, 'infer_cfg.yml')
        with open(deploy_file) as f:
            yml_conf = yaml.safe_load(f)
        self.check_model(yml_conf)
        self.arch = yml_conf['arch']
        self.archcls = KEYPOINT_SUPPORT_MODELS[yml_conf['arch']]
        self.preprocess_infos = yml_conf['Preprocess']
        self.min_subgraph_size = yml_conf['min_subgraph_size']
        self.labels = yml_conf['label_list']
        self.tagmap = False
        self.use_dynamic_shape = yml_conf['use_dynamic_shape']
        if 'keypoint_bottomup' == self.archcls:
            self.tagmap = True
        self.print_config()

    def check_model(self, yml_conf):
        """
        Raises:
            ValueError: loaded model not in supported model type
        """
        for support_model in KEYPOINT_SUPPORT_MODELS:
            if support_model in yml_conf['arch']:
                return True
        raise ValueError("Unsupported arch: {}, expect {}".format(yml_conf[
            'arch'], KEYPOINT_SUPPORT_MODELS))

    def print_config(self):
        print('----------- Model Configuration -----------')
        print('%s: %s' % ('Model Arch', self.arch))
        print('%s: ' % ('Transform Order'))
        for op_info in self.preprocess_infos:
            print('--%s: %s' % ('transform op', op_info['type']))
        print('--------------------------------------------')
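
# A sketch of the infer_cfg.yml fields this class reads (the values are
# illustrative assumptions, not copied from a real export):
#
#   arch: HRNet
#   min_subgraph_size: 3
#   use_dynamic_shape: false
#   label_list: [person]
#   Preprocess:
#   - type: TopDownEvalAffine
#     trainsize: [192, 256]
#   - type: NormalizeImage
#     mean: [0.485, 0.456, 0.406]
#     std: [0.229, 0.224, 0.225]
#   - type: Permute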


def visualize(image_list, results, visual_thresh=0.6, save_dir='output'):
    im_results = {}
    for i, image_file in enumerate(image_list):
        skeletons = results['keypoint']
        scores = results['score']
        skeleton = skeletons[i:i + 1]
        score = scores[i:i + 1]
        im_results['keypoint'] = [skeleton, score]
        visualize_pose(
            image_file,
            im_results,
            visual_thresh=visual_thresh,
            save_dir=save_dir)
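
# Note: visualize slices the batched results back into single-image windows
# (skeletons[i:i + 1]) so that visualize_pose always receives one image's
# keypoints together with the matching score.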


def main():
    detector = KeyPointDetector(
        FLAGS.model_dir,
        device=FLAGS.device,
        run_mode=FLAGS.run_mode,
        batch_size=FLAGS.batch_size,
        trt_min_shape=FLAGS.trt_min_shape,
        trt_max_shape=FLAGS.trt_max_shape,
        trt_opt_shape=FLAGS.trt_opt_shape,
        trt_calib_mode=FLAGS.trt_calib_mode,
        cpu_threads=FLAGS.cpu_threads,
        enable_mkldnn=FLAGS.enable_mkldnn,
        threshold=FLAGS.threshold,
        output_dir=FLAGS.output_dir,
        use_dark=FLAGS.use_dark)

    # predict from video file or camera video stream
    if FLAGS.video_file is not None or FLAGS.camera_id != -1:
        detector.predict_video(FLAGS.video_file, FLAGS.camera_id)
    else:
        # predict from image
        img_list = get_test_images(FLAGS.image_dir, FLAGS.image_file)
        detector.predict_image(img_list, FLAGS.run_benchmark, repeats=10)
        if not FLAGS.run_benchmark:
            detector.det_times.info(average=True)
        else:
            mems = {
                'cpu_rss_mb': detector.cpu_mem / len(img_list),
                'gpu_rss_mb': detector.gpu_mem / len(img_list),
                'gpu_util': detector.gpu_util * 100 / len(img_list)
            }
            perf_info = detector.det_times.report(average=True)
            model_dir = FLAGS.model_dir
            mode = FLAGS.run_mode
            model_info = {
                'model_name': model_dir.strip('/').split('/')[-1],
                'precision': mode.split('_')[-1]
            }
            data_info = {
                'batch_size': 1,
                'shape': "dynamic_shape",
                'data_num': perf_info['img_num']
            }
            det_log = PaddleInferBenchmark(detector.config, model_info,
                                           data_info, perf_info, mems)
            det_log('KeyPoint')


if __name__ == '__main__':
    paddle.enable_static()
    parser = argsparser()
    FLAGS = parser.parse_args()
    print_arguments(FLAGS)
    FLAGS.device = FLAGS.device.upper()
    assert FLAGS.device in ['CPU', 'GPU', 'XPU'
                            ], "device should be CPU, GPU or XPU"
    assert not FLAGS.use_gpu, "use_gpu has been deprecated, please use --device"
    main()
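
# Example invocations (a sketch; flag names come from utils.argsparser and the
# model directory is a placeholder):
#
#   # single image on GPU
#   python keypoint_infer.py --model_dir=./hrnet_model_dir \
#       --image_file=demo.jpg --device=GPU
#
#   # video file on CPU
#   python keypoint_infer.py --model_dir=./hrnet_model_dir \
#       --video_file=demo.mp4 --device=CPU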