# model_l.py
  1. import torch
  2. from torch import nn
  3. import torch.nn.functional as F
  4. from torchinfo import summary
  5. class ResidualBlock(nn.Module):
  6. def __init__(self, channels):
  7. super(ResidualBlock, self).__init__()
  8. self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)
  9. self.bn1 = nn.BatchNorm2d(channels)
  10. self.relu6_1 = nn.ReLU6(inplace=True)
  11. self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)
  12. self.bn2 = nn.BatchNorm2d(channels)
  13. self.relu6_2 = nn.ReLU6(inplace=True)
  14. self.relu6_latest = nn.ReLU6(inplace=True)
  15. def forward(self, x):
  16. residual = self.conv1(x)
  17. residual = self.bn1(residual)
  18. residual = self.relu6_1(residual)
  19. residual = self.conv2(residual)
  20. residual = self.bn2(residual)
  21. residual = self.relu6_2(residual)
  22. add = x + residual
  23. return self.relu6_latest(add)
  24. class ResidualBlock_2(nn.Module):
  25. def __init__(self, channels, reduction=2):
  26. super(ResidualBlock_2, self).__init__()
  27. self.conv1 = nn.Conv2d(channels, channels//reduction, kernel_size=1, padding=0, bias=False)
  28. self.bn1 = nn.BatchNorm2d(channels//reduction)
  29. self.relu1= nn.ReLU6(inplace=True)
  30. self.conv2 = nn.Conv2d(channels//reduction, channels//reduction, kernel_size=3, padding=1, bias=False)
  31. self.bn2 = nn.BatchNorm2d(channels//reduction)
  32. self.relu2 = nn.ReLU6(inplace=True)
  33. self.conv3 = nn.Conv2d(channels // reduction, channels // reduction, kernel_size=3, padding=1, bias=False)
  34. self.bn3 = nn.BatchNorm2d(channels // reduction)
  35. self.relu3 = nn.ReLU6(inplace=True)
  36. self.conv4 = nn.Conv2d(channels//reduction, channels, kernel_size=1, padding=0, bias=False)
  37. self.bn4 = nn.BatchNorm2d(channels)
  38. self.relu4 = nn.ReLU6(inplace=True)
  39. self.relu_latest = nn.ReLU6(inplace=True)
  40. def forward(self, x):
  41. residual = self.conv1(x)
  42. residual = self.bn1(residual)
  43. residual = self.relu1(residual)
  44. residual = self.conv2(residual)
  45. residual = self.bn2(residual)
  46. residual = self.relu2(residual)
  47. residual = self.conv3(residual)
  48. residual = self.bn3(residual)
  49. residual = self.relu3(residual)
  50. residual = self.conv4(residual)
  51. residual = self.bn4(residual)
  52. residual = self.relu4(residual)
  53. add = x + residual
  54. return self.relu_latest(add)
  55. class M64ColorNet(nn.Module):
  56. def __init__(self):
  57. super(M64ColorNet, self).__init__()
  58. self.block1 = nn.Sequential(
  59. nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1, bias=False, stride=1),
  60. nn.BatchNorm2d(16),
  61. nn.ReLU6(inplace=True)
  62. )
  63. self.block2 = nn.Sequential(
  64. nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, padding=1, bias=False, stride=2),
  65. nn.BatchNorm2d(32),
  66. nn.ReLU6(inplace=True)
  67. )
  68. self.block3 = nn.Sequential(
  69. nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1, bias=False, stride=2),
  70. nn.BatchNorm2d(64),
  71. nn.ReLU6(inplace=True)
  72. )
  73. self.block4 = ResidualBlock(64)
  74. self.block5 = ResidualBlock(64)
  75. # self.block6 = nn.Sequential(
  76. # nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1, bias=False),
  77. # nn.BatchNorm2d(128), nn.ReLU6(inplace=True))
  78. # self.block6 = ResidualBlock(64)
  79. self.block6_1 = nn.Sequential(
  80. nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1, bias=False, stride=2),
  81. nn.BatchNorm2d(128),
  82. nn.ReLU6(inplace=True)
  83. )
  84. self.block6_2 = ResidualBlock(128)
  85. self.block6_3 = ResidualBlock(128)
  86. # self.block6_3 = nn.Sequential(
  87. # nn.Conv2d(in_channels=128, out_channels=64, kernel_size=3, padding=1, bias=False),
  88. # nn.BatchNorm2d(64),
  89. # nn.ReLU6(inplace=True))
  90. self.block6_4 = nn.Sequential(
  91. # nn.Conv2d(in_channels=64, out_channels=32, kernel_size=3, padding=1, bias=False),
  92. nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=3, padding=1, bias=False, stride=2,output_padding=1),
  93. nn.BatchNorm2d(64),
  94. nn.ReLU6(inplace=True)
  95. )
  96. # self.block6_3 = ResidualBlock(64)
  97. # self.block6_5 = ResidualBlock(64)
  98. self.block7 = ResidualBlock(64)
  99. self.block8 = ResidualBlock(64)
  100. # self.block9 = nn.Sequential(
  101. # nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=1, bias=False),
  102. # nn.BatchNorm2d(64),
  103. # nn.ReLU6(inplace=True)
  104. # )
  105. self.block10 = nn.Sequential(
  106. # nn.Conv2d(in_channels=64, out_channels=32, kernel_size=3, padding=1, bias=False),
  107. nn.ConvTranspose2d(in_channels=64, out_channels=32, kernel_size=3, padding=1, bias=False, stride=2, output_padding=1),
  108. nn.BatchNorm2d(32),
  109. nn.ReLU6(inplace=True)
  110. )
  111. self.block11 = nn.Sequential(
  112. # nn.Conv2d(in_channels=32, out_channels=16, kernel_size=3, padding=1, bias=False),
  113. nn.ConvTranspose2d(in_channels=32, out_channels=16, kernel_size=3, padding=1, bias=False, stride=2, output_padding=1),
  114. nn.BatchNorm2d(16),
  115. nn.ReLU6(inplace=True)
  116. )
  117. self.block12 = nn.Sequential(
  118. nn.Conv2d(in_channels=16, out_channels=3, kernel_size=3, padding=1, bias=False),
  119. nn.BatchNorm2d(3),
  120. nn.ReLU6(inplace=True)
  121. )
  122. # self.dropout = nn.Dropout(0.4)
  123. def forward(self, x):
  124. input = x
  125. x = self.block1(x)
  126. input2 = x
  127. x = self.block2(x)
  128. input3 = x
  129. x = self.block3(x)
  130. input4 = x
  131. x = self.block4(x)
  132. x = self.block5(x)
  133. # x = self.block6(x)
  134. x = self.block6_1(x)
  135. input4_1 = x
  136. x = self.block6_2(x)
  137. x = self.block6_3(x)
  138. x = input4_1 + x
  139. x = self.block6_4(x)
  140. # x = self.block6_5(x)
  141. x = self.block7(x)
  142. x = self.block8(x)
  143. x = input4 + x
  144. # x = self.block9(x)
  145. x = self.block10(x)
  146. x = input3 + x
  147. x = self.block11(x)
  148. x = input2 + x
  149. x = self.block12(x)
  150. x = input + x
  151. x = torch.sigmoid(x)
  152. return x
  153. @staticmethod
  154. def load_trained_model(ckpt_path):
  155. ckpt_dict = torch.load(ckpt_path, map_location=torch.device('cpu'))
  156. model = M64ColorNet()
  157. model.load_state_dict(ckpt_dict["model_state_dict"])
  158. model.eval()
  159. return model, ckpt_dict["mean"], ckpt_dict["std"], ckpt_dict["loss"], ckpt_dict["ssim_score"], ckpt_dict["psnr_score"]
  160. if __name__ == "__main__":
  161. model = M64ColorNet()
  162. summary(model, input_size=(16, 3, 256, 256))