lane_seg_infer.py 7.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. import argparse
  15. import codecs
  16. import os
  17. import yaml
  18. import numpy as np
  19. import cv2
  20. from sklearn.cluster import DBSCAN
  21. from pptracking.python.det_infer import load_predictor
  22. class LaneSegPredictor:
  23. def __init__(self, lane_seg_config, model_dir):
  24. """
  25. Prepare for prediction.
  26. The usage and docs of paddle inference, please refer to
  27. https://paddleinference.paddlepaddle.org.cn/product_introduction/summary.html
  28. """
  29. if not os.path.exists(lane_seg_config):
  30. raise ValueError("Cannot find : {},".format(lane_seg_config))
  31. args = yaml.safe_load(open(lane_seg_config))
  32. self.model_dir = model_dir
  33. self.args = args[args['type']]
  34. self.shape = None
  35. self.filter_horizontal_flag = self.args['filter_horizontal_flag']
  36. self.horizontal_filtration_degree = self.args[
  37. 'horizontal_filtration_degree']
  38. self.horizontal_filtering_threshold = self.args[
  39. 'horizontal_filtering_threshold']
  40. try:
  41. self.predictor, _ = load_predictor(
  42. model_dir=self.model_dir,
  43. run_mode=self.args['run_mode'],
  44. batch_size=self.args['batch_size'],
  45. device=self.args['device'],
  46. min_subgraph_size=self.args['min_subgraph_size'],
  47. use_dynamic_shape=self.args['use_dynamic_shape'],
  48. trt_min_shape=self.args['trt_min_shape'],
  49. trt_max_shape=self.args['trt_max_shape'],
  50. trt_opt_shape=self.args['trt_opt_shape'],
  51. trt_calib_mode=self.args['trt_calib_mode'],
  52. cpu_threads=self.args['cpu_threads'],
  53. enable_mkldnn=self.args['enable_mkldnn'])
  54. except Exception as e:
  55. print(str(e))
  56. exit()
  57. def run(self, img):
  58. input_names = self.predictor.get_input_names()
  59. input_handle = self.predictor.get_input_handle(input_names[0])
  60. output_names = self.predictor.get_output_names()
  61. output_handle = self.predictor.get_output_handle(output_names[0])
  62. img = np.array(img)
  63. self.shape = img.shape[1:3]
  64. img = self.normalize(img)
  65. img = np.transpose(img, (0, 3, 1, 2))
  66. input_handle.reshape(img.shape)
  67. input_handle.copy_from_cpu(img)
  68. self.predictor.run()
  69. results = output_handle.copy_to_cpu()
  70. results = self.postprocess(results)
  71. return self.get_line(results)
  72. def normalize(self, im, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)):
  73. mean = np.array(mean)[np.newaxis, np.newaxis, :]
  74. std = np.array(std)[np.newaxis, np.newaxis, :]
  75. im = im.astype(np.float32, copy=False) / 255.0
  76. im -= mean
  77. im /= std
  78. return im
  79. def postprocess(self, pred):
  80. pred = np.argmax(pred, axis=1)
  81. pred[pred == 3] = 0
  82. pred[pred > 0] = 255
  83. return pred
  84. def get_line(self, results):
  85. lines = []
  86. directions = []
  87. for i in range(results.shape[0]):
  88. line, direction = self.hough_line(np.uint8(results[i]))
  89. lines.append(line)
  90. directions.append(direction)
  91. return lines, directions
  92. def get_distance(self, array_1, array_2):
  93. lon_a = array_1[0]
  94. lat_a = array_1[1]
  95. lon_b = array_2[0]
  96. lat_b = array_2[1]
  97. s = pow(pow((lat_b - lat_a), 2) + pow((lon_b - lon_a), 2), 0.5)
  98. return s
  99. def get_angle(self, array):
  100. import math
  101. x1, y1, x2, y2 = array
  102. a_x = x2 - x1
  103. a_y = y2 - y1
  104. angle1 = math.atan2(a_y, a_x)
  105. angle1 = int(angle1 * 180 / math.pi)
  106. if angle1 > 90:
  107. angle1 = 180 - angle1
  108. return angle1
  109. def get_proportion(self, lines):
  110. proportion = 0.0
  111. h, w = self.shape
  112. for line in lines:
  113. x1, y1, x2, y2 = line
  114. length = abs(y2 - y1) / h + abs(x2 - x1) / w
  115. proportion = proportion + length
  116. return proportion
  117. def line_cluster(self, linesP):
  118. points = []
  119. for i in range(0, len(linesP)):
  120. l = linesP[i]
  121. x_center = (float(
  122. (max(l[2], l[0]) - min(l[2], l[0]))) / 2.0 + min(l[2], l[0]))
  123. y_center = (float(
  124. (max(l[3], l[1]) - min(l[3], l[1]))) / 2.0 + min(l[3], l[1]))
  125. points.append([x_center, y_center])
  126. dbscan = DBSCAN(
  127. eps=50, min_samples=2, metric=self.get_distance).fit(points)
  128. labels = dbscan.labels_
  129. n_clusters_ = len(set(labels)) - (1 if -1 in labels else 0)
  130. cluster_list = list([] for i in range(n_clusters_))
  131. if linesP is not None:
  132. for i in range(0, len(linesP)):
  133. if labels[i] == -1:
  134. continue
  135. l = linesP[i]
  136. x1, y1, x2, y2 = l
  137. if y2 >= y1:
  138. cluster_list[labels[i]].append([x1, y1, x2, y2])
  139. else:
  140. ll = [x2, y2, x1, y1]
  141. cluster_list[labels[i]].append(ll)
  142. return cluster_list
  143. def hough_line(self,
  144. binary_img,
  145. min_line=50,
  146. min_line_points=50,
  147. max_line_gap=10):
  148. linesP = cv2.HoughLinesP(binary_img, 1, np.pi / 180, min_line, None,
  149. min_line_points, max_line_gap)
  150. if linesP is None:
  151. return [], None
  152. coarse_cluster_list = self.line_cluster(linesP[:, 0])
  153. filter_lines_output, direction = self.filter_lines(coarse_cluster_list)
  154. return filter_lines_output, direction
  155. def filter_lines(self, coarse_cluster_list):
  156. lines = []
  157. angles = []
  158. for i in range(len(coarse_cluster_list)):
  159. if len(coarse_cluster_list[i]) == 0:
  160. continue
  161. coarse_cluster_list[i] = np.array(coarse_cluster_list[i])
  162. distance = abs(coarse_cluster_list[i][:, 3] - coarse_cluster_list[i]
  163. [:, 1]) + abs(coarse_cluster_list[i][:, 2] -
  164. coarse_cluster_list[i][:, 0])
  165. l = coarse_cluster_list[i][np.argmax(distance)]
  166. angles.append(self.get_angle(l))
  167. lines.append(l)
  168. if len(lines) == 0:
  169. return [], None
  170. if not self.filter_horizontal_flag:
  171. return lines, None
  172. #filter horizontal roads
  173. angles = np.array(angles)
  174. max_angle, min_angle = np.max(angles), np.min(angles)
  175. if (max_angle - min_angle) < self.horizontal_filtration_degree:
  176. return lines, np.mean(angles)
  177. thr_angle = (
  178. max_angle + min_angle) * self.horizontal_filtering_threshold
  179. lines = np.array(lines)
  180. min_angle_line = lines[np.where(angles < thr_angle)]
  181. max_angle_line = lines[np.where(angles >= thr_angle)]
  182. max_angle_line_pro = self.get_proportion(max_angle_line)
  183. min_angle_line_pro = self.get_proportion(min_angle_line)
  184. if max_angle_line_pro >= min_angle_line_pro:
  185. angle_list = angles[np.where(angles >= thr_angle)]
  186. return max_angle_line, np.mean(angle_list)
  187. else:
  188. angle_list = angles[np.where(angles < thr_angle)]
  189. return min_angle_line, np.mean(angle_list)