|
@@ -0,0 +1,295 @@
|
|
|
+from infer import infer_test
|
|
|
+from torchmetrics.functional import peak_signal_noise_ratio, structural_similarity_index_measure
|
|
|
+from torchvision import transforms
|
|
|
+from torch.utils.tensorboard import SummaryWriter
|
|
|
+import argparse
|
|
|
+import torchvision.transforms as T
|
|
|
+import shutil
|
|
|
+import os
|
|
|
+from matplotlib import pyplot as plt
|
|
|
+from model import M64ColorNet
|
|
|
+from loss import DocCleanLoss
|
|
|
+from torch.utils.data import DataLoader
|
|
|
+from dataset import DocCleanDataset
|
|
|
+import torch
|
|
|
+from tqdm import tqdm
|
|
|
+from nni.compression.pytorch.pruning import L1NormPruner
|
|
|
+from nni.compression.pytorch.speedup import ModelSpeedup
|
|
|
+import matplotlib
|
|
|
+matplotlib.use('Agg')
|
|
|
+# from torchinfo import summary
|
|
|
+writer = SummaryWriter()
|
|
|
+
|
|
|
+
|
|
|
+def boolean_string(s):
|
|
|
+ ''' Check s string is true or false.
|
|
|
+ Args:
|
|
|
+ s: the string
|
|
|
+ Returns:
|
|
|
+ boolean
|
|
|
+ '''
|
|
|
+ s = s.lower()
|
|
|
+ if s not in {'false', 'true'}:
|
|
|
+ raise ValueError('Not a valid boolean string')
|
|
|
+ return s == 'true'
|
|
|
+
|
|
|
+
|
|
|
+# path parameters
|
|
|
+parser = argparse.ArgumentParser()
|
|
|
+parser.add_argument('--develop',
|
|
|
+ type=boolean_string,
|
|
|
+ help='Develop mode turn off by default',
|
|
|
+ default=False)
|
|
|
+parser.add_argument('--lr',
|
|
|
+ type=float,
|
|
|
+ help='Develop mode turn off by default',
|
|
|
+ default=1e-3)
|
|
|
+parser.add_argument('--batch_size',
|
|
|
+ type=int,
|
|
|
+ help='Develop mode turn off by default',
|
|
|
+ default=16)
|
|
|
+parser.add_argument('--retrain',
|
|
|
+ type=boolean_string,
|
|
|
+ help='Whether to restore the checkpoint',
|
|
|
+ default=False)
|
|
|
+parser.add_argument('--epochs',
|
|
|
+ type=int,
|
|
|
+ help='Max training epoch',
|
|
|
+ default=500)
|
|
|
+parser.add_argument('--dataset',
|
|
|
+ type=str,
|
|
|
+ help='Max training epoch',
|
|
|
+ default="dataset/raw_data/imgs_Trainblocks")
|
|
|
+parser.add_argument('--shuffle',
|
|
|
+ type=boolean_string,
|
|
|
+ help='Whether to shuffle dataset',
|
|
|
+ default=True)
|
|
|
+
|
|
|
+
|
|
|
+def saveEvalImg(img_dir: str, batch_idx: int, imgs, pred_imgs, gt_imgs, normalized_imgs):
|
|
|
+ transform = T.ToPILImage()
|
|
|
+ for idx, (img, normalized_img, pred_img, gt_img) in enumerate(zip(imgs, normalized_imgs, pred_imgs, gt_imgs)):
|
|
|
+ img = transform(img)
|
|
|
+ normalized_img = transform(normalized_img)
|
|
|
+ pred_img = transform(pred_img)
|
|
|
+ gt_img = transform(gt_img)
|
|
|
+ f, axarr = plt.subplots(1, 4)
|
|
|
+ axarr[0].imshow(img)
|
|
|
+ axarr[0].title.set_text('orig')
|
|
|
+ axarr[1].imshow(normalized_img)
|
|
|
+ axarr[1].title.set_text('normal')
|
|
|
+ axarr[2].imshow(pred_img)
|
|
|
+ axarr[2].title.set_text('pred')
|
|
|
+ axarr[3].imshow(gt_img)
|
|
|
+ axarr[3].title.set_text('gt')
|
|
|
+ f.savefig(f"{img_dir}/{batch_idx:04d}_{idx}.jpg")
|
|
|
+ plt.close()
|
|
|
+
|
|
|
+
|
|
|
+def evaluator(model:torch.nn.Module, epoch:int, test_loader:DataLoader, tag:str):
|
|
|
+ img_dir = f"{output}/{tag}/{epoch}"
|
|
|
+ if os.path.exists(img_dir):
|
|
|
+ shutil.rmtree(img_dir, ignore_errors=True)
|
|
|
+ os.makedirs(img_dir)
|
|
|
+ valid_loss = 0
|
|
|
+ model.eval()
|
|
|
+ eval_criterion = DocCleanLoss(device)
|
|
|
+ with torch.no_grad():
|
|
|
+ ssim_score = 0
|
|
|
+ psnr_score = 0
|
|
|
+ for index, (imgs, normalized_imgs, gt_imgs) in enumerate(tqdm(test_loader)):
|
|
|
+ imgs = imgs.to(device)
|
|
|
+ gt_imgs = gt_imgs.to(device)
|
|
|
+ normalized_imgs = normalized_imgs.to(device)
|
|
|
+ pred_imgs = model(normalized_imgs)
|
|
|
+ ssim_score += structural_similarity_index_measure(
|
|
|
+ pred_imgs, gt_imgs).item()
|
|
|
+ psnr_score += peak_signal_noise_ratio(pred_imgs, gt_imgs).item()
|
|
|
+ loss, _, _, _ = eval_criterion(pred_imgs, gt_imgs)
|
|
|
+ valid_loss += loss.item()
|
|
|
+ if index % 30 == 0:
|
|
|
+ saveEvalImg(img_dir=img_dir, batch_idx=index, imgs=imgs,
|
|
|
+ pred_imgs=pred_imgs, gt_imgs=gt_imgs, normalized_imgs=normalized_imgs)
|
|
|
+ data_len = len(test_loader)
|
|
|
+ valid_loss = valid_loss / data_len
|
|
|
+ psnr_score = psnr_score / data_len
|
|
|
+ ssim_score = ssim_score / data_len
|
|
|
+ return valid_loss, psnr_score, ssim_score
|
|
|
+
|
|
|
+
|
|
|
+def batch_mean_std(loader):
|
|
|
+ nb_samples = 0.
|
|
|
+ channel_mean = torch.zeros(3)
|
|
|
+ channel_std = torch.zeros(3)
|
|
|
+ for images, _, _ in tqdm(loader):
|
|
|
+ # scale image to be between 0 and 1
|
|
|
+ N, C, H, W = images.shape[:4]
|
|
|
+ data = images.view(N, C, -1)
|
|
|
+
|
|
|
+ channel_mean += data.mean(2).sum(0)
|
|
|
+ channel_std += data.std(2).sum(0)
|
|
|
+ nb_samples += N
|
|
|
+
|
|
|
+ channel_mean /= nb_samples
|
|
|
+ channel_std /= nb_samples
|
|
|
+ return channel_mean, channel_std
|
|
|
+
|
|
|
+
|
|
|
+def saveCkpt(model, model_path, epoch, optimizer, scheduler, validation_loss, mean, std, psnr_score, ssim_score):
|
|
|
+ torch.save({
|
|
|
+ 'epoch': epoch,
|
|
|
+ 'model_state_dict': model.state_dict(),
|
|
|
+ 'optimizer_state_dict': optimizer.state_dict(),
|
|
|
+ 'scheduler_state_dict': scheduler.state_dict(),
|
|
|
+ 'loss': validation_loss,
|
|
|
+ 'mean': mean,
|
|
|
+ 'std': std,
|
|
|
+ 'psnr_score': psnr_score,
|
|
|
+ 'ssim_score': ssim_score
|
|
|
+ }, model_path)
|
|
|
+
|
|
|
+def trainer(model:torch.nn.Module, criterion:DocCleanLoss, optimizer:torch.optim.Adam, tag:str, epoch:int):
|
|
|
+ # train
|
|
|
+ model.train()
|
|
|
+ running_loss = 0
|
|
|
+ running_content_loss = 0
|
|
|
+ running_style_loss = 0
|
|
|
+ running_pixel_loss = 0
|
|
|
+ img_dir = f"{output}/{tag}/{epoch}"
|
|
|
+ if os.path.exists(img_dir):
|
|
|
+ shutil.rmtree(img_dir, ignore_errors=True)
|
|
|
+ os.makedirs(img_dir)
|
|
|
+ for index, (imgs, normalized_imgs, gt_imgs) in enumerate(tqdm(train_loader)):
|
|
|
+ optimizer.zero_grad()
|
|
|
+ imgs = imgs.to(device)
|
|
|
+ gt_imgs = gt_imgs.to(device)
|
|
|
+ normalized_imgs = normalized_imgs.to(device)
|
|
|
+ pred_imgs = model(normalized_imgs)
|
|
|
+ loss, p_l_loss, content_loss, style_loss = criterion(
|
|
|
+ pred_imgs, gt_imgs)
|
|
|
+ loss.backward()
|
|
|
+ optimizer.step()
|
|
|
+ running_loss += loss.item()
|
|
|
+ running_pixel_loss += p_l_loss.item()
|
|
|
+ running_content_loss += content_loss.item()
|
|
|
+ running_style_loss += style_loss.item()
|
|
|
+ if index % 200 == 0:
|
|
|
+ saveEvalImg(img_dir=img_dir, batch_idx=index, imgs=imgs,
|
|
|
+ pred_imgs=pred_imgs, gt_imgs=gt_imgs, normalized_imgs=normalized_imgs)
|
|
|
+ return running_loss, running_pixel_loss, running_content_loss, running_style_loss
|
|
|
+
|
|
|
+def model_pruning():
|
|
|
+ model, mean, std = M64ColorNet.load_trained_model("output/model.pt")
|
|
|
+ model.to(device)
|
|
|
+ # Compress this model.
|
|
|
+ config_list = [{'op_types': ['Conv2d'], 'sparsity': 0.8}]
|
|
|
+ pruner = L1NormPruner(model, config_list)
|
|
|
+ _, masks = pruner.compress()
|
|
|
+
|
|
|
+ print('\nThe accuracy with masks:')
|
|
|
+ evaluator(model, 0, test_loader, "masks")
|
|
|
+
|
|
|
+ pruner._unwrap_model()
|
|
|
+ ModelSpeedup(model, dummy_input=torch.rand(1, 3, 256, 256).to(device), masks_file=masks).speedup_model()
|
|
|
+
|
|
|
+ print('\nThe accuracy after speedup:')
|
|
|
+ evaluator(model, 0, test_loader, "speedup")
|
|
|
+
|
|
|
+ # Need a new optimizer due to the modules in model will be replaced during speedup.
|
|
|
+ optimizer = torch.optim.Adam(params=model.parameters(), lr=args.lr)
|
|
|
+ criterion = DocCleanLoss(device=device)
|
|
|
+ print('\nFinetune the model after speedup:')
|
|
|
+ for i in range(5):
|
|
|
+ trainer(model, criterion, optimizer, "train_finetune", i)
|
|
|
+ evaluator(model, i, test_loader, "eval_finetune")
|
|
|
+
|
|
|
+def pretrain():
|
|
|
+ print(f"device={device} \
|
|
|
+ develop={args.develop} \
|
|
|
+ lr={args.lr} \
|
|
|
+ mean={mean} \
|
|
|
+ std={std} \
|
|
|
+ shuffle={args.shuffle}")
|
|
|
+ model_cls = M64ColorNet
|
|
|
+ model = model_cls()
|
|
|
+ model.to(device)
|
|
|
+ # summary(model, input_size=(batch_size, 3, 256, 256))
|
|
|
+ optimizer = torch.optim.Adam(params=model.parameters(), lr=args.lr)
|
|
|
+ scheduler = torch.optim.lr_scheduler.StepLR(
|
|
|
+ optimizer, step_size=15, gamma=0.8)
|
|
|
+
|
|
|
+ model_path = f"{output}/model.pt"
|
|
|
+ current_epoch = 1
|
|
|
+ previous_loss = float('inf')
|
|
|
+ criterion = DocCleanLoss(device)
|
|
|
+ if os.path.exists(model_path):
|
|
|
+ checkpoint = torch.load(model_path)
|
|
|
+ model.load_state_dict(checkpoint['model_state_dict'])
|
|
|
+ optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
|
|
|
+ scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
|
|
|
+ current_epoch = checkpoint['epoch'] + 1
|
|
|
+ previous_loss = checkpoint['loss']
|
|
|
+ for epoch in range(current_epoch, current_epoch+args.epochs):
|
|
|
+ running_loss, running_pixel_loss, running_content_loss, running_style_loss = trainer(model, criterion, optimizer, "train", epoch)
|
|
|
+ train_loss = running_loss / len(train_loader)
|
|
|
+ train_content_loss = running_content_loss / len(train_loader)
|
|
|
+ train_style_loss = running_style_loss / len(train_loader)
|
|
|
+ train_pixel_loss = running_pixel_loss / len(train_loader)
|
|
|
+ # evaluate
|
|
|
+ validation_loss, psnr_score, ssim_score = evaluator(model, epoch, test_loader, "eval")
|
|
|
+ writer.add_scalar("Loss/train", train_loss, epoch)
|
|
|
+ writer.add_scalar("Loss/validation", validation_loss, epoch)
|
|
|
+ writer.add_scalar("metric/psnr", psnr_score, epoch)
|
|
|
+ writer.add_scalar("metric/ssim", ssim_score, epoch)
|
|
|
+ if previous_loss > validation_loss:
|
|
|
+ # This model_path is used for resume training. Hold the latest ckpt.
|
|
|
+ saveCkpt(model, model_path, epoch, optimizer, scheduler, validation_loss, mean, std, psnr_score, ssim_score)
|
|
|
+ # This for each epoch ckpt.
|
|
|
+ saveCkpt(model, f"{output}/model_{epoch}.pt", epoch, optimizer, scheduler, validation_loss, mean, std, psnr_score, ssim_score)
|
|
|
+ infer_test(f"{output}/infer_test/{epoch}",
|
|
|
+ "infer_imgs", model_path, model_cls)
|
|
|
+ previous_loss = validation_loss
|
|
|
+ scheduler.step()
|
|
|
+ print(
|
|
|
+ f"epoch:{epoch} \
|
|
|
+ train_loss:{round(train_loss, 4)} \
|
|
|
+ validation_loss:{round(validation_loss, 4)} \
|
|
|
+ pixel_loss:{round(train_pixel_loss, 4)} \
|
|
|
+ content_loss:{round(train_content_loss, 8)} \
|
|
|
+ style_loss:{round(train_style_loss, 4)} \
|
|
|
+ lr:{round(optimizer.param_groups[0]['lr'], 5)} \
|
|
|
+ psnr:{round(psnr_score, 3)} \
|
|
|
+ ssim:{round(ssim_score, 3)}"
|
|
|
+ )
|
|
|
+
|
|
|
+if __name__ == "__main__":
|
|
|
+ args = parser.parse_args()
|
|
|
+ train_img_names, eval_img_names, imgs_dir = DocCleanDataset.prepareDataset(args.dataset, args.shuffle)
|
|
|
+ output = "output"
|
|
|
+ if args.retrain == True:
|
|
|
+ shutil.rmtree(output, ignore_errors=True)
|
|
|
+ if os.path.exists(output) == False:
|
|
|
+ os.mkdir(output)
|
|
|
+ print(
|
|
|
+ f"trainset num:{len(train_img_names)}\nevalset num:{len(eval_img_names)}")
|
|
|
+
|
|
|
+ dataset = DocCleanDataset(
|
|
|
+ img_names=train_img_names, imgs_dir=imgs_dir, dev=args.develop)
|
|
|
+ mean, std = batch_mean_std(DataLoader(
|
|
|
+ dataset=dataset, batch_size=args.batch_size))
|
|
|
+ # mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
|
|
|
+ # transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=mean, std=std)])
|
|
|
+ train_set = DocCleanDataset(
|
|
|
+ img_names=train_img_names, imgs_dir=imgs_dir, normalized_tuple=(mean, std), dev=args.develop, img_aug=True)
|
|
|
+ test_set = DocCleanDataset(
|
|
|
+ img_names=eval_img_names, imgs_dir=imgs_dir, normalized_tuple=(mean, std), dev=args.develop)
|
|
|
+ train_loader = DataLoader(
|
|
|
+ dataset=train_set, batch_size=args.batch_size, shuffle=args.shuffle)
|
|
|
+ test_loader = DataLoader(dataset=test_set, batch_size=args.batch_size)
|
|
|
+ device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
|
|
|
+
|
|
|
+ pretrain()
|
|
|
+
|
|
|
+ # model_pruning()
|
|
|
+
|
|
|
+ writer.flush()
|