trainer.py 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111
  1. ''' Pytorch Recursive CNN Trainer
  2. Authors : Khurram Javed
  3. Maintainer : Khurram Javed
  4. Lab : TUKL-SEECS R&D Lab
  5. Email : 14besekjaved@seecs.edu.pk '''
  6. from __future__ import print_function
  7. import logging
  8. from torch.autograd import Variable
  9. logger = logging.getLogger('iCARL')
  10. import torch.nn.functional as F
  11. import torch
  12. from tqdm import tqdm
  13. class GenericTrainer:
  14. '''
  15. Base class for trainer; to implement a new training routine, inherit from this.
  16. '''
  17. def __init__(self):
  18. pass
  19. class Trainer(GenericTrainer):
  20. def __init__(self, train_iterator, model, cuda, optimizer):
  21. super().__init__()
  22. self.cuda = cuda
  23. self.train_iterator = train_iterator
  24. self.model = model
  25. self.optimizer = optimizer
  26. def update_lr(self, epoch, schedule, gammas):
  27. for temp in range(0, len(schedule)):
  28. if schedule[temp] == epoch:
  29. for param_group in self.optimizer.param_groups:
  30. self.current_lr = param_group['lr']
  31. param_group['lr'] = self.current_lr * gammas[temp]
  32. logger.debug("Changing learning rate from %0.9f to %0.9f", self.current_lr,
  33. self.current_lr * gammas[temp])
  34. self.current_lr *= gammas[temp]
  35. def train(self, epoch):
  36. self.model.train()
  37. lossAvg = None
  38. for img, target in tqdm(self.train_iterator):
  39. if self.cuda:
  40. img, target = img.cuda(), target.cuda()
  41. self.optimizer.zero_grad()
  42. response = self.model(Variable(img))
  43. # print (response[0])
  44. # print (target[0])
  45. loss = F.mse_loss(response, Variable(target.float()))
  46. loss = torch.sqrt(loss)
  47. if lossAvg is None:
  48. lossAvg = loss
  49. else:
  50. lossAvg += loss
  51. # logger.debug("Cur loss %s", str(loss))
  52. loss.backward()
  53. self.optimizer.step()
  54. lossAvg /= len(self.train_iterator)
  55. logger.info("Avg Loss %s", str((lossAvg).cpu().data.numpy()))
  56. return str((lossAvg).cpu().data.numpy())
  57. class CIFARTrainer(GenericTrainer):
  58. def __init__(self, train_iterator, model, cuda, optimizer):
  59. super().__init__()
  60. self.cuda = cuda
  61. self.train_iterator = train_iterator
  62. self.model = model
  63. self.optimizer = optimizer
  64. self.criterion = torch.nn.CrossEntropyLoss()
  65. def update_lr(self, epoch, schedule, gammas):
  66. for temp in range(0, len(schedule)):
  67. if schedule[temp] == epoch:
  68. for param_group in self.optimizer.param_groups:
  69. self.current_lr = param_group['lr']
  70. param_group['lr'] = self.current_lr * gammas[temp]
  71. logger.debug("Changing learning rate from %0.9f to %0.9f", self.current_lr,
  72. self.current_lr * gammas[temp])
  73. self.current_lr *= gammas[temp]
  74. def train(self, epoch):
  75. self.model.train()
  76. train_loss = 0
  77. correct = 0
  78. total = 0
  79. for inputs, targets in tqdm(self.train_iterator):
  80. if self.cuda:
  81. inputs, targets = inputs.cuda(), targets.cuda()
  82. self.optimizer.zero_grad()
  83. outputs = self.model(Variable(inputs), pretrain=True)
  84. loss = self.criterion(outputs, Variable(targets))
  85. loss.backward()
  86. self.optimizer.step()
  87. train_loss += loss.item()
  88. _, predicted = outputs.max(1)
  89. total += targets.size(0)
  90. correct += predicted.eq(targets).sum().item()
  91. logger.info("Accuracy : %s", str((correct * 100) / total))
  92. return correct / total