# trainer.py
'''PyTorch Recursive CNN Trainer

Authors : Khurram Javed
Maintainer : Khurram Javed
Lab : TUKL-SEECS R&D Lab
Email : 14besekjaved@seecs.edu.pk
'''
from __future__ import print_function

import logging

import torch
import torch.nn.functional as F
from torch.autograd import Variable
from tqdm import tqdm

logger = logging.getLogger('iCARL')
  13. class GenericTrainer:
  14. '''
  15. Base class for trainer; to implement a new training routine, inherit from this.
  16. '''
  17. def __init__(self):
  18. pass
  19. class Trainer(GenericTrainer):
  20. def __init__(self, train_iterator, model, cuda, optimizer):
  21. super().__init__()
  22. self.cuda = cuda
  23. self.train_iterator = train_iterator
  24. self.model = model
  25. self.optimizer = optimizer
  26. def update_lr(self, epoch, schedule, gammas):
  27. for temp in range(0, len(schedule)):
  28. if schedule[temp] == epoch:
  29. for param_group in self.optimizer.param_groups:
  30. self.current_lr = param_group['lr']
  31. param_group['lr'] = self.current_lr * gammas[temp]
  32. logger.debug("Changing learning rate from %0.9f to %0.9f", self.current_lr,
  33. self.current_lr * gammas[temp])
  34. self.current_lr *= gammas[temp]
  35. def train(self, epoch):
  36. self.model.train()
  37. lossAvg = None
  38. for img, target in tqdm(self.train_iterator):
  39. if self.cuda:
  40. img, target = img.cuda(), target.cuda()
  41. self.optimizer.zero_grad()
  42. response = self.model(Variable(img))
  43. # print (response[0])
  44. # print (target[0])
  45. loss = F.mse_loss(response, Variable(target.float()))
  46. loss = torch.sqrt(loss)
  47. if lossAvg is None:
  48. lossAvg = loss
  49. else:
  50. lossAvg += loss
  51. # logger.debug("Cur loss %s", str(loss))
  52. loss.backward()
  53. self.optimizer.step()
  54. lossAvg /= len(self.train_iterator)
  55. logger.info("Avg Loss %s", str((lossAvg).cpu().data.numpy()))
  56. class CIFARTrainer(GenericTrainer):
  57. def __init__(self, train_iterator, model, cuda, optimizer):
  58. super().__init__()
  59. self.cuda = cuda
  60. self.train_iterator = train_iterator
  61. self.model = model
  62. self.optimizer = optimizer
  63. self.criterion = torch.nn.CrossEntropyLoss()
  64. def update_lr(self, epoch, schedule, gammas):
  65. for temp in range(0, len(schedule)):
  66. if schedule[temp] == epoch:
  67. for param_group in self.optimizer.param_groups:
  68. self.current_lr = param_group['lr']
  69. param_group['lr'] = self.current_lr * gammas[temp]
  70. logger.debug("Changing learning rate from %0.9f to %0.9f", self.current_lr,
  71. self.current_lr * gammas[temp])
  72. self.current_lr *= gammas[temp]
  73. def train(self, epoch):
  74. self.model.train()
  75. train_loss = 0
  76. correct = 0
  77. total = 0
  78. for inputs, targets in tqdm(self.train_iterator):
  79. if self.cuda:
  80. inputs, targets = inputs.cuda(), targets.cuda()
  81. self.optimizer.zero_grad()
  82. outputs = self.model(Variable(inputs), pretrain=True)
  83. loss = self.criterion(outputs, Variable(targets))
  84. loss.backward()
  85. self.optimizer.step()
  86. train_loss += loss.item()
  87. _, predicted = outputs.max(1)
  88. total += targets.size(0)
  89. correct += predicted.eq(targets).sum().item()
  90. logger.info("Accuracy : %s", str((correct * 100) / total))
  91. return correct / total