loss.py 2.5 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465
  1. import torch
  2. from MS_SSIM_L1_loss import MS_SSIM_L1_LOSS
  3. from vgg19 import VGG19
  # Normalised Gram matrix helper; used for the style term of PerceptualLoss.
def gram(x: torch.Tensor) -> torch.Tensor:
    """Return the normalised Gram matrix of a batch of feature maps.

    Args:
        x: 4-D tensor whose size is unpacked as ``(bs, ch, h, w)``. It does
            not need to be contiguous.

    Returns:
        Tensor of shape ``(bs, ch, ch)``. Entry ``[b, i, j]`` is the inner
        product of channels ``i`` and ``j`` of sample ``b`` over all spatial
        positions, divided by ``ch * h * w``.
    """
    bs, ch, h, w = x.size()
    # reshape (rather than view) also accepts non-contiguous inputs such as
    # transposed or sliced feature maps; it still returns a view when the
    # memory layout allows it.
    features = x.reshape(bs, ch, h * w)
    return features.bmm(features.transpose(1, 2)) / (ch * h * w)
  10. # class PixelLevelLoss(torch.nn.Module):
  11. # def __init__(self, device):
  12. # super(PixelLevelLoss, self).__init__()
  13. # # self.l1_loss = torch.nn.L1Loss()
  14. # self.l1_loss = MS_SSIM_L1_LOSS(device=device)
  15. # def forward(self, pred, gt):
  16. # # pred_yuv = rgb_to_ycbcr(pred)
  17. # # gt_yuv = rgb_to_ycbcr(gt)
  18. # # loss = torch.norm(torch.sub(pred_yuv, gt_yuv), p=1)
  19. # # loss = l1_loss(pred_yuv, gt_yuv)
  20. # loss = self.l1_loss(pred, gt)
  21. # return loss
  # Reference: https://zhuanlan.zhihu.com/p/92102879
class PerceptualLoss(torch.nn.Module):
    """Content, style (Gram-matrix) and total-variation loss terms.

    The three terms are returned separately; weighting them is left to the
    caller.
    """

    def __init__(self):
        super().__init__()
        # Despite its name, this attribute holds a mean-squared-error
        # criterion (an L1 criterion was used earlier). The name is kept so
        # existing code that accesses ``self.l1_loss`` keeps working.
        self.l1_loss = torch.nn.MSELoss()

    def tv_loss(self, y_hat):
        """Anisotropic total-variation penalty of an image batch.

        Args:
            y_hat: tensor indexed as ``(N, C, H, W)``; dims 2 and 3 are
                treated as the spatial axes.

        Returns:
            0-d tensor equal to half the sum of the mean absolute vertical
            difference and the mean absolute horizontal difference between
            neighbouring pixels. A direction with no neighbour pairs (size 1
            along that axis) contributes 0 instead of the NaN that ``mean()``
            of an empty tensor would give.
        """
        d_vert = y_hat[:, :, 1:, :] - y_hat[:, :, :-1, :]
        d_horz = y_hat[:, :, :, 1:] - y_hat[:, :, :, :-1]
        zero = y_hat.new_zeros(())
        tv_vert = d_vert.abs().mean() if d_vert.numel() > 0 else zero
        tv_horz = d_horz.abs().mean() if d_horz.numel() > 0 else zero
        return 0.5 * (tv_vert + tv_horz)

    def forward(self, y_hat, contents, style_pred_list, style_gt_list):
        """Compute the content, style and total-variation losses.

        Args:
            y_hat: predicted image batch; only used for the TV term.
            contents: pair ``(content_pred, content_gt)`` of 4-D feature
                tensors (shape unpacked as ``(_, c, h, w)``).
            style_pred_list: iterable of 4-D feature tensors from the
                prediction.
            style_gt_list: iterable of 4-D feature tensors from the target,
                paired element-wise with ``style_pred_list``.

        Returns:
            ``(content_loss, style_loss, tv_l)``, all tensors. ``style_loss``
            is the sum, over feature pairs, of the MSE between their Gram
            matrices. It is a zero tensor when both style lists are empty.

        Raises:
            ValueError: if the two style lists have different lengths.
                ``zip`` would otherwise silently drop the extra entries.
        """
        content_pred, content_gt = contents
        _, c, h, w = content_pred.shape
        # NOTE(review): MSELoss (reduction='mean') already averages over all
        # N*C*H*W elements, so dividing by C*H*W again normalises twice. The
        # loss weights used by callers may be tuned to this scale, so the
        # behaviour is kept; confirm before changing.
        content_loss = self.l1_loss(content_pred, content_gt) / float(c * h * w)

        style_preds = list(style_pred_list)
        style_gts = list(style_gt_list)
        if len(style_preds) != len(style_gts):
            raise ValueError(
                "style_pred_list and style_gt_list differ in length: "
                f"{len(style_preds)} != {len(style_gts)}"
            )
        # Start from a zero tensor (not the int 0) so style_loss is always
        # a tensor, even when there are no style features.
        style_loss = content_pred.new_zeros(())
        for style_pred, style_gt in zip(style_preds, style_gts):
            style_loss = style_loss + self.l1_loss(gram(style_pred), gram(style_gt))

        tv_l = self.tv_loss(y_hat)
        return content_loss, style_loss, tv_l
  # Training criterion combining pixel-level and VGG19-based perceptual losses.
class DocCleanLoss(torch.nn.Module):
    """Weighted sum of a pixel-level loss and VGG19-based perceptual losses.

    Args:
        device: device the VGG19 feature extractor is moved to; also passed
            to ``MS_SSIM_L1_LOSS``.
        pixel_weight: weight of the pixel-level (``MS_SSIM_L1_LOSS``) term.
        content_weight: weight of the perceptual content term.
        style_weight: weight of the perceptual style term.
        tv_weight: weight of the total-variation term.

    The weight defaults reproduce the values that were previously hard-coded.
    """

    def __init__(self, device, *, pixel_weight=1e1, content_weight=1e1,
                 style_weight=1e-1, tv_weight=1e1) -> None:
        super().__init__()
        self.vgg19 = VGG19()
        self.vgg19.to(device)
        self.vgg19.eval()
        # VGG19 serves only as a fixed feature extractor for the loss. Do not
        # allocate gradients for its weights; gradients still flow through it
        # to the predicted images.
        for param in self.vgg19.parameters():
            param.requires_grad_(False)
        self.pixel_level_loss = MS_SSIM_L1_LOSS(device=device)
        self.perceptual_loss = PerceptualLoss()
        self.pixel_weight = pixel_weight
        self.content_weight = content_weight
        self.style_weight = style_weight
        self.tv_weight = tv_weight

    def train(self, mode: bool = True):
        """Set training mode on this module but keep ``vgg19`` in eval mode.

        Without this override, ``criterion.train()`` would recurse into the
        VGG19 submodule and undo the ``eval()`` call made in ``__init__``.
        """
        super().train(mode)
        self.vgg19.eval()
        return self

    def forward(self, pred_imgs, gt_imgs):
        """Compute the combined loss for a predicted/ground-truth image pair.

        Returns:
            ``(total, p_l_loss, content_loss, style_loss)``. ``total`` is the
            weighted sum of the pixel-level, content, style and
            total-variation terms. The TV term is part of ``total`` but is
            not returned on its own, which keeps the 4-tuple return shape.
        """
        p_l_loss = self.pixel_level_loss(pred_imgs, gt_imgs)
        contents, style_pred_list, style_gt_list = self.vgg19(pred_imgs, gt_imgs)
        content_loss, style_loss, tv_l = self.perceptual_loss(
            pred_imgs, contents, style_pred_list, style_gt_list
        )
        total = (
            self.pixel_weight * p_l_loss
            + self.content_weight * content_loss
            + self.style_weight * style_loss
            + self.tv_weight * tv_l
        )
        return total, p_l_loss, content_loss, style_loss