# det_keypoint_unite_infer.py
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. import os
  15. import json
  16. import cv2
  17. import math
  18. import numpy as np
  19. import paddle
  20. import yaml
  21. from det_keypoint_unite_utils import argsparser
  22. from preprocess import decode_image
  23. from infer import Detector, DetectorPicoDet, PredictConfig, print_arguments, get_test_images, bench_log
  24. from keypoint_infer import KeyPointDetector, PredictConfig_KeyPoint
  25. from visualize import visualize_pose
  26. from benchmark_utils import PaddleInferBenchmark
  27. from utils import get_current_memory_mb
  28. from keypoint_postprocess import translate_to_ori_images
  29. KEYPOINT_SUPPORT_MODELS = {
  30. 'HigherHRNet': 'keypoint_bottomup',
  31. 'HRNet': 'keypoint_topdown'
  32. }
  33. def predict_with_given_det(image, det_res, keypoint_detector,
  34. keypoint_batch_size, run_benchmark):
  35. keypoint_res = {}
  36. rec_images, records, det_rects = keypoint_detector.get_person_from_rect(
  37. image, det_res)
  38. if len(det_rects) == 0:
  39. keypoint_res['keypoint'] = [[], []]
  40. return keypoint_res
  41. keypoint_vector = []
  42. score_vector = []
  43. rect_vector = det_rects
  44. keypoint_results = keypoint_detector.predict_image(
  45. rec_images, run_benchmark, repeats=10, visual=False)
  46. keypoint_vector, score_vector = translate_to_ori_images(keypoint_results,
  47. np.array(records))
  48. keypoint_res['keypoint'] = [
  49. keypoint_vector.tolist(), score_vector.tolist()
  50. ] if len(keypoint_vector) > 0 else [[], []]
  51. keypoint_res['bbox'] = rect_vector
  52. return keypoint_res
  53. def topdown_unite_predict(detector,
  54. topdown_keypoint_detector,
  55. image_list,
  56. keypoint_batch_size=1,
  57. save_res=False):
  58. det_timer = detector.get_timer()
  59. store_res = []
  60. for i, img_file in enumerate(image_list):
  61. # Decode image in advance in det + pose prediction
  62. det_timer.preprocess_time_s.start()
  63. image, _ = decode_image(img_file, {})
  64. det_timer.preprocess_time_s.end()
  65. if FLAGS.run_benchmark:
  66. results = detector.predict_image(
  67. [image], run_benchmark=True, repeats=10)
  68. cm, gm, gu = get_current_memory_mb()
  69. detector.cpu_mem += cm
  70. detector.gpu_mem += gm
  71. detector.gpu_util += gu
  72. else:
  73. results = detector.predict_image([image], visual=False)
  74. results = detector.filter_box(results, FLAGS.det_threshold)
  75. if results['boxes_num'] > 0:
  76. keypoint_res = predict_with_given_det(
  77. image, results, topdown_keypoint_detector, keypoint_batch_size,
  78. FLAGS.run_benchmark)
  79. if save_res:
  80. save_name = img_file if isinstance(img_file, str) else i
  81. store_res.append([
  82. save_name, keypoint_res['bbox'],
  83. [keypoint_res['keypoint'][0], keypoint_res['keypoint'][1]]
  84. ])
  85. else:
  86. results["keypoint"] = [[], []]
  87. keypoint_res = results
  88. if FLAGS.run_benchmark:
  89. cm, gm, gu = get_current_memory_mb()
  90. topdown_keypoint_detector.cpu_mem += cm
  91. topdown_keypoint_detector.gpu_mem += gm
  92. topdown_keypoint_detector.gpu_util += gu
  93. else:
  94. if not os.path.exists(FLAGS.output_dir):
  95. os.makedirs(FLAGS.output_dir)
  96. visualize_pose(
  97. img_file,
  98. keypoint_res,
  99. visual_thresh=FLAGS.keypoint_threshold,
  100. save_dir=FLAGS.output_dir)
  101. if save_res:
  102. """
  103. 1) store_res: a list of image_data
  104. 2) image_data: [imageid, rects, [keypoints, scores]]
  105. 3) rects: list of rect [xmin, ymin, xmax, ymax]
  106. 4) keypoints: 17(joint numbers)*[x, y, conf], total 51 data in list
  107. 5) scores: mean of all joint conf
  108. """
  109. with open("det_keypoint_unite_image_results.json", 'w') as wf:
  110. json.dump(store_res, wf, indent=4)
  111. def topdown_unite_predict_video(detector,
  112. topdown_keypoint_detector,
  113. camera_id,
  114. keypoint_batch_size=1,
  115. save_res=False):
  116. video_name = 'output.mp4'
  117. if camera_id != -1:
  118. capture = cv2.VideoCapture(camera_id)
  119. else:
  120. capture = cv2.VideoCapture(FLAGS.video_file)
  121. video_name = os.path.split(FLAGS.video_file)[-1]
  122. # Get Video info : resolution, fps, frame count
  123. width = int(capture.get(cv2.CAP_PROP_FRAME_WIDTH))
  124. height = int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
  125. fps = int(capture.get(cv2.CAP_PROP_FPS))
  126. frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
  127. print("fps: %d, frame_count: %d" % (fps, frame_count))
  128. if not os.path.exists(FLAGS.output_dir):
  129. os.makedirs(FLAGS.output_dir)
  130. out_path = os.path.join(FLAGS.output_dir, video_name)
  131. fourcc = cv2.VideoWriter_fourcc(* 'mp4v')
  132. writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height))
  133. index = 0
  134. store_res = []
  135. keypoint_smoothing = KeypointSmoothing(
  136. width, height, filter_type=FLAGS.filter_type, beta=0.05)
  137. while (1):
  138. ret, frame = capture.read()
  139. if not ret:
  140. break
  141. index += 1
  142. print('detect frame: %d' % (index))
  143. frame2 = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
  144. results = detector.predict_image([frame2], visual=False)
  145. results = detector.filter_box(results, FLAGS.det_threshold)
  146. if results['boxes_num'] == 0:
  147. writer.write(frame)
  148. continue
  149. keypoint_res = predict_with_given_det(
  150. frame2, results, topdown_keypoint_detector, keypoint_batch_size,
  151. FLAGS.run_benchmark)
  152. if FLAGS.smooth and len(keypoint_res['keypoint'][0]) == 1:
  153. current_keypoints = np.array(keypoint_res['keypoint'][0][0])
  154. smooth_keypoints = keypoint_smoothing.smooth_process(
  155. current_keypoints)
  156. keypoint_res['keypoint'][0][0] = smooth_keypoints.tolist()
  157. im = visualize_pose(
  158. frame,
  159. keypoint_res,
  160. visual_thresh=FLAGS.keypoint_threshold,
  161. returnimg=True)
  162. if save_res:
  163. store_res.append([
  164. index, keypoint_res['bbox'],
  165. [keypoint_res['keypoint'][0], keypoint_res['keypoint'][1]]
  166. ])
  167. writer.write(im)
  168. if camera_id != -1:
  169. cv2.imshow('Mask Detection', im)
  170. if cv2.waitKey(1) & 0xFF == ord('q'):
  171. break
  172. writer.release()
  173. print('output_video saved to: {}'.format(out_path))
  174. if save_res:
  175. """
  176. 1) store_res: a list of frame_data
  177. 2) frame_data: [frameid, rects, [keypoints, scores]]
  178. 3) rects: list of rect [xmin, ymin, xmax, ymax]
  179. 4) keypoints: 17(joint numbers)*[x, y, conf], total 51 data in list
  180. 5) scores: mean of all joint conf
  181. """
  182. with open("det_keypoint_unite_video_results.json", 'w') as wf:
  183. json.dump(store_res, wf, indent=4)
  184. class KeypointSmoothing(object):
  185. # The following code are modified from:
  186. # https://github.com/jaantollander/OneEuroFilter
  187. def __init__(self,
  188. width,
  189. height,
  190. filter_type,
  191. alpha=0.5,
  192. fc_d=0.1,
  193. fc_min=0.1,
  194. beta=0.1,
  195. thres_mult=0.3):
  196. super(KeypointSmoothing, self).__init__()
  197. self.image_width = width
  198. self.image_height = height
  199. self.threshold = np.array([
  200. 0.005, 0.005, 0.005, 0.005, 0.005, 0.01, 0.01, 0.01, 0.01, 0.01,
  201. 0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01
  202. ]) * thres_mult
  203. self.filter_type = filter_type
  204. self.alpha = alpha
  205. self.dx_prev_hat = None
  206. self.x_prev_hat = None
  207. self.fc_d = fc_d
  208. self.fc_min = fc_min
  209. self.beta = beta
  210. if self.filter_type == 'OneEuro':
  211. self.smooth_func = self.one_euro_filter
  212. elif self.filter_type == 'EMA':
  213. self.smooth_func = self.ema_filter
  214. else:
  215. raise ValueError('filter type must be one_euro or ema')
  216. def smooth_process(self, current_keypoints):
  217. if self.x_prev_hat is None:
  218. self.x_prev_hat = current_keypoints[:, :2]
  219. self.dx_prev_hat = np.zeros(current_keypoints[:, :2].shape)
  220. return current_keypoints
  221. else:
  222. result = current_keypoints
  223. num_keypoints = len(current_keypoints)
  224. for i in range(num_keypoints):
  225. result[i, :2] = self.smooth(current_keypoints[i, :2],
  226. self.threshold[i], i)
  227. return result
  228. def smooth(self, current_keypoint, threshold, index):
  229. distance = np.sqrt(
  230. np.square((current_keypoint[0] - self.x_prev_hat[index][0]) /
  231. self.image_width) + np.square((current_keypoint[
  232. 1] - self.x_prev_hat[index][1]) / self.image_height))
  233. if distance < threshold:
  234. result = self.x_prev_hat[index]
  235. else:
  236. result = self.smooth_func(current_keypoint, self.x_prev_hat[index],
  237. index)
  238. return result
  239. def one_euro_filter(self, x_cur, x_pre, index):
  240. te = 1
  241. self.alpha = self.smoothing_factor(te, self.fc_d)
  242. dx_cur = (x_cur - x_pre) / te
  243. dx_cur_hat = self.exponential_smoothing(dx_cur, self.dx_prev_hat[index])
  244. fc = self.fc_min + self.beta * np.abs(dx_cur_hat)
  245. self.alpha = self.smoothing_factor(te, fc)
  246. x_cur_hat = self.exponential_smoothing(x_cur, x_pre)
  247. self.dx_prev_hat[index] = dx_cur_hat
  248. self.x_prev_hat[index] = x_cur_hat
  249. return x_cur_hat
  250. def ema_filter(self, x_cur, x_pre, index):
  251. x_cur_hat = self.exponential_smoothing(x_cur, x_pre)
  252. self.x_prev_hat[index] = x_cur_hat
  253. return x_cur_hat
  254. def smoothing_factor(self, te, fc):
  255. r = 2 * math.pi * fc * te
  256. return r / (r + 1)
  257. def exponential_smoothing(self, x_cur, x_pre, index=0):
  258. return self.alpha * x_cur + (1 - self.alpha) * x_pre
  259. def main():
  260. deploy_file = os.path.join(FLAGS.det_model_dir, 'infer_cfg.yml')
  261. with open(deploy_file) as f:
  262. yml_conf = yaml.safe_load(f)
  263. arch = yml_conf['arch']
  264. detector_func = 'Detector'
  265. if arch == 'PicoDet':
  266. detector_func = 'DetectorPicoDet'
  267. detector = eval(detector_func)(FLAGS.det_model_dir,
  268. device=FLAGS.device,
  269. run_mode=FLAGS.run_mode,
  270. trt_min_shape=FLAGS.trt_min_shape,
  271. trt_max_shape=FLAGS.trt_max_shape,
  272. trt_opt_shape=FLAGS.trt_opt_shape,
  273. trt_calib_mode=FLAGS.trt_calib_mode,
  274. cpu_threads=FLAGS.cpu_threads,
  275. enable_mkldnn=FLAGS.enable_mkldnn,
  276. threshold=FLAGS.det_threshold)
  277. topdown_keypoint_detector = KeyPointDetector(
  278. FLAGS.keypoint_model_dir,
  279. device=FLAGS.device,
  280. run_mode=FLAGS.run_mode,
  281. batch_size=FLAGS.keypoint_batch_size,
  282. trt_min_shape=FLAGS.trt_min_shape,
  283. trt_max_shape=FLAGS.trt_max_shape,
  284. trt_opt_shape=FLAGS.trt_opt_shape,
  285. trt_calib_mode=FLAGS.trt_calib_mode,
  286. cpu_threads=FLAGS.cpu_threads,
  287. enable_mkldnn=FLAGS.enable_mkldnn,
  288. use_dark=FLAGS.use_dark)
  289. keypoint_arch = topdown_keypoint_detector.pred_config.arch
  290. assert KEYPOINT_SUPPORT_MODELS[
  291. keypoint_arch] == 'keypoint_topdown', 'Detection-Keypoint unite inference only supports topdown models.'
  292. # predict from video file or camera video stream
  293. if FLAGS.video_file is not None or FLAGS.camera_id != -1:
  294. topdown_unite_predict_video(detector, topdown_keypoint_detector,
  295. FLAGS.camera_id, FLAGS.keypoint_batch_size,
  296. FLAGS.save_res)
  297. else:
  298. # predict from image
  299. img_list = get_test_images(FLAGS.image_dir, FLAGS.image_file)
  300. topdown_unite_predict(detector, topdown_keypoint_detector, img_list,
  301. FLAGS.keypoint_batch_size, FLAGS.save_res)
  302. if not FLAGS.run_benchmark:
  303. detector.det_times.info(average=True)
  304. topdown_keypoint_detector.det_times.info(average=True)
  305. else:
  306. mode = FLAGS.run_mode
  307. det_model_dir = FLAGS.det_model_dir
  308. det_model_info = {
  309. 'model_name': det_model_dir.strip('/').split('/')[-1],
  310. 'precision': mode.split('_')[-1]
  311. }
  312. bench_log(detector, img_list, det_model_info, name='Det')
  313. keypoint_model_dir = FLAGS.keypoint_model_dir
  314. keypoint_model_info = {
  315. 'model_name': keypoint_model_dir.strip('/').split('/')[-1],
  316. 'precision': mode.split('_')[-1]
  317. }
  318. bench_log(topdown_keypoint_detector, img_list, keypoint_model_info,
  319. FLAGS.keypoint_batch_size, 'KeyPoint')
  320. if __name__ == '__main__':
  321. paddle.enable_static()
  322. parser = argsparser()
  323. FLAGS = parser.parse_args()
  324. print_arguments(FLAGS)
  325. FLAGS.device = FLAGS.device.upper()
  326. assert FLAGS.device in ['CPU', 'GPU', 'XPU'
  327. ], "device should be CPU, GPU or XPU"
  328. main()