# train_model.py
''' Document Localization using Recursive CNN
Maintainer : Khurram Javed
Email : kjaved@ualberta.ca '''
from __future__ import print_function

import argparse

import torch
import torch.utils.data as td

import dataprocessor
import experiment as ex
import model
import trainer
import utils
parser = argparse.ArgumentParser(description='Recursive-CNNs')
parser.add_argument('--batch-size', type=int, default=32, metavar='N',
                    help='input batch size for training (default: 32)')
parser.add_argument('--eval_interval', type=int, default=5,
                    help='number of epochs between evaluations on the validation set')
parser.add_argument('--lr', type=float, default=0.005, metavar='LR',
                    help='learning rate (default: 0.005)')
parser.add_argument('--schedule', type=int, nargs='+', default=[10, 20, 30],
                    help='Decrease learning rate at these epochs.')
parser.add_argument('--gammas', type=float, nargs='+', default=[0.2, 0.2, 0.2],
                    help='LR is multiplied by gamma[k] at schedule[k]; the number of gammas '
                         'should equal the number of schedule entries')
parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
                    help='SGD momentum (default: 0.9)')
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disables CUDA training')
parser.add_argument('--pretrain', action='store_true', default=False,
                    help='Pretrain the model on the CIFAR dataset?')
parser.add_argument('--load-ram', action='store_true', default=False,
                    help='Load data in RAM. TODO: remove this flag')
parser.add_argument('--debug', action='store_true', default=True,
                    help='Debug messages')
parser.add_argument('--seed', type=int, default=2323,
                    help='Seed value to be used')
parser.add_argument('--log-interval', type=int, default=5, metavar='N',
                    help='how many batches to wait before logging training status')
parser.add_argument('--model-type', default="resnet",
                    help='model type to be used. Example : resnet32, resnet20, densenet, test')
parser.add_argument('--name', default="noname",
                    help='Name of the experiment')
parser.add_argument('--output-dir', default="./",
                    help='Directory to store the results; a new folder "DDMMYYYY" will be created '
                         'in the specified directory to save the results.')
parser.add_argument('--decay', type=float, default=0.00001, help='Weight decay (L2 penalty).')
parser.add_argument('--epochs', type=int, default=100, help='Number of epochs for training')
parser.add_argument('--dataset', default="document", help='Dataset to be used; example document, corner')
parser.add_argument('--loader', default="hdd",
                    help='Loader to load data; hdd for reading from the hdd and ram for loading all data in the memory')
parser.add_argument("-i", "--data-dirs", nargs='+', default="/Users/khurramjaved96/documentTest64",
                    help="input Directory of train data")
parser.add_argument("-v", "--validation-dirs", nargs='+', default="/Users/khurramjaved96/documentTest64",
                    help="input Directory of val data")

args = parser.parse_args()
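# Example invocation (illustrative only; the data paths below are hypothetical
# and must point at your own dataset directories):
#   python train_model.py --name my_experiment --dataset document --model-type resnet \
#       -i /path/to/train/data -v /path/to/val/data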
# Define an experiment.
my_experiment = ex.experiment(args.name, args, args.output_dir)

# Add logging support.
logger = utils.utils.setup_logger(my_experiment.path)
args.cuda = not args.no_cuda and torch.cuda.is_available()

dataset = dataprocessor.DatasetFactory.get_dataset(args.data_dirs, args.dataset)
dataset_val = dataprocessor.DatasetFactory.get_dataset(args.validation_dirs, args.dataset)

# Fix the seed.
seed = args.seed
torch.manual_seed(seed)
if args.cuda:
    torch.cuda.manual_seed(seed)
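# Note: seeding torch alone does not make training fully deterministic; cuDNN
# autotuning and DataLoader worker processes can still introduce variation.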
# Loader used for training data.
train_dataset_loader = dataprocessor.LoaderFactory.get_loader(args.loader, dataset.myData,
                                                              transform=dataset.train_transform,
                                                              cuda=args.cuda)
# Loader used for validation data.
val_dataset_loader = dataprocessor.LoaderFactory.get_loader(args.loader, dataset_val.myData,
                                                            transform=dataset_val.test_transform,
                                                            cuda=args.cuda)

kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}
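# pin_memory=True speeds up host-to-GPU transfers; a single worker process is
# used to load batches in the background.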
# Iterator to iterate over training data.
train_iterator = torch.utils.data.DataLoader(train_dataset_loader,
                                             batch_size=args.batch_size, shuffle=True, **kwargs)
# Iterator to iterate over validation data.
val_iterator = torch.utils.data.DataLoader(val_dataset_loader,
                                           batch_size=args.batch_size, shuffle=True, **kwargs)
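# Note: shuffling the validation set is harmless for aggregate metrics such as
# RMSE, though it is not required for evaluation.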
# Get the required model.
myModel = model.ModelFactory.get_model(args.model_type, args.dataset)
if args.cuda:
    myModel.cuda()
# Should I pretrain the model on CIFAR?
if args.pretrain:
    trainset = dataprocessor.DatasetFactory.get_dataset(None, "CIFAR")
    train_iterator_cifar = torch.utils.data.DataLoader(trainset, batch_size=32, shuffle=True, num_workers=2)

    # Define the optimizer used for pretraining.
    cifar_optimizer = torch.optim.SGD(myModel.parameters(), args.lr, momentum=args.momentum,
                                      weight_decay=args.decay, nesterov=True)

    # Trainer object used for pretraining.
    cifar_trainer = trainer.CIFARTrainer(train_iterator_cifar, myModel, args.cuda, cifar_optimizer)

    for epoch in range(0, 70):
        logger.info("Epoch : %d", epoch)
        cifar_trainer.update_lr(epoch, [30, 45, 60], args.gammas)
        cifar_trainer.train(epoch)
    # Freeze the model: count the total number of parameter tensors so that
    # the first 50% of them can be frozen, leaving only the later layers trainable.
    gen_len = sum(1 for _ in myModel.parameters())
    counter = 0
    for name, param in myModel.named_parameters():
        if counter < int(gen_len * 0.5):
            param.requires_grad = False
            logger.warning(name)  # frozen layers are logged at warning level so they stand out
        else:
            logger.info(name)
        counter += 1
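    # Rationale (an assumption inferred from the surrounding code, not stated by
    # the author): freezing the earlier half keeps the generic low-level features
    # learned during CIFAR pretraining fixed, so document training only
    # fine-tunes the later, task-specific layers.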
# Define the optimizer used in the experiment; only parameters that still
# require gradients (i.e. the unfrozen ones) are optimized.
optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, myModel.parameters()), args.lr,
                            momentum=args.momentum,
                            weight_decay=args.decay, nesterov=True)

# Trainer object used for training.
my_trainer = trainer.Trainer(train_iterator, myModel, args.cuda, optimizer)

# Evaluator.
my_eval = trainer.EvaluatorFactory.get_evaluator("rmse", args.cuda)
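# "rmse" presumably selects a root-mean-squared-error metric over the predicted
# coordinates; see trainer.EvaluatorFactory for the exact definition.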
# Run args.epochs epochs of training, tracking the best (lowest) average loss.
cnt = 1
best_loss = 100
for epoch in range(0, args.epochs):
    logger.info("Epoch : %d", epoch)
    my_trainer.update_lr(epoch, args.schedule, args.gammas)
    avg_loss = float(my_trainer.train(epoch))
    if avg_loss < best_loss:
        torch.save(myModel.state_dict(),
                   my_experiment.path + "_" + args.dataset + "_" + args.model_type + "_best.pth")
        logger.info("best model saved, avg_loss is %s", avg_loss)
        best_loss = avg_loss
    logger.info("the best model's avg_loss is %s", best_loss)
    if cnt % args.eval_interval == 0:
        my_eval.evaluate(my_trainer.model, val_iterator)
    cnt += 1
torch.save(myModel.state_dict(),
           my_experiment.path + "_" + args.dataset + "_" + args.model_type + "_final.pth")
# torch.save(myModel, my_experiment.path + args.dataset + "_" + args.model_type + ".pth")
my_experiment.store_json()
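# Reusing the saved weights later, a minimal sketch (the model type and dataset
# must match the values used at training time):
#   myModel = model.ModelFactory.get_model(args.model_type, args.dataset)
#   myModel.load_state_dict(torch.load(my_experiment.path + "_" + args.dataset +
#                                      "_" + args.model_type + "_final.pth"))
#   myModel.eval()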