east_postprocess.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from __future__ import print_function
  17. import numpy as np
  18. from .locality_aware_nms import nms_locality
  19. import cv2
  20. import paddle
  21. import os
  22. from ppocr.utils.utility import check_install
  23. import sys
  class EASTPostProcess(object):
      """Post-processing for the EAST text detector.

      Converts the score map and geometry map produced by the EAST head into
      quadrilateral text boxes in source-image coordinates:

      1. keep pixels whose score exceeds ``score_thresh``;
      2. restore one quad per kept pixel from the geometry offsets;
      3. merge the quads with locality-aware NMS (lanms when installed,
         otherwise ``nms_locality``);
      4. re-score each merged quad by the mean score-map value inside it and
         drop those at or below ``cover_thresh``;
      5. rescale to the source image, reorder corners and drop tiny quads.

      Args:
          score_thresh (float): minimum score-map value for a pixel to propose
              a quad.
          cover_thresh (float): minimum mean score inside a merged quad for it
              to be kept.
          nms_thresh (float): threshold forwarded to the NMS routine.
          **kwargs: accepted and ignored.
      """

      # Quads with either edge at the first corner (p0-p1 or p3-p0) shorter
      # than this many source-image pixels are discarded.
      _MIN_BOX_SIDE = 5

      # Cached result of the lanms availability check: the module, or None when
      # it could not be imported. Checking once per instance avoids re-running
      # check_install / the import / the warning print for every image.
      _lanms = None
      _lanms_checked = False

      def __init__(self,
                   score_thresh=0.8,
                   cover_thresh=0.1,
                   nms_thresh=0.2,
                   **kwargs):
          self.score_thresh = score_thresh
          self.cover_thresh = cover_thresh
          self.nms_thresh = nms_thresh

      def restore_rectangle_quad(self, origin, geometry):
          """Restore quadrangles from anchor points and per-corner offsets.

          Args:
              origin: (n, 2) array of anchor (x, y) coordinates.
              geometry: (n, 8) array of offsets; corner ``k`` is
                  ``origin - geometry[:, 2k:2k + 2]``.

          Returns:
              (n, 4, 2) array of corner coordinates.
          """
          # Repeat each anchor once per corner so it lines up with the 8 offsets.
          origin_concat = np.concatenate(
              (origin, origin, origin, origin), axis=1)  # (n, 8)
          pred_quads = origin_concat - geometry
          return pred_quads.reshape((-1, 4, 2))  # (n, 4, 2)

      def detect(self,
                 score_map,
                 geo_map,
                 score_thresh=0.8,
                 cover_thresh=0.1,
                 nms_thresh=0.2):
          """Restore the text boxes of one image from its score and geo maps.

          Args:
              score_map: score map; only its first channel ``score_map[0]``
                  (a 2-D map) is used.
              geo_map: channel-first 3-D geometry map; after moving the channel
                  axis last, each pixel must hold the 8 quad offsets.
              score_thresh (float): pixel score threshold.
              cover_thresh (float): mean-score threshold for merged quads.
              nms_thresh (float): threshold forwarded to NMS.

          Returns:
              ``[]`` when no quad survives; otherwise an (m, 9) array whose
              first 8 columns are corner coordinates (x1, y1, ..., x4, y4) at
              4x the map resolution and whose last column is the mean score
              inside the quad.
          """
          score_map = score_map[0]
          # (C, H, W) -> (H, W, C) so the offsets of a pixel can be gathered
          # with a single (row, col) index.
          geo_map = np.transpose(geo_map, (1, 2, 0))

          # (row, col) indices of pixels whose score clears the threshold.
          xy_text = np.argwhere(score_map > score_thresh)
          if len(xy_text) == 0:
              return []
          # Sort the candidates along the y axis (row index).
          xy_text = xy_text[np.argsort(xy_text[:, 0])]

          # [:, ::-1] turns (row, col) into (x, y); the factor 4 (mirrored by the
          # `// 4` below) treats the maps as 1/4 of the input resolution.
          text_box_restored = self.restore_rectangle_quad(
              xy_text[:, ::-1] * 4, geo_map[xy_text[:, 0], xy_text[:, 1], :])
          boxes = np.zeros((text_box_restored.shape[0], 9), dtype=np.float32)
          boxes[:, :8] = text_box_restored.reshape((-1, 8))
          boxes[:, 8] = score_map[xy_text[:, 0], xy_text[:, 1]]

          boxes = self._merge_boxes(boxes, nms_thresh)
          if boxes.shape[0] == 0:
              return []

          # Here we filter some low score boxes by the average score map;
          # this is different from the original paper.
          for i, box in enumerate(boxes):
              mask = np.zeros_like(score_map, dtype=np.uint8)
              cv2.fillPoly(mask, box[:8].reshape(
                  (-1, 4, 2)).astype(np.int32) // 4, 1)
              boxes[i, 8] = cv2.mean(score_map, mask)[0]
          return boxes[boxes[:, 8] > cover_thresh]

      def _get_lanms(self):
          """Return the lanms module if it can be set up, else None (checked once)."""
          if not self._lanms_checked:
              self._lanms_checked = True
              try:
                  check_install('lanms', 'lanms-nova')
                  import lanms
              except Exception:
                  print(
                      'You should install lanms by pip3 install lanms-nova to speed up nms_locality'
                  )
              else:
                  self._lanms = lanms
          return self._lanms

      def _merge_boxes(self, boxes, nms_thresh):
          """Merge (n, 9) float32 quads with lanms if available, else nms_locality."""
          lanms = self._get_lanms()
          if lanms is not None:
              return lanms.merge_quadrangle_n9(boxes, nms_thresh)
          return nms_locality(boxes.astype(np.float64), nms_thresh)

      def sort_poly(self, p):
          """Reorder the four corners of a quad into a canonical order.

          The corner with the smallest ``x + y`` becomes the first point and the
          cyclic order is kept. If the first edge (p0 -> p1) is not more
          horizontal than vertical, the traversal direction is reversed
          (``[0, 3, 2, 1]``).

          Args:
              p: (4, 2) array of corner coordinates.

          Returns:
              The reordered (4, 2) array.
          """
          min_axis = np.argmin(np.sum(p, axis=1))
          p = p[(np.arange(4) + min_axis) % 4]
          if abs(p[0, 0] - p[1, 0]) > abs(p[0, 1] - p[1, 1]):
              return p
          return p[[0, 3, 2, 1]]

      @staticmethod
      def _to_numpy(value):
          """Convert a paddle Tensor to a numpy array; pass anything else through."""
          if isinstance(value, paddle.Tensor):
              return value.numpy()
          return value

      def _restore_boxes(self, boxes, ratio_h, ratio_w):
          """Map detected quads to source-image scale, reorder and size-filter them.

          Args:
              boxes: (m, 9) array as returned by ``detect``; not modified.
              ratio_h (float): vertical resize ratio; y coordinates are divided by it.
              ratio_w (float): horizontal resize ratio; x coordinates are divided by it.

          Returns:
              List of (4, 2) int32 corner arrays that pass the size filter.
          """
          # Copy so the caller's array is left untouched.
          quads = boxes[:, :8].reshape((-1, 4, 2)).copy()
          quads[:, :, 0] /= ratio_w
          quads[:, :, 1] /= ratio_h
          kept = []
          for quad in quads:
              quad = self.sort_poly(quad.astype(np.int32))
              if (np.linalg.norm(quad[0] - quad[1]) < self._MIN_BOX_SIDE or
                      np.linalg.norm(quad[3] - quad[0]) < self._MIN_BOX_SIDE):
                  continue
              kept.append(quad)
          return kept

      def __call__(self, outs_dict, shape_list):
          """Post-process a batch of EAST outputs.

          Args:
              outs_dict (dict): must contain 'f_score' and 'f_geo', indexable per
                  image (paddle Tensors or numpy arrays; each is converted to
                  numpy independently).
              shape_list: one entry per image, unpacked as
                  (src_h, src_w, ratio_h, ratio_w).

          Returns:
              list[dict]: one ``{'points': np.ndarray}`` per entry of
              ``shape_list``; ``points`` has shape (k, 4, 2) with int32 corners
              in source-image coordinates (an empty array when nothing survives).
          """
          score_list = self._to_numpy(outs_dict['f_score'])
          geo_list = self._to_numpy(outs_dict['f_geo'])

          dt_boxes_list = []
          for ino, shape in enumerate(shape_list):
              score = score_list[ino]
              boxes = self.detect(
                  score_map=score,
                  geo_map=geo_list[ino],
                  score_thresh=self.score_thresh,
                  cover_thresh=self.cover_thresh,
                  nms_thresh=self.nms_thresh)
              boxes_norm = []
              if len(boxes) > 0:
                  _, _, ratio_h, ratio_w = shape
                  boxes_norm = self._restore_boxes(boxes, ratio_h, ratio_w)
              dt_boxes_list.append({'points': np.array(boxes_norm)})
          return dt_boxes_list