1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465 |
- import torch
- from MS_SSIM_L1_loss import MS_SSIM_L1_LOSS
- from vgg19 import VGG19
def gram(x):
    """Return the normalized Gram matrix of a batch of feature maps.

    Args:
        x: 4-D tensor unpacked as ``(bs, ch, h, w)``.

    Returns:
        Tensor of shape ``(bs, ch, ch)`` holding, for every sample, the inner
        products between flattened channel maps divided by ``ch * h * w``.

    Raises:
        ValueError: if ``x`` does not have exactly four dimensions (raised by
            the shape unpacking).
    """
    (bs, ch, h, w) = x.size()
    # reshape() instead of view(): view() raises on non-contiguous inputs
    # (e.g. permuted or sliced feature maps); reshape() copies only if needed.
    f = x.reshape(bs, ch, h * w)
    f_T = f.transpose(1, 2)
    G = f.bmm(f_T) / (ch * h * w)
    return G
- # class PixelLevelLoss(torch.nn.Module):
- # def __init__(self, device):
- # super(PixelLevelLoss, self).__init__()
- # # self.l1_loss = torch.nn.L1Loss()
- # self.l1_loss = MS_SSIM_L1_LOSS(device=device)
- # def forward(self, pred, gt):
- # # pred_yuv = rgb_to_ycbcr(pred)
- # # gt_yuv = rgb_to_ycbcr(gt)
- # # loss = torch.norm(torch.sub(pred_yuv, gt_yuv), p=1)
- # # loss = l1_loss(pred_yuv, gt_yuv)
- # loss = self.l1_loss(pred, gt)
- # return loss
- # https://zhuanlan.zhihu.com/p/92102879
class PerceptualLoss(torch.nn.Module):
    """Content, style and total-variation terms computed from feature maps.

    The three terms are returned separately and unweighted; the caller is
    responsible for combining them.
    """

    def __init__(self):
        super().__init__()
        # NOTE: despite the attribute name this is a mean-squared-error
        # criterion, not L1. The name is kept because it is a public attribute.
        self.l1_loss = torch.nn.MSELoss()

    def tv_loss(self, y_hat):
        """Return the mean absolute difference between neighbouring pixels.

        Differences are taken along dim 2 and dim 3 of ``y_hat``, each is
        averaged, and the two means are averaged together.
        """
        diff_rows = y_hat[:, :, 1:, :] - y_hat[:, :, :-1, :]
        diff_cols = y_hat[:, :, :, 1:] - y_hat[:, :, :, :-1]
        return 0.5 * (torch.abs(diff_rows).mean() + torch.abs(diff_cols).mean())

    def forward(self, y_hat, contents, style_pred_list, style_gt_list):
        """Compute the content, style and total-variation losses.

        Args:
            y_hat: tensor the total-variation term is computed on.
            contents: pair ``(content_pred, content_gt)`` of 4-D feature maps.
            style_pred_list: iterable of feature maps for the prediction.
            style_gt_list: iterable of feature maps for the ground truth,
                paired positionally with ``style_pred_list``.

        Returns:
            Tuple ``(content_loss, style_loss, tv_loss)``. ``style_loss`` is
            the plain integer ``0`` when the style lists are empty.
        """
        content_pred, content_gt = contents
        _batch, channels, height, width = content_pred.shape
        # NOTE(review): MSELoss already averages over all elements by default;
        # this extra division scales the content term down further.
        n_features = float(channels * height * width)
        content_loss = self.l1_loss(content_pred, content_gt) / n_features

        # Accumulate the MSE between Gram matrices of each paired style layer.
        style_loss = 0
        for feat_pred, feat_gt in zip(style_pred_list, style_gt_list):
            gram_pred = gram(feat_pred)
            gram_gt = gram(feat_gt)
            style_loss += self.l1_loss(gram_pred, gram_gt)

        smoothness = self.tv_loss(y_hat)
        return content_loss, style_loss, smoothness
class DocCleanLoss(torch.nn.Module):
    """Weighted sum of a pixel-level loss and VGG19-based perceptual losses.

    The pixel-level term comes from ``MS_SSIM_L1_LOSS``; the content, style
    and total-variation terms come from ``PerceptualLoss`` applied to the
    outputs of a ``VGG19`` feature extractor.

    Args:
        device: device the VGG19 extractor is moved to; also forwarded to
            ``MS_SSIM_L1_LOSS``.
        pixel_weight: multiplier for the pixel-level term.
        content_weight: multiplier for the content term.
        style_weight: multiplier for the style term.
        tv_weight: multiplier for the total-variation term.

    The weight defaults reproduce the previously hard-coded combination
    ``1e1 * pixel + 1e1 * content + 1e-1 * style + 1e1 * tv``.
    """

    def __init__(
        self,
        device,
        *,
        pixel_weight: float = 1e1,
        content_weight: float = 1e1,
        style_weight: float = 1e-1,
        tv_weight: float = 1e1,
    ) -> None:
        super().__init__()
        self.vgg19 = VGG19()
        self.vgg19.to(device)
        self.vgg19.eval()
        self.pixel_level_loss = MS_SSIM_L1_LOSS(device=device)
        self.perceptual_loss = PerceptualLoss()
        self.pixel_weight = pixel_weight
        self.content_weight = content_weight
        self.style_weight = style_weight
        self.tv_weight = tv_weight

    def train(self, mode: bool = True):
        """Set the training flag but keep the VGG19 extractor in eval mode.

        ``__init__`` puts VGG19 in eval mode; without this override a later
        ``.train()`` call on this module (or a parent module) would switch it
        back to training mode.
        """
        super().train(mode)
        self.vgg19.eval()
        return self

    def forward(self, pred_imgs, gt_imgs):
        """Compute the combined loss for a batch.

        Args:
            pred_imgs: predicted images.
            gt_imgs: ground-truth images.

        Returns:
            Tuple ``(total, pixel_loss, content_loss, style_loss)``. The
            total-variation term contributes to ``total`` but is not returned
            separately.
        """
        p_l_loss = self.pixel_level_loss(pred_imgs, gt_imgs)
        contents, style_pred_list, style_gt_list = self.vgg19(pred_imgs, gt_imgs)
        content_loss, style_loss, tv_l = self.perceptual_loss(
            pred_imgs, contents, style_pred_list, style_gt_list
        )
        total = (
            self.pixel_weight * p_l_loss
            + self.content_weight * content_loss
            + self.style_weight * style_loss
            + self.tv_weight * tv_l
        )
        return total, p_l_loss, content_loss, style_loss
|