123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384 |
- ''' Document Localization using Recursive CNN
- Maintainer : Khurram Javed
- Email : kjaved@ualberta.ca '''
- import numpy as np
- import torch
- from PIL import Image
- from torchvision import transforms
- import model
class corner_finder():
    """Refine one corner estimate by repeatedly predicting and cropping.

    The wrapped network is run on a 32x32 resize of the current window; the
    window is then shrunk around the predicted point and the process repeats
    until either side of the window is 10 pixels or less.
    """

    def __init__(self, CHECKPOINT_DIR):
        """Create the "resnet"/"corner" model and load its weights.

        Args:
            CHECKPOINT_DIR: Handed directly to ``torch.load``, so despite the
                name it must be something ``torch.load`` can open (e.g. the
                path of a saved state dict).
        """
        self.model = model.ModelFactory.get_model("resnet", "corner")
        # Deserialize onto the CPU first; the model is moved to the GPU
        # afterwards only when one is actually available.
        self.model.load_state_dict(torch.load(CHECKPOINT_DIR, map_location='cpu'))
        if torch.cuda.is_available():
            self.model.cuda()
        self.model.eval()

    def get_location(self, img, retainFactor=0.85):
        """Locate the corner in ``img`` and return it as integer ``(x, y)``.

        Args:
            img: Array-like image with a ``shape`` whose first two axes are
                rows and columns. It is copied, never modified.
                NOTE(review): presumably an HxWx3 uint8 array, since it is
                fed to ``PIL.Image.fromarray`` -- confirm against callers.
            retainFactor: Fraction of each side kept after every refinement
                step. Must lie strictly between 0 and 1.

        Returns:
            Tuple ``(x, y)`` of Python ints in the coordinate frame of the
            original ``img``. ``(0, 0)`` is returned when either side of
            ``img`` is 10 pixels or less, because no prediction is made.

        Raises:
            ValueError: If ``retainFactor`` is not in the open interval (0, 1).
        """
        # A factor >= 1 never shrinks the window, so the refinement loop
        # would spin forever; a factor <= 0 yields empty or wrapped slices.
        if not 0 < retainFactor < 1:
            raise ValueError(
                "retainFactor must be strictly between 0 and 1, got %r" % (retainFactor,))

        # Too small to run even a single refinement step.
        if img.shape[0] <= 10 or img.shape[1] <= 10:
            return (0, 0)

        use_cuda = torch.cuda.is_available()
        to_input = transforms.Compose([transforms.Resize([32, 32]),
                                       transforms.ToTensor()])

        window = np.copy(img)
        # Accumulated top-left offset of ``window`` inside the original image.
        offset_x = 0.0
        offset_y = 0.0
        corner = (0, 0)

        with torch.no_grad():
            while window.shape[0] > 10 and window.shape[1] > 10:
                height, width = window.shape[0], window.shape[1]

                batch = to_input(Image.fromarray(window)).unsqueeze(0)
                if use_cuda:
                    batch = batch.cuda()
                response = self.model(batch).detach().cpu().numpy()[0]

                # The two outputs are scaled by the window width and height,
                # i.e. they are treated as fractions of the window size.
                corner = response * (width, height)
                x_loc = int(corner[0])
                y_loc = int(corner[1])

                crop_w = int(width * retainFactor)
                crop_h = int(height * retainFactor)

                # Centre the next window on the prediction, clamped so that it
                # stays inside the current window.
                # NOTE(review): the x axis rounds while the y axis truncates;
                # kept exactly as-is so existing results do not change.
                if x_loc > width / 2:
                    right_edge = min(x_loc + int(round(width * retainFactor / 2)), width)
                    start_x = right_edge - int(round(width * retainFactor))
                else:
                    start_x = max(x_loc - int(width * retainFactor / 2), 0)

                if y_loc > height / 2:
                    bottom_edge = min(y_loc + int(height * retainFactor / 2), height)
                    start_y = bottom_edge - crop_h
                else:
                    start_y = max(y_loc - int(height * retainFactor / 2), 0)

                offset_x += start_x
                offset_y += start_y
                window = window[start_y:start_y + crop_h,
                                start_x:start_x + crop_w]

        # Same final step as before: the accumulated offsets (which already
        # include the last crop) plus the last in-window prediction.
        return (int(round(offset_x + corner[0])), int(round(offset_y + corner[1])))
# Running this module directly does nothing; it only defines corner_finder.
if __name__ == "__main__":
    pass
|