train.py

from infer import infer_test
from torchmetrics.functional import peak_signal_noise_ratio, structural_similarity_index_measure
from torchvision import transforms
from torch.utils.tensorboard import SummaryWriter
import argparse
import torchvision.transforms as T
import shutil
import os
from matplotlib import pyplot as plt
from model import M64ColorNet
from loss import DocCleanLoss
from torch.utils.data import DataLoader
from dataset import DocCleanDataset
import torch
from tqdm import tqdm
from nni.compression.pytorch.pruning import L1NormPruner
from nni.compression.pytorch.speedup import ModelSpeedup
import matplotlib
matplotlib.use('Agg')
# from torchinfo import summary

writer = SummaryWriter()


def boolean_string(s):
    '''Check whether the string s is "true" or "false".

    Args:
        s: the string

    Returns:
        boolean
    '''
    s = s.lower()
    if s not in {'false', 'true'}:
        raise ValueError('Not a valid boolean string')
    return s == 'true'


# path parameters
parser = argparse.ArgumentParser()
parser.add_argument('--develop',
                    type=boolean_string,
                    help='Develop mode, off by default',
                    default=False)
parser.add_argument('--lr',
                    type=float,
                    help='Learning rate',
                    default=1e-3)
parser.add_argument('--batch_size',
                    type=int,
                    help='Batch size',
                    default=16)
parser.add_argument('--retrain',
                    type=boolean_string,
                    help='Retrain from scratch (removes the output directory instead of resuming from a checkpoint)',
                    default=False)
parser.add_argument('--epochs',
                    type=int,
                    help='Max training epochs',
                    default=500)
parser.add_argument('--dataset',
                    type=str,
                    help='Path to the training dataset',
                    default="dataset/raw_data/imgs_Trainblocks")
parser.add_argument('--shuffle',
                    type=boolean_string,
                    help='Whether to shuffle the dataset',
                    default=True)
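
# Example invocations (a sketch; flags and defaults follow the argparse definitions above,
# and boolean flags take the strings "true"/"false" via boolean_string):
#   python train.py --lr 1e-3 --batch_size 16 --epochs 500
#   python train.py --develop true --batch_size 4 --epochs 2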


def saveEvalImg(img_dir: str, batch_idx: int, imgs, pred_imgs, gt_imgs, normalized_imgs):
    transform = T.ToPILImage()
    for idx, (img, normalized_img, pred_img, gt_img) in enumerate(zip(imgs, normalized_imgs, pred_imgs, gt_imgs)):
        img = transform(img)
        normalized_img = transform(normalized_img)
        pred_img = transform(pred_img)
        gt_img = transform(gt_img)
        f, axarr = plt.subplots(1, 4)
        axarr[0].imshow(img)
        axarr[0].title.set_text('orig')
        axarr[1].imshow(normalized_img)
        axarr[1].title.set_text('normal')
        axarr[2].imshow(pred_img)
        axarr[2].title.set_text('pred')
        axarr[3].imshow(gt_img)
        axarr[3].title.set_text('gt')
        f.savefig(f"{img_dir}/{batch_idx:04d}_{idx}.jpg")
        plt.close()


def evaluator(model: torch.nn.Module, epoch: int, test_loader: DataLoader, tag: str):
    img_dir = f"{output}/{tag}/{epoch}"
    if os.path.exists(img_dir):
        shutil.rmtree(img_dir, ignore_errors=True)
    os.makedirs(img_dir)
    valid_loss = 0
    model.eval()
    eval_criterion = DocCleanLoss(device)
    with torch.no_grad():
        ssim_score = 0
        psnr_score = 0
        for index, (imgs, normalized_imgs, gt_imgs) in enumerate(tqdm(test_loader)):
            imgs = imgs.to(device)
            gt_imgs = gt_imgs.to(device)
            normalized_imgs = normalized_imgs.to(device)
            pred_imgs = model(normalized_imgs)
            ssim_score += structural_similarity_index_measure(
                pred_imgs, gt_imgs).item()
            psnr_score += peak_signal_noise_ratio(pred_imgs, gt_imgs).item()
            loss, _, _, _ = eval_criterion(pred_imgs, gt_imgs)
            valid_loss += loss.item()
            if index % 30 == 0:
                saveEvalImg(img_dir=img_dir, batch_idx=index, imgs=imgs,
                            pred_imgs=pred_imgs, gt_imgs=gt_imgs, normalized_imgs=normalized_imgs)
        data_len = len(test_loader)
        valid_loss = valid_loss / data_len
        psnr_score = psnr_score / data_len
        ssim_score = ssim_score / data_len
    return valid_loss, psnr_score, ssim_score


def batch_mean_std(loader):
    nb_samples = 0.
    channel_mean = torch.zeros(3)
    channel_std = torch.zeros(3)
    for images, _, _ in tqdm(loader):
        # images are expected to be float tensors scaled to [0, 1]
        N, C, H, W = images.shape[:4]
        data = images.view(N, C, -1)
        channel_mean += data.mean(2).sum(0)
        channel_std += data.std(2).sum(0)
        nb_samples += N
    channel_mean /= nb_samples
    channel_std /= nb_samples
    return channel_mean, channel_std
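
# The per-channel mean/std computed above feed input normalization; in this script they are
# passed to DocCleanDataset as normalized_tuple (see __main__ below). A minimal standalone
# sketch using torchvision, assuming img_tensor is a (3, H, W) float tensor in [0, 1]:
#   normalize = transforms.Normalize(mean=mean.tolist(), std=std.tolist())
#   normalized_img = normalize(img_tensor)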


def saveCkpt(model, model_path, epoch, optimizer, scheduler, validation_loss, mean, std, psnr_score, ssim_score):
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'loss': validation_loss,
        'mean': mean,
        'std': std,
        'psnr_score': psnr_score,
        'ssim_score': ssim_score
    }, model_path)
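
# Restoring a checkpoint written by saveCkpt (a sketch mirroring what pretrain() does below):
#   checkpoint = torch.load(model_path)
#   model.load_state_dict(checkpoint['model_state_dict'])
#   optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
#   scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
#   start_epoch = checkpoint['epoch'] + 1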


def trainer(model: torch.nn.Module, criterion: DocCleanLoss, optimizer: torch.optim.Adam, tag: str, epoch: int):
    # train
    model.train()
    running_loss = 0
    running_content_loss = 0
    running_style_loss = 0
    running_pixel_loss = 0
    img_dir = f"{output}/{tag}/{epoch}"
    if os.path.exists(img_dir):
        shutil.rmtree(img_dir, ignore_errors=True)
    os.makedirs(img_dir)
    for index, (imgs, normalized_imgs, gt_imgs) in enumerate(tqdm(train_loader)):
        optimizer.zero_grad()
        imgs = imgs.to(device)
        gt_imgs = gt_imgs.to(device)
        normalized_imgs = normalized_imgs.to(device)
        pred_imgs = model(normalized_imgs)
        loss, p_l_loss, content_loss, style_loss = criterion(
            pred_imgs, gt_imgs)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        running_pixel_loss += p_l_loss.item()
        running_content_loss += content_loss.item()
        running_style_loss += style_loss.item()
        if index % 200 == 0:
            saveEvalImg(img_dir=img_dir, batch_idx=index, imgs=imgs,
                        pred_imgs=pred_imgs, gt_imgs=gt_imgs, normalized_imgs=normalized_imgs)
    return running_loss, running_pixel_loss, running_content_loss, running_style_loss


def model_pruning():
    model, mean, std = M64ColorNet.load_trained_model("output/model.pt")
    model.to(device)
    # Compress this model.
    config_list = [{'op_types': ['Conv2d'], 'sparsity': 0.8}]
    pruner = L1NormPruner(model, config_list)
    _, masks = pruner.compress()
    print('\nThe accuracy with masks:')
    evaluator(model, 0, test_loader, "masks")
    pruner._unwrap_model()
    ModelSpeedup(model, dummy_input=torch.rand(1, 3, 256, 256).to(device), masks_file=masks).speedup_model()
    print('\nThe accuracy after speedup:')
    evaluator(model, 0, test_loader, "speedup")
    # Need a new optimizer because the modules in the model are replaced during speedup.
    optimizer = torch.optim.Adam(params=model.parameters(), lr=args.lr)
    criterion = DocCleanLoss(device=device)
    print('\nFinetune the model after speedup:')
    for i in range(5):
        trainer(model, criterion, optimizer, "train_finetune", i)
        evaluator(model, i, test_loader, "eval_finetune")


def pretrain():
    print(f"device={device} \
        develop={args.develop} \
        lr={args.lr} \
        mean={mean} \
        std={std} \
        shuffle={args.shuffle}")
    model_cls = M64ColorNet
    model = model_cls()
    model.to(device)
    # summary(model, input_size=(batch_size, 3, 256, 256))
    optimizer = torch.optim.Adam(params=model.parameters(), lr=args.lr)
    scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer, step_size=15, gamma=0.8)
    model_path = f"{output}/model.pt"
    current_epoch = 1
    previous_loss = float('inf')
    criterion = DocCleanLoss(device)
    if os.path.exists(model_path):
        checkpoint = torch.load(model_path)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        current_epoch = checkpoint['epoch'] + 1
        previous_loss = checkpoint['loss']
    for epoch in range(current_epoch, current_epoch + args.epochs):
        running_loss, running_pixel_loss, running_content_loss, running_style_loss = trainer(
            model, criterion, optimizer, "train", epoch)
        train_loss = running_loss / len(train_loader)
        train_content_loss = running_content_loss / len(train_loader)
        train_style_loss = running_style_loss / len(train_loader)
        train_pixel_loss = running_pixel_loss / len(train_loader)
        # evaluate
        validation_loss, psnr_score, ssim_score = evaluator(model, epoch, test_loader, "eval")
        writer.add_scalar("Loss/train", train_loss, epoch)
        writer.add_scalar("Loss/validation", validation_loss, epoch)
        writer.add_scalar("metric/psnr", psnr_score, epoch)
        writer.add_scalar("metric/ssim", ssim_score, epoch)
        if previous_loss > validation_loss:
            # model_path is used to resume training; it always holds the latest best ckpt.
            saveCkpt(model, model_path, epoch, optimizer, scheduler, validation_loss, mean, std, psnr_score, ssim_score)
            # Per-epoch checkpoint.
            saveCkpt(model, f"{output}/model_{epoch}.pt", epoch, optimizer, scheduler, validation_loss, mean, std, psnr_score, ssim_score)
            infer_test(f"{output}/infer_test/{epoch}",
                       "infer_imgs", model_path, model_cls)
            previous_loss = validation_loss
        scheduler.step()
        print(
            f"epoch:{epoch} \
            train_loss:{round(train_loss, 4)} \
            validation_loss:{round(validation_loss, 4)} \
            pixel_loss:{round(train_pixel_loss, 4)} \
            content_loss:{round(train_content_loss, 8)} \
            style_loss:{round(train_style_loss, 4)} \
            lr:{round(optimizer.param_groups[0]['lr'], 5)} \
            psnr:{round(psnr_score, 3)} \
            ssim:{round(ssim_score, 3)}"
        )


if __name__ == "__main__":
    args = parser.parse_args()
    train_img_names, eval_img_names, imgs_dir = DocCleanDataset.prepareDataset(args.dataset, args.shuffle)
    output = "output"
    if args.retrain:
        shutil.rmtree(output, ignore_errors=True)
    if not os.path.exists(output):
        os.mkdir(output)
    print(
        f"trainset num:{len(train_img_names)}\nevalset num:{len(eval_img_names)}")
    dataset = DocCleanDataset(
        img_names=train_img_names, imgs_dir=imgs_dir, dev=args.develop)
    mean, std = batch_mean_std(DataLoader(
        dataset=dataset, batch_size=args.batch_size))
    # mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
    # transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=mean, std=std)])
    train_set = DocCleanDataset(
        img_names=train_img_names, imgs_dir=imgs_dir, normalized_tuple=(mean, std), dev=args.develop, img_aug=True)
    test_set = DocCleanDataset(
        img_names=eval_img_names, imgs_dir=imgs_dir, normalized_tuple=(mean, std), dev=args.develop)
    train_loader = DataLoader(
        dataset=train_set, batch_size=args.batch_size, shuffle=args.shuffle)
    test_loader = DataLoader(dataset=test_set, batch_size=args.batch_size)
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    pretrain()
    # model_pruning()
    writer.flush()