# fairmot_onnx_openvino.py — FairMOT pedestrian tracking demo (ONNX model run
# with OpenVINO, post-processed with PaddleDetection's JDE tracker).
  1. from collections import defaultdict
  2. from pathlib import Path
  3. import cv2
  4. import numpy as np
  5. import paddle.vision.transforms as T
  6. from openvino.inference_engine import IECore
  7. from ppdet.modeling.mot.tracker import JDETracker
  8. from ppdet.modeling.mot.visualization import plot_tracking_dict
  9. root_path = Path(__file__).parent
  10. target_height = 320
  11. target_width = 576
  12. # -------------------------------
  13. def get_net():
  14. ie = IECore()
  15. model_path = root_path / "fairmot_576_320_v3.onnx"
  16. net = ie.read_network(model= str(model_path))
  17. exec_net = ie.load_network(network=net, device_name="CPU")
  18. return net, exec_net
  19. def get_output_names(net):
  20. output_names = [key for key in net.outputs]
  21. return output_names
  22. def prepare_input():
  23. transforms = [
  24. T.Resize(size=(target_height, target_width)),
  25. T.Normalize(mean=(0,0,0), std=(1,1,1), data_format='HWC', to_rgb= True),
  26. T.Transpose()
  27. ]
  28. img_file = root_path / "street.jpeg"
  29. img = cv2.imread(str(img_file))
  30. normalized_img = T.Compose(transforms)(img)
  31. normalized_img = normalized_img.astype(np.float32, copy=False) / 255.0
  32. # add an new axis in front
  33. img_input = normalized_img[np.newaxis, :]
  34. # scale_factor is calculated as: im_shape / original_im_shape
  35. h_scale = target_height / img.shape[0]
  36. w_scale = target_width / img.shape[1]
  37. input = {"image": img_input, "im_shape": [target_height, target_width], "scale_factor": [h_scale, w_scale]}
  38. return input, img
  39. def predict(exec_net, input):
  40. result = exec_net.infer(input)
  41. return result
  42. def postprocess(pred_dets, pred_embs, threshold = 0.5):
  43. tracker = JDETracker()
  44. online_targets_dict = tracker.update(pred_dets, pred_embs)
  45. online_tlwhs = defaultdict(list)
  46. online_scores = defaultdict(list)
  47. online_ids = defaultdict(list)
  48. for cls_id in range(1):
  49. online_targets = online_targets_dict[cls_id]
  50. for t in online_targets:
  51. tlwh = t.tlwh
  52. tid = t.track_id
  53. tscore = t.score
  54. # make sure the tscore is no less then the threshold.
  55. if tscore < threshold: continue
  56. # make sure the target area is not less than the min_box_area.
  57. if tlwh[2] * tlwh[3] <= tracker.min_box_area:
  58. continue
  59. # make sure the vertical ratio of a found target is within the range (1.6 as default ratio).
  60. if tracker.vertical_ratio > 0 and tlwh[2] / tlwh[3] > tracker.vertical_ratio:
  61. continue
  62. online_tlwhs[cls_id].append(tlwh)
  63. online_ids[cls_id].append(tid)
  64. online_scores[cls_id].append(tscore)
  65. online_im = plot_tracking_dict(
  66. img,
  67. 1,
  68. online_tlwhs,
  69. online_ids,
  70. online_scores,
  71. frame_id=0)
  72. return online_im
  73. # -------------------------------
  74. net, exec_net = get_net()
  75. output_names = get_output_names(net)
  76. del net
  77. input, img = prepare_input()
  78. result = predict(exec_net, input)
  79. pred_dets = result[output_names[0]]
  80. pred_embs = result[output_names[1]]
  81. processed_img = postprocess(pred_dets, pred_embs)
  82. tracked_img_file_path = root_path / "tracked.jpg"
  83. cv2.imwrite(str(tracked_img_file_path), processed_img)