import matplotlib
matplotlib.use('Agg')  # select the non-interactive backend before pyplot is imported
from matplotlib import pyplot as plt
from infer import infer_test
from torchmetrics.functional import peak_signal_noise_ratio, structural_similarity_index_measure
from torchvision import transforms
from torch.utils.tensorboard import SummaryWriter
import argparse
import torchvision.transforms as T
import shutil
import os
from model import M64ColorNet
from loss import DocCleanLoss
from torch.utils.data import DataLoader
from dataset import DocCleanDataset
import torch
from tqdm import tqdm
from nni.compression.pytorch.pruning import L1NormPruner
from nni.compression.pytorch.speedup import ModelSpeedup
# from torchinfo import summary

writer = SummaryWriter()


def boolean_string(s):
    '''
    Check whether the string s is 'true' or 'false'.
    Args:
        s: the string to check
    Returns:
        boolean
    '''
    s = s.lower()
    if s not in {'false', 'true'}:
        raise ValueError('Not a valid boolean string')
    return s == 'true'


# command-line parameters
parser = argparse.ArgumentParser()
parser.add_argument('--develop', type=boolean_string, help='Develop mode, off by default', default=False)
parser.add_argument('--lr', type=float, help='Learning rate', default=1e-3)
parser.add_argument('--batch_size', type=int, help='Batch size', default=16)
parser.add_argument('--retrain', type=boolean_string, help='Whether to restore the checkpoint', default=False)
parser.add_argument('--epochs', type=int, help='Max training epochs', default=500)
parser.add_argument('--dataset', type=str, help='Dataset directory', default="dataset/raw_data/imgs_Trainblocks")
parser.add_argument('--shuffle', type=boolean_string, help='Whether to shuffle the dataset', default=True)


def saveEvalImg(img_dir: str, batch_idx: int, imgs, pred_imgs, gt_imgs, normalized_imgs):
    # Save a side-by-side comparison (original / normalized / prediction / ground truth) for each sample.
    transform = T.ToPILImage()
    for idx, (img, normalized_img, pred_img, gt_img) in enumerate(zip(imgs, normalized_imgs, pred_imgs, gt_imgs)):
        img = transform(img)
        normalized_img = transform(normalized_img)
        pred_img = transform(pred_img)
        gt_img = transform(gt_img)
        f, axarr = plt.subplots(1, 4)
        axarr[0].imshow(img)
        axarr[0].title.set_text('orig')
        axarr[1].imshow(normalized_img)
        axarr[1].title.set_text('normal')
        axarr[2].imshow(pred_img)
        axarr[2].title.set_text('pred')
        axarr[3].imshow(gt_img)
        axarr[3].title.set_text('gt')
        f.savefig(f"{img_dir}/{batch_idx:04d}_{idx}.jpg")
        plt.close(f)


def evaluator(model: torch.nn.Module, epoch: int, test_loader: DataLoader, tag: str):
    img_dir = f"{output}/{tag}/{epoch}"
    if os.path.exists(img_dir):
        shutil.rmtree(img_dir, ignore_errors=True)
    os.makedirs(img_dir)
    valid_loss = 0
    model.eval()
    eval_criterion = DocCleanLoss(device)
    with torch.no_grad():
        ssim_score = 0
        psnr_score = 0
        for index, (imgs, normalized_imgs, gt_imgs) in enumerate(tqdm(test_loader)):
            imgs = imgs.to(device)
            gt_imgs = gt_imgs.to(device)
            normalized_imgs = normalized_imgs.to(device)
            pred_imgs = model(normalized_imgs)
            ssim_score += structural_similarity_index_measure(pred_imgs, gt_imgs).item()
            psnr_score += peak_signal_noise_ratio(pred_imgs, gt_imgs).item()
            loss, _, _, _ = eval_criterion(pred_imgs, gt_imgs)
            valid_loss += loss.item()
            if index % 30 == 0:
                saveEvalImg(img_dir=img_dir, batch_idx=index, imgs=imgs, pred_imgs=pred_imgs, gt_imgs=gt_imgs, normalized_imgs=normalized_imgs)
    data_len = len(test_loader)
    valid_loss = valid_loss / data_len
    psnr_score = psnr_score / data_len
    ssim_score = ssim_score / data_len
    return valid_loss, psnr_score, ssim_score


def batch_mean_std(loader):
    nb_samples = 0.0
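    # Running per-channel sums. Note that channel_std below accumulates the
    # mean of per-image stds, a common approximation of the true dataset std
    # (the exact pooled std would need a second pass, or sums of squares over
    # all pixels).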
    channel_mean = torch.zeros(3)
    channel_std = torch.zeros(3)
    for images, _, _ in tqdm(loader):
        # images are already scaled to [0, 1] by ToTensor; flatten the spatial dims
        N, C, H, W = images.shape[:4]
        data = images.view(N, C, -1)
        channel_mean += data.mean(2).sum(0)
        channel_std += data.std(2).sum(0)
        nb_samples += N
    channel_mean /= nb_samples
    channel_std /= nb_samples
    return channel_mean, channel_std


def saveCkpt(model, model_path, epoch, optimizer, scheduler, validation_loss, mean, std, psnr_score, ssim_score):
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'loss': validation_loss,
        'mean': mean,
        'std': std,
        'psnr_score': psnr_score,
        'ssim_score': ssim_score
    }, model_path)


def trainer(model: torch.nn.Module, criterion: DocCleanLoss, optimizer: torch.optim.Adam, tag: str, epoch: int):
    # train
    model.train()
    running_loss = 0
    running_content_loss = 0
    running_style_loss = 0
    running_pixel_loss = 0
    img_dir = f"{output}/{tag}/{epoch}"
    if os.path.exists(img_dir):
        shutil.rmtree(img_dir, ignore_errors=True)
    os.makedirs(img_dir)
    for index, (imgs, normalized_imgs, gt_imgs) in enumerate(tqdm(train_loader)):
        optimizer.zero_grad()
        imgs = imgs.to(device)
        gt_imgs = gt_imgs.to(device)
        normalized_imgs = normalized_imgs.to(device)
        pred_imgs = model(normalized_imgs)
        loss, p_l_loss, content_loss, style_loss = criterion(pred_imgs, gt_imgs)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        running_pixel_loss += p_l_loss.item()
        running_content_loss += content_loss.item()
        running_style_loss += style_loss.item()
        if index % 200 == 0:
            saveEvalImg(img_dir=img_dir, batch_idx=index, imgs=imgs, pred_imgs=pred_imgs, gt_imgs=gt_imgs, normalized_imgs=normalized_imgs)
    return running_loss, running_pixel_loss, running_content_loss, running_style_loss


def model_pruning():
    model, mean, std = M64ColorNet.load_trained_model("output/model.pt")
    model.to(device)
    # Compress the model: prune 80% of the Conv2d weights by L1 norm.
    config_list = [{'op_types': ['Conv2d'], 'sparsity': 0.8}]
    pruner = L1NormPruner(model, config_list)
    _, masks = pruner.compress()
    print('\nThe accuracy with masks:')
    valid_loss, psnr_score, ssim_score = evaluator(model, 0, test_loader, "masks")
    print(f"loss:{valid_loss:.4f} psnr:{psnr_score:.3f} ssim:{ssim_score:.3f}")
    pruner._unwrap_model()
    ModelSpeedup(model, dummy_input=torch.rand(1, 3, 256, 256).to(device), masks_file=masks).speedup_model()
    print('\nThe accuracy after speedup:')
    valid_loss, psnr_score, ssim_score = evaluator(model, 0, test_loader, "speedup")
    print(f"loss:{valid_loss:.4f} psnr:{psnr_score:.3f} ssim:{ssim_score:.3f}")
    # A new optimizer is needed because the model's modules are replaced during speedup.
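    # ModelSpeedup physically swaps the pruned Conv2d modules for smaller
    # ones, so the parameter tensors an earlier optimizer would reference no
    # longer exist in the model.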
    optimizer = torch.optim.Adam(params=model.parameters(), lr=args.lr)
    criterion = DocCleanLoss(device=device)
    print('\nFinetune the model after speedup:')
    for i in range(5):
        trainer(model, criterion, optimizer, "train_finetune", i)
        evaluator(model, i, test_loader, "eval_finetune")


def pretrain():
    print(f"device={device} develop={args.develop} lr={args.lr} "
          f"mean={mean} std={std} shuffle={args.shuffle}")
    model_cls = M64ColorNet
    model = model_cls()
    model.to(device)
    # summary(model, input_size=(batch_size, 3, 256, 256))
    optimizer = torch.optim.Adam(params=model.parameters(), lr=args.lr)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=15, gamma=0.8)
    model_path = f"{output}/model.pt"
    current_epoch = 1
    previous_loss = float('inf')
    criterion = DocCleanLoss(device)
    if os.path.exists(model_path):
        # Resume from the latest checkpoint if one exists.
        checkpoint = torch.load(model_path)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        current_epoch = checkpoint['epoch'] + 1
        previous_loss = checkpoint['loss']
    for epoch in range(current_epoch, current_epoch + args.epochs):
        running_loss, running_pixel_loss, running_content_loss, running_style_loss = trainer(model, criterion, optimizer, "train", epoch)
        train_loss = running_loss / len(train_loader)
        train_content_loss = running_content_loss / len(train_loader)
        train_style_loss = running_style_loss / len(train_loader)
        train_pixel_loss = running_pixel_loss / len(train_loader)
        # evaluate
        validation_loss, psnr_score, ssim_score = evaluator(model, epoch, test_loader, "eval")
        writer.add_scalar("Loss/train", train_loss, epoch)
        writer.add_scalar("Loss/validation", validation_loss, epoch)
        writer.add_scalar("metric/psnr", psnr_score, epoch)
        writer.add_scalar("metric/ssim", ssim_score, epoch)
        if previous_loss > validation_loss:
            # model_path always holds the latest best checkpoint; it is used to resume training.
            saveCkpt(model, model_path, epoch, optimizer, scheduler, validation_loss, mean, std, psnr_score, ssim_score)
            # Also keep a per-epoch checkpoint.
            saveCkpt(model, f"{output}/model_{epoch}.pt", epoch, optimizer, scheduler, validation_loss, mean, std, psnr_score, ssim_score)
            infer_test(f"{output}/infer_test/{epoch}", "infer_imgs", model_path, model_cls)
            previous_loss = validation_loss
        scheduler.step()
        print(
            f"epoch:{epoch} "
            f"train_loss:{round(train_loss, 4)} "
            f"validation_loss:{round(validation_loss, 4)} "
            f"pixel_loss:{round(train_pixel_loss, 4)} "
            f"content_loss:{round(train_content_loss, 8)} "
            f"style_loss:{round(train_style_loss, 4)} "
            f"lr:{round(optimizer.param_groups[0]['lr'], 5)} "
            f"psnr:{round(psnr_score, 3)} "
            f"ssim:{round(ssim_score, 3)}"
        )


if __name__ == "__main__":
    args = parser.parse_args()
    train_img_names, eval_img_names, imgs_dir = DocCleanDataset.prepareDataset(args.dataset, args.shuffle)
    output = "output"
    if args.retrain:
        shutil.rmtree(output, ignore_errors=True)
    if not os.path.exists(output):
        os.mkdir(output)
    print(f"trainset num:{len(train_img_names)}\nevalset num:{len(eval_img_names)}")
    dataset = DocCleanDataset(img_names=train_img_names, imgs_dir=imgs_dir, dev=args.develop)
    mean, std = batch_mean_std(DataLoader(dataset=dataset, batch_size=args.batch_size))
    # mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
    # transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=mean, std=std)])
    train_set = DocCleanDataset(img_names=train_img_names, imgs_dir=imgs_dir, normalized_tuple=(mean, std), dev=args.develop, img_aug=True)
    test_set = DocCleanDataset(img_names=eval_img_names, imgs_dir=imgs_dir, normalized_tuple=(mean, std), dev=args.develop)
    train_loader = DataLoader(dataset=train_set, batch_size=args.batch_size, shuffle=args.shuffle)
    test_loader = DataLoader(dataset=test_set, batch_size=args.batch_size)
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    pretrain()
    # model_pruning()
    writer.flush()
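# Example invocation (the script name is assumed; the flags and defaults come
# from the argparse setup above):
#   python train.py --dataset dataset/raw_data/imgs_Trainblocks \
#       --lr 1e-3 --batch_size 16 --epochs 500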