vgg19.py 2.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869
  1. import torchvision.models as models
  2. from torchsummary import summary
  3. import torch
  4. class VGG19(torch.nn.Module):
  5. def __init__(self):
  6. super(VGG19, self).__init__()
  7. vgg_net = models.vgg19(pretrained=True)
  8. # summary(vgg_net, (3, 224, 224))
  9. features = vgg_net.features
  10. self.relu_1_1 = torch.nn.Sequential()
  11. self.relu_1_2 = torch.nn.Sequential()
  12. self.relu_2_1 = torch.nn.Sequential()
  13. self.relu_3_1 = torch.nn.Sequential()
  14. self.relu_4_1 = torch.nn.Sequential()
  15. self.relu_5_1 = torch.nn.Sequential()
  16. for x in range(0, 2):
  17. self.relu_1_1.add_module(str(x), features[x])
  18. for x in range(2, 4):
  19. self.relu_1_2.add_module(str(x), features[x])
  20. for x in range(4, 7):
  21. self.relu_2_1.add_module(str(x), features[x])
  22. for x in range(7, 12):
  23. self.relu_3_1.add_module(str(x), features[x])
  24. for x in range(12, 21):
  25. self.relu_4_1.add_module(str(x), features[x])
  26. for x in range(21, 30):
  27. self.relu_5_1.add_module(str(x), features[x])
  28. # don't need the gradients, just want the features
  29. for param in self.parameters():
  30. param.requires_grad = False
  31. def forward(self, pred, gt):
  32. h_pred = self.relu_1_1(pred)
  33. h_gt = self.relu_1_1(gt)
  34. style_pred_1 = h_pred
  35. style_gt_1 = h_pred
  36. h_pred = self.relu_1_2(h_pred)
  37. h_gt = self.relu_1_2(h_gt)
  38. content_pred = h_pred
  39. content_gt = h_gt
  40. h_pred = self.relu_2_1(h_pred)
  41. h_gt = self.relu_2_1(h_gt)
  42. style_pred_2 = h_pred
  43. style_gt_2 = h_gt
  44. h_pred = self.relu_3_1(h_pred)
  45. h_gt = self.relu_3_1(h_gt)
  46. style_pred_3 = h_pred
  47. style_gt_3 = h_gt
  48. h_pred = self.relu_4_1(h_pred)
  49. h_gt = self.relu_4_1(h_gt)
  50. style_pred_4 = h_pred
  51. style_gt_4 = h_gt
  52. h_pred = self.relu_5_1(h_pred)
  53. h_gt = self.relu_5_1(h_gt)
  54. style_pred_5 = h_pred
  55. style_gt_5 = h_gt
  56. contents = (content_pred, content_gt)
  57. style_gt_list = [style_gt_1, style_gt_2, style_gt_3, style_gt_4, style_gt_5]
  58. style_pred_list = [style_pred_1, style_pred_2, style_pred_3, style_pred_4, style_pred_5]
  59. return contents, style_pred_list, style_gt_list