slicebase.py 9.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267
  1. # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. #
  15. # Reference: https://github.com/CAPTAIN-WHU/DOTA_devkit
  16. from __future__ import absolute_import
  17. from __future__ import division
  18. from __future__ import print_function
  19. import os
  20. import math
  21. import copy
  22. from numbers import Number
  23. from multiprocessing import Pool
  24. import cv2
  25. import numpy as np
  26. from tqdm import tqdm
  27. import shapely.geometry as shgeo
  28. def choose_best_pointorder_fit_another(poly1, poly2):
  29. """
  30. To make the two polygons best fit with each point
  31. """
  32. x1, y1, x2, y2, x3, y3, x4, y4 = poly1
  33. combinate = [
  34. np.array([x1, y1, x2, y2, x3, y3, x4, y4]),
  35. np.array([x2, y2, x3, y3, x4, y4, x1, y1]),
  36. np.array([x3, y3, x4, y4, x1, y1, x2, y2]),
  37. np.array([x4, y4, x1, y1, x2, y2, x3, y3])
  38. ]
  39. dst_coordinate = np.array(poly2)
  40. distances = np.array(
  41. [np.sum((coord - dst_coordinate)**2) for coord in combinate])
  42. sorted = distances.argsort()
  43. return combinate[sorted[0]]
  44. def cal_line_length(point1, point2):
  45. return math.sqrt(
  46. math.pow(point1[0] - point2[0], 2) + math.pow(point1[1] - point2[1], 2))
  47. class SliceBase(object):
  48. def __init__(self,
  49. gap=512,
  50. subsize=1024,
  51. thresh=0.7,
  52. choosebestpoint=True,
  53. ext='.png',
  54. padding=True,
  55. num_process=8,
  56. image_only=False):
  57. self.gap = gap
  58. self.subsize = subsize
  59. self.slide = subsize - gap
  60. self.thresh = thresh
  61. self.choosebestpoint = choosebestpoint
  62. self.ext = ext
  63. self.padding = padding
  64. self.num_process = num_process
  65. self.image_only = image_only
  66. def get_windows(self, height, width):
  67. windows = []
  68. left, up = 0, 0
  69. while (left < width):
  70. if (left + self.subsize >= width):
  71. left = max(width - self.subsize, 0)
  72. up = 0
  73. while (up < height):
  74. if (up + self.subsize >= height):
  75. up = max(height - self.subsize, 0)
  76. right = min(left + self.subsize, width - 1)
  77. down = min(up + self.subsize, height - 1)
  78. windows.append((left, up, right, down))
  79. if (up + self.subsize >= height):
  80. break
  81. else:
  82. up = up + self.slide
  83. if (left + self.subsize >= width):
  84. break
  85. else:
  86. left = left + self.slide
  87. return windows
  88. def slice_image_single(self, image, windows, output_dir, output_name):
  89. image_dir = os.path.join(output_dir, 'images')
  90. for (left, up, right, down) in windows:
  91. image_name = output_name + str(left) + '___' + str(up) + self.ext
  92. subimg = copy.deepcopy(image[up:up + self.subsize, left:left +
  93. self.subsize])
  94. h, w, c = subimg.shape
  95. if (self.padding):
  96. outimg = np.zeros((self.subsize, self.subsize, 3))
  97. outimg[0:h, 0:w, :] = subimg
  98. cv2.imwrite(os.path.join(image_dir, image_name), outimg)
  99. else:
  100. cv2.imwrite(os.path.join(image_dir, image_name), subimg)
  101. def iof(self, poly1, poly2):
  102. inter_poly = poly1.intersection(poly2)
  103. inter_area = inter_poly.area
  104. poly1_area = poly1.area
  105. half_iou = inter_area / poly1_area
  106. return inter_poly, half_iou
  107. def translate(self, poly, left, up):
  108. n = len(poly)
  109. out_poly = np.zeros(n)
  110. for i in range(n // 2):
  111. out_poly[i * 2] = int(poly[i * 2] - left)
  112. out_poly[i * 2 + 1] = int(poly[i * 2 + 1] - up)
  113. return out_poly
  114. def get_poly4_from_poly5(self, poly):
  115. distances = [
  116. cal_line_length((poly[i * 2], poly[i * 2 + 1]),
  117. (poly[(i + 1) * 2], poly[(i + 1) * 2 + 1]))
  118. for i in range(int(len(poly) / 2 - 1))
  119. ]
  120. distances.append(
  121. cal_line_length((poly[0], poly[1]), (poly[8], poly[9])))
  122. pos = np.array(distances).argsort()[0]
  123. count = 0
  124. out_poly = []
  125. while count < 5:
  126. if (count == pos):
  127. out_poly.append(
  128. (poly[count * 2] + poly[(count * 2 + 2) % 10]) / 2)
  129. out_poly.append(
  130. (poly[(count * 2 + 1) % 10] + poly[(count * 2 + 3) % 10]) /
  131. 2)
  132. count = count + 1
  133. elif (count == (pos + 1) % 5):
  134. count = count + 1
  135. continue
  136. else:
  137. out_poly.append(poly[count * 2])
  138. out_poly.append(poly[count * 2 + 1])
  139. count = count + 1
  140. return out_poly
  141. def slice_anno_single(self, annos, windows, output_dir, output_name):
  142. anno_dir = os.path.join(output_dir, 'labelTxt')
  143. for (left, up, right, down) in windows:
  144. image_poly = shgeo.Polygon(
  145. [(left, up), (right, up), (right, down), (left, down)])
  146. anno_file = output_name + str(left) + '___' + str(up) + '.txt'
  147. with open(os.path.join(anno_dir, anno_file), 'w') as f:
  148. for anno in annos:
  149. gt_poly = shgeo.Polygon(
  150. [(anno['poly'][0], anno['poly'][1]),
  151. (anno['poly'][2], anno['poly'][3]),
  152. (anno['poly'][4], anno['poly'][5]),
  153. (anno['poly'][6], anno['poly'][7])])
  154. if gt_poly.area <= 0:
  155. continue
  156. inter_poly, iof = self.iof(gt_poly, image_poly)
  157. if iof == 1:
  158. final_poly = self.translate(anno['poly'], left, up)
  159. elif iof > 0:
  160. inter_poly = shgeo.polygon.orient(inter_poly, sign=1)
  161. out_poly = list(inter_poly.exterior.coords)[0:-1]
  162. if len(out_poly) < 4 or len(out_poly) > 5:
  163. continue
  164. final_poly = []
  165. for p in out_poly:
  166. final_poly.append(p[0])
  167. final_poly.append(p[1])
  168. if len(out_poly) == 5:
  169. final_poly = self.get_poly4_from_poly5(final_poly)
  170. if self.choosebestpoint:
  171. final_poly = choose_best_pointorder_fit_another(
  172. final_poly, anno['poly'])
  173. final_poly = self.translate(final_poly, left, up)
  174. final_poly = np.clip(final_poly, 1, self.subsize)
  175. else:
  176. continue
  177. outline = ' '.join(list(map(str, final_poly)))
  178. if iof >= self.thresh:
  179. outline = outline + ' ' + anno['name'] + ' ' + str(anno[
  180. 'difficult'])
  181. else:
  182. outline = outline + ' ' + anno['name'] + ' ' + '2'
  183. f.write(outline + '\n')
  184. def slice_data_single(self, info, rate, output_dir):
  185. file_name = info['image_file']
  186. base_name = os.path.splitext(os.path.split(file_name)[-1])[0]
  187. base_name = base_name + '__' + str(rate) + '__'
  188. img = cv2.imread(file_name)
  189. if img.shape == ():
  190. return
  191. if (rate != 1):
  192. resize_img = cv2.resize(
  193. img, None, fx=rate, fy=rate, interpolation=cv2.INTER_CUBIC)
  194. else:
  195. resize_img = img
  196. height, width, _ = resize_img.shape
  197. windows = self.get_windows(height, width)
  198. self.slice_image_single(resize_img, windows, output_dir, base_name)
  199. if not self.image_only:
  200. annos = info['annotation']
  201. for anno in annos:
  202. anno['poly'] = list(map(lambda x: rate * x, anno['poly']))
  203. self.slice_anno_single(annos, windows, output_dir, base_name)
  204. def check_or_mkdirs(self, path):
  205. if not os.path.exists(path):
  206. os.makedirs(path, exist_ok=True)
  207. def slice_data(self, infos, rates, output_dir):
  208. """
  209. Args:
  210. infos (list[dict]): data_infos
  211. rates (float, list): scale rates
  212. output_dir (str): output directory
  213. """
  214. if isinstance(rates, Number):
  215. rates = [rates, ]
  216. self.check_or_mkdirs(output_dir)
  217. self.check_or_mkdirs(os.path.join(output_dir, 'images'))
  218. if not self.image_only:
  219. self.check_or_mkdirs(os.path.join(output_dir, 'labelTxt'))
  220. pbar = tqdm(total=len(rates) * len(infos), desc='slicing data')
  221. if self.num_process <= 1:
  222. for rate in rates:
  223. for info in infos:
  224. self.slice_data_single(info, rate, output_dir)
  225. pbar.update()
  226. else:
  227. pool = Pool(self.num_process)
  228. for rate in rates:
  229. for info in infos:
  230. pool.apply_async(
  231. self.slice_data_single, (info, rate, output_dir),
  232. callback=lambda x: pbar.update())
  233. pool.close()
  234. pool.join()
  235. pbar.close()