MS_SSIM_L1_loss.py 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596
  1. # -*- coding: utf-8 -*-
  2. """
  3. Created on Thu Dec 3 00:28:15 2020
  4. @author: Yunpeng Li, Tianjin University
  5. """
  6. import torch
  7. import torch.nn as nn
  8. import torch.nn.functional as F
  9. class MS_SSIM_L1_LOSS(nn.Module):
  10. # Have to use cuda, otherwise the speed is too slow.
  11. def __init__(self, gaussian_sigmas=[0.5, 1.0, 2.0, 4.0, 8.0],
  12. data_range = 1.0,
  13. K=(0.01, 0.03),
  14. alpha=0.025,
  15. compensation=200.0,
  16. device=torch.device('cpu')):
  17. super(MS_SSIM_L1_LOSS, self).__init__()
  18. self.DR = data_range
  19. self.C1 = (K[0] * data_range) ** 2
  20. self.C2 = (K[1] * data_range) ** 2
  21. self.pad = int(2 * gaussian_sigmas[-1])
  22. self.alpha = alpha
  23. self.compensation=compensation
  24. filter_size = int(4 * gaussian_sigmas[-1] + 1)
  25. g_masks = torch.zeros((3*len(gaussian_sigmas), 1, filter_size, filter_size))
  26. for idx, sigma in enumerate(gaussian_sigmas):
  27. # r0,g0,b0,r1,g1,b1,...,rM,gM,bM
  28. g_masks[3*idx+0, 0, :, :] = self._fspecial_gauss_2d(filter_size, sigma)
  29. g_masks[3*idx+1, 0, :, :] = self._fspecial_gauss_2d(filter_size, sigma)
  30. g_masks[3*idx+2, 0, :, :] = self._fspecial_gauss_2d(filter_size, sigma)
  31. self.g_masks = g_masks.to(device)
  32. def _fspecial_gauss_1d(self, size, sigma):
  33. """Create 1-D gauss kernel
  34. Args:
  35. size (int): the size of gauss kernel
  36. sigma (float): sigma of normal distribution
  37. Returns:
  38. torch.Tensor: 1D kernel (size)
  39. """
  40. coords = torch.arange(size).to(dtype=torch.float)
  41. coords -= size // 2
  42. g = torch.exp(-(coords ** 2) / (2 * sigma ** 2))
  43. g /= g.sum()
  44. return g.reshape(-1)
  45. def _fspecial_gauss_2d(self, size, sigma):
  46. """Create 2-D gauss kernel
  47. Args:
  48. size (int): the size of gauss kernel
  49. sigma (float): sigma of normal distribution
  50. Returns:
  51. torch.Tensor: 2D kernel (size x size)
  52. """
  53. gaussian_vec = self._fspecial_gauss_1d(size, sigma)
  54. return torch.outer(gaussian_vec, gaussian_vec)
  55. def forward(self, x, y):
  56. b, c, h, w = x.shape
  57. mux = F.conv2d(x, self.g_masks, groups=3, padding=self.pad)
  58. muy = F.conv2d(y, self.g_masks, groups=3, padding=self.pad)
  59. mux2 = mux * mux
  60. muy2 = muy * muy
  61. muxy = mux * muy
  62. sigmax2 = F.conv2d(x * x, self.g_masks, groups=3, padding=self.pad) - mux2
  63. sigmay2 = F.conv2d(y * y, self.g_masks, groups=3, padding=self.pad) - muy2
  64. sigmaxy = F.conv2d(x * y, self.g_masks, groups=3, padding=self.pad) - muxy
  65. # l(j), cs(j) in MS-SSIM
  66. l = (2 * muxy + self.C1) / (mux2 + muy2 + self.C1) # [B, 15, H, W]
  67. cs = (2 * sigmaxy + self.C2) / (sigmax2 + sigmay2 + self.C2)
  68. lM = l[:, -1, :, :] * l[:, -2, :, :] * l[:, -3, :, :]
  69. PIcs = cs.prod(dim=1)
  70. loss_ms_ssim = 1 - lM*PIcs # [B, H, W]
  71. loss_l1 = F.l1_loss(x, y, reduction='none') # [B, 3, H, W]
  72. # average l1 loss in 3 channels
  73. gaussian_l1 = F.conv2d(loss_l1, self.g_masks.narrow(dim=0, start=-3, length=3),
  74. groups=3, padding=self.pad).mean(1) # [B, H, W]
  75. loss_mix = self.alpha * loss_ms_ssim + (1 - self.alpha) * gaussian_l1 / self.DR
  76. loss_mix = self.compensation*loss_mix
  77. return loss_mix.mean()