123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869 |
- import torchvision.models as models
- from torchsummary import summary
- import torch
class VGG19(torch.nn.Module):
    """Frozen VGG19 feature extractor applied to a prediction and a ground truth.

    The ``features`` part of torchvision's pretrained VGG19 is cut into six
    consecutive stages (see ``_STAGES``).  ``forward`` pushes both inputs
    through the stages in order and collects the intermediate activations:
    the output of ``relu_1_2`` is returned as the "content" feature, the
    outputs of the other five stages as "style" features.

    All parameters have ``requires_grad`` set to ``False``; gradients can
    still flow back to the inputs.
    """

    # (attribute name, first index, one-past-last index) into
    # ``vgg_net.features``.  The stages are contiguous and are applied in this
    # order.  Each layer is registered under its original index (as a string),
    # so ``state_dict`` keys look like ``relu_1_1.0.weight``.
    _STAGES = (
        ("relu_1_1", 0, 2),
        ("relu_1_2", 2, 4),
        ("relu_2_1", 4, 7),
        ("relu_3_1", 7, 12),
        ("relu_4_1", 12, 21),
        ("relu_5_1", 21, 30),
    )

    # Stage whose output is reported as the content feature instead of being
    # appended to the style lists.
    _CONTENT_STAGE = "relu_1_2"

    def __init__(self):
        """Load pretrained VGG19 and split its feature layers into stages."""
        super(VGG19, self).__init__()
        # NOTE(review): newer torchvision releases prefer the ``weights=``
        # argument over ``pretrained=True``; kept as-is because the installed
        # torchvision version is not visible from this file.
        vgg_net = models.vgg19(pretrained=True)
        features = vgg_net.features

        for stage_name, start, stop in self._STAGES:
            stage = torch.nn.Sequential()
            for idx in range(start, stop):
                stage.add_module(str(idx), features[idx])
            setattr(self, stage_name, stage)

        # don't need the gradients, just want the features
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, pred: torch.Tensor, gt: torch.Tensor):
        """Extract content and style features for ``pred`` and ``gt``.

        Args:
            pred: Predicted image batch.  Presumably ``(N, 3, H, W)`` as
                expected by VGG19 -- confirm against the caller.
            gt: Ground-truth image batch, same layout as ``pred``.

        Returns:
            A 3-tuple ``(contents, style_pred_list, style_gt_list)``:

            * ``contents`` -- tuple ``(content_pred, content_gt)`` holding the
              ``relu_1_2`` outputs for ``pred`` and ``gt``.
            * ``style_pred_list`` -- list of the five stage outputs for
              ``pred`` in the order ``relu_1_1``, ``relu_2_1``, ``relu_3_1``,
              ``relu_4_1``, ``relu_5_1``.
            * ``style_gt_list`` -- the same five stage outputs for ``gt``.
        """
        h_pred, h_gt = pred, gt
        content_pred = content_gt = None
        style_pred_list = []
        style_gt_list = []

        for stage_name, _, _ in self._STAGES:
            stage = getattr(self, stage_name)
            h_pred = stage(h_pred)
            h_gt = stage(h_gt)
            if stage_name == self._CONTENT_STAGE:
                content_pred, content_gt = h_pred, h_gt
            else:
                style_pred_list.append(h_pred)
                # The ground-truth list must hold the ground-truth activation;
                # storing the prediction's activation here would make the two
                # first-level style features identical.
                style_gt_list.append(h_gt)

        contents = (content_pred, content_gt)
        return contents, style_pred_list, style_gt_list
|