ct_process.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358
  1. # copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import cv2
  16. import random
  17. import pyclipper
  18. import paddle
  19. import numpy as np
  20. from ppocr.utils.utility import check_install
  21. import scipy.io as scio
  22. from PIL import Image
  23. import paddle.vision.transforms as transforms
  24. class RandomScale():
  25. def __init__(self, short_size=640, **kwargs):
  26. self.short_size = short_size
  27. def scale_aligned(self, img, scale):
  28. oh, ow = img.shape[0:2]
  29. h = int(oh * scale + 0.5)
  30. w = int(ow * scale + 0.5)
  31. if h % 32 != 0:
  32. h = h + (32 - h % 32)
  33. if w % 32 != 0:
  34. w = w + (32 - w % 32)
  35. img = cv2.resize(img, dsize=(w, h))
  36. factor_h = h / oh
  37. factor_w = w / ow
  38. return img, factor_h, factor_w
  39. def __call__(self, data):
  40. img = data['image']
  41. h, w = img.shape[0:2]
  42. random_scale = np.array([0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3])
  43. scale = (np.random.choice(random_scale) * self.short_size) / min(h, w)
  44. img, factor_h, factor_w = self.scale_aligned(img, scale)
  45. data['scale_factor'] = (factor_w, factor_h)
  46. data['image'] = img
  47. return data
  48. class MakeShrink():
  49. def __init__(self, kernel_scale=0.7, **kwargs):
  50. self.kernel_scale = kernel_scale
  51. def dist(self, a, b):
  52. return np.linalg.norm((a - b), ord=2, axis=0)
  53. def perimeter(self, bbox):
  54. peri = 0.0
  55. for i in range(bbox.shape[0]):
  56. peri += self.dist(bbox[i], bbox[(i + 1) % bbox.shape[0]])
  57. return peri
  58. def shrink(self, bboxes, rate, max_shr=20):
  59. check_install('Polygon', 'Polygon3')
  60. import Polygon as plg
  61. rate = rate * rate
  62. shrinked_bboxes = []
  63. for bbox in bboxes:
  64. area = plg.Polygon(bbox).area()
  65. peri = self.perimeter(bbox)
  66. try:
  67. pco = pyclipper.PyclipperOffset()
  68. pco.AddPath(bbox, pyclipper.JT_ROUND,
  69. pyclipper.ET_CLOSEDPOLYGON)
  70. offset = min(
  71. int(area * (1 - rate) / (peri + 0.001) + 0.5), max_shr)
  72. shrinked_bbox = pco.Execute(-offset)
  73. if len(shrinked_bbox) == 0:
  74. shrinked_bboxes.append(bbox)
  75. continue
  76. shrinked_bbox = np.array(shrinked_bbox[0])
  77. if shrinked_bbox.shape[0] <= 2:
  78. shrinked_bboxes.append(bbox)
  79. continue
  80. shrinked_bboxes.append(shrinked_bbox)
  81. except Exception as e:
  82. shrinked_bboxes.append(bbox)
  83. return shrinked_bboxes
  84. def __call__(self, data):
  85. img = data['image']
  86. bboxes = data['polys']
  87. words = data['texts']
  88. scale_factor = data['scale_factor']
  89. gt_instance = np.zeros(img.shape[0:2], dtype='uint8') # h,w
  90. training_mask = np.ones(img.shape[0:2], dtype='uint8')
  91. training_mask_distance = np.ones(img.shape[0:2], dtype='uint8')
  92. for i in range(len(bboxes)):
  93. bboxes[i] = np.reshape(bboxes[i] * (
  94. [scale_factor[0], scale_factor[1]] * (bboxes[i].shape[0] // 2)),
  95. (bboxes[i].shape[0] // 2, 2)).astype('int32')
  96. for i in range(len(bboxes)):
  97. #different value for different bbox
  98. cv2.drawContours(gt_instance, [bboxes[i]], -1, i + 1, -1)
  99. # set training mask to 0
  100. cv2.drawContours(training_mask, [bboxes[i]], -1, 0, -1)
  101. # for not accurate annotation, use training_mask_distance
  102. if words[i] == '###' or words[i] == '???':
  103. cv2.drawContours(training_mask_distance, [bboxes[i]], -1, 0, -1)
  104. # make shrink
  105. gt_kernel_instance = np.zeros(img.shape[0:2], dtype='uint8')
  106. kernel_bboxes = self.shrink(bboxes, self.kernel_scale)
  107. for i in range(len(bboxes)):
  108. cv2.drawContours(gt_kernel_instance, [kernel_bboxes[i]], -1, i + 1,
  109. -1)
  110. # for training mask, kernel and background= 1, box region=0
  111. if words[i] != '###' and words[i] != '???':
  112. cv2.drawContours(training_mask, [kernel_bboxes[i]], -1, 1, -1)
  113. gt_kernel = gt_kernel_instance.copy()
  114. # for gt_kernel, kernel = 1
  115. gt_kernel[gt_kernel > 0] = 1
  116. # shrink 2 times
  117. tmp1 = gt_kernel_instance.copy()
  118. erode_kernel = np.ones((3, 3), np.uint8)
  119. tmp1 = cv2.erode(tmp1, erode_kernel, iterations=1)
  120. tmp2 = tmp1.copy()
  121. tmp2 = cv2.erode(tmp2, erode_kernel, iterations=1)
  122. # compute text region
  123. gt_kernel_inner = tmp1 - tmp2
  124. # gt_instance: text instance, bg=0, diff word use diff value
  125. # training_mask: text instance mask, word=0,kernel and bg=1
  126. # gt_kernel_instance: text kernel instance, bg=0, diff word use diff value
  127. # gt_kernel: text_kernel, bg=0,diff word use same value
  128. # gt_kernel_inner: text kernel reference
  129. # training_mask_distance: word without anno = 0, else 1
  130. data['image'] = [
  131. img, gt_instance, training_mask, gt_kernel_instance, gt_kernel,
  132. gt_kernel_inner, training_mask_distance
  133. ]
  134. return data
  135. class GroupRandomHorizontalFlip():
  136. def __init__(self, p=0.5, **kwargs):
  137. self.p = p
  138. def __call__(self, data):
  139. imgs = data['image']
  140. if random.random() < self.p:
  141. for i in range(len(imgs)):
  142. imgs[i] = np.flip(imgs[i], axis=1).copy()
  143. data['image'] = imgs
  144. return data
  145. class GroupRandomRotate():
  146. def __init__(self, **kwargs):
  147. pass
  148. def __call__(self, data):
  149. imgs = data['image']
  150. max_angle = 10
  151. angle = random.random() * 2 * max_angle - max_angle
  152. for i in range(len(imgs)):
  153. img = imgs[i]
  154. w, h = img.shape[:2]
  155. rotation_matrix = cv2.getRotationMatrix2D((h / 2, w / 2), angle, 1)
  156. img_rotation = cv2.warpAffine(
  157. img, rotation_matrix, (h, w), flags=cv2.INTER_NEAREST)
  158. imgs[i] = img_rotation
  159. data['image'] = imgs
  160. return data
  161. class GroupRandomCropPadding():
  162. def __init__(self, target_size=(640, 640), **kwargs):
  163. self.target_size = target_size
  164. def __call__(self, data):
  165. imgs = data['image']
  166. h, w = imgs[0].shape[0:2]
  167. t_w, t_h = self.target_size
  168. p_w, p_h = self.target_size
  169. if w == t_w and h == t_h:
  170. return data
  171. t_h = t_h if t_h < h else h
  172. t_w = t_w if t_w < w else w
  173. if random.random() > 3.0 / 8.0 and np.max(imgs[1]) > 0:
  174. # make sure to crop the text region
  175. tl = np.min(np.where(imgs[1] > 0), axis=1) - (t_h, t_w)
  176. tl[tl < 0] = 0
  177. br = np.max(np.where(imgs[1] > 0), axis=1) - (t_h, t_w)
  178. br[br < 0] = 0
  179. br[0] = min(br[0], h - t_h)
  180. br[1] = min(br[1], w - t_w)
  181. i = random.randint(tl[0], br[0]) if tl[0] < br[0] else 0
  182. j = random.randint(tl[1], br[1]) if tl[1] < br[1] else 0
  183. else:
  184. i = random.randint(0, h - t_h) if h - t_h > 0 else 0
  185. j = random.randint(0, w - t_w) if w - t_w > 0 else 0
  186. n_imgs = []
  187. for idx in range(len(imgs)):
  188. if len(imgs[idx].shape) == 3:
  189. s3_length = int(imgs[idx].shape[-1])
  190. img = imgs[idx][i:i + t_h, j:j + t_w, :]
  191. img_p = cv2.copyMakeBorder(
  192. img,
  193. 0,
  194. p_h - t_h,
  195. 0,
  196. p_w - t_w,
  197. borderType=cv2.BORDER_CONSTANT,
  198. value=tuple(0 for i in range(s3_length)))
  199. else:
  200. img = imgs[idx][i:i + t_h, j:j + t_w]
  201. img_p = cv2.copyMakeBorder(
  202. img,
  203. 0,
  204. p_h - t_h,
  205. 0,
  206. p_w - t_w,
  207. borderType=cv2.BORDER_CONSTANT,
  208. value=(0, ))
  209. n_imgs.append(img_p)
  210. data['image'] = n_imgs
  211. return data
  212. class MakeCentripetalShift():
  213. def __init__(self, **kwargs):
  214. pass
  215. def jaccard(self, As, Bs):
  216. A = As.shape[0] # small
  217. B = Bs.shape[0] # large
  218. dis = np.sqrt(
  219. np.sum((As[:, np.newaxis, :].repeat(
  220. B, axis=1) - Bs[np.newaxis, :, :].repeat(
  221. A, axis=0))**2,
  222. axis=-1))
  223. ind = np.argmin(dis, axis=-1)
  224. return ind
  225. def __call__(self, data):
  226. imgs = data['image']
  227. img, gt_instance, training_mask, gt_kernel_instance, gt_kernel, gt_kernel_inner, training_mask_distance = \
  228. imgs[0], imgs[1], imgs[2], imgs[3], imgs[4], imgs[5], imgs[6]
  229. max_instance = np.max(gt_instance) # num bbox
  230. # make centripetal shift
  231. gt_distance = np.zeros((2, *img.shape[0:2]), dtype=np.float32)
  232. for i in range(1, max_instance + 1):
  233. # kernel_reference
  234. ind = (gt_kernel_inner == i)
  235. if np.sum(ind) == 0:
  236. training_mask[gt_instance == i] = 0
  237. training_mask_distance[gt_instance == i] = 0
  238. continue
  239. kpoints = np.array(np.where(ind)).transpose(
  240. (1, 0))[:, ::-1].astype('float32')
  241. ind = (gt_instance == i) * (gt_kernel_instance == 0)
  242. if np.sum(ind) == 0:
  243. continue
  244. pixels = np.where(ind)
  245. points = np.array(pixels).transpose(
  246. (1, 0))[:, ::-1].astype('float32')
  247. bbox_ind = self.jaccard(points, kpoints)
  248. offset_gt = kpoints[bbox_ind] - points
  249. gt_distance[:, pixels[0], pixels[1]] = offset_gt.T * 0.1
  250. img = Image.fromarray(img)
  251. img = img.convert('RGB')
  252. data["image"] = img
  253. data["gt_kernel"] = gt_kernel.astype("int64")
  254. data["training_mask"] = training_mask.astype("int64")
  255. data["gt_instance"] = gt_instance.astype("int64")
  256. data["gt_kernel_instance"] = gt_kernel_instance.astype("int64")
  257. data["training_mask_distance"] = training_mask_distance.astype("int64")
  258. data["gt_distance"] = gt_distance.astype("float32")
  259. return data
  260. class ScaleAlignedShort():
  261. def __init__(self, short_size=640, **kwargs):
  262. self.short_size = short_size
  263. def __call__(self, data):
  264. img = data['image']
  265. org_img_shape = img.shape
  266. h, w = img.shape[0:2]
  267. scale = self.short_size * 1.0 / min(h, w)
  268. h = int(h * scale + 0.5)
  269. w = int(w * scale + 0.5)
  270. if h % 32 != 0:
  271. h = h + (32 - h % 32)
  272. if w % 32 != 0:
  273. w = w + (32 - w % 32)
  274. img = cv2.resize(img, dsize=(w, h))
  275. new_img_shape = img.shape
  276. img_shape = np.array(org_img_shape + new_img_shape)
  277. data['shape'] = img_shape
  278. data['image'] = img
  279. return data