infer_demo.py

import cv2
import numpy as np
import argparse
import onnxruntime as ort
from pathlib import Path
from tqdm import tqdm

import write_json as wj


class PicoDet:
    def __init__(self,
                 model_pb_path,
                 label_path,
                 prob_threshold=0.4,
                 iou_threshold=0.3):
        self.classes = list(
            map(lambda x: x.strip(), open(label_path, 'r').readlines()))
        self.num_classes = len(self.classes)
        self.prob_threshold = prob_threshold
        self.iou_threshold = iou_threshold
        # per-channel mean/std in BGR order, matching the training preprocessing
        self.mean = np.array(
            [103.53, 116.28, 123.675], dtype=np.float32).reshape(1, 1, 3)
        self.std = np.array(
            [57.375, 57.12, 58.395], dtype=np.float32).reshape(1, 1, 3)
        so = ort.SessionOptions()
        so.log_severity_level = 3  # only log errors and above
        self.net = ort.InferenceSession(model_pb_path, so)
        # model input is NCHW, so dims 2 and 3 are (height, width)
        self.input_shape = (self.net.get_inputs()[0].shape[2],
                            self.net.get_inputs()[0].shape[3])

    def _normalize(self, img):
        img = img.astype(np.float32)
        img = (img / 255.0 - self.mean / 255.0) / (self.std / 255.0)
        return img

    def resize_image(self, srcimg, keep_ratio=False):
        top, left, newh, neww = 0, 0, self.input_shape[0], self.input_shape[1]
        origin_shape = srcimg.shape[:2]
        im_scale_y = newh / float(origin_shape[0])
        im_scale_x = neww / float(origin_shape[1])
        scale_factor = np.array([[im_scale_y, im_scale_x]]).astype('float32')
        if keep_ratio and srcimg.shape[0] != srcimg.shape[1]:
            hw_scale = srcimg.shape[0] / srcimg.shape[1]
            if hw_scale > 1:
                newh, neww = self.input_shape[0], int(self.input_shape[1] /
                                                      hw_scale)
                img = cv2.resize(
                    srcimg, (neww, newh), interpolation=cv2.INTER_AREA)
                left = int((self.input_shape[1] - neww) * 0.5)
                img = cv2.copyMakeBorder(
                    img, 0, 0, left, self.input_shape[1] - neww - left,
                    cv2.BORDER_CONSTANT,
                    value=0)  # pad left/right to the full input width
            else:
                newh, neww = int(self.input_shape[0] *
                                 hw_scale), self.input_shape[1]
                img = cv2.resize(
                    srcimg, (neww, newh), interpolation=cv2.INTER_AREA)
                top = int((self.input_shape[0] - newh) * 0.5)
                img = cv2.copyMakeBorder(
                    img, top, self.input_shape[0] - newh - top, 0, 0,
                    cv2.BORDER_CONSTANT,
                    value=0)  # pad top/bottom to the full input height
        else:
            # cv2.resize expects dsize as (width, height)
            img = cv2.resize(
                srcimg, (self.input_shape[1], self.input_shape[0]),
                interpolation=cv2.INTER_AREA)
        return img, scale_factor

    def get_color_map_list(self, num_classes):
        color_map = num_classes * [0, 0, 0]
        for i in range(0, num_classes):
            j = 0
            lab = i
            while lab:
                color_map[i * 3] |= (((lab >> 0) & 1) << (7 - j))
                color_map[i * 3 + 1] |= (((lab >> 1) & 1) << (7 - j))
                color_map[i * 3 + 2] |= (((lab >> 2) & 1) << (7 - j))
                j += 1
                lab >>= 3
        color_map = [color_map[i:i + 3] for i in range(0, len(color_map), 3)]
        return color_map

    def detect(self, srcimg):
        img, scale_factor = self.resize_image(srcimg)
        img = self._normalize(img)
        shape_list = []
        blob = np.expand_dims(np.transpose(img, (2, 0, 1)), axis=0)
        outs = self.net.run(None, {
            self.net.get_inputs()[0].name: blob,
            self.net.get_inputs()[1].name: scale_factor
        })
        # each output row is [class_id, score, xmin, ymin, xmax, ymax]
        outs = np.array(outs[0])
        expect_boxes = (outs[:, 1] > self.prob_threshold) & (outs[:, 0] > -1)
        np_boxes = outs[expect_boxes, :]
        color_list = self.get_color_map_list(self.num_classes)
        clsid2color = {}
        for i in range(np_boxes.shape[0]):
            classid, conf = int(np_boxes[i, 0]), np_boxes[i, 1]
            xmin, ymin, xmax, ymax = (int(np_boxes[i, 2]), int(np_boxes[i, 3]),
                                      int(np_boxes[i, 4]), int(np_boxes[i, 5]))
            if classid not in clsid2color:
                clsid2color[classid] = color_list[classid]
            color = tuple(clsid2color[classid])
            cv2.rectangle(
                srcimg, (xmin, ymin), (xmax, ymax), color, thickness=1)
            # both corner points and the label are available here, so collect
            # them as a labelme-style rectangle shape
            flags = {}
            shape = {
                'label': self.classes[classid],
                'points': [[xmin, ymin], [xmax, ymax]],
                'group_id': None,
                'shape_type': 'rectangle',
                'flags': flags
            }
            shape_list.append(shape)
            cv2.putText(
                srcimg,
                self.classes[classid] + ':' + str(round(conf, 3)),
                (xmin, ymin - 10),
                cv2.FONT_HERSHEY_SIMPLEX,
                0.8, color,
                thickness=2)
        return srcimg, shape_list

    def detect_folder(self, img_fold, result_path):
        img_fold = Path(img_fold)
        result_path = Path(result_path)
        result_path.mkdir(parents=True, exist_ok=True)
        img_name_list = filter(
            lambda x: str(x).endswith(".png") or str(x).endswith(".jpg"),
            img_fold.iterdir())
        img_name_list = list(img_name_list)
        print(f"find {len(img_name_list)} images")
        for img_path in tqdm(img_name_list):
            img = cv2.imread(str(img_path))
            # image height and width
            image_height = img.shape[0]
            image_width = img.shape[1]
            # run detection to get the drawn image and the predicted boxes
            srcimg, shape_list = self.detect(img)
            # annotation JSON is named after the image
            img_name = img_path.name
            json_file_path = img_name[0:-4] + '.json'
            # write the JSON annotation
            wj.make_json(img_fold, json_file_path, shape_list, img_name,
                         image_height, image_width)
            save_path = str(result_path / img_name.replace(".png", ".jpg"))
            cv2.imwrite(save_path, srcimg)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--modelpath',
        type=str,
        default='picodet_l_416_lp_0904_2.onnx',
        help="onnx filepath")
    parser.add_argument(
        '--classfile',
        type=str,
        default='lp_label.txt',
        help="classname filepath")
    parser.add_argument(
        '--confThreshold', default=0.5, type=float, help='class confidence')
    parser.add_argument(
        '--nmsThreshold', default=0.6, type=float, help='nms iou thresh')
    parser.add_argument(
        "--img_fold", dest="img_fold", type=str, default="./images")
    parser.add_argument(
        "--result_fold", dest="result_fold", type=str, default="./results")
    args = parser.parse_args()

    net = PicoDet(
        args.modelpath,
        args.classfile,
        prob_threshold=args.confThreshold,
        iou_threshold=args.nmsThreshold)
    net.detect_folder(args.img_fold, args.result_fold)
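
The write_json module imported at the top is not included in this listing. Since the shape dicts built in detect() use labelme-style fields, here is a minimal sketch of what make_json could look like; the signature follows the call site in detect_folder(), but the field names and the version string are assumptions, not the author's actual helper.

import json
from pathlib import Path


def make_json(img_fold, json_file_path, shape_list, image_path,
              image_height, image_width):
    # Sketch only: write a labelme-compatible annotation file into img_fold.
    annotation = {
        'version': '5.0.1',            # labelme format version, placeholder
        'flags': {},
        'shapes': shape_list,          # rectangles produced by PicoDet.detect()
        'imagePath': image_path,
        'imageData': None,             # no embedded base64 pixel data
        'imageHeight': image_height,
        'imageWidth': image_width,
    }
    with open(Path(img_fold) / json_file_path, 'w', encoding='utf-8') as f:
        json.dump(annotation, f, ensure_ascii=False, indent=2)

With the defaults above, a typical invocation would be: python infer_demo.py --modelpath picodet_l_416_lp_0904_2.onnx --classfile lp_label.txt --img_fold ./images --result_fold ./results, which saves the images with drawn detections to ./results and writes one .json annotation per image.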