# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import yaml
import glob
import cv2
import numpy as np
import math
import paddle
import sys
# Sequence was removed from `collections` in Python 3.10; import it from
# collections.abc instead.
from collections.abc import Sequence

# add deploy path of PaddleDetection to sys.path
parent_path = os.path.abspath(os.path.join(__file__, *(['..'] * 3)))
sys.path.insert(0, parent_path)

from paddle.inference import Config, create_predictor
from python.utils import argsparser, Timer, get_current_memory_mb
from python.benchmark_utils import PaddleInferBenchmark
from python.infer import Detector, print_arguments
from pipeline.pphuman.attr_infer import AttrDetector


class VehicleAttr(AttrDetector):
    """
    Args:
        model_dir (str): root path of model.pdiparams, model.pdmodel and infer_cfg.yml
        device (str): device to run on, one of CPU/GPU/XPU, default is CPU
        run_mode (str): inference mode (paddle/trt_fp32/trt_fp16)
        batch_size (int): batch size used in inference
        trt_min_shape (int): min shape for dynamic shape in trt
        trt_max_shape (int): max shape for dynamic shape in trt
        trt_opt_shape (int): opt shape for dynamic shape in trt
        trt_calib_mode (bool): if the model is produced by TRT offline quantitative
            calibration, trt_calib_mode needs to be set to True
        cpu_threads (int): number of cpu threads
        enable_mkldnn (bool): whether to enable MKLDNN
        output_dir (str): directory to save output results
        color_threshold (float): score threshold for vehicle color recognition
        type_threshold (float): score threshold for vehicle type recognition
    """
    def __init__(self,
                 model_dir,
                 device='CPU',
                 run_mode='paddle',
                 batch_size=1,
                 trt_min_shape=1,
                 trt_max_shape=1280,
                 trt_opt_shape=640,
                 trt_calib_mode=False,
                 cpu_threads=1,
                 enable_mkldnn=False,
                 output_dir='output',
                 color_threshold=0.5,
                 type_threshold=0.5):
        super(VehicleAttr, self).__init__(
            model_dir=model_dir,
            device=device,
            run_mode=run_mode,
            batch_size=batch_size,
            trt_min_shape=trt_min_shape,
            trt_max_shape=trt_max_shape,
            trt_opt_shape=trt_opt_shape,
            trt_calib_mode=trt_calib_mode,
            cpu_threads=cpu_threads,
            enable_mkldnn=enable_mkldnn,
            output_dir=output_dir)
        self.color_threshold = color_threshold
        self.type_threshold = type_threshold
        self.result_history = {}
        # Label lists for the attribute model's output vector: the first 10
        # scores map to color_list and the remaining 9 to type_list (see
        # postprocess below).
        self.color_list = [
            "yellow", "orange", "green", "gray", "red", "blue", "white",
            "golden", "brown", "black"
        ]
        self.type_list = [
            "sedan", "suv", "van", "hatchback", "mpv", "pickup", "bus", "truck",
            "estate"
        ]

    @classmethod
    def init_with_cfg(cls, args, cfg):
        return cls(model_dir=cfg['model_dir'],
                   batch_size=cfg['batch_size'],
                   color_threshold=cfg['color_threshold'],
                   type_threshold=cfg['type_threshold'],
                   device=args.device,
                   run_mode=args.run_mode,
                   trt_min_shape=args.trt_min_shape,
                   trt_max_shape=args.trt_max_shape,
                   trt_opt_shape=args.trt_opt_shape,
                   trt_calib_mode=args.trt_calib_mode,
                   cpu_threads=args.cpu_threads,
                   enable_mkldnn=args.enable_mkldnn)
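
    # Sketch of the config entry that `init_with_cfg` expects. The key names
    # come from the dictionary accesses above; the section name and the example
    # values are illustrative only and may differ in your pipeline config:
    #
    #   VEHICLE_ATTR:
    #     model_dir: output_inference/vehicle_attribute_model/
    #     batch_size: 8
    #     color_threshold: 0.5
    #     type_threshold: 0.5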

    def postprocess(self, inputs, result):
        # postprocess output of predictor
        im_results = result['output']
        batch_res = []
        for res in im_results:
            res = res.tolist()
            attr_res = []
            color_res_str = "Color: "
            type_res_str = "Type: "
            color_idx = np.argmax(res[:10])
            type_idx = np.argmax(res[10:])

            if res[color_idx] >= self.color_threshold:
                color_res_str += self.color_list[color_idx]
            else:
                color_res_str += "Unknown"
            attr_res.append(color_res_str)

            if res[type_idx + 10] >= self.type_threshold:
                type_res_str += self.type_list[type_idx]
            else:
                type_res_str += "Unknown"
            attr_res.append(type_res_str)

            batch_res.append(attr_res)
        result = {'output': batch_res}
        return result
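

# `main()` is called from the `__main__` block below but is not defined in this
# file. The sketch here is an assumption: it mirrors the pattern of the other
# PaddleDetection deploy entry points (e.g. pipeline/pphuman/attr_infer.py),
# building the detector from the CLI flags and running it over the input
# images. `get_test_images` is assumed to be available from python.infer, and
# the exact `predict_image` signature may vary across versions.
def main():
    from python.infer import get_test_images

    detector = VehicleAttr(
        FLAGS.model_dir,
        device=FLAGS.device,
        run_mode=FLAGS.run_mode,
        batch_size=FLAGS.batch_size,
        trt_min_shape=FLAGS.trt_min_shape,
        trt_max_shape=FLAGS.trt_max_shape,
        trt_opt_shape=FLAGS.trt_opt_shape,
        trt_calib_mode=FLAGS.trt_calib_mode,
        cpu_threads=FLAGS.cpu_threads,
        enable_mkldnn=FLAGS.enable_mkldnn,
        output_dir=FLAGS.output_dir)

    # Collect images passed via --image_file / --image_dir and run
    # vehicle attribute prediction on them.
    img_list = get_test_images(FLAGS.image_dir, FLAGS.image_file)
    detector.predict_image(img_list)
    detector.det_times.info(average=True)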


if __name__ == '__main__':
    paddle.enable_static()
    parser = argsparser()
    FLAGS = parser.parse_args()
    print_arguments(FLAGS)
    FLAGS.device = FLAGS.device.upper()
    assert FLAGS.device in ['CPU', 'GPU', 'XPU'], \
        "device should be CPU, GPU or XPU"
    assert not FLAGS.use_gpu, "use_gpu has been deprecated, please use --device"

    main()