utils.py

''' Document Localization using Recursive CNN
Maintainer : Khurram Javed
Email : kjaved@ualberta.ca '''

import random

import cv2
import numpy as np
import Polygon as polygon  # Polygon3 package; provides the polygon intersection used by the SmartDoc metric


def unison_shuffled_copies(a, b):
    # Shuffle two equal-length arrays with the same random permutation.
    assert len(a) == len(b)
    p = np.random.permutation(len(a))
    return a[p], b[p]
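
# Example (illustrative only; `images` and `labels` are hypothetical parallel arrays):
#   images, labels = unison_shuffled_copies(images, labels)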


def intersection(a, b, img):
    # IoU of two convex quadrilaterals, computed by rasterising both onto masks
    # the size of `img`.
    img1 = np.zeros_like(img)
    cv2.fillConvexPoly(img1, a, (255, 0, 0))
    img1 = np.sum(img1, axis=2)
    img1 = img1 / 255

    img2 = np.zeros_like(img)
    cv2.fillConvexPoly(img2, b, (255, 0, 0))
    img2 = np.sum(img2, axis=2)
    img2 = img2 / 255

    inte = img1 * img2
    union = np.logical_or(img1, img2)
    iou = np.sum(inte) / np.sum(union)
    print(iou)
    return iou
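
# Example (illustrative sketch; `quad_a`, `quad_b` are hypothetical 4x2 int32 corner
# arrays and `frame` a 3-channel image of matching size):
#   iou = intersection(quad_a, quad_b, frame)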


def intersection_with_correction(a, b, img):
    img1 = np.zeros_like(img)
    cv2.fillConvexPoly(img1, a, (255, 0, 0))
    img2 = np.zeros_like(img)
    cv2.fillConvexPoly(img2, b, (255, 0, 0))

    min_x = min(a[0][0], a[1][0], a[2][0], a[3][0])
    min_y = min(a[0][1], a[1][1], a[2][1], a[3][1])
    max_x = max(a[0][0], a[1][0], a[2][0], a[3][0])
    max_y = max(a[0][1], a[1][1], a[2][1], a[3][1])
    dst = np.array(((min_x, min_y), (max_x, min_y), (max_x, max_y), (min_x, max_y)))

    # Warp both masks so the ground-truth quad becomes axis-aligned before computing IoU.
    # cv2.warpPerspective expects dsize as (width, height).
    mat = cv2.getPerspectiveTransform(a.astype(np.float32), dst.astype(np.float32))
    img1 = cv2.warpPerspective(img1, mat, (img.shape[1], img.shape[0]))
    img2 = cv2.warpPerspective(img2, mat, (img.shape[1], img.shape[0]))

    img1 = np.sum(img1, axis=2)
    img1 = img1 / 255
    img2 = np.sum(img2, axis=2)
    img2 = img2 / 255

    inte = img1 * img2
    union = np.logical_or(img1, img2)
    iou = np.sum(inte) / np.sum(union)
    return iou


def intersection_with_correction_smart_doc_implementation(gt, prediction, img):
    # Reference : https://github.com/jchazalon/smartdoc15-ch1-eval
    gt = sort_gt(gt)
    prediction = sort_gt(prediction)
    img1 = np.zeros_like(img)
    cv2.fillConvexPoly(img1, gt, (255, 0, 0))

    target_width = 2100
    target_height = 2970
    # Referential: (0,0) at TL, x > 0 toward right and y > 0 toward bottom
    # Corner order: TL, TR, BR, BL (matching the output of sort_gt)
    # object_coord_target = np.float32([[0, 0], [0, target_height], [target_width, target_height], [target_width, 0]])
    object_coord_target = np.array(
        np.float32([[0, 0], [target_width, 0], [target_width, target_height], [0, target_height]]))
    # print (gt, object_coord_target)

    # 1/ Compute the homography that maps the ground-truth quad to the target referential
    H = cv2.getPerspectiveTransform(gt.astype(np.float32).reshape(-1, 1, 2),
                                    object_coord_target.reshape(-1, 1, 2))
    # 2/ Apply to test result to project in target referential
    test_coords = cv2.perspectiveTransform(prediction.astype(np.float32).reshape(-1, 1, 2), H)
    # 3/ Compute intersection between target region and test result region
    # poly = Polygon.Polygon([(0,0),(1,0),(0,1)])
    poly_target = polygon.Polygon(object_coord_target.reshape(-1, 2))
    poly_test = polygon.Polygon(test_coords.reshape(-1, 2))
    poly_inter = poly_target & poly_test

    area_target = poly_target.area()
    area_test = poly_test.area()
    area_inter = poly_inter.area()
    area_union = area_test + area_target - area_inter
    # Little hack to cope with float precision issues when dealing with polygons:
    # if the intersection area is close enough to the target or test area but slightly
    # greater, cap it, assuming the difference is due to rounding issues.
    area_min = min(area_target, area_test)
    if area_min < area_inter and area_min * 1.0000000001 > area_inter:
        area_inter = area_min
        print("Capping area_inter.")
    jaccard_index = area_inter / area_union
    return jaccard_index
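
# Example (illustrative sketch; `gt_corners` and `pred_corners` are hypothetical 4x2
# corner arrays in image coordinates and `frame` is the corresponding image):
#   ji = intersection_with_correction_smart_doc_implementation(gt_corners, pred_corners, frame)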


def __rotateImage(image, angle):
    rot_mat = cv2.getRotationMatrix2D((image.shape[1] / 2, image.shape[0] / 2), angle, 1)
    result = cv2.warpAffine(image, rot_mat, (image.shape[1], image.shape[0]), flags=cv2.INTER_LINEAR)
    return result, rot_mat


def rotate(img, gt, angle):
    img, mat = __rotateImage(img, angle)
    gt = gt.astype(np.float64)
    # Apply the same 2x3 affine matrix to each ground-truth corner.
    for a in range(0, 4):
        gt[a] = np.dot(mat[..., 0:2], gt[a]) + mat[..., 2]
    return img, gt
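
# Example (illustrative; `frame` and `corners` are hypothetical image/corner inputs):
#   rotated_img, rotated_gt = rotate(frame, corners, random.randint(-30, 30))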


def random_crop(img, gt):
    ptr1 = (min(gt[0][0], gt[1][0], gt[2][0], gt[3][0]),
            min(gt[0][1], gt[1][1], gt[2][1], gt[3][1]))
    ptr2 = (max(gt[0][0], gt[1][0], gt[2][0], gt[3][0]),
            max(gt[0][1], gt[1][1], gt[2][1], gt[3][1]))
    start_x = np.random.randint(0, int(max(ptr1[0] - 1, 1)))
    start_y = np.random.randint(0, int(max(ptr1[1] - 1, 1)))
    end_x = np.random.randint(int(min(ptr2[0] + 1, img.shape[1] - 1)), img.shape[1])
    end_y = np.random.randint(int(min(ptr2[1] + 1, img.shape[0] - 1)), img.shape[0])
    img = img[start_y:end_y, start_x:end_x]

    # Shift the corners into the crop's frame and normalise them to [0, 1].
    myGt = gt - (start_x, start_y)
    myGt = myGt * (1.0 / img.shape[1], 1.0 / img.shape[0])

    # Re-order so the corner closest to the origin comes first (TL, TR, BR, BL).
    myGtTemp = myGt * myGt
    sum_array = myGtTemp.sum(axis=1)
    tl_index = np.argmin(sum_array)
    tl = myGt[tl_index]
    tr = myGt[(tl_index + 1) % 4]
    br = myGt[(tl_index + 2) % 4]
    bl = myGt[(tl_index + 3) % 4]
    return img, (tl, tr, br, bl)
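
# Example (illustrative; `frame` and `corners` are hypothetical inputs): crop around
# the document while keeping all four corners inside the crop.
#   cropped, (tl, tr, br, bl) = random_crop(frame, corners)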


def get_corners(img, gt):
    gt = gt.astype(int)
    list_of_points = {}
    myGt = gt
    myGtTemp = myGt * myGt
    sum_array = myGtTemp.sum(axis=1)
    tl_index = np.argmin(sum_array)
    tl = myGt[tl_index]
    tr = myGt[(tl_index + 1) % 4]
    br = myGt[(tl_index + 2) % 4]
    bl = myGt[(tl_index + 3) % 4]
    list_of_points["tr"] = tr
    list_of_points["tl"] = tl
    list_of_points["br"] = br
    list_of_points["bl"] = bl

    gt_list = []
    images_list = []
    # For each corner, cut a random window that contains that corner; the window
    # extents are bounded by the neighbouring corners and the image border.
    for k, v in list_of_points.items():
        if k == "tl":
            cords_x = __get_cords(v[0], 0, list_of_points["tr"][0], buf=10,
                                  size=abs(list_of_points["tr"][0] - v[0]))
            cords_y = __get_cords(v[1], 0, list_of_points["bl"][1], buf=10,
                                  size=abs(list_of_points["bl"][1] - v[1]))
            gt = (v[0] - cords_x[0], v[1] - cords_y[0])
            cut_image = img[cords_y[0]:cords_y[1], cords_x[0]:cords_x[1]]
        if k == "tr":
            cords_x = __get_cords(v[0], list_of_points["tl"][0], img.shape[1], buf=10,
                                  size=abs(list_of_points["tl"][0] - v[0]))
            cords_y = __get_cords(v[1], 0, list_of_points["br"][1], buf=10,
                                  size=abs(list_of_points["br"][1] - v[1]))
            gt = (v[0] - cords_x[0], v[1] - cords_y[0])
            cut_image = img[cords_y[0]:cords_y[1], cords_x[0]:cords_x[1]]
        if k == "bl":
            cords_x = __get_cords(v[0], 0, list_of_points["br"][0], buf=10,
                                  size=abs(list_of_points["br"][0] - v[0]))
            cords_y = __get_cords(v[1], list_of_points["tl"][1], img.shape[0], buf=10,
                                  size=abs(list_of_points["tl"][1] - v[1]))
            gt = (v[0] - cords_x[0], v[1] - cords_y[0])
            cut_image = img[cords_y[0]:cords_y[1], cords_x[0]:cords_x[1]]
        if k == "br":
            cords_x = __get_cords(v[0], list_of_points["bl"][0], img.shape[1], buf=10,
                                  size=abs(list_of_points["bl"][0] - v[0]))
            cords_y = __get_cords(v[1], list_of_points["tr"][1], img.shape[0], buf=10,
                                  size=abs(list_of_points["tr"][1] - v[1]))
            gt = (v[0] - cords_x[0], v[1] - cords_y[0])
            cut_image = img[cords_y[0]:cords_y[1], cords_x[0]:cords_x[1]]
        # cv2.circle(cut_image, gt, 2, (255, 0, 0), 6)
        # Resize each corner crop to 300x300 and rescale the corner location accordingly.
        mah_size = cut_image.shape
        cut_image = cv2.resize(cut_image, (300, 300))
        a = int(gt[0] * 300 / mah_size[1])
        b = int(gt[1] * 300 / mah_size[0])
        images_list.append(cut_image)
        gt_list.append((a, b))
    return images_list, gt_list
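
# Example (illustrative; `frame` and `corners` are hypothetical inputs): produce one
# 300x300 crop per document corner plus that corner's location inside each crop,
# e.g. for training the corner-refinement model.
#   corner_crops, corner_positions = get_corners(frame, corners)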


def __get_cords(cord, min_start, max_end, size=299, buf=5, random_scale=True):
    # Pick a random window [x_start, x_start + size] inside [min_start, max_end]
    # that contains `cord`, optionally shrinking the window by a random factor.
    # size = max(abs(cord-min_start), abs(cord-max_end))
    iter = 0
    if random_scale:
        size /= random.randint(1, 4)
    while (max_end - min_start) < size:
        size = size * .9
    temp = -1
    while temp < 1:
        temp = random.normalvariate(size / 2, size / 6)
    x_start = max(cord - temp, min_start)
    x_start = int(x_start)
    if x_start >= cord:
        print("XSTART AND CORD", x_start, cord)
    assert (x_start < cord)
    while (x_start < min_start) or (x_start + size > max_end) or (x_start + size <= cord):
        # x_start = random.randint(int(min(max(min_start, int(cord - size + buf)), cord - buf - 1)), cord - buf)
        temp = -1
        while temp < 1:
            temp = random.normalvariate(size / 2, size / 6)
        temp = max(temp, 1)
        x_start = max(cord - temp, min_start)
        x_start = int(x_start)
        size = size * .995
        iter += 1
        if iter == 1000:
            x_start = int(cord - (size / 2))
            print("Gets here")
            break
    assert (x_start >= 0)
    if x_start >= cord:
        print("XSTART AND CORD", x_start, cord)
    assert (x_start < cord)
    assert (x_start + size <= max_end)
    assert (x_start + size > cord)
    return x_start, int(x_start + size)


def setup_logger(path):
    import logging
    logger = logging.getLogger('iCARL')
    logger.setLevel(logging.DEBUG)
    fh = logging.FileHandler(path + ".log")
    fh.setLevel(logging.DEBUG)
    fh2 = logging.FileHandler("../temp.log")
    fh2.setLevel(logging.DEBUG)
    ch = logging.StreamHandler()
    ch.setLevel(logging.DEBUG)
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    fh.setFormatter(formatter)
    fh2.setFormatter(formatter)
    logger.addHandler(fh)
    logger.addHandler(fh2)
    logger.addHandler(ch)
    return logger
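
# Example (illustrative; "experiment1" is a hypothetical run name): log to
# experiment1.log, ../temp.log, and the console.
#   logger = setup_logger("experiment1")
#   logger.info("training started")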


def sort_gt(gt):
    '''
    Sort the ground-truth corners so that TL is the corner closest to the origin,
    followed by TR, BR, and BL in the original cyclic order.
    :param gt: 4x2 array of corner coordinates
    :return: sorted gt
    '''
    myGtTemp = gt * gt
    sum_array = myGtTemp.sum(axis=1)
    tl_index = np.argmin(sum_array)
    tl = gt[tl_index]
    tr = gt[(tl_index + 1) % 4]
    br = gt[(tl_index + 2) % 4]
    bl = gt[(tl_index + 3) % 4]
    return np.asarray((tl, tr, br, bl))
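
# Example (illustrative, synthetic corners): the point closest to the origin becomes TL
# and the remaining points follow in the original cyclic order.
#   corners = np.array([[500, 40], [510, 620], [30, 610], [20, 50]])
#   sort_gt(corners)  # -> [[ 20  50] [500  40] [510 620] [ 30 610]]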