# onnx_infer.py
  1. # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from __future__ import print_function
  17. import os
  18. import sys
  19. import six
  20. import glob
  21. import copy
  22. import yaml
  23. import argparse
  24. import cv2
  25. import numpy as np
  26. from shapely.geometry import Polygon
  27. from onnxruntime import InferenceSession
  28. # preprocess ops
  29. def decode_image(img_path):
  30. with open(img_path, 'rb') as f:
  31. im_read = f.read()
  32. data = np.frombuffer(im_read, dtype='uint8')
  33. im = cv2.imdecode(data, 1) # BGR mode, but need RGB mode
  34. im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
  35. img_info = {
  36. "im_shape": np.array(
  37. im.shape[:2], dtype=np.float32),
  38. "scale_factor": np.array(
  39. [1., 1.], dtype=np.float32)
  40. }
  41. return im, img_info
  42. class Resize(object):
  43. def __init__(self, target_size, keep_ratio=True, interp=cv2.INTER_LINEAR):
  44. if isinstance(target_size, int):
  45. target_size = [target_size, target_size]
  46. self.target_size = target_size
  47. self.keep_ratio = keep_ratio
  48. self.interp = interp
  49. def __call__(self, im, im_info):
  50. assert len(self.target_size) == 2
  51. assert self.target_size[0] > 0 and self.target_size[1] > 0
  52. im_channel = im.shape[2]
  53. im_scale_y, im_scale_x = self.generate_scale(im)
  54. im = cv2.resize(
  55. im,
  56. None,
  57. None,
  58. fx=im_scale_x,
  59. fy=im_scale_y,
  60. interpolation=self.interp)
  61. im_info['im_shape'] = np.array(im.shape[:2]).astype('float32')
  62. im_info['scale_factor'] = np.array(
  63. [im_scale_y, im_scale_x]).astype('float32')
  64. return im, im_info
  65. def generate_scale(self, im):
  66. origin_shape = im.shape[:2]
  67. im_c = im.shape[2]
  68. if self.keep_ratio:
  69. im_size_min = np.min(origin_shape)
  70. im_size_max = np.max(origin_shape)
  71. target_size_min = np.min(self.target_size)
  72. target_size_max = np.max(self.target_size)
  73. im_scale = float(target_size_min) / float(im_size_min)
  74. if np.round(im_scale * im_size_max) > target_size_max:
  75. im_scale = float(target_size_max) / float(im_size_max)
  76. im_scale_x = im_scale
  77. im_scale_y = im_scale
  78. else:
  79. resize_h, resize_w = self.target_size
  80. im_scale_y = resize_h / float(origin_shape[0])
  81. im_scale_x = resize_w / float(origin_shape[1])
  82. return im_scale_y, im_scale_x
  83. class Permute(object):
  84. def __init__(self, ):
  85. super(Permute, self).__init__()
  86. def __call__(self, im, im_info):
  87. im = im.transpose((2, 0, 1))
  88. return im, im_info
  89. class NormalizeImage(object):
  90. def __init__(self, mean, std, is_scale=True, norm_type='mean_std'):
  91. self.mean = mean
  92. self.std = std
  93. self.is_scale = is_scale
  94. self.norm_type = norm_type
  95. def __call__(self, im, im_info):
  96. im = im.astype(np.float32, copy=False)
  97. if self.is_scale:
  98. scale = 1.0 / 255.0
  99. im *= scale
  100. if self.norm_type == 'mean_std':
  101. mean = np.array(self.mean)[np.newaxis, np.newaxis, :]
  102. std = np.array(self.std)[np.newaxis, np.newaxis, :]
  103. im -= mean
  104. im /= std
  105. return im, im_info
  106. class PadStride(object):
  107. def __init__(self, stride=0):
  108. self.coarsest_stride = stride
  109. def __call__(self, im, im_info):
  110. coarsest_stride = self.coarsest_stride
  111. if coarsest_stride <= 0:
  112. return im, im_info
  113. im_c, im_h, im_w = im.shape
  114. pad_h = int(np.ceil(float(im_h) / coarsest_stride) * coarsest_stride)
  115. pad_w = int(np.ceil(float(im_w) / coarsest_stride) * coarsest_stride)
  116. padding_im = np.zeros((im_c, pad_h, pad_w), dtype=np.float32)
  117. padding_im[:, :im_h, :im_w] = im
  118. return padding_im, im_info
  119. class Compose:
  120. def __init__(self, transforms):
  121. self.transforms = []
  122. for op_info in transforms:
  123. new_op_info = op_info.copy()
  124. op_type = new_op_info.pop('type')
  125. self.transforms.append(eval(op_type)(**new_op_info))
  126. def __call__(self, img_path):
  127. img, im_info = decode_image(img_path)
  128. for t in self.transforms:
  129. img, im_info = t(img, im_info)
  130. inputs = copy.deepcopy(im_info)
  131. inputs['image'] = img
  132. return inputs
  133. # postprocess
  134. def rbox_iou(g, p):
  135. g = np.array(g)
  136. p = np.array(p)
  137. g = Polygon(g[:8].reshape((4, 2)))
  138. p = Polygon(p[:8].reshape((4, 2)))
  139. g = g.buffer(0)
  140. p = p.buffer(0)
  141. if not g.is_valid or not p.is_valid:
  142. return 0
  143. inter = Polygon(g).intersection(Polygon(p)).area
  144. union = g.area + p.area - inter
  145. if union == 0:
  146. return 0
  147. else:
  148. return inter / union
  149. def multiclass_nms_rotated(pred_bboxes,
  150. pred_scores,
  151. iou_threshlod=0.1,
  152. score_threshold=0.1):
  153. """
  154. Args:
  155. pred_bboxes (numpy.ndarray): [B, N, 8]
  156. pred_scores (numpy.ndarray): [B, C, N]
  157. Return:
  158. bboxes (numpy.ndarray): [N, 10]
  159. bbox_num (numpy.ndarray): [B]
  160. """
  161. bbox_num = []
  162. bboxes = []
  163. for bbox_per_img, score_per_img in zip(pred_bboxes, pred_scores):
  164. num_per_img = 0
  165. for cls_id, score_per_cls in enumerate(score_per_img):
  166. keep_mask = score_per_cls > score_threshold
  167. bbox = bbox_per_img[keep_mask]
  168. score = score_per_cls[keep_mask]
  169. idx = score.argsort()[::-1]
  170. bbox = bbox[idx]
  171. score = score[idx]
  172. keep_idx = []
  173. for i, b in enumerate(bbox):
  174. supressed = False
  175. for gi in keep_idx:
  176. g = bbox[gi]
  177. if rbox_iou(b, g) > iou_threshlod:
  178. supressed = True
  179. break
  180. if supressed:
  181. continue
  182. keep_idx.append(i)
  183. keep_box = bbox[keep_idx]
  184. keep_score = score[keep_idx]
  185. keep_cls_ids = np.ones(len(keep_idx)) * cls_id
  186. bboxes.append(
  187. np.concatenate(
  188. [keep_cls_ids[:, None], keep_score[:, None], keep_box],
  189. axis=-1))
  190. num_per_img += len(keep_idx)
  191. bbox_num.append(num_per_img)
  192. return np.concatenate(bboxes, axis=0), np.array(bbox_num)
  193. def get_test_images(infer_dir, infer_img):
  194. """
  195. Get image path list in TEST mode
  196. """
  197. assert infer_img is not None or infer_dir is not None, \
  198. "--image_file or --image_dir should be set"
  199. assert infer_img is None or os.path.isfile(infer_img), \
  200. "{} is not a file".format(infer_img)
  201. assert infer_dir is None or os.path.isdir(infer_dir), \
  202. "{} is not a directory".format(infer_dir)
  203. # infer_img has a higher priority
  204. if infer_img and os.path.isfile(infer_img):
  205. return [infer_img]
  206. images = set()
  207. infer_dir = os.path.abspath(infer_dir)
  208. assert os.path.isdir(infer_dir), \
  209. "infer_dir {} is not a directory".format(infer_dir)
  210. exts = ['jpg', 'jpeg', 'png', 'bmp']
  211. exts += [ext.upper() for ext in exts]
  212. for ext in exts:
  213. images.update(glob.glob('{}/*.{}'.format(infer_dir, ext)))
  214. images = list(images)
  215. assert len(images) > 0, "no image found in {}".format(infer_dir)
  216. print("Found {} inference images in total.".format(len(images)))
  217. return images
  218. def predict_image(infer_config, predictor, img_list):
  219. # load preprocess transforms
  220. transforms = Compose(infer_config['Preprocess'])
  221. # predict image
  222. for img_path in img_list:
  223. inputs = transforms(img_path)
  224. inputs_name = [var.name for var in predictor.get_inputs()]
  225. inputs = {k: inputs[k][None, ] for k in inputs_name}
  226. outputs = predictor.run(output_names=None, input_feed=inputs)
  227. bboxes, bbox_num = multiclass_nms_rotated(
  228. np.array(outputs[0]), np.array(outputs[1]))
  229. print("ONNXRuntime predict: ")
  230. for bbox in bboxes:
  231. if bbox[0] > -1 and bbox[1] > infer_config['draw_threshold']:
  232. print(f"{int(bbox[0])} {bbox[1]} "
  233. f"{bbox[2]} {bbox[3]} {bbox[4]} {bbox[5]}"
  234. f"{bbox[6]} {bbox[7]} {bbox[8]} {bbox[9]}")
  235. def parse_args():
  236. parser = argparse.ArgumentParser(description=__doc__)
  237. parser.add_argument("--infer_cfg", type=str, help="infer_cfg.yml")
  238. parser.add_argument(
  239. '--onnx_file',
  240. type=str,
  241. default="model.onnx",
  242. help="onnx model file path")
  243. parser.add_argument("--image_dir", type=str)
  244. parser.add_argument("--image_file", type=str)
  245. return parser.parse_args()
  246. if __name__ == '__main__':
  247. FLAGS = parse_args()
  248. # load image list
  249. img_list = get_test_images(FLAGS.image_dir, FLAGS.image_file)
  250. # load predictor
  251. predictor = InferenceSession(FLAGS.onnx_file)
  252. # load infer config
  253. with open(FLAGS.infer_cfg) as f:
  254. infer_config = yaml.safe_load(f)
  255. predict_image(infer_config, predictor, img_list)