corner_extractor.py 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172
  1. ''' Document Localization using Recursive CNN
  2. Maintainer : Khurram Javed
  3. Email : kjaved@ualberta.ca '''
  4. import numpy as np
  5. import torch
  6. from PIL import Image
  7. from torchvision import transforms
  8. import model
  9. class GetCorners:
  10. def __init__(self, checkpoint_dir, model_type = "resnet"):
  11. self.model = model.ModelFactory.get_model(model_type, 'document')
  12. self.model.load_state_dict(torch.load(checkpoint_dir, map_location='cpu'))
  13. if torch.cuda.is_available():
  14. self.model.cuda()
  15. self.model.eval()
  16. def get(self, pil_image):
  17. with torch.no_grad():
  18. image_array = np.copy(pil_image)
  19. pil_image = Image.fromarray(pil_image)
  20. test_transform = transforms.Compose([transforms.Resize([32, 32]),
  21. transforms.ToTensor()])
  22. img_temp = test_transform(pil_image)
  23. img_temp = img_temp.unsqueeze(0)
  24. if torch.cuda.is_available():
  25. img_temp = img_temp.cuda()
  26. model_prediction = self.model(img_temp).cpu().data.numpy()[0]
  27. model_prediction = np.array(model_prediction)
  28. x_cords = model_prediction[[0, 2, 4, 6]]
  29. y_cords = model_prediction[[1, 3, 5, 7]]
  30. x_cords = x_cords * image_array.shape[1]
  31. y_cords = y_cords * image_array.shape[0]
  32. # Extract the four corners of the image. Read "Region Extractor" in Section III of the paper for an explanation.
  33. top_left = image_array[
  34. max(0, int(2 * y_cords[0] - (y_cords[3] + y_cords[0]) / 2)):int((y_cords[3] + y_cords[0]) / 2),
  35. max(0, int(2 * x_cords[0] - (x_cords[1] + x_cords[0]) / 2)):int((x_cords[1] + x_cords[0]) / 2)]
  36. top_right = image_array[
  37. max(0, int(2 * y_cords[1] - (y_cords[1] + y_cords[2]) / 2)):int((y_cords[1] + y_cords[2]) / 2),
  38. int((x_cords[1] + x_cords[0]) / 2):min(image_array.shape[1] - 1,
  39. int(x_cords[1] + (x_cords[1] - x_cords[0]) / 2))]
  40. bottom_right = image_array[int((y_cords[1] + y_cords[2]) / 2):min(image_array.shape[0] - 1, int(
  41. y_cords[2] + (y_cords[2] - y_cords[1]) / 2)),
  42. int((x_cords[2] + x_cords[3]) / 2):min(image_array.shape[1] - 1,
  43. int(x_cords[2] + (x_cords[2] - x_cords[3]) / 2))]
  44. bottom_left = image_array[int((y_cords[0] + y_cords[3]) / 2):min(image_array.shape[0] - 1, int(
  45. y_cords[3] + (y_cords[3] - y_cords[0]) / 2)),
  46. max(0, int(2 * x_cords[3] - (x_cords[2] + x_cords[3]) / 2)):int(
  47. (x_cords[3] + x_cords[2]) / 2)]
  48. top_left = (top_left, max(0, int(2 * x_cords[0] - (x_cords[1] + x_cords[0]) / 2)),
  49. max(0, int(2 * y_cords[0] - (y_cords[3] + y_cords[0]) / 2)))
  50. top_right = (
  51. top_right, int((x_cords[1] + x_cords[0]) / 2), max(0, int(2 * y_cords[1] - (y_cords[1] + y_cords[2]) / 2)))
  52. bottom_right = (bottom_right, int((x_cords[2] + x_cords[3]) / 2), int((y_cords[1] + y_cords[2]) / 2))
  53. bottom_left = (bottom_left, max(0, int(2 * x_cords[3] - (x_cords[2] + x_cords[3]) / 2)),
  54. int((y_cords[0] + y_cords[3]) / 2))
  55. return top_left, top_right, bottom_right, bottom_left