# corner_refiner.py (2.9 KB)
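'''Document Localization using Recursive CNN.

Corner refinement: ``corner_finder`` repeatedly crops an image around a CNN's
corner prediction to localize a single document corner.

Maintainer: Khurram Javed
Email: kjaved@ualberta.ca
'''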
  4. import numpy as np
  5. import torch
  6. from PIL import Image
  7. from torchvision import transforms
  8. import model
  9. class corner_finder():
  10. def __init__(self, CHECKPOINT_DIR, model_type = "resnet"):
  11. self.model = model.ModelFactory.get_model(model_type, "corner")
  12. self.model.load_state_dict(torch.load(CHECKPOINT_DIR, map_location='cpu'))
  13. if torch.cuda.is_available():
  14. self.model.cuda()
  15. self.model.eval()
  16. def get_location(self, img, retainFactor=0.85):
  17. with torch.no_grad():
  18. ans_x = 0.0
  19. ans_y = 0.0
  20. o_img = np.copy(img)
  21. y = [0, 0]
  22. x_start = 0
  23. y_start = 0
  24. up_scale_factor = (img.shape[1], img.shape[0])
  25. myImage = np.copy(o_img)
  26. test_transform = transforms.Compose([transforms.Resize([32, 32]),
  27. transforms.ToTensor()])
  28. CROP_FRAC = retainFactor
  29. while (myImage.shape[0] > 10 and myImage.shape[1] > 10):
  30. img_temp = Image.fromarray(myImage)
  31. img_temp = test_transform(img_temp)
  32. img_temp = img_temp.unsqueeze(0)
  33. if torch.cuda.is_available():
  34. img_temp = img_temp.cuda()
  35. response = self.model(img_temp).cpu().data.numpy()
  36. response = response[0]
  37. response_up = response
  38. response_up = response_up * up_scale_factor
  39. y = response_up + (x_start, y_start)
  40. x_loc = int(y[0])
  41. y_loc = int(y[1])
  42. if x_loc > myImage.shape[1] / 2:
  43. start_x = min(x_loc + int(round(myImage.shape[1] * CROP_FRAC / 2)), myImage.shape[1]) - int(round(
  44. myImage.shape[1] * CROP_FRAC))
  45. else:
  46. start_x = max(x_loc - int(myImage.shape[1] * CROP_FRAC / 2), 0)
  47. if y_loc > myImage.shape[0] / 2:
  48. start_y = min(y_loc + int(myImage.shape[0] * CROP_FRAC / 2), myImage.shape[0]) - int(
  49. myImage.shape[0] * CROP_FRAC)
  50. else:
  51. start_y = max(y_loc - int(myImage.shape[0] * CROP_FRAC / 2), 0)
  52. ans_x += start_x
  53. ans_y += start_y
  54. myImage = myImage[start_y:start_y + int(myImage.shape[0] * CROP_FRAC),
  55. start_x:start_x + int(myImage.shape[1] * CROP_FRAC)]
  56. img = img[start_y:start_y + int(img.shape[0] * CROP_FRAC),
  57. start_x:start_x + int(img.shape[1] * CROP_FRAC)]
  58. up_scale_factor = (img.shape[1], img.shape[0])
  59. ans_x += y[0]
  60. ans_y += y[1]
  61. return (int(round(ans_x)), int(round(ans_y)))
  62. if __name__ == "__main__":
  63. pass