import torch
from torch import nn


class ResidualBlock(nn.Module):
    """Two 3x3 conv/BN/ReLU6 layers with an identity skip connection."""

    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(channels)
        self.relu6_1 = nn.ReLU6(inplace=True)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(channels)
        self.relu6_2 = nn.ReLU6(inplace=True)
        self.relu6_latest = nn.ReLU6(inplace=True)

    def forward(self, x):
        # Convolutional branch.
        residual = self.conv1(x)
        residual = self.bn1(residual)
        residual = self.relu6_1(residual)
        residual = self.conv2(residual)
        residual = self.bn2(residual)
        residual = self.relu6_2(residual)
        # Add the identity skip connection and apply a final activation.
        add = x + residual
        return self.relu6_latest(add)


class M64ColorNet(nn.Module):
    """Encoder-residual-decoder network with long skip connections.

    Input and output are 3-channel images. The two stride-2 convolutions
    downsample by a factor of 4, and the two transposed convolutions
    restore the original resolution.
    """

    def __init__(self):
        super().__init__()
        # Encoder: 3 -> 16 -> 32 -> 64 channels, downsampling twice.
        self.block1 = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1, bias=False, stride=1),
            nn.BatchNorm2d(16),
            nn.ReLU6(inplace=True))
        self.block2 = nn.Sequential(
            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, padding=1, bias=False, stride=2),
            nn.BatchNorm2d(32),
            nn.ReLU6(inplace=True))
        self.block3 = nn.Sequential(
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1, bias=False, stride=2),
            nn.BatchNorm2d(64),
            nn.ReLU6(inplace=True))
        # Bottleneck: five residual blocks at 64 channels.
        self.block4 = ResidualBlock(64)
        self.block5 = ResidualBlock(64)
        self.block6 = ResidualBlock(64)
        self.block7 = ResidualBlock(64)
        self.block8 = ResidualBlock(64)
        self.block9 = nn.Sequential(
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU6(inplace=True))
        # Decoder: transposed convolutions upsample back to the input resolution.
        self.block10 = nn.Sequential(
            # nn.Conv2d(in_channels=64, out_channels=32, kernel_size=3, padding=1, bias=False),
            nn.ConvTranspose2d(in_channels=64, out_channels=32, kernel_size=3, padding=1, bias=False, stride=2, output_padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU6(inplace=True))
        self.block11 = nn.Sequential(
            # nn.Conv2d(in_channels=32, out_channels=16, kernel_size=3, padding=1, bias=False),
            nn.ConvTranspose2d(in_channels=32, out_channels=16, kernel_size=3, padding=1, bias=False, stride=2, output_padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU6(inplace=True))
        self.block12 = nn.Sequential(
            nn.Conv2d(in_channels=16, out_channels=3, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(3),
            nn.ReLU6(inplace=True))
        # self.dropout = nn.Dropout(0.4)

    def forward(self, x):
        input = x
        x = self.block1(x)
        input2 = x
        x = self.block2(x)
        input3 = x
        x = self.block3(x)
        input4 = x
        x = self.block4(x)
        x = self.block5(x)
        x = self.block6(x)
        x = self.block7(x)
        x = self.block8(x)
        x = input4 + x        # long skip from the encoder at 1/4 resolution
        x = self.block9(x)
        x = self.block10(x)
        x = input3 + x        # long skip at 1/2 resolution
        x = self.block11(x)
        x = input2 + x        # long skip at full resolution
        x = self.block12(x)
        x = input + x         # global skip from the raw input
        x = torch.sigmoid(x)  # map the output to [0, 1]
        return x

    @staticmethod
    def load_trained_model(ckpt_path):
        # Restore weights from a training checkpoint and return the model in eval mode,
        # along with the normalization statistics and metrics stored alongside it.
        ckpt_dict = torch.load(ckpt_path, map_location=torch.device('cpu'))
        model = M64ColorNet()
        model.load_state_dict(ckpt_dict["model_state_dict"])
        model.eval()
        return model, ckpt_dict["mean"], ckpt_dict["std"], ckpt_dict["loss"], ckpt_dict["ssim_score"], ckpt_dict["psnr_score"]
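

# Minimal usage sketch (illustrative, not part of the original module): build the
# network and run a dummy forward pass. Spatial dimensions are assumed to be
# divisible by 4 so the two stride-2 downsamples and the two transposed
# convolutions produce matching shapes for the long skip connections.
if __name__ == "__main__":
    net = M64ColorNet()
    net.eval()
    with torch.no_grad():
        dummy = torch.rand(1, 3, 64, 64)  # batch of one 3-channel 64x64 image
        out = net(dummy)
    print(out.shape)  # torch.Size([1, 3, 64, 64]); output matches the input shape
    assert out.shape == dummy.shape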