db_postprocess.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """
  15. This code is adapted from:
  16. https://github.com/WenmuZhou/DBNet.pytorch/blob/master/post_processing/seg_detector_representer.py
  17. """
  18. from __future__ import absolute_import
  19. from __future__ import division
  20. from __future__ import print_function
  21. import numpy as np
  22. import cv2
  23. import paddle
  24. from shapely.geometry import Polygon
  25. import pyclipper
  26. class DBPostProcess(object):
  27. """
  28. The post process for Differentiable Binarization (DB).
  29. """
  30. def __init__(self,
  31. thresh=0.3,
  32. box_thresh=0.7,
  33. max_candidates=1000,
  34. unclip_ratio=2.0,
  35. use_dilation=False,
  36. score_mode="fast",
  37. box_type='quad',
  38. **kwargs):
  39. self.thresh = thresh
  40. self.box_thresh = box_thresh
  41. self.max_candidates = max_candidates
  42. self.unclip_ratio = unclip_ratio
  43. self.min_size = 3
  44. self.score_mode = score_mode
  45. self.box_type = box_type
  46. assert score_mode in [
  47. "slow", "fast"
  48. ], "Score mode must be in [slow, fast] but got: {}".format(score_mode)
  49. self.dilation_kernel = None if not use_dilation else np.array(
  50. [[1, 1], [1, 1]])
  51. def polygons_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
  52. '''
  53. _bitmap: single map with shape (1, H, W),
  54. whose values are binarized as {0, 1}
  55. '''
  56. bitmap = _bitmap
  57. height, width = bitmap.shape
  58. boxes = []
  59. scores = []
  60. contours, _ = cv2.findContours((bitmap * 255).astype(np.uint8),
  61. cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
  62. for contour in contours[:self.max_candidates]:
  63. epsilon = 0.002 * cv2.arcLength(contour, True)
  64. approx = cv2.approxPolyDP(contour, epsilon, True)
  65. points = approx.reshape((-1, 2))
  66. if points.shape[0] < 4:
  67. continue
  68. score = self.box_score_fast(pred, points.reshape(-1, 2))
  69. if self.box_thresh > score:
  70. continue
  71. if points.shape[0] > 2:
  72. box = self.unclip(points, self.unclip_ratio)
  73. if len(box) > 1:
  74. continue
  75. else:
  76. continue
  77. box = box.reshape(-1, 2)
  78. _, sside = self.get_mini_boxes(box.reshape((-1, 1, 2)))
  79. if sside < self.min_size + 2:
  80. continue
  81. box = np.array(box)
  82. box[:, 0] = np.clip(
  83. np.round(box[:, 0] / width * dest_width), 0, dest_width)
  84. box[:, 1] = np.clip(
  85. np.round(box[:, 1] / height * dest_height), 0, dest_height)
  86. boxes.append(box.tolist())
  87. scores.append(score)
  88. return boxes, scores
  89. def boxes_from_bitmap(self, pred, _bitmap,classes, dest_width, dest_height):
  90. '''
  91. _bitmap: single map with shape (1, H, W),
  92. whose values are binarized as {0, 1}
  93. '''
  94. bitmap = _bitmap
  95. height, width = bitmap.shape
  96. outs = cv2.findContours((bitmap * 255).astype(np.uint8), cv2.RETR_LIST,
  97. cv2.CHAIN_APPROX_SIMPLE)
  98. if len(outs) == 3:
  99. img, contours, _ = outs[0], outs[1], outs[2]
  100. elif len(outs) == 2:
  101. contours, _ = outs[0], outs[1]
  102. num_contours = min(len(contours), self.max_candidates)
  103. boxes = []
  104. scores = []
  105. class_indexes = []
  106. class_scores = []
  107. for index in range(num_contours):
  108. contour = contours[index]
  109. points, sside = self.get_mini_boxes(contour)
  110. if sside < self.min_size:
  111. continue
  112. points = np.array(points)
  113. if self.score_mode == "fast":
  114. score, class_index, class_score = self.box_score_fast(pred, points.reshape(-1, 2), classes)
  115. else:
  116. score, class_index, class_score = self.box_score_slow(pred, contour, classes)
  117. print("origin score:" + str(score))
  118. if self.box_thresh > score:
  119. continue
  120. box = self.unclip(points, self.unclip_ratio).reshape(-1, 1, 2)
  121. box, sside = self.get_mini_boxes(box)
  122. if sside < self.min_size + 2:
  123. continue
  124. box = np.array(box)
  125. box[:, 0] = np.clip(
  126. np.round(box[:, 0] / width * dest_width), 0, dest_width)
  127. box[:, 1] = np.clip(
  128. np.round(box[:, 1] / height * dest_height), 0, dest_height)
  129. boxes.append(box.astype("int32"))
  130. scores.append(score)
  131. class_indexes.append(class_index)
  132. class_scores.append(class_score)
  133. if classes is None:
  134. return np.array(boxes, dtype="int32"), scores
  135. else:
  136. return np.array(boxes, dtype="int32"), scores, class_indexes, class_scores
  137. def unclip(self, box, unclip_ratio):
  138. poly = Polygon(box)
  139. distance = poly.area * unclip_ratio / poly.length
  140. offset = pyclipper.PyclipperOffset()
  141. offset.AddPath(box, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
  142. expanded = np.array(offset.Execute(distance))
  143. return expanded
  144. def get_mini_boxes(self, contour):
  145. bounding_box = cv2.minAreaRect(contour)
  146. points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0])
  147. index_1, index_2, index_3, index_4 = 0, 1, 2, 3
  148. if points[1][1] > points[0][1]:
  149. index_1 = 0
  150. index_4 = 1
  151. else:
  152. index_1 = 1
  153. index_4 = 0
  154. if points[3][1] > points[2][1]:
  155. index_2 = 2
  156. index_3 = 3
  157. else:
  158. index_2 = 3
  159. index_3 = 2
  160. box = [
  161. points[index_1], points[index_2], points[index_3], points[index_4]
  162. ]
  163. return box, min(bounding_box[1])
  164. def box_score_fast(self, bitmap, _box,classes):
  165. '''
  166. box_score_fast: use bbox mean score as the mean score
  167. '''
  168. # print(classes)
  169. h, w = bitmap.shape[:2]
  170. box = _box.copy()
  171. xmin = np.clip(np.floor(box[:, 0].min()).astype("int32"), 0, w - 1)
  172. xmax = np.clip(np.ceil(box[:, 0].max()).astype("int32"), 0, w - 1)
  173. ymin = np.clip(np.floor(box[:, 1].min()).astype("int32"), 0, h - 1)
  174. ymax = np.clip(np.ceil(box[:, 1].max()).astype("int32"), 0, h - 1)
  175. # box__ = box.reshape(1, -1, 2)
  176. mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8)
  177. box[:, 0] = box[:, 0] - xmin
  178. box[:, 1] = box[:, 1] - ymin
  179. cv2.fillPoly(mask, box.reshape(1, -1, 2).astype("int32"), 1)
  180. if classes is None:
  181. return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0], None, None
  182. else:
  183. k = 255
  184. class_mask = np.full((ymax - ymin + 1, xmax - xmin + 1), k, dtype=np.int32)
  185. cv2.fillPoly(class_mask, box.reshape(1, -1, 2).astype(np.int32), 0)
  186. classes = classes[ymin:ymax + 1, xmin:xmax + 1]
  187. new_classes = classes + class_mask
  188. # 拉平
  189. a = new_classes.reshape(-1)
  190. b = np.where(a >= k)
  191. # print(len(b[0].tolist()))
  192. classes = np.delete(a, b[0].tolist())
  193. class_index = np.argmax(np.bincount(classes))
  194. print(class_index)
  195. class_score = np.sum(classes == class_index) / len(classes)
  196. print(class_score)
  197. return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0], class_index, class_score
  198. def box_score_slow(self, bitmap, contour,classes):
  199. '''
  200. box_score_slow: use polyon mean score as the mean score
  201. '''
  202. h, w = bitmap.shape[:2]
  203. contour = contour.copy()
  204. contour = np.reshape(contour, (-1, 2))
  205. xmin = np.clip(np.min(contour[:, 0]), 0, w - 1)
  206. xmax = np.clip(np.max(contour[:, 0]), 0, w - 1)
  207. ymin = np.clip(np.min(contour[:, 1]), 0, h - 1)
  208. ymax = np.clip(np.max(contour[:, 1]), 0, h - 1)
  209. mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8)
  210. contour[:, 0] = contour[:, 0] - xmin
  211. contour[:, 1] = contour[:, 1] - ymin
  212. cv2.fillPoly(mask, contour.reshape(1, -1, 2).astype("int32"), 1)
  213. if classes is None:
  214. return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0], None, None
  215. else:
  216. k = 999
  217. class_mask = np.full((ymax - ymin + 1, xmax - xmin + 1), k, dtype=np.int32)
  218. cv2.fillPoly(class_mask, contour.reshape(1, -1, 2).astype("int32"), 0)
  219. classes = classes[ymin:ymax + 1, xmin:xmax + 1]
  220. new_classes = classes + class_mask
  221. # 拉平
  222. a = new_classes.reshape(-1)
  223. b = np.where(a >= k)
  224. classes = np.delete(a, b[0].tolist())
  225. class_index = np.argmax(np.bincount(classes))
  226. class_score = np.sum(classes == class_index) / len(classes)
  227. return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0], class_index, class_score
  228. def __call__(self, outs_dict, shape_list):
  229. pred = outs_dict['maps']
  230. if isinstance(pred, paddle.Tensor):
  231. pred = pred.numpy()
  232. pred = pred[:, 0, :, :]
  233. segmentation = pred > self.thresh
  234. print(pred.shape)
  235. if "classes" in outs_dict:
  236. classes = outs_dict['classes']
  237. # print(classes)
  238. # print("jerome1")
  239. # print(classes.shape)
  240. # print(classes)
  241. # np.set_printoptions(threshold=np.inf)
  242. if isinstance(classes, paddle.Tensor):
  243. # classes = paddle.argmax(classes, axis=1, dtype='int32')
  244. classes = classes.numpy()
  245. # else:
  246. # classes = np.argmax(classes, axis=1).astype(np.int32)
  247. classes = classes[:, 0, :, :]
  248. print(classes.shape)
  249. else:
  250. classes = None
  251. boxes_batch = []
  252. for batch_index in range(pred.shape[0]):
  253. src_h, src_w, ratio_h, ratio_w = shape_list[batch_index]
  254. if self.dilation_kernel is not None:
  255. mask = cv2.dilate(
  256. np.array(segmentation[batch_index]).astype(np.uint8),
  257. self.dilation_kernel)
  258. else:
  259. mask = segmentation[batch_index]
  260. if self.box_type == 'poly':
  261. boxes, scores = self.polygons_from_bitmap(pred[batch_index],
  262. mask, src_w, src_h)
  263. elif self.box_type == 'quad':
  264. if classes is None:
  265. boxes, scores = self.boxes_from_bitmap(pred[batch_index], mask, None,
  266. src_w, src_h)
  267. else:
  268. boxes, scores, class_indexes, class_scores = self.boxes_from_bitmap(pred[batch_index], mask,
  269. classes[batch_index],
  270. src_w, src_h)
  271. boxes_batch.append({'points': boxes, "classes": class_indexes, "class_scores": class_scores})
  272. else:
  273. raise ValueError("box_type can only be one of ['quad', 'poly']")
  274. boxes_batch.append({'points': boxes})
  275. return boxes_batch
  276. class DistillationDBPostProcess(object):
  277. def __init__(self,
  278. model_name=["student"],
  279. key=None,
  280. thresh=0.3,
  281. box_thresh=0.6,
  282. max_candidates=1000,
  283. unclip_ratio=1.5,
  284. use_dilation=False,
  285. score_mode="fast",
  286. box_type='quad',
  287. **kwargs):
  288. self.model_name = model_name
  289. self.key = key
  290. self.post_process = DBPostProcess(
  291. thresh=thresh,
  292. box_thresh=box_thresh,
  293. max_candidates=max_candidates,
  294. unclip_ratio=unclip_ratio,
  295. use_dilation=use_dilation,
  296. score_mode=score_mode,
  297. box_type=box_type)
  298. def __call__(self, predicts, shape_list):
  299. results = {}
  300. for k in self.model_name:
  301. results[k] = self.post_process(predicts[k], shape_list=shape_list)
  302. return results