forked from rrmina/fast-neural-style-pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
vgg.py
50 lines (42 loc) · 1.78 KB
/
vgg.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import torch
import torch.nn as nn
from torchvision import models, transforms
import utils
class VGG19(nn.Module):
    """Frozen VGG-19 feature stack that returns selected intermediate activations.

    Only the ``features`` sub-module of torchvision's VGG-19 is kept; the rest
    of the network is dropped once the checkpoint has been loaded.

    Args:
        vgg_path: Path to a VGG-19 ``state_dict`` checkpoint readable by
            ``torch.load``.
    """

    # Key of each tapped layer inside ``features._modules`` -> name under which
    # its output is reported. Indices follow torchvision's VGG-19 layout.
    _LAYERS = {
        '3': 'relu1_2',
        '8': 'relu2_2',
        '17': 'relu3_4',
        '22': 'relu4_2',
        '26': 'relu4_4',
        '35': 'relu5_4',
    }
    # Nothing is tapped past this layer, so forward() can stop right after it.
    _LAST_LAYER = '35'

    def __init__(self, vgg_path="models/vgg19-d01eb7cb.pth"):
        super(VGG19, self).__init__()

        # Build the skeleton without downloading anything. Calling the factory
        # with no arguments is equivalent to the legacy ``pretrained=False``
        # and avoids the deprecation warning on torchvision >= 0.13.
        vgg19_features = models.vgg19()

        # NOTE: torch.load unpickles the file -- only use trusted checkpoints.
        # map_location keeps loading working on CPU-only hosts even when the
        # checkpoint was saved from a GPU; load_state_dict copies the values
        # into the model's own tensors either way.
        state_dict = torch.load(vgg_path, map_location="cpu")
        # strict=False: missing or unexpected keys are tolerated, as before.
        vgg19_features.load_state_dict(state_dict, strict=False)
        self.features = vgg19_features.features

        # Freeze the weights: the network is used as a fixed feature extractor.
        for param in self.features.parameters():
            param.requires_grad = False

    def forward(self, x):
        """Run ``x`` through the feature stack and collect the tapped outputs.

        Args:
            x: Input tensor accepted by the first layer of ``self.features``
                (presumably an image batch -- confirm against the caller).

        Returns:
            dict mapping each activation name in ``_LAYERS`` (e.g.
            ``'relu1_2'``) to the output of the corresponding layer, in
            network order.
        """
        features = {}
        for name, layer in self.features._modules.items():
            x = layer(x)
            tap = self._LAYERS.get(name)
            if tap is not None:
                features[tap] = x
            # Layers after the last tap cannot affect the result; skip them.
            if name == self._LAST_LAYER:
                break
        return features
class VGG16(nn.Module):
    """Frozen VGG-16 feature stack that returns selected intermediate activations.

    Only the ``features`` sub-module of torchvision's VGG-16 is kept; the rest
    of the network is dropped once the checkpoint has been loaded.

    Args:
        vgg_path: Path to a VGG-16 ``state_dict`` checkpoint readable by
            ``torch.load``.
    """

    # Key of each tapped layer inside ``features._modules`` -> name under which
    # its output is reported. Indices follow torchvision's VGG-16 layout.
    _LAYERS = {
        '3': 'relu1_2',
        '8': 'relu2_2',
        '15': 'relu3_3',
        '22': 'relu4_3',
    }
    # Nothing is tapped past this layer, so forward() stops right after it.
    _LAST_LAYER = '22'

    def __init__(self, vgg_path="models/vgg16-00b39a1b.pth"):
        super(VGG16, self).__init__()

        # Build the skeleton without downloading anything. Calling the factory
        # with no arguments is equivalent to the legacy ``pretrained=False``
        # and avoids the deprecation warning on torchvision >= 0.13.
        vgg16_features = models.vgg16()

        # NOTE: torch.load unpickles the file -- only use trusted checkpoints.
        # map_location keeps loading working on CPU-only hosts even when the
        # checkpoint was saved from a GPU; load_state_dict copies the values
        # into the model's own tensors either way.
        state_dict = torch.load(vgg_path, map_location="cpu")
        # strict=False: missing or unexpected keys are tolerated, as before.
        vgg16_features.load_state_dict(state_dict, strict=False)
        self.features = vgg16_features.features

        # Freeze the weights: the network is used as a fixed feature extractor.
        for param in self.features.parameters():
            param.requires_grad = False

    def forward(self, x):
        """Run ``x`` through the feature stack and collect the tapped outputs.

        Args:
            x: Input tensor accepted by the first layer of ``self.features``
                (presumably an image batch -- confirm against the caller).

        Returns:
            dict mapping each activation name in ``_LAYERS`` (e.g.
            ``'relu1_2'``) to the output of the corresponding layer, in
            network order.
        """
        features = {}
        for name, layer in self.features._modules.items():
            x = layer(x)
            tap = self._LAYERS.get(name)
            if tap is not None:
                features[tap] = x
            # Layers after the last tap cannot affect the result; skip them.
            if name == self._LAST_LAYER:
                break
        return features