import torchvision.models as models
from torchsummary import summary
import torch


class VGG19(torch.nn.Module):
    """Frozen, pretrained VGG19 feature extractor returning paired activations.

    The first 30 layers of ``torchvision.models.vgg19(pretrained=True).features``
    are split into six consecutive ``torch.nn.Sequential`` stages
    (``relu_1_1`` ... ``relu_5_1``). All parameters are frozen
    (``requires_grad = False``).

    ``forward`` runs a prediction and a ground-truth input through the same
    stages and returns:

    * ``contents``: ``(pred, gt)`` activations after the ``relu_1_2`` stage;
    * ``style_pred_list`` / ``style_gt_list``: activations after the
      ``relu_1_1``, ``relu_2_1``, ``relu_3_1``, ``relu_4_1`` and ``relu_5_1``
      stages, in that order.
    """

    # (attribute name, start, end) -> stage holds features[start:end].
    # The slice boundaries are meant to end just after the ReLU each stage is
    # named for; NOTE(review): confirm against the torchvision VGG19 layout in
    # use. Attribute names and per-layer submodule names (str(index)) are kept
    # as before so state_dict keys do not change.
    _STAGES = (
        ("relu_1_1", 0, 2),
        ("relu_1_2", 2, 4),
        ("relu_2_1", 4, 7),
        ("relu_3_1", 7, 12),
        ("relu_4_1", 12, 21),
        ("relu_5_1", 21, 30),
    )
    # Index into _STAGES of the stage whose output is used as "content".
    _CONTENT_STAGE = 1

    def __init__(self):
        super(VGG19, self).__init__()
        vgg_net = models.vgg19(pretrained=True)
        # summary(vgg_net, (3, 224, 224))
        features = vgg_net.features

        # Stages are registered in _STAGES order, matching the original
        # registration order of the six Sequential attributes.
        for name, start, end in self._STAGES:
            stage = torch.nn.Sequential()
            for idx in range(start, end):
                stage.add_module(str(idx), features[idx])
            setattr(self, name, stage)

        # don't need the gradients, just want the features
        for param in self.parameters():
            param.requires_grad = False

    def _stage_outputs(self, x):
        """Run ``x`` through every stage and return each stage's output.

        Stages are chained: stage ``i`` consumes the output of stage ``i-1``.
        The returned list has one tensor per entry of ``_STAGES``, in order.
        """
        outputs = []
        for name, _, _ in self._STAGES:
            x = getattr(self, name)(x)
            outputs.append(x)
        return outputs

    def forward(self, pred, gt):
        """Extract content and style activations for ``pred`` and ``gt``.

        Args:
            pred: predicted image batch. Presumably (N, 3, H, W) and normalized
                the way the pretrained VGG19 expects. TODO confirm with callers.
            gt: ground-truth image batch, same assumptions as ``pred``.

        Returns:
            ``(contents, style_pred_list, style_gt_list)``, where ``contents``
            is ``(pred_feat, gt_feat)`` from the ``relu_1_2`` stage and each
            style list holds the ``relu_1_1``, ``relu_2_1``, ``relu_3_1``,
            ``relu_4_1`` and ``relu_5_1`` stage outputs for the respective input.
        """
        # Both inputs go through the same helper, so the pred/gt lists cannot
        # get crossed (previously style_gt_1 was mistakenly taken from pred).
        pred_feats = self._stage_outputs(pred)
        gt_feats = self._stage_outputs(gt)

        c = self._CONTENT_STAGE
        contents = (pred_feats[c], gt_feats[c])
        # Every stage except the content stage is a style layer.
        style_pred_list = pred_feats[:c] + pred_feats[c + 1:]
        style_gt_list = gt_feats[:c] + gt_feats[c + 1:]
        return contents, style_pred_list, style_gt_list