import torch

from MS_SSIM_L1_loss import MS_SSIM_L1_LOSS
from vgg19 import VGG19


def gram(x):
    """Return the normalized Gram matrix of a batch of feature maps.

    Args:
        x: 4-D tensor, unpacked as ``(bs, ch, h, w)``.

    Returns:
        Tensor of shape ``(bs, ch, ch)`` equal to ``F @ F^T / (ch * h * w)``,
        where ``F`` is ``x`` flattened to ``(bs, ch, h * w)``.
    """
    bs, ch, h, w = x.size()
    # reshape (not view) so non-contiguous inputs, e.g. channels_last feature
    # maps, are accepted; it still returns a view when x is contiguous.
    f = x.reshape(bs, ch, h * w)
    return f.bmm(f.transpose(1, 2)) / (ch * h * w)


# Reference for the content / style (Gram) / total-variation formulation:
# https://zhuanlan.zhihu.com/p/92102879
class PerceptualLoss(torch.nn.Module):
    """Content, style (Gram-matrix) and total-variation loss terms."""

    def __init__(self):
        super(PerceptualLoss, self).__init__()
        # Despite the attribute name this is an MSE (L2) criterion; the name is
        # kept so external code that accesses ``l1_loss`` keeps working.
        self.l1_loss = torch.nn.MSELoss()

    def tv_loss(self, y_hat):
        """Return the anisotropic L1 total variation of ``y_hat``.

        Averages the mean absolute difference between neighbouring elements
        along dim 2 and along dim 3 (height and width if ``y_hat`` is NCHW —
        assumed, TODO confirm). ``y_hat`` must be 4-D.
        """
        diff_dim2 = torch.abs(y_hat[:, :, 1:, :] - y_hat[:, :, :-1, :]).mean()
        diff_dim3 = torch.abs(y_hat[:, :, :, 1:] - y_hat[:, :, :, :-1]).mean()
        return 0.5 * (diff_dim2 + diff_dim3)

    def forward(self, y_hat, contents, style_pred_list, style_gt_list):
        """Compute the perceptual loss terms.

        Args:
            y_hat: predicted images, used only for the TV term (4-D).
            contents: pair ``(content_pred, content_gt)`` of 4-D feature maps.
            style_pred_list: feature maps of the prediction, one per style layer.
            style_gt_list: matching feature maps of the target. ``zip`` pairs
                the two lists and silently ignores extra entries in the longer one.

        Returns:
            Tuple ``(content_loss, style_loss, tv_loss)``. ``style_loss`` is the
            int ``0`` when the style lists are empty.
        """
        content_pred, content_gt = contents
        _, c, h, w = content_pred.shape
        # NOTE(review): MSELoss already averages over all elements (default
        # reduction='mean'), so dividing by c*h*w normalizes a second time and
        # makes this term very small. Kept as-is because the weights in
        # DocCleanLoss were presumably tuned against this scale — confirm.
        content_loss = self.l1_loss(content_pred, content_gt) / float(c * h * w)

        style_loss = 0
        for style_pred, style_gt in zip(style_pred_list, style_gt_list):
            style_loss += self.l1_loss(gram(style_pred), gram(style_gt))

        return content_loss, style_loss, self.tv_loss(y_hat)


class DocCleanLoss(torch.nn.Module):
    """Weighted sum of an MS-SSIM+L1 pixel loss and VGG19 perceptual losses.

    Args:
        device: device the VGG19 feature extractor is moved to; also passed
            to ``MS_SSIM_L1_LOSS``.
        pixel_weight: weight of the MS-SSIM+L1 pixel-level term.
        content_weight: weight of the VGG content term.
        style_weight: weight of the Gram-matrix style term.
        tv_weight: weight of the total-variation term.

    The weight defaults reproduce the previously hard-coded values.
    """

    def __init__(
        self,
        device,
        *,
        pixel_weight=1e1,
        content_weight=1e1,
        style_weight=1e-1,
        tv_weight=1e1,
    ) -> None:
        super(DocCleanLoss, self).__init__()
        self.vgg19 = VGG19()
        self.vgg19.to(device)
        # NOTE(review): eval() is applied once here. A later .train() call on
        # this module (or on a parent module) switches vgg19 back to train
        # mode. That is harmless only if VGG19 has no dropout/batch-norm — confirm.
        self.vgg19.eval()
        self.pixel_level_loss = MS_SSIM_L1_LOSS(device=device)
        self.perceptual_loss = PerceptualLoss()
        self.pixel_weight = pixel_weight
        self.content_weight = content_weight
        self.style_weight = style_weight
        self.tv_weight = tv_weight

    def forward(self, pred_imgs, gt_imgs):
        """Compute the total training loss and its main components.

        Returns:
            Tuple ``(total, p_l_loss, content_loss, style_loss)``. The TV term
            contributes to ``total`` but is not returned on its own.
        """
        p_l_loss = self.pixel_level_loss(pred_imgs, gt_imgs)
        # The unpacking assumes VGG19 returns (content-feature pair,
        # predicted style features, target style features).
        contents, style_pred_list, style_gt_list = self.vgg19(pred_imgs, gt_imgs)
        content_loss, style_loss, tv_l = self.perceptual_loss(
            pred_imgs, contents, style_pred_list, style_gt_list
        )
        total = (
            self.pixel_weight * p_l_loss
            + self.content_weight * content_loss
            + self.style_weight * style_loss
            + self.tv_weight * tv_l
        )
        return total, p_l_loss, content_loss, style_loss