# train_seg_model.py
''' Document Localization using Recursive CNN
Maintainer : Khurram Javed
Email : kjaved@ualberta.ca '''

from __future__ import print_function

import argparse

import torch
import torch.utils.data as td

import dataprocessor
import experiment as ex
import model
import trainer
import utils

parser = argparse.ArgumentParser(description='iCarl2.0')
parser.add_argument('--batch-size', type=int, default=32, metavar='N',
                    help='input batch size for training (default: 32)')
parser.add_argument('--lr', type=float, default=0.005, metavar='LR',
                    help='learning rate (default: 0.005)')
parser.add_argument('--schedule', type=int, nargs='+', default=[10, 20, 30],
                    help='Decrease learning rate at these epochs.')
parser.add_argument('--gammas', type=float, nargs='+', default=[0.2, 0.2, 0.2],
                    help='LR is multiplied by gamma on schedule; the number of gammas '
                         'should equal the number of schedule entries')
parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
                    help='SGD momentum (default: 0.9)')
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disables CUDA training')
parser.add_argument('--pretrain', action='store_true', default=False,
                    help='Pretrain the model on the CIFAR dataset?')
parser.add_argument('--load-ram', action='store_true', default=False,
                    help='Load data in RAM')
parser.add_argument('--debug', action='store_true', default=False,
                    help='Debug messages')
parser.add_argument('--seed', type=int, default=2323,
                    help='Seed value to be used')
parser.add_argument('--log-interval', type=int, default=5, metavar='N',
                    help='how many batches to wait before logging training status')
parser.add_argument('--model-type', default="resnet",
                    help='model type to be used. Example : resnet32, resnet20, densenet, test')
parser.add_argument('--name', default="noname",
                    help='Name of the experiment')
parser.add_argument('--output-dir', default="../",
                    help='Directory to store the results; a new folder "DDMMYYYY" will be created '
                         'in the specified directory to save the results.')
parser.add_argument('--decay', type=float, default=0.00001, help='Weight decay (L2 penalty).')
parser.add_argument('--epochs', type=int, default=40, help='Number of epochs for each increment')
parser.add_argument('--dataset', default="document", help='Dataset to be used; example CIFAR, MNIST')
parser.add_argument('--loader', default="hdd", help='Loader to be used; example hdd, ram')
parser.add_argument("-i", "--data-dirs", nargs='+', default="/Users/khurramjaved96/documentTest64",
                    help="input directory of train data")
parser.add_argument("-v", "--validation-dirs", nargs='+', default="/Users/khurramjaved96/documentTest64",
                    help="input directory of val data")
args = parser.parse_args()
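
# Example invocation (sketch; the data paths are placeholders):
#   python train_seg_model.py --name doc_exp --model-type resnet \
#       -i /path/to/train_data -v /path/to/val_data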

# Define an experiment.
my_experiment = ex.experiment(args.name, args, args.output_dir)

# Add logging support
logger = utils.utils.setup_logger(my_experiment.path)

args.cuda = not args.no_cuda and torch.cuda.is_available()

dataset = dataprocessor.DatasetFactory.get_dataset(args.data_dirs, args.dataset)
dataset_val = dataprocessor.DatasetFactory.get_dataset(args.validation_dirs, args.dataset)

# Fix the seed.
seed = args.seed
torch.manual_seed(seed)
if args.cuda:
    torch.cuda.manual_seed(seed)
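# Note: torch.manual_seed seeds the CPU RNG, and torch.cuda.manual_seed
# additionally seeds the current GPU, so CUDA runs stay reproducible too.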

# Loader used for training data
train_dataset_loader = dataprocessor.LoaderFactory.get_loader(args.loader, dataset.myData,
                                                              transform=dataset.train_transform,
                                                              cuda=args.cuda)
# Loader used for validation data
val_dataset_loader = dataprocessor.LoaderFactory.get_loader(args.loader, dataset_val.myData,
                                                            transform=dataset.test_transform,
                                                            cuda=args.cuda)

kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}
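# Note: pin_memory=True places fetched batches in page-locked memory, which
# speeds up host-to-GPU copies; it is only useful when training on CUDA.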

# Iterator to iterate over training data.
train_iterator = torch.utils.data.DataLoader(train_dataset_loader,
                                             batch_size=args.batch_size, shuffle=True, **kwargs)

# Iterator to iterate over validation data.
val_iterator = torch.utils.data.DataLoader(val_dataset_loader,
                                           batch_size=args.batch_size, shuffle=True, **kwargs)

# Get the required model
myModel = model.ModelFactory.get_model(args.model_type, args.dataset)
if args.cuda:
    myModel.cuda()

# Should I pretrain the model on CIFAR?
if args.pretrain:
    trainset = dataprocessor.DatasetFactory.get_dataset(None, "CIFAR")
    train_iterator_cifar = torch.utils.data.DataLoader(trainset, batch_size=32, shuffle=True, num_workers=2)

    # Define the optimizer used in the experiment
    cifar_optimizer = torch.optim.SGD(myModel.parameters(), args.lr, momentum=args.momentum,
                                      weight_decay=args.decay, nesterov=True)

    # Trainer object used for training
    cifar_trainer = trainer.CIFARTrainer(train_iterator_cifar, myModel, args.cuda, cifar_optimizer)

    for epoch in range(0, 70):
        logger.info("Epoch : %d", epoch)
        cifar_trainer.update_lr(epoch, [30, 45, 60], args.gammas)
        cifar_trainer.train(epoch)
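
    # Note: CIFAR pretraining runs for a fixed 70 epochs, decaying the LR at
    # epochs 30, 45 and 60 by the --gammas factors (0.2 each by default).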

    # Freeze the model
    counter = 0
    # Total number of parameter tensors, computed once so that the first
    # 50% of layers can be frozen
    gen_len = sum(1 for _ in myModel.parameters())
    for name, param in myModel.named_parameters():
        if counter < int(gen_len * 0.5):
            param.requires_grad = False
            logger.warning(name)
        else:
            logger.info(name)
        counter += 1
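
    # Optional sanity check (sketch): roughly half of the parameter tensors
    # should remain trainable after freezing.
    # trainable = sum(1 for p in myModel.parameters() if p.requires_grad)
    # logger.info("Trainable tensors after freezing: %d / %d", trainable, gen_len)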

# Define the optimizer used in the experiment; only parameters that are
# still trainable are passed to it.
optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, myModel.parameters()), args.lr,
                            momentum=args.momentum,
                            weight_decay=args.decay, nesterov=True)

# Trainer object used for training
my_trainer = trainer.Trainer(train_iterator, myModel, args.cuda, optimizer)

# Evaluator
my_eval = trainer.EvaluatorFactory.get_evaluator("rmse", args.cuda)
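# Note: the "rmse" evaluator reports root-mean-squared error on the validation
# set; for document localization this is presumably the error in the predicted
# corner coordinates.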

# Run args.epochs training epochs, evaluating on the validation set after each one
for epoch in range(0, args.epochs):
    logger.info("Epoch : %d", epoch)
    my_trainer.update_lr(epoch, args.schedule, args.gammas)
    my_trainer.train(epoch)
    my_eval.evaluate(my_trainer.model, val_iterator)

torch.save(myModel.state_dict(), my_experiment.path + args.dataset + "_" + args.model_type + ".pb")
my_experiment.store_json()
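
# To reload the saved weights later (sketch; assumes the same model type and
# dataset are passed to the factory):
#   myModel = model.ModelFactory.get_model(args.model_type, args.dataset)
#   myModel.load_state_dict(torch.load(my_experiment.path + args.dataset + "_" + args.model_type + ".pb"))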