train_model.py

''' Document Localization using Recursive CNN
Maintainer : Khurram Javed
Email : kjaved@ualberta.ca '''

from __future__ import print_function

import argparse

import torch
import torch.utils.data as td

import dataprocessor
import experiment as ex
import model
import trainer
import utils
parser = argparse.ArgumentParser(description='Recursive-CNNs')
parser.add_argument('--batch-size', type=int, default=32, metavar='N',
                    help='input batch size for training (default: 32)')
parser.add_argument('--lr', type=float, default=0.005, metavar='LR',
                    help='learning rate (default: 0.005)')
parser.add_argument('--schedule', type=int, nargs='+', default=[10, 20, 30],
                    help='Decrease learning rate at these epochs.')
parser.add_argument('--gammas', type=float, nargs='+', default=[0.2, 0.2, 0.2],
                    help='LR is multiplied by gammas[k] at schedule[k]; the number of gammas '
                         'should equal the number of schedule entries')
parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
                    help='SGD momentum (default: 0.9)')
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disables CUDA training')
parser.add_argument('--pretrain', action='store_true', default=False,
                    help='pretrain the model on the CIFAR dataset first')
parser.add_argument('--load-ram', action='store_true', default=False,
                    help='Load data in RAM. TODO: remove this')
parser.add_argument('--debug', action='store_true', default=True,
                    help='Debug messages')  # note: with store_true and default=True this flag is always on
parser.add_argument('--seed', type=int, default=2323,
                    help='Seed value to be used')
parser.add_argument('--log-interval', type=int, default=5, metavar='N',
                    help='how many batches to wait before logging training status')
parser.add_argument('--model-type', default="resnet",
                    help='model type to be used. Example: resnet32, resnet20, densenet, test')
parser.add_argument('--name', default="noname",
                    help='Name of the experiment')
parser.add_argument('--output-dir', default="output/",
                    help='Directory to store the results; a new folder "DDMMYYYY" will be created '
                         'in the specified directory to save the results.')
parser.add_argument('--decay', type=float, default=0.00001, help='Weight decay (L2 penalty).')
parser.add_argument('--epochs', type=int, default=40, help='Number of epochs for training')
parser.add_argument('--dataset', default="document", help='Dataset to be used; example: document, corner')
parser.add_argument('--loader', default="hdd",
                    help='Loader to load data; hdd for reading from the hdd and ram for loading all data into memory')
parser.add_argument("-i", "--data-dirs", nargs='+', default=["/Users/khurramjaved96/documentTest64"],
                    help="input directories of train data")
parser.add_argument("-v", "--validation-dirs", nargs='+', default=["/Users/khurramjaved96/documentTest64"],
                    help="input directories of val data")
args = parser.parse_args()
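
# Example invocation (paths below are placeholders, not shipped data):
#   python train_model.py --name doc-exp --model-type resnet --dataset document \
#       -i /path/to/train_data -v /path/to/val_data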

# Define an experiment.
my_experiment = ex.experiment(args.name, args, args.output_dir)

# Add logging support
logger = utils.utils.setup_logger(my_experiment.path)

args.cuda = not args.no_cuda and torch.cuda.is_available()

dataset = dataprocessor.DatasetFactory.get_dataset(args.data_dirs, args.dataset)
dataset_val = dataprocessor.DatasetFactory.get_dataset(args.validation_dirs, args.dataset)

# Fix the seed.
seed = args.seed
torch.manual_seed(seed)
if args.cuda:
    torch.cuda.manual_seed(seed)
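# Note: seeding torch alone does not make GPU runs bit-for-bit deterministic;
# torch.backends.cudnn.deterministic = True (with cudnn.benchmark = False) can
# additionally be set, at some speed cost.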

# Loader used for training data
train_dataset_loader = dataprocessor.LoaderFactory.get_loader(args.loader, dataset.myData,
                                                              transform=dataset.train_transform,
                                                              cuda=args.cuda)
# Loader used for validation data
val_dataset_loader = dataprocessor.LoaderFactory.get_loader(args.loader, dataset_val.myData,
                                                            transform=dataset_val.test_transform,
                                                            cuda=args.cuda)

kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}

# Iterator to iterate over training data.
train_iterator = torch.utils.data.DataLoader(train_dataset_loader,
                                             batch_size=args.batch_size, shuffle=True, **kwargs)
# Iterator to iterate over validation data.
val_iterator = torch.utils.data.DataLoader(val_dataset_loader,
                                           batch_size=args.batch_size, shuffle=True, **kwargs)

# Get the required model
myModel = model.ModelFactory.get_model(args.model_type, args.dataset)
if args.cuda:
    myModel.cuda()

# Should I pretrain the model on CIFAR?
if args.pretrain:
    trainset = dataprocessor.DatasetFactory.get_dataset(None, "CIFAR")
    train_iterator_cifar = torch.utils.data.DataLoader(trainset, batch_size=32, shuffle=True, num_workers=2)

    # Define the optimizer used in the experiment
    cifar_optimizer = torch.optim.SGD(myModel.parameters(), args.lr, momentum=args.momentum,
                                      weight_decay=args.decay, nesterov=True)

    # Trainer object used for training
    cifar_trainer = trainer.CIFARTrainer(train_iterator_cifar, myModel, args.cuda, cifar_optimizer)

    for epoch in range(0, 70):
        logger.info("Epoch : %d", epoch)
        cifar_trainer.update_lr(epoch, [30, 45, 60], args.gammas)
        cifar_trainer.train(epoch)
    # Freeze the model: keep roughly the first half of the layers fixed so the
    # low-level features learned on CIFAR are preserved during fine-tuning.
    counter = 0
    # Total number of parameter tensors, used to freeze a fixed fraction of them.
    gen_len = sum(1 for _ in myModel.parameters())
    for name, param in myModel.named_parameters():
        if counter < int(gen_len * 0.5):
            param.requires_grad = False
            logger.warning(name)
        else:
            logger.info(name)
        counter += 1
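
    # Optional sanity check: report how many parameter tensors remain trainable,
    # so the freeze ratio can be verified in the logs.
    trainable = sum(1 for p in myModel.parameters() if p.requires_grad)
    logger.info("Trainable parameter tensors : %d / %d", trainable, gen_len)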

# Define the optimizer used in the experiment; only un-frozen parameters are updated.
optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, myModel.parameters()), args.lr,
                            momentum=args.momentum,
                            weight_decay=args.decay, nesterov=True)

# Trainer object used for training
my_trainer = trainer.Trainer(train_iterator, myModel, args.cuda, optimizer)

# Evaluator (root-mean-square error on the validation set)
my_eval = trainer.EvaluatorFactory.get_evaluator("rmse", args.cuda)

# Run args.epochs epochs of training, evaluating after each one
for epoch in range(0, args.epochs):
    logger.info("Epoch : %d", epoch)
    my_trainer.update_lr(epoch, args.schedule, args.gammas)
    my_trainer.train(epoch)
    my_eval.evaluate(my_trainer.model, val_iterator)

torch.save(myModel.state_dict(), my_experiment.path + args.dataset + "_" + args.model_type + ".pb")
my_experiment.store_json()
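
# The saved weights can later be restored by mirroring the save path above, e.g.:
#   myModel.load_state_dict(torch.load(my_experiment.path + args.dataset + "_" + args.model_type + ".pb"))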