attr_infer.py

# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import yaml
import glob
from functools import reduce

import cv2
import numpy as np
import math
import paddle
from paddle.inference import Config
from paddle.inference import create_predictor

import sys
# add deploy path of PaddleDetection to sys.path
parent_path = os.path.abspath(os.path.join(__file__, *(['..'])))
sys.path.insert(0, parent_path)

from python.benchmark_utils import PaddleInferBenchmark
from python.preprocess import preprocess, Resize, NormalizeImage, Permute, PadStride, LetterBoxResize, WarpAffine
from python.visualize import visualize_attr
from python.utils import argsparser, Timer, get_current_memory_mb
from python.infer import Detector, get_test_images, print_arguments, load_predictor

from PIL import Image, ImageDraw, ImageFont


class AttrDetector(Detector):
    """
    Args:
        model_dir (str): root path of model.pdiparams, model.pdmodel and infer_cfg.yml
        device (str): device to run on, one of CPU/GPU/XPU, default is CPU
        run_mode (str): mode of running (paddle/trt_fp32/trt_fp16)
        batch_size (int): batch size for inference
        trt_min_shape (int): min shape for dynamic shape in trt
        trt_max_shape (int): max shape for dynamic shape in trt
        trt_opt_shape (int): opt shape for dynamic shape in trt
        trt_calib_mode (bool): if the model is produced by TRT offline quantitative
            calibration, trt_calib_mode needs to be set to True
        cpu_threads (int): number of cpu threads
        enable_mkldnn (bool): whether to enable MKLDNN
        output_dir (str): the path of the output directory
        threshold (float): score threshold for visualization
    """

    def __init__(self,
                 model_dir,
                 device='CPU',
                 run_mode='paddle',
                 batch_size=1,
                 trt_min_shape=1,
                 trt_max_shape=1280,
                 trt_opt_shape=640,
                 trt_calib_mode=False,
                 cpu_threads=1,
                 enable_mkldnn=False,
                 output_dir='output',
                 threshold=0.5):
        super(AttrDetector, self).__init__(
            model_dir=model_dir,
            device=device,
            run_mode=run_mode,
            batch_size=batch_size,
            trt_min_shape=trt_min_shape,
            trt_max_shape=trt_max_shape,
            trt_opt_shape=trt_opt_shape,
            trt_calib_mode=trt_calib_mode,
            cpu_threads=cpu_threads,
            enable_mkldnn=enable_mkldnn,
            output_dir=output_dir,
            threshold=threshold)

    @classmethod
    def init_with_cfg(cls, args, cfg):
        return cls(model_dir=cfg['model_dir'],
                   batch_size=cfg['batch_size'],
                   device=args.device,
                   run_mode=args.run_mode,
                   trt_min_shape=args.trt_min_shape,
                   trt_max_shape=args.trt_max_shape,
                   trt_opt_shape=args.trt_opt_shape,
                   trt_calib_mode=args.trt_calib_mode,
                   cpu_threads=args.cpu_threads,
                   enable_mkldnn=args.enable_mkldnn)

    def get_label(self):
        return self.pred_config.labels

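    # The decoding in postprocess() assumes one attribute-score vector per input
    # image, with the index layout implied by the slices used below: 0 hat,
    # 1 glasses, 2-3 short/long sleeve, 4-7 upper-clothing styles, 8-13
    # lower-clothing styles, 14 boots, 15-17 bag types, 18 hold-objects-in-front,
    # 19-21 age bins, 22 gender (female), 23+ body orientation. This summary is
    # inferred from the code, not from a separate model spec. An illustrative
    # decoded entry looks like:
    #   ['Male', 'Age18-60', 'Front', 'Glasses: False', 'Hat: False',
    #    'HoldObjectsInFront: False', 'No bag', 'Upper: ShortSleeve',
    #    'Lower:  Trousers', 'Boots']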
    def postprocess(self, inputs, result):
        # postprocess output of predictor
        im_results = result['output']
        labels = self.pred_config.labels
        age_list = ['AgeLess18', 'Age18-60', 'AgeOver60']
        direct_list = ['Front', 'Side', 'Back']
        bag_list = ['HandBag', 'ShoulderBag', 'Backpack']
        upper_list = ['UpperStride', 'UpperLogo', 'UpperPlaid', 'UpperSplice']
        lower_list = [
            'LowerStripe', 'LowerPattern', 'LongCoat', 'Trousers', 'Shorts',
            'Skirt&Dress'
        ]
        glasses_threshold = 0.3
        hold_threshold = 0.6
        batch_res = []
        for res in im_results:
            res = res.tolist()
            label_res = []
            # gender
            gender = 'Female' if res[22] > self.threshold else 'Male'
            label_res.append(gender)
            # age
            age = age_list[np.argmax(res[19:22])]
            label_res.append(age)
            # direction
            direction = direct_list[np.argmax(res[23:])]
            label_res.append(direction)
            # glasses
            glasses = 'Glasses: '
            if res[1] > glasses_threshold:
                glasses += 'True'
            else:
                glasses += 'False'
            label_res.append(glasses)
            # hat
            hat = 'Hat: '
            if res[0] > self.threshold:
                hat += 'True'
            else:
                hat += 'False'
            label_res.append(hat)
            # hold obj
            hold_obj = 'HoldObjectsInFront: '
            if res[18] > hold_threshold:
                hold_obj += 'True'
            else:
                hold_obj += 'False'
            label_res.append(hold_obj)
            # bag
            bag = bag_list[np.argmax(res[15:18])]
            bag_score = res[15 + np.argmax(res[15:18])]
            bag_label = bag if bag_score > self.threshold else 'No bag'
            label_res.append(bag_label)
            # upper
            upper_label = 'Upper:'
            sleeve = 'LongSleeve' if res[3] > res[2] else 'ShortSleeve'
            upper_label += ' {}'.format(sleeve)
            upper_res = res[4:8]
            if np.max(upper_res) > self.threshold:
                upper_label += ' {}'.format(upper_list[np.argmax(upper_res)])
            label_res.append(upper_label)
            # lower
            lower_res = res[8:14]
            lower_label = 'Lower: '
            has_lower = False
            for i, l in enumerate(lower_res):
                if l > self.threshold:
                    lower_label += ' {}'.format(lower_list[i])
                    has_lower = True
            if not has_lower:
                lower_label += ' {}'.format(lower_list[np.argmax(lower_res)])
            label_res.append(lower_label)
            # shoe
            shoe = 'Boots' if res[14] > self.threshold else 'No boots'
            label_res.append(shoe)
            batch_res.append(label_res)
        result = {'output': batch_res}
        return result

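    # predict() assumes the preprocessed batch has already been copied to the
    # predictor's input handles by the inherited Detector.preprocess() from
    # python.infer (not shown in this file), which is why it only calls
    # predictor.run(). When repeats > 1 (benchmark mode), the extra runs are for
    # timing only and the returned array comes from the last run.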
    def predict(self, repeats=1):
        '''
        Args:
            repeats (int): repeats number for prediction
        Returns:
            result (dict): 'output' holds the raw attribute score array copied
                from the predictor's first output tensor, one row per input image
        '''
        # model prediction
        for i in range(repeats):
            self.predictor.run()
            output_names = self.predictor.get_output_names()
            output_tensor = self.predictor.get_output_handle(output_names[0])
            np_output = output_tensor.copy_to_cpu()
        result = dict(output=np_output)
        return result

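    # In predict_image(), run_benchmark=True runs each stage twice (the first
    # pass is a discarded warmup) and samples CPU/GPU memory via
    # get_current_memory_mb(); run_benchmark=False runs each stage once and
    # optionally visualizes the result.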
    def predict_image(self,
                      image_list,
                      run_benchmark=False,
                      repeats=1,
                      visual=True):
        batch_loop_cnt = math.ceil(float(len(image_list)) / self.batch_size)
        results = []
        for i in range(batch_loop_cnt):
            start_index = i * self.batch_size
            end_index = min((i + 1) * self.batch_size, len(image_list))
            batch_image_list = image_list[start_index:end_index]
            if run_benchmark:
                # preprocess
                inputs = self.preprocess(batch_image_list)  # warmup
                self.det_times.preprocess_time_s.start()
                inputs = self.preprocess(batch_image_list)
                self.det_times.preprocess_time_s.end()
                # model prediction
                result = self.predict(repeats=repeats)  # warmup
                self.det_times.inference_time_s.start()
                result = self.predict(repeats=repeats)
                self.det_times.inference_time_s.end(repeats=repeats)
                # postprocess
                result_warmup = self.postprocess(inputs, result)  # warmup
                self.det_times.postprocess_time_s.start()
                result = self.postprocess(inputs, result)
                self.det_times.postprocess_time_s.end()
                self.det_times.img_num += len(batch_image_list)

                cm, gm, gu = get_current_memory_mb()
                self.cpu_mem += cm
                self.gpu_mem += gm
                self.gpu_util += gu
            else:
                # preprocess
                self.det_times.preprocess_time_s.start()
                inputs = self.preprocess(batch_image_list)
                self.det_times.preprocess_time_s.end()
                # model prediction
                self.det_times.inference_time_s.start()
                result = self.predict()
                self.det_times.inference_time_s.end()
                # postprocess
                self.det_times.postprocess_time_s.start()
                result = self.postprocess(inputs, result)
                self.det_times.postprocess_time_s.end()
                self.det_times.img_num += len(batch_image_list)
                if visual:
                    visualize(
                        batch_image_list, result, output_dir=self.output_dir)
            results.append(result)
            if visual:
                print('Test iter {}'.format(i))
        results = self.merge_batch_result(results)
        return results

    def merge_batch_result(self, batch_result):
        if len(batch_result) == 1:
            return batch_result[0]
        res_key = batch_result[0].keys()
        results = {k: [] for k in res_key}
        for res in batch_result:
            for k, v in res.items():
                results[k].extend(v)
        return results

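# Minimal programmatic usage sketch (mirrors main() below). The model path and
# image path are placeholders; the exported model directory must contain
# model.pdmodel, model.pdiparams and infer_cfg.yml:
#
#     detector = AttrDetector('output_inference/attribute_model', device='GPU')
#     res = detector.predict_image(['demo/person_crop.jpg'], visual=False)
#     print(res['output'])  # one decoded label list per image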
def visualize(image_list, batch_res, output_dir='output'):
    # visualize the predict result
    batch_res = batch_res['output']
    for image_file, res in zip(image_list, batch_res):
        im = visualize_attr(image_file, [res])
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)
        img_name = os.path.split(image_file)[-1]
        out_path = os.path.join(output_dir, img_name)
        cv2.imwrite(out_path, im)
        print("save result to: " + out_path)


def main():
    detector = AttrDetector(
        FLAGS.model_dir,
        device=FLAGS.device,
        run_mode=FLAGS.run_mode,
        batch_size=FLAGS.batch_size,
        trt_min_shape=FLAGS.trt_min_shape,
        trt_max_shape=FLAGS.trt_max_shape,
        trt_opt_shape=FLAGS.trt_opt_shape,
        trt_calib_mode=FLAGS.trt_calib_mode,
        cpu_threads=FLAGS.cpu_threads,
        enable_mkldnn=FLAGS.enable_mkldnn,
        threshold=FLAGS.threshold,
        output_dir=FLAGS.output_dir)

    # predict from image
    if FLAGS.image_dir is None and FLAGS.image_file is not None:
        assert FLAGS.batch_size == 1, "batch_size should be 1, when image_file is not None"
    img_list = get_test_images(FLAGS.image_dir, FLAGS.image_file)
    detector.predict_image(img_list, FLAGS.run_benchmark, repeats=10)
    if not FLAGS.run_benchmark:
        detector.det_times.info(average=True)
    else:
        mems = {
            'cpu_rss_mb': detector.cpu_mem / len(img_list),
            'gpu_rss_mb': detector.gpu_mem / len(img_list),
            'gpu_util': detector.gpu_util * 100 / len(img_list)
        }
        perf_info = detector.det_times.report(average=True)
        model_dir = FLAGS.model_dir
        mode = FLAGS.run_mode
        model_info = {
            'model_name': model_dir.strip('/').split('/')[-1],
            'precision': mode.split('_')[-1]
        }
        data_info = {
            'batch_size': FLAGS.batch_size,
            'shape': "dynamic_shape",
            'data_num': perf_info['img_num']
        }
        det_log = PaddleInferBenchmark(detector.config, model_info, data_info,
                                       perf_info, mems)
        det_log('Attr')

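# Example CLI invocation (paths are placeholders; the flags come from
# python.utils.argsparser):
#     python attr_infer.py --model_dir=output_inference/attribute_model \
#         --image_file=demo/person_crop.jpg --device=GPU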
if __name__ == '__main__':
    paddle.enable_static()
    parser = argsparser()
    FLAGS = parser.parse_args()
    print_arguments(FLAGS)
    FLAGS.device = FLAGS.device.upper()
    assert FLAGS.device in ['CPU', 'GPU', 'XPU'], "device should be CPU, GPU or XPU"
    assert not FLAGS.use_gpu, "use_gpu has been deprecated, please use --device"

    main()