123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295 |
- from infer import infer_test
- from torchmetrics.functional import peak_signal_noise_ratio, structural_similarity_index_measure
- from torchvision import transforms
- from torch.utils.tensorboard import SummaryWriter
- import argparse
- import torchvision.transforms as T
- import shutil
- import os
- from matplotlib import pyplot as plt
- from model import M64ColorNet
- from loss import DocCleanLoss
- from torch.utils.data import DataLoader
- from dataset import DocCleanDataset
- import torch
- from tqdm import tqdm
- from nni.compression.pytorch.pruning import L1NormPruner
- from nni.compression.pytorch.speedup import ModelSpeedup
- import matplotlib
matplotlib.use('Agg')  # headless backend: figures are only written to disk, never shown
# from torchinfo import summary
writer = SummaryWriter()  # TensorBoard writer shared by the whole script (flushed in __main__)
def boolean_string(s):
    """Parse a case-insensitive 'true'/'false' string into a bool.

    Args:
        s: the string to parse.

    Returns:
        True for 'true', False for 'false' (any capitalization).

    Raises:
        ValueError: if s is neither 'true' nor 'false'.
    """
    lowered = s.lower()
    if lowered == 'true':
        return True
    if lowered == 'false':
        return False
    raise ValueError('Not a valid boolean string')
# Command-line parameters for the training script.
parser = argparse.ArgumentParser()
parser.add_argument('--develop',
                    type=boolean_string,
                    help='Develop mode, turned off by default',
                    default=False)
parser.add_argument('--lr',
                    type=float,
                    help='Initial learning rate for the Adam optimizer',
                    default=1e-3)
parser.add_argument('--batch_size',
                    type=int,
                    help='Mini-batch size for training and evaluation',
                    default=16)
parser.add_argument('--retrain',
                    type=boolean_string,
                    help='Whether to restore the checkpoint',
                    default=False)
parser.add_argument('--epochs',
                    type=int,
                    help='Max training epoch',
                    default=500)
parser.add_argument('--dataset',
                    type=str,
                    help='Path to the training-block image dataset directory',
                    default="dataset/raw_data/imgs_Trainblocks")
parser.add_argument('--shuffle',
                    type=boolean_string,
                    help='Whether to shuffle dataset',
                    default=True)
def saveEvalImg(img_dir: str, batch_idx: int, imgs, pred_imgs, gt_imgs, normalized_imgs):
    """Write one 4-panel comparison figure (orig / normal / pred / gt) per sample.

    Each figure is saved as {img_dir}/{batch_idx:04d}_{idx}.jpg, where idx is
    the sample's position inside the batch. Tensors are converted to PIL
    images before plotting.
    """
    to_pil = T.ToPILImage()
    titles = ('orig', 'normal', 'pred', 'gt')
    for idx, panels in enumerate(zip(imgs, normalized_imgs, pred_imgs, gt_imgs)):
        fig, axes = plt.subplots(1, 4)
        for axis, title, tensor in zip(axes, titles, panels):
            axis.imshow(to_pil(tensor))
            axis.title.set_text(title)
        fig.savefig(f"{img_dir}/{batch_idx:04d}_{idx}.jpg")
        plt.close()
def evaluator(model: torch.nn.Module, epoch: int, test_loader: DataLoader, tag: str):
    """Run a full no-grad evaluation pass over test_loader.

    Recreates the image dump directory {output}/{tag}/{epoch} (relies on the
    module-level `output` and `device` globals), saves a comparison figure for
    every 30th batch, and returns per-batch averages.

    Returns:
        (mean validation loss, mean PSNR, mean SSIM) over all batches.
    """
    img_dir = f"{output}/{tag}/{epoch}"
    if os.path.exists(img_dir):
        shutil.rmtree(img_dir, ignore_errors=True)
    os.makedirs(img_dir)
    total_loss = 0
    model.eval()
    criterion = DocCleanLoss(device)
    with torch.no_grad():
        total_ssim = 0
        total_psnr = 0
        for batch_idx, (imgs, normalized_imgs, gt_imgs) in enumerate(tqdm(test_loader)):
            imgs = imgs.to(device)
            gt_imgs = gt_imgs.to(device)
            normalized_imgs = normalized_imgs.to(device)
            pred_imgs = model(normalized_imgs)
            total_ssim += structural_similarity_index_measure(pred_imgs, gt_imgs).item()
            total_psnr += peak_signal_noise_ratio(pred_imgs, gt_imgs).item()
            loss, _, _, _ = criterion(pred_imgs, gt_imgs)
            total_loss += loss.item()
            if batch_idx % 30 == 0:
                saveEvalImg(img_dir=img_dir, batch_idx=batch_idx, imgs=imgs,
                            pred_imgs=pred_imgs, gt_imgs=gt_imgs, normalized_imgs=normalized_imgs)
    num_batches = len(test_loader)
    return total_loss / num_batches, total_psnr / num_batches, total_ssim / num_batches
def batch_mean_std(loader):
    """Compute per-channel mean and std over a loader yielding (images, _, _).

    Note: the "std" here is the average over all samples of each image's own
    per-channel std, not the std of the pooled pixel population.

    Returns:
        (channel_mean, channel_std): two tensors of shape (3,).
    """
    sample_count = 0.
    mean_acc = torch.zeros(3)
    std_acc = torch.zeros(3)
    for images, _, _ in tqdm(loader):
        # Flatten spatial dims so stats are taken per channel.
        batch, channels = images.shape[0], images.shape[1]
        flattened = images.view(batch, channels, -1)
        mean_acc += flattened.mean(2).sum(0)
        std_acc += flattened.std(2).sum(0)
        sample_count += batch
    mean_acc /= sample_count
    std_acc /= sample_count
    return mean_acc, std_acc
def saveCkpt(model, model_path, epoch, optimizer, scheduler, validation_loss, mean, std, psnr_score, ssim_score):
    """Serialize a full training checkpoint to model_path.

    The checkpoint bundles model/optimizer/scheduler state with the epoch
    number, validation loss, dataset normalization stats (mean/std), and the
    PSNR/SSIM metrics, so training can be resumed and the model deployed with
    the correct normalization.
    """
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'loss': validation_loss,
        'mean': mean,
        'std': std,
        'psnr_score': psnr_score,
        'ssim_score': ssim_score,
    }
    torch.save(checkpoint, model_path)
def trainer(model: torch.nn.Module, criterion: DocCleanLoss, optimizer: torch.optim.Adam, tag: str, epoch: int):
    """Run one training epoch over the module-level train_loader.

    Recreates the image dump directory {output}/{tag}/{epoch} (uses the global
    `output`, `device`, and `train_loader`), saves a comparison figure every
    200 batches, and returns the SUMMED losses across all batches.

    Returns:
        (total loss, pixel loss, content loss, style loss) sums.
    """
    model.train()
    totals = {'loss': 0, 'content': 0, 'style': 0, 'pixel': 0}
    img_dir = f"{output}/{tag}/{epoch}"
    if os.path.exists(img_dir):
        shutil.rmtree(img_dir, ignore_errors=True)
    os.makedirs(img_dir)
    for step, (imgs, normalized_imgs, gt_imgs) in enumerate(tqdm(train_loader)):
        optimizer.zero_grad()
        imgs = imgs.to(device)
        gt_imgs = gt_imgs.to(device)
        normalized_imgs = normalized_imgs.to(device)
        pred_imgs = model(normalized_imgs)
        loss, pixel_loss, content_loss, style_loss = criterion(pred_imgs, gt_imgs)
        loss.backward()
        optimizer.step()
        totals['loss'] += loss.item()
        totals['pixel'] += pixel_loss.item()
        totals['content'] += content_loss.item()
        totals['style'] += style_loss.item()
        if step % 200 == 0:
            saveEvalImg(img_dir=img_dir, batch_idx=step, imgs=imgs,
                        pred_imgs=pred_imgs, gt_imgs=gt_imgs, normalized_imgs=normalized_imgs)
    return totals['loss'], totals['pixel'], totals['content'], totals['style']
def model_pruning():
    """Prune the trained model with NNI L1-norm pruning, speed it up, then finetune.

    Pipeline: load checkpoint -> mask Conv2d weights at 80% sparsity ->
    evaluate masked model -> physically shrink the network with ModelSpeedup
    -> evaluate again -> finetune for 5 epochs.

    NOTE(review): relies on the module-level globals `device`, `test_loader`,
    `args` (and `train_loader` via trainer), so it must be called after the
    __main__ setup below. The call order compress -> _unwrap_model ->
    ModelSpeedup is required by the NNI API.
    """
    model, mean, std = M64ColorNet.load_trained_model("output/model.pt")
    model.to(device)
    # Compress this model.
    config_list = [{'op_types': ['Conv2d'], 'sparsity': 0.8}]
    pruner = L1NormPruner(model, config_list)
    # compress() returns the masked model and the per-layer masks.
    _, masks = pruner.compress()
    print('\nThe accuracy with masks:')
    evaluator(model, 0, test_loader, "masks")
    # Remove the pruner wrappers so ModelSpeedup sees the plain modules.
    pruner._unwrap_model()
    ModelSpeedup(model, dummy_input=torch.rand(1, 3, 256, 256).to(device), masks_file=masks).speedup_model()
    print('\nThe accuracy after speedup:')
    evaluator(model, 0, test_loader, "speedup")
    # Need a new optimizer due to the modules in model will be replaced during speedup.
    optimizer = torch.optim.Adam(params=model.parameters(), lr=args.lr)
    criterion = DocCleanLoss(device=device)
    print('\nFinetune the model after speedup:')
    for i in range(5):
        trainer(model, criterion, optimizer, "train_finetune", i)
        evaluator(model, i, test_loader, "eval_finetune")
def pretrain():
    """Main training loop: train, evaluate, checkpoint on improvement.

    Uses the module-level globals `device`, `args`, `mean`, `std`, `output`,
    `train_loader`, `test_loader`, and `writer` set up in __main__. If a
    checkpoint exists at {output}/model.pt, training resumes from it. After
    each epoch the model is evaluated; whenever validation loss improves, the
    latest checkpoint is overwritten, a per-epoch checkpoint is saved, and
    inference samples are produced via infer_test.
    """
    print(f"device={device} \
        develop={args.develop} \
        lr={args.lr} \
        mean={mean} \
        std={std} \
        shuffle={args.shuffle}")
    model_cls = M64ColorNet
    model = model_cls()
    model.to(device)
    # summary(model, input_size=(batch_size, 3, 256, 256))
    optimizer = torch.optim.Adam(params=model.parameters(), lr=args.lr)
    # Decay LR by 20% every 15 epochs.
    scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer, step_size=15, gamma=0.8)
    model_path = f"{output}/model.pt"
    current_epoch = 1
    previous_loss = float('inf')
    criterion = DocCleanLoss(device)
    if os.path.exists(model_path):
        # Resume training from the latest checkpoint.
        checkpoint = torch.load(model_path)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        current_epoch = checkpoint['epoch'] + 1
        previous_loss = checkpoint['loss']
    for epoch in range(current_epoch, current_epoch+args.epochs):
        # trainer returns SUMMED losses; divide by batch count for averages.
        running_loss, running_pixel_loss, running_content_loss, running_style_loss = trainer(model, criterion, optimizer, "train", epoch)
        train_loss = running_loss / len(train_loader)
        train_content_loss = running_content_loss / len(train_loader)
        train_style_loss = running_style_loss / len(train_loader)
        train_pixel_loss = running_pixel_loss / len(train_loader)
        # evaluate
        validation_loss, psnr_score, ssim_score = evaluator(model, epoch, test_loader, "eval")
        writer.add_scalar("Loss/train", train_loss, epoch)
        writer.add_scalar("Loss/validation", validation_loss, epoch)
        writer.add_scalar("metric/psnr", psnr_score, epoch)
        writer.add_scalar("metric/ssim", ssim_score, epoch)
        if previous_loss > validation_loss:
            # This model_path is used for resume training. Hold the latest ckpt.
            saveCkpt(model, model_path, epoch, optimizer, scheduler, validation_loss, mean, std, psnr_score, ssim_score)
            # This for each epoch ckpt.
            saveCkpt(model, f"{output}/model_{epoch}.pt", epoch, optimizer, scheduler, validation_loss, mean, std, psnr_score, ssim_score)
            infer_test(f"{output}/infer_test/{epoch}",
                       "infer_imgs", model_path, model_cls)
            previous_loss = validation_loss
        scheduler.step()
        print(
            f"epoch:{epoch} \
            train_loss:{round(train_loss, 4)} \
            validation_loss:{round(validation_loss, 4)} \
            pixel_loss:{round(train_pixel_loss, 4)} \
            content_loss:{round(train_content_loss, 8)} \
            style_loss:{round(train_style_loss, 4)} \
            lr:{round(optimizer.param_groups[0]['lr'], 5)} \
            psnr:{round(psnr_score, 3)} \
            ssim:{round(ssim_score, 3)}"
        )
if __name__ == "__main__":
    args = parser.parse_args()
    # Split the dataset into train/eval image-name lists.
    train_img_names, eval_img_names, imgs_dir = DocCleanDataset.prepareDataset(args.dataset, args.shuffle)
    output = "output"
    if args.retrain:
        # Start from scratch: discard previous checkpoints and eval images.
        shutil.rmtree(output, ignore_errors=True)
    os.makedirs(output, exist_ok=True)
    print(
        f"trainset num:{len(train_img_names)}\nevalset num:{len(eval_img_names)}")
    # First pass over the raw training images to compute normalization stats.
    dataset = DocCleanDataset(
        img_names=train_img_names, imgs_dir=imgs_dir, dev=args.develop)
    mean, std = batch_mean_std(DataLoader(
        dataset=dataset, batch_size=args.batch_size))
    # Alternative: ImageNet stats [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
    train_set = DocCleanDataset(
        img_names=train_img_names, imgs_dir=imgs_dir, normalized_tuple=(mean, std), dev=args.develop, img_aug=True)
    test_set = DocCleanDataset(
        img_names=eval_img_names, imgs_dir=imgs_dir, normalized_tuple=(mean, std), dev=args.develop)
    train_loader = DataLoader(
        dataset=train_set, batch_size=args.batch_size, shuffle=args.shuffle)
    test_loader = DataLoader(dataset=test_set, batch_size=args.batch_size)
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    pretrain()
    # model_pruning()
    writer.flush()
|