modelfactory.py 1.2 KB

12345678910111213141516171819202122232425262728293031323334353637
  1. ''' Document Localization using Recursive CNN
  2. Maintainer : Khurram Javed
  3. Email : kjaved@ualberta.ca '''
  4. import model.resnet32 as resnet
  5. import model.cornerModel as tm
  6. import torchvision.models as models
  7. class ModelFactory():
  8. def __init__(self):
  9. pass
  10. @staticmethod
  11. def get_model(model_type, dataset):
  12. if model_type == "resnet":
  13. if dataset == 'document':
  14. return resnet.resnet20(8)
  15. elif dataset == 'corner':
  16. return resnet.resnet20(2)
  17. if model_type == "resnet8":
  18. if dataset == 'document':
  19. return resnet.resnet8(8)
  20. elif dataset == 'corner':
  21. return resnet.resnet8(2)
  22. elif model_type == 'shallow':
  23. if dataset == 'document':
  24. return tm.MobileNet(8)
  25. elif dataset == 'corner':
  26. return tm.MobileNet(2)
  27. elif model_type =="squeeze":
  28. if dataset == 'document':
  29. return models.squeezenet1_1(True)
  30. elif dataset == 'corner':
  31. return models.squeezenet1_1(True)
  32. else:
  33. print("Unsupported model; either implement the model in model/ModelFactory or choose a different model")
  34. assert (False)