datasetfactory.py 608 B

12345678910111213141516171819
  1. ''' Document Localization using Recursive CNN
  2. Maintainer : Khurram Javed
  3. Email : kjaved@ualberta.ca '''
  4. import dataprocessor.dataset as data
  5. import torchvision
  6. class DatasetFactory:
  7. def __init__(self):
  8. pass
  9. @staticmethod
  10. def get_dataset(directory, type="document"):
  11. if type=="document":
  12. return data.SmartDoc(directory)
  13. elif type =="corner":
  14. return data.SmartDocCorner(directory)
  15. elif type=="CIFAR":
  16. return torchvision.datasets.CIFAR100(root='./data', train=True, download=True, transform=torchvision.transforms.ToTensor())