drrg_postprocess.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """
This code is adapted from:
  16. https://github.com/open-mmlab/mmocr/blob/main/mmocr/models/textdet/postprocess/drrg_postprocessor.py
  17. """
import collections
import functools
import operator

import cv2
import numpy as np
import paddle
from numpy.linalg import norm
  24. class Node:
  25. def __init__(self, ind):
  26. self.__ind = ind
  27. self.__links = set()
  28. @property
  29. def ind(self):
  30. return self.__ind
  31. @property
  32. def links(self):
  33. return set(self.__links)
  34. def add_link(self, link_node):
  35. self.__links.add(link_node)
  36. link_node.__links.add(self)
  37. def graph_propagation(edges, scores, text_comps, edge_len_thr=50.):
  38. assert edges.ndim == 2
  39. assert edges.shape[1] == 2
  40. assert edges.shape[0] == scores.shape[0]
  41. assert text_comps.ndim == 2
  42. assert isinstance(edge_len_thr, float)
  43. edges = np.sort(edges, axis=1)
  44. score_dict = {}
  45. for i, edge in enumerate(edges):
  46. if text_comps is not None:
  47. box1 = text_comps[edge[0], :8].reshape(4, 2)
  48. box2 = text_comps[edge[1], :8].reshape(4, 2)
  49. center1 = np.mean(box1, axis=0)
  50. center2 = np.mean(box2, axis=0)
  51. distance = norm(center1 - center2)
  52. if distance > edge_len_thr:
  53. scores[i] = 0
  54. if (edge[0], edge[1]) in score_dict:
  55. score_dict[edge[0], edge[1]] = 0.5 * (
  56. score_dict[edge[0], edge[1]] + scores[i])
  57. else:
  58. score_dict[edge[0], edge[1]] = scores[i]
  59. nodes = np.sort(np.unique(edges.flatten()))
  60. mapping = -1 * np.ones((np.max(nodes) + 1), dtype=np.int32)
  61. mapping[nodes] = np.arange(nodes.shape[0])
  62. order_inds = mapping[edges]
  63. vertices = [Node(node) for node in nodes]
  64. for ind in order_inds:
  65. vertices[ind[0]].add_link(vertices[ind[1]])
  66. return vertices, score_dict
  67. def connected_components(nodes, score_dict, link_thr):
  68. assert isinstance(nodes, list)
  69. assert all([isinstance(node, Node) for node in nodes])
  70. assert isinstance(score_dict, dict)
  71. assert isinstance(link_thr, float)
  72. clusters = []
  73. nodes = set(nodes)
  74. while nodes:
  75. node = nodes.pop()
  76. cluster = {node}
  77. node_queue = [node]
  78. while node_queue:
  79. node = node_queue.pop(0)
  80. neighbors = set([
  81. neighbor for neighbor in node.links if
  82. score_dict[tuple(sorted([node.ind, neighbor.ind]))] >= link_thr
  83. ])
  84. neighbors.difference_update(cluster)
  85. nodes.difference_update(neighbors)
  86. cluster.update(neighbors)
  87. node_queue.extend(neighbors)
  88. clusters.append(list(cluster))
  89. return clusters
  90. def clusters2labels(clusters, num_nodes):
  91. assert isinstance(clusters, list)
  92. assert all([isinstance(cluster, list) for cluster in clusters])
  93. assert all(
  94. [isinstance(node, Node) for cluster in clusters for node in cluster])
  95. assert isinstance(num_nodes, int)
  96. node_labels = np.zeros(num_nodes)
  97. for cluster_ind, cluster in enumerate(clusters):
  98. for node in cluster:
  99. node_labels[node.ind] = cluster_ind
  100. return node_labels
  101. def remove_single(text_comps, comp_pred_labels):
  102. assert text_comps.ndim == 2
  103. assert text_comps.shape[0] == comp_pred_labels.shape[0]
  104. single_flags = np.zeros_like(comp_pred_labels)
  105. pred_labels = np.unique(comp_pred_labels)
  106. for label in pred_labels:
  107. current_label_flag = (comp_pred_labels == label)
  108. if np.sum(current_label_flag) == 1:
  109. single_flags[np.where(current_label_flag)[0][0]] = 1
  110. keep_ind = [i for i in range(len(comp_pred_labels)) if not single_flags[i]]
  111. filtered_text_comps = text_comps[keep_ind, :]
  112. filtered_labels = comp_pred_labels[keep_ind]
  113. return filtered_text_comps, filtered_labels
  114. def norm2(point1, point2):
  115. return ((point1[0] - point2[0])**2 + (point1[1] - point2[1])**2)**0.5
  116. def min_connect_path(points):
  117. assert isinstance(points, list)
  118. assert all([isinstance(point, list) for point in points])
  119. assert all([isinstance(coord, int) for point in points for coord in point])
  120. points_queue = points.copy()
  121. shortest_path = []
  122. current_edge = [[], []]
  123. edge_dict0 = {}
  124. edge_dict1 = {}
  125. current_edge[0] = points_queue[0]
  126. current_edge[1] = points_queue[0]
  127. points_queue.remove(points_queue[0])
  128. while points_queue:
  129. for point in points_queue:
  130. length0 = norm2(point, current_edge[0])
  131. edge_dict0[length0] = [point, current_edge[0]]
  132. length1 = norm2(current_edge[1], point)
  133. edge_dict1[length1] = [current_edge[1], point]
  134. key0 = min(edge_dict0.keys())
  135. key1 = min(edge_dict1.keys())
  136. if key0 <= key1:
  137. start = edge_dict0[key0][0]
  138. end = edge_dict0[key0][1]
  139. shortest_path.insert(0, [points.index(start), points.index(end)])
  140. points_queue.remove(start)
  141. current_edge[0] = start
  142. else:
  143. start = edge_dict1[key1][0]
  144. end = edge_dict1[key1][1]
  145. shortest_path.append([points.index(start), points.index(end)])
  146. points_queue.remove(end)
  147. current_edge[1] = end
  148. edge_dict0 = {}
  149. edge_dict1 = {}
  150. shortest_path = functools.reduce(operator.concat, shortest_path)
  151. shortest_path = sorted(set(shortest_path), key=shortest_path.index)
  152. return shortest_path
  153. def in_contour(cont, point):
  154. x, y = point
  155. is_inner = cv2.pointPolygonTest(cont, (int(x), int(y)), False) > 0.5
  156. return is_inner
  157. def fix_corner(top_line, bot_line, start_box, end_box):
  158. assert isinstance(top_line, list)
  159. assert all(isinstance(point, list) for point in top_line)
  160. assert isinstance(bot_line, list)
  161. assert all(isinstance(point, list) for point in bot_line)
  162. assert start_box.shape == end_box.shape == (4, 2)
  163. contour = np.array(top_line + bot_line[::-1])
  164. start_left_mid = (start_box[0] + start_box[3]) / 2
  165. start_right_mid = (start_box[1] + start_box[2]) / 2
  166. end_left_mid = (end_box[0] + end_box[3]) / 2
  167. end_right_mid = (end_box[1] + end_box[2]) / 2
  168. if not in_contour(contour, start_left_mid):
  169. top_line.insert(0, start_box[0].tolist())
  170. bot_line.insert(0, start_box[3].tolist())
  171. elif not in_contour(contour, start_right_mid):
  172. top_line.insert(0, start_box[1].tolist())
  173. bot_line.insert(0, start_box[2].tolist())
  174. if not in_contour(contour, end_left_mid):
  175. top_line.append(end_box[0].tolist())
  176. bot_line.append(end_box[3].tolist())
  177. elif not in_contour(contour, end_right_mid):
  178. top_line.append(end_box[1].tolist())
  179. bot_line.append(end_box[2].tolist())
  180. return top_line, bot_line
  181. def comps2boundaries(text_comps, comp_pred_labels):
  182. assert text_comps.ndim == 2
  183. assert len(text_comps) == len(comp_pred_labels)
  184. boundaries = []
  185. if len(text_comps) < 1:
  186. return boundaries
  187. for cluster_ind in range(0, int(np.max(comp_pred_labels)) + 1):
  188. cluster_comp_inds = np.where(comp_pred_labels == cluster_ind)
  189. text_comp_boxes = text_comps[cluster_comp_inds, :8].reshape(
  190. (-1, 4, 2)).astype(np.int32)
  191. score = np.mean(text_comps[cluster_comp_inds, -1])
  192. if text_comp_boxes.shape[0] < 1:
  193. continue
  194. elif text_comp_boxes.shape[0] > 1:
  195. centers = np.mean(text_comp_boxes, axis=1).astype(np.int32).tolist()
  196. shortest_path = min_connect_path(centers)
  197. text_comp_boxes = text_comp_boxes[shortest_path]
  198. top_line = np.mean(
  199. text_comp_boxes[:, 0:2, :], axis=1).astype(np.int32).tolist()
  200. bot_line = np.mean(
  201. text_comp_boxes[:, 2:4, :], axis=1).astype(np.int32).tolist()
  202. top_line, bot_line = fix_corner(
  203. top_line, bot_line, text_comp_boxes[0], text_comp_boxes[-1])
  204. boundary_points = top_line + bot_line[::-1]
  205. else:
  206. top_line = text_comp_boxes[0, 0:2, :].astype(np.int32).tolist()
  207. bot_line = text_comp_boxes[0, 2:4:-1, :].astype(np.int32).tolist()
  208. boundary_points = top_line + bot_line
  209. boundary = [p for coord in boundary_points for p in coord] + [score]
  210. boundaries.append(boundary)
  211. return boundaries
  212. class DRRGPostprocess(object):
  213. """Merge text components and construct boundaries of text instances.
  214. Args:
  215. link_thr (float): The edge score threshold.
  216. """
  217. def __init__(self, link_thr, **kwargs):
  218. assert isinstance(link_thr, float)
  219. self.link_thr = link_thr
  220. def __call__(self, preds, shape_list):
  221. """
  222. Args:
  223. edges (ndarray): The edge array of shape N * 2, each row is a node
  224. index pair that makes up an edge in graph.
  225. scores (ndarray): The edge score array of shape (N,).
  226. text_comps (ndarray): The text components.
  227. Returns:
  228. List[list[float]]: The predicted boundaries of text instances.
  229. """
  230. edges, scores, text_comps = preds
  231. if edges is not None:
  232. if isinstance(edges, paddle.Tensor):
  233. edges = edges.numpy()
  234. if isinstance(scores, paddle.Tensor):
  235. scores = scores.numpy()
  236. if isinstance(text_comps, paddle.Tensor):
  237. text_comps = text_comps.numpy()
  238. assert len(edges) == len(scores)
  239. assert text_comps.ndim == 2
  240. assert text_comps.shape[1] == 9
  241. vertices, score_dict = graph_propagation(edges, scores, text_comps)
  242. clusters = connected_components(vertices, score_dict, self.link_thr)
  243. pred_labels = clusters2labels(clusters, text_comps.shape[0])
  244. text_comps, pred_labels = remove_single(text_comps, pred_labels)
  245. boundaries = comps2boundaries(text_comps, pred_labels)
  246. else:
  247. boundaries = []
  248. boundaries, scores = self.resize_boundary(
  249. boundaries, (1 / shape_list[0, 2:]).tolist()[::-1])
  250. boxes_batch = [dict(points=boundaries, scores=scores)]
  251. return boxes_batch
  252. def resize_boundary(self, boundaries, scale_factor):
  253. """Rescale boundaries via scale_factor.
  254. Args:
  255. boundaries (list[list[float]]): The boundary list. Each boundary
  256. with size 2k+1 with k>=4.
  257. scale_factor(ndarray): The scale factor of size (4,).
  258. Returns:
  259. boundaries (list[list[float]]): The scaled boundaries.
  260. """
  261. boxes = []
  262. scores = []
  263. for b in boundaries:
  264. sz = len(b)
  265. scores.append(b[-1])
  266. b = (np.array(b[:sz - 1]) *
  267. (np.tile(scale_factor[:2], int(
  268. (sz - 1) / 2)).reshape(1, sz - 1))).flatten().tolist()
  269. boxes.append(np.array(b).reshape([-1, 2]))
  270. return boxes, scores