model.py

import torch
from torch import nn
import torch.nn.functional as F


class ResidualBlock(nn.Module):
    """Two conv-BN-ReLU6 stages with an identity skip connection."""

    def __init__(self, channels):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(channels)
        self.relu6_1 = nn.ReLU6(inplace=True)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(channels)
        self.relu6_2 = nn.ReLU6(inplace=True)
        self.relu6_latest = nn.ReLU6(inplace=True)

    def forward(self, x):
        residual = self.conv1(x)
        residual = self.bn1(residual)
        residual = self.relu6_1(residual)
        residual = self.conv2(residual)
        residual = self.bn2(residual)
        residual = self.relu6_2(residual)
        # Add the block input back in, then apply a final activation.
        add = x + residual
        return self.relu6_latest(add)


class M64ColorNet(nn.Module):
    def __init__(self):
        super(M64ColorNet, self).__init__()
        # Encoder: 3 -> 16 -> 32 -> 64 channels; blocks 2 and 3 halve the spatial size.
        self.block1 = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1, bias=False, stride=1),
            nn.BatchNorm2d(16),
            nn.ReLU6(inplace=True)
        )
        self.block2 = nn.Sequential(
            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, padding=1, bias=False, stride=2),
            nn.BatchNorm2d(32),
            nn.ReLU6(inplace=True)
        )
        self.block3 = nn.Sequential(
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1, bias=False, stride=2),
            nn.BatchNorm2d(64),
            nn.ReLU6(inplace=True)
        )
        # Bottleneck: five residual blocks at 64 channels.
        self.block4 = ResidualBlock(64)
        self.block5 = ResidualBlock(64)
        self.block6 = ResidualBlock(64)
        self.block7 = ResidualBlock(64)
        self.block8 = ResidualBlock(64)
        self.block9 = nn.Sequential(
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU6(inplace=True)
        )
        # Decoder: transposed convolutions upsample back to the input resolution.
        self.block10 = nn.Sequential(
            # nn.Conv2d(in_channels=64, out_channels=32, kernel_size=3, padding=1, bias=False),
            nn.ConvTranspose2d(in_channels=64, out_channels=32, kernel_size=3, padding=1, bias=False, stride=2, output_padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU6(inplace=True)
        )
        self.block11 = nn.Sequential(
            # nn.Conv2d(in_channels=32, out_channels=16, kernel_size=3, padding=1, bias=False),
            nn.ConvTranspose2d(in_channels=32, out_channels=16, kernel_size=3, padding=1, bias=False, stride=2, output_padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU6(inplace=True)
        )
        self.block12 = nn.Sequential(
            nn.Conv2d(in_channels=16, out_channels=3, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(3),
            nn.ReLU6(inplace=True)
        )
        # self.dropout = nn.Dropout(0.4)
    def forward(self, x):
        # Keep intermediate activations for the encoder-decoder skip connections below.
        input = x
        x = self.block1(x)
        input2 = x
        x = self.block2(x)
        input3 = x
        x = self.block3(x)
        input4 = x
        x = self.block4(x)
        x = self.block5(x)
        x = self.block6(x)
        x = self.block7(x)
        x = self.block8(x)
        x = input4 + x
        x = self.block9(x)
        x = self.block10(x)
        x = input3 + x
        x = self.block11(x)
        x = input2 + x
        x = self.block12(x)
        x = input + x
        x = torch.sigmoid(x)
        return x
    @staticmethod
    def load_trained_model(ckpt_path):
        # Returns the model in eval mode plus the normalization stats and metrics stored in the checkpoint.
        ckpt_dict = torch.load(ckpt_path, map_location=torch.device('cpu'))
        model = M64ColorNet()
        model.load_state_dict(ckpt_dict["model_state_dict"])
        model.eval()
        return model, ckpt_dict["mean"], ckpt_dict["std"], ckpt_dict["loss"], ckpt_dict["ssim_score"], ckpt_dict["psnr_score"]
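

# --- Usage sketch (added for illustration; not part of the original file) ---
# A quick shape check: feed a random image-shaped tensor through the network.
# The 1x3x256x256 size is an arbitrary choice; any height/width divisible by 4
# keeps the stride-2 convolutions and transposed convolutions shape-consistent
# with the skip connections.
if __name__ == "__main__":
    model = M64ColorNet()
    model.eval()
    with torch.no_grad():
        dummy = torch.rand(1, 3, 256, 256)  # one RGB image in NCHW layout
        out = model(dummy)
    print(out.shape)  # expected: torch.Size([1, 3, 256, 256])

    # Loading a trained checkpoint (the path below is hypothetical):
    # model, mean, std, loss, ssim_score, psnr_score = \
    #     M64ColorNet.load_trained_model("checkpoints/m64colornet.pt")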