12345678910111213141516171819202122232425262728293031323334353637 |
- ''' Document Localization using Recursive CNN
- Maintainer : Khurram Javed
- Email : kjaved@ualberta.ca '''
- import model.resnet32 as resnet
- import model.cornerModel as tm
- import torchvision.models as models
class ModelFactory:
    """Factory that maps a (model type, dataset) pair to a network instance."""

    def __init__(self):
        pass

    @staticmethod
    def get_model(model_type, dataset):
        """Build and return the network selected by ``model_type`` and ``dataset``.

        Supported ``model_type`` values are ``"resnet"``, ``"resnet8"``,
        ``"shallow"`` and ``"squeeze"``; supported ``dataset`` values are
        ``"document"`` and ``"corner"``.

        For the resnet and shallow models the constructor receives ``8`` for
        the ``"document"`` dataset and ``2`` for the ``"corner"`` dataset
        (presumably the number of regression outputs -- confirm against the
        model definitions).

        Args:
            model_type: Name of the architecture to build.
            dataset: Name of the dataset the model will be trained on.

        Returns:
            The object returned by the selected model constructor.

        Raises:
            ValueError: If ``model_type`` is not supported, or ``dataset`` is
                not supported for that model. Previously the first case was
                an ``assert`` (a no-op under ``python -O``) and the second
                could silently return ``None``.
        """
        # Builders are zero-argument callables so that only the selected
        # model is constructed, and only when it is actually requested.
        builders = {
            "resnet": {
                "document": lambda: resnet.resnet20(8),
                "corner": lambda: resnet.resnet20(2),
            },
            "resnet8": {
                "document": lambda: resnet.resnet8(8),
                "corner": lambda: resnet.resnet8(2),
            },
            "shallow": {
                "document": lambda: tm.MobileNet(8),
                "corner": lambda: tm.MobileNet(2),
            },
            # NOTE(review): both datasets get the same torchvision SqueezeNet;
            # the positional True is presumably the `pretrained` flag and the
            # output layer is not resized here -- confirm this is intended.
            "squeeze": {
                "document": lambda: models.squeezenet1_1(True),
                "corner": lambda: models.squeezenet1_1(True),
            },
        }

        per_dataset = builders.get(model_type)
        if per_dataset is None:
            raise ValueError(
                "Unsupported model; either implement the model in "
                "model/ModelFactory or choose a different model"
            )

        build = per_dataset.get(dataset)
        if build is None:
            raise ValueError(
                "Unsupported dataset %r for model %r; expected one of %s"
                % (dataset, model_type, sorted(per_dataset))
            )

        return build()
|