import torch
from torch import nn
import torch.nn.functional as F

# torchinfo is only needed for the model-summary demo under ``__main__``;
# keep it optional so importing the model for training/inference does not
# require it to be installed.
try:
    from torchinfo import summary
except ImportError:  # pragma: no cover - depends on the environment
    summary = None


class ResidualBlock(nn.Module):
    """Two 3x3 Conv-BN-ReLU6 stages added back onto the input, then ReLU6.

    Channel count and spatial size are preserved (3x3 kernels, padding 1,
    stride 1), so the residual can be added directly to the input. Because
    the last op is ReLU6, every output element lies in [0, 6].

    Args:
        channels: number of input (and output) channels.
    """

    def __init__(self, channels: int):
        super().__init__()
        # Attribute names are part of the state_dict keys; keep them stable so
        # existing checkpoints keep loading.
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(channels)
        self.relu6_1 = nn.ReLU6(inplace=True)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(channels)
        self.relu6_2 = nn.ReLU6(inplace=True)
        self.relu6_latest = nn.ReLU6(inplace=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        residual = self.relu6_1(self.bn1(self.conv1(x)))
        residual = self.relu6_2(self.bn2(self.conv2(residual)))
        # ``x + residual`` allocates a new tensor, so the in-place ReLU6 below
        # does not modify the caller's ``x``.
        return self.relu6_latest(x + residual)


class M64ColorNet(nn.Module):
    """Small encoder/decoder CNN with additive skip connections.

    Pipeline (3 channels in, 3 channels out):

    * block1: 3 -> 16 channels, full resolution.
    * block2, block3: stride-2 convs, down to 1/4 resolution with 32 channels.
    * block4, block7, block8: residual blocks at 1/4 resolution.
    * block10, block11: stride-2 transposed convs back to full resolution.
    * block12: 16 -> 3 channels.

    Each decoder stage is summed with the matching encoder activation, and the
    network input is added before a final sigmoid, so outputs lie in [0, 1].

    The input height and width must be multiples of 4. Otherwise the
    upsampled tensors do not match the encoder activations they are added to.

    NOTE(review): block12 ends in ReLU6, so its output is >= 0. The model
    output is therefore ``sigmoid(input + non_negative)``, which is elementwise
    >= ``sigmoid(input)``. This bounds how far outputs can fall below the input.
    It is left unchanged to keep trained checkpoints behaving identically;
    revisit if that bound is unintended.
    """

    def __init__(self):
        super().__init__()
        # Submodule names and Sequential indices are the state_dict keys used
        # by saved checkpoints; do not rename or reorder them.
        self.block1 = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1, bias=False, stride=1),
            nn.BatchNorm2d(16),
            nn.ReLU6(inplace=True),
        )
        self.block2 = nn.Sequential(
            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, padding=1, bias=False, stride=2),
            nn.BatchNorm2d(32),
            nn.ReLU6(inplace=True),
        )
        self.block3 = nn.Sequential(
            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, padding=1, bias=False, stride=2),
            nn.BatchNorm2d(32),
            nn.ReLU6(inplace=True),
        )
        self.block4 = ResidualBlock(32)
        self.block7 = ResidualBlock(32)
        self.block8 = ResidualBlock(32)
        # kernel 3, padding 1, stride 2, output_padding 1 exactly doubles H and W.
        self.block10 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=32, out_channels=32, kernel_size=3, padding=1, bias=False,
                               stride=2, output_padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU6(inplace=True),
        )
        self.block11 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=32, out_channels=16, kernel_size=3, padding=1, bias=False,
                               stride=2, output_padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU6(inplace=True),
        )
        self.block12 = nn.Sequential(
            nn.Conv2d(in_channels=16, out_channels=3, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(3),
            nn.ReLU6(inplace=True),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the network.

        Args:
            x: image batch with 3 channels. Presumably (N, 3, H, W), as used by
                the ``__main__`` demo; H and W must be multiples of 4.

        Returns:
            A tensor of the same shape as ``x`` with values in [0, 1].
        """
        skip_in = x                     # 3 ch, full resolution
        x = self.block1(x)
        skip_full = x                   # 16 ch, full resolution
        x = self.block2(x)
        skip_half = x                   # 32 ch, 1/2 resolution
        x = self.block3(x)
        skip_quarter = x                # 32 ch, 1/4 resolution

        x = self.block4(x)
        x = self.block7(x)
        x = self.block8(x)
        x = skip_quarter + x

        x = self.block10(x)             # -> 1/2 resolution
        x = skip_half + x
        x = self.block11(x)             # -> full resolution
        x = skip_full + x
        x = self.block12(x)             # -> 3 channels
        x = skip_in + x
        return torch.sigmoid(x)

    @classmethod
    def load_trained_model(cls, ckpt_path):
        """Build a model from a checkpoint file and put it in eval mode on the CPU.

        The checkpoint must be a dict containing the keys ``"model_state_dict"``,
        ``"mean"``, ``"std"``, ``"loss"``, ``"ssim_score"`` and ``"psnr_score"``.
        ``mean``/``std`` are presumably the input-normalization statistics used
        during training; confirm against the training script.

        Args:
            ckpt_path: path (or file-like object) accepted by ``torch.load``.

        Returns:
            ``(model, mean, std, loss, ssim_score, psnr_score)``

        Raises:
            KeyError: if any of the keys above is missing from the checkpoint.
        """
        # NOTE: torch.load unpickles the file; only load checkpoints from
        # trusted sources.
        ckpt_dict = torch.load(ckpt_path, map_location=torch.device("cpu"))
        model = cls()
        model.load_state_dict(ckpt_dict["model_state_dict"])
        model.eval()
        return (
            model,
            ckpt_dict["mean"],
            ckpt_dict["std"],
            ckpt_dict["loss"],
            ckpt_dict["ssim_score"],
            ckpt_dict["psnr_score"],
        )


if __name__ == "__main__":
    if summary is None:
        raise SystemExit("torchinfo is not installed; install it to print the model summary.")
    model = M64ColorNet()
    summary(model, input_size=(16, 3, 256, 256))