gmc.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is based on https://github.com/WWangYuHsiang/SMILEtrack/blob/main/BoT-SORT/tracker/gmc.py
"""
  17. import cv2
  18. import matplotlib.pyplot as plt
  19. import numpy as np
  20. import copy
  21. import time
  22. from ppdet.core.workspace import register, serializable
  23. @register
  24. @serializable
  25. class GMC:
  26. def __init__(self, method='sparseOptFlow', downscale=2, verbose=None):
  27. super(GMC, self).__init__()
  28. self.method = method
  29. self.downscale = max(1, int(downscale))
  30. if self.method == 'orb':
  31. self.detector = cv2.FastFeatureDetector_create(20)
  32. self.extractor = cv2.ORB_create()
  33. self.matcher = cv2.BFMatcher(cv2.NORM_HAMMING)
  34. elif self.method == 'sift':
  35. self.detector = cv2.SIFT_create(
  36. nOctaveLayers=3, contrastThreshold=0.02, edgeThreshold=20)
  37. self.extractor = cv2.SIFT_create(
  38. nOctaveLayers=3, contrastThreshold=0.02, edgeThreshold=20)
  39. self.matcher = cv2.BFMatcher(cv2.NORM_L2)
  40. elif self.method == 'ecc':
  41. number_of_iterations = 5000
  42. termination_eps = 1e-6
  43. self.warp_mode = cv2.MOTION_EUCLIDEAN
  44. self.criteria = (cv2.TERM_CRITERIA_EPS | cv2.TERM_CRITERIA_COUNT,
  45. number_of_iterations, termination_eps)
  46. elif self.method == 'sparseOptFlow':
  47. self.feature_params = dict(
  48. maxCorners=1000,
  49. qualityLevel=0.01,
  50. minDistance=1,
  51. blockSize=3,
  52. useHarrisDetector=False,
  53. k=0.04)
  54. # self.gmc_file = open('GMC_results.txt', 'w')
  55. elif self.method == 'file' or self.method == 'files':
  56. seqName = verbose[0]
  57. ablation = verbose[1]
  58. if ablation:
  59. filePath = r'tracker/GMC_files/MOT17_ablation'
  60. else:
  61. filePath = r'tracker/GMC_files/MOTChallenge'
  62. if '-FRCNN' in seqName:
  63. seqName = seqName[:-6]
  64. elif '-DPM' in seqName:
  65. seqName = seqName[:-4]
  66. elif '-SDP' in seqName:
  67. seqName = seqName[:-4]
  68. self.gmcFile = open(filePath + "/GMC-" + seqName + ".txt", 'r')
  69. if self.gmcFile is None:
  70. raise ValueError("Error: Unable to open GMC file in directory:"
  71. + filePath)
  72. elif self.method == 'none' or self.method == 'None':
  73. self.method = 'none'
  74. else:
  75. raise ValueError("Error: Unknown CMC method:" + method)
  76. self.prevFrame = None
  77. self.prevKeyPoints = None
  78. self.prevDescriptors = None
  79. self.initializedFirstFrame = False
  80. def apply(self, raw_frame, detections=None):
  81. if self.method == 'orb' or self.method == 'sift':
  82. return self.applyFeaures(raw_frame, detections)
  83. elif self.method == 'ecc':
  84. return self.applyEcc(raw_frame, detections)
  85. elif self.method == 'sparseOptFlow':
  86. return self.applySparseOptFlow(raw_frame, detections)
  87. elif self.method == 'file':
  88. return self.applyFile(raw_frame, detections)
  89. elif self.method == 'none':
  90. return np.eye(2, 3)
  91. else:
  92. return np.eye(2, 3)
  93. def applyEcc(self, raw_frame, detections=None):
  94. # Initialize
  95. height, width, _ = raw_frame.shape
  96. frame = cv2.cvtColor(raw_frame, cv2.COLOR_BGR2GRAY)
  97. H = np.eye(2, 3, dtype=np.float32)
  98. # Downscale image (TODO: consider using pyramids)
  99. if self.downscale > 1.0:
  100. frame = cv2.GaussianBlur(frame, (3, 3), 1.5)
  101. frame = cv2.resize(frame, (width // self.downscale,
  102. height // self.downscale))
  103. width = width // self.downscale
  104. height = height // self.downscale
  105. # Handle first frame
  106. if not self.initializedFirstFrame:
  107. # Initialize data
  108. self.prevFrame = frame.copy()
  109. # Initialization done
  110. self.initializedFirstFrame = True
  111. return H
  112. # Run the ECC algorithm. The results are stored in warp_matrix.
  113. # (cc, H) = cv2.findTransformECC(self.prevFrame, frame, H, self.warp_mode, self.criteria)
  114. try:
  115. (cc,
  116. H) = cv2.findTransformECC(self.prevFrame, frame, H, self.warp_mode,
  117. self.criteria, None, 1)
  118. except:
  119. print('Warning: find transform failed. Set warp as identity')
  120. return H
  121. def applyFeaures(self, raw_frame, detections=None):
  122. # Initialize
  123. height, width, _ = raw_frame.shape
  124. frame = cv2.cvtColor(raw_frame, cv2.COLOR_BGR2GRAY)
  125. H = np.eye(2, 3)
  126. # Downscale image (TODO: consider using pyramids)
  127. if self.downscale > 1.0:
  128. # frame = cv2.GaussianBlur(frame, (3, 3), 1.5)
  129. frame = cv2.resize(frame, (width // self.downscale,
  130. height // self.downscale))
  131. width = width // self.downscale
  132. height = height // self.downscale
  133. # find the keypoints
  134. mask = np.zeros_like(frame)
  135. # mask[int(0.05 * height): int(0.95 * height), int(0.05 * width): int(0.95 * width)] = 255
  136. mask[int(0.02 * height):int(0.98 * height), int(0.02 * width):int(
  137. 0.98 * width)] = 255
  138. if detections is not None:
  139. for det in detections:
  140. tlbr = (det[:4] / self.downscale).astype(np.int_)
  141. mask[tlbr[1]:tlbr[3], tlbr[0]:tlbr[2]] = 0
  142. keypoints = self.detector.detect(frame, mask)
  143. # compute the descriptors
  144. keypoints, descriptors = self.extractor.compute(frame, keypoints)
  145. # Handle first frame
  146. if not self.initializedFirstFrame:
  147. # Initialize data
  148. self.prevFrame = frame.copy()
  149. self.prevKeyPoints = copy.copy(keypoints)
  150. self.prevDescriptors = copy.copy(descriptors)
  151. # Initialization done
  152. self.initializedFirstFrame = True
  153. return H
  154. # Match descriptors.
  155. knnMatches = self.matcher.knnMatch(self.prevDescriptors, descriptors, 2)
  156. # Filtered matches based on smallest spatial distance
  157. matches = []
  158. spatialDistances = []
  159. maxSpatialDistance = 0.25 * np.array([width, height])
  160. # Handle empty matches case
  161. if len(knnMatches) == 0:
  162. # Store to next iteration
  163. self.prevFrame = frame.copy()
  164. self.prevKeyPoints = copy.copy(keypoints)
  165. self.prevDescriptors = copy.copy(descriptors)
  166. return H
  167. for m, n in knnMatches:
  168. if m.distance < 0.9 * n.distance:
  169. prevKeyPointLocation = self.prevKeyPoints[m.queryIdx].pt
  170. currKeyPointLocation = keypoints[m.trainIdx].pt
  171. spatialDistance = (
  172. prevKeyPointLocation[0] - currKeyPointLocation[0],
  173. prevKeyPointLocation[1] - currKeyPointLocation[1])
  174. if (np.abs(spatialDistance[0]) < maxSpatialDistance[0]) and \
  175. (np.abs(spatialDistance[1]) < maxSpatialDistance[1]):
  176. spatialDistances.append(spatialDistance)
  177. matches.append(m)
  178. meanSpatialDistances = np.mean(spatialDistances, 0)
  179. stdSpatialDistances = np.std(spatialDistances, 0)
  180. inliesrs = (spatialDistances - meanSpatialDistances
  181. ) < 2.5 * stdSpatialDistances
  182. goodMatches = []
  183. prevPoints = []
  184. currPoints = []
  185. for i in range(len(matches)):
  186. if inliesrs[i, 0] and inliesrs[i, 1]:
  187. goodMatches.append(matches[i])
  188. prevPoints.append(self.prevKeyPoints[matches[i].queryIdx].pt)
  189. currPoints.append(keypoints[matches[i].trainIdx].pt)
  190. prevPoints = np.array(prevPoints)
  191. currPoints = np.array(currPoints)
  192. # Draw the keypoint matches on the output image
  193. if 0:
  194. matches_img = np.hstack((self.prevFrame, frame))
  195. matches_img = cv2.cvtColor(matches_img, cv2.COLOR_GRAY2BGR)
  196. W = np.size(self.prevFrame, 1)
  197. for m in goodMatches:
  198. prev_pt = np.array(
  199. self.prevKeyPoints[m.queryIdx].pt, dtype=np.int_)
  200. curr_pt = np.array(keypoints[m.trainIdx].pt, dtype=np.int_)
  201. curr_pt[0] += W
  202. color = np.random.randint(0, 255, (3, ))
  203. color = (int(color[0]), int(color[1]), int(color[2]))
  204. matches_img = cv2.line(matches_img, prev_pt, curr_pt,
  205. tuple(color), 1, cv2.LINE_AA)
  206. matches_img = cv2.circle(matches_img, prev_pt, 2,
  207. tuple(color), -1)
  208. matches_img = cv2.circle(matches_img, curr_pt, 2,
  209. tuple(color), -1)
  210. plt.figure()
  211. plt.imshow(matches_img)
  212. plt.show()
  213. # Find rigid matrix
  214. if (np.size(prevPoints, 0) > 4) and (
  215. np.size(prevPoints, 0) == np.size(prevPoints, 0)):
  216. H, inliesrs = cv2.estimateAffinePartial2D(prevPoints, currPoints,
  217. cv2.RANSAC)
  218. # Handle downscale
  219. if self.downscale > 1.0:
  220. H[0, 2] *= self.downscale
  221. H[1, 2] *= self.downscale
  222. else:
  223. print('Warning: not enough matching points')
  224. # Store to next iteration
  225. self.prevFrame = frame.copy()
  226. self.prevKeyPoints = copy.copy(keypoints)
  227. self.prevDescriptors = copy.copy(descriptors)
  228. return H
  229. def applySparseOptFlow(self, raw_frame, detections=None):
  230. t0 = time.time()
  231. # Initialize
  232. height, width, _ = raw_frame.shape
  233. frame = cv2.cvtColor(raw_frame, cv2.COLOR_BGR2GRAY)
  234. H = np.eye(2, 3)
  235. # Downscale image
  236. if self.downscale > 1.0:
  237. # frame = cv2.GaussianBlur(frame, (3, 3), 1.5)
  238. frame = cv2.resize(frame, (width // self.downscale,
  239. height // self.downscale))
  240. # find the keypoints
  241. keypoints = cv2.goodFeaturesToTrack(
  242. frame, mask=None, **self.feature_params)
  243. # Handle first frame
  244. if not self.initializedFirstFrame:
  245. # Initialize data
  246. self.prevFrame = frame.copy()
  247. self.prevKeyPoints = copy.copy(keypoints)
  248. # Initialization done
  249. self.initializedFirstFrame = True
  250. return H
  251. if self.prevFrame.shape != frame.shape:
  252. self.prevFrame = frame.copy()
  253. self.prevKeyPoints = copy.copy(keypoints)
  254. return H
  255. # find correspondences
  256. matchedKeypoints, status, err = cv2.calcOpticalFlowPyrLK(
  257. self.prevFrame, frame, self.prevKeyPoints, None)
  258. # leave good correspondences only
  259. prevPoints = []
  260. currPoints = []
  261. for i in range(len(status)):
  262. if status[i]:
  263. prevPoints.append(self.prevKeyPoints[i])
  264. currPoints.append(matchedKeypoints[i])
  265. prevPoints = np.array(prevPoints)
  266. currPoints = np.array(currPoints)
  267. # Find rigid matrix
  268. if (np.size(prevPoints, 0) > 4) and (
  269. np.size(prevPoints, 0) == np.size(prevPoints, 0)):
  270. H, inliesrs = cv2.estimateAffinePartial2D(prevPoints, currPoints,
  271. cv2.RANSAC)
  272. # Handle downscale
  273. if self.downscale > 1.0:
  274. H[0, 2] *= self.downscale
  275. H[1, 2] *= self.downscale
  276. else:
  277. print('Warning: not enough matching points')
  278. # Store to next iteration
  279. self.prevFrame = frame.copy()
  280. self.prevKeyPoints = copy.copy(keypoints)
  281. t1 = time.time()
  282. # gmc_line = str(1000 * (t1 - t0)) + "\t" + str(H[0, 0]) + "\t" + str(H[0, 1]) + "\t" + str(
  283. # H[0, 2]) + "\t" + str(H[1, 0]) + "\t" + str(H[1, 1]) + "\t" + str(H[1, 2]) + "\n"
  284. # self.gmc_file.write(gmc_line)
  285. return H
  286. def applyFile(self, raw_frame, detections=None):
  287. line = self.gmcFile.readline()
  288. tokens = line.split("\t")
  289. H = np.eye(2, 3, dtype=np.float_)
  290. H[0, 0] = float(tokens[1])
  291. H[0, 1] = float(tokens[2])
  292. H[0, 2] = float(tokens[3])
  293. H[1, 0] = float(tokens[4])
  294. H[1, 1] = float(tokens[5])
  295. H[1, 2] = float(tokens[6])
  296. return H