-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathstyle.py
117 lines (91 loc) · 3.41 KB
/
style.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import sys
from keras.models import load_model
from keras.preprocessing import image
from keras import backend as K
import numpy as np
from scipy.misc import imsave
from vgg import VGG19, preprocess_input
def get_vgg_features(vgg, inputs, target_layer):
    '''
    Evaluate one ``blockN_conv1`` layer of ``vgg`` on ``inputs``.

    Args:
        vgg: model object exposing ``layers`` (each with ``name`` and
            ``output``) and ``input``.
        inputs: value fed to ``vgg.input``.
        target_layer: 1-based level; ``N`` selects ``blockN_conv1``.

    Returns:
        The result of calling the compiled backend function, i.e. the
        evaluated output(s) of the selected layer (presumably a list with
        one array per matching layer -- confirm against the Keras backend).

    Raises:
        IndexError: if ``target_layer`` is not in ``1..5``. Previously 0 and
            negative values silently wrapped around to a different block.
        ValueError: if ``vgg`` has no layer with the selected name.
    '''
    output_layers = [
        'block1_conv1',
        'block2_conv1',
        'block3_conv1',
        'block4_conv1',
        'block5_conv1',
    ]
    # Validate up front: Python's negative indexing would otherwise map
    # target_layer=0 to 'block5_conv1' without any error.
    if not 1 <= target_layer <= len(output_layers):
        raise IndexError(
            'target_layer must be between 1 and %d, got %r'
            % (len(output_layers), target_layer))
    wanted = output_layers[target_layer - 1]

    outputs = [layer.output for layer in vgg.layers if layer.name == wanted]
    if not outputs:
        raise ValueError('model has no layer named %r' % wanted)

    # The learning-phase placeholder is fed the value 1.
    # NOTE(review): 1 conventionally means "training" in Keras; confirm that
    # is intended for feature extraction.
    f = K.function([vgg.input] + [K.learning_phase()], outputs)
    return f([inputs, 1.])
def _to_chw(features):
    '''Return ``features`` (e.g. 1xHxWxC) as a CxHxW array.'''
    arr = np.asarray(features)
    if arr.ndim >= 3 and all(dim == 1 for dim in arr.shape[:-3]):
        # Strip only the leading singleton axes, so a feature map whose
        # height, width or channel count is 1 keeps its three real axes.
        arr = arr.reshape(arr.shape[-3:])
    else:
        # Any other layout: keep the original squeeze-everything behaviour.
        arr = np.squeeze(arr)
    return np.transpose(arr, (2, 0, 1))


def _centered_eig(flat, threshold=1e-5):
    '''Centre a CxN matrix and decompose its channel covariance.

    Returns ``(mean, centered, basis, eigvals)`` where ``basis``/``eigvals``
    keep only the components whose eigenvalue exceeds ``threshold``.
    '''
    mean = flat.mean(axis=1, keepdims=True)
    centered = flat - mean
    # max(..., 1) avoids a division by zero for a single-pixel feature map.
    cov = np.dot(centered, centered.T) / max(flat.shape[1] - 1, 1)
    basis, eigvals, _ = np.linalg.svd(cov)
    rank = int((eigvals > threshold).sum())
    return mean, centered, basis[:, :rank], eigvals[:rank]


def wct(content, style, alpha=0.6, eps=1e-5):
    '''
    https://github.com/eridgd/WCT-TF/blob/master/ops.py
    Perform Whiten-Color Transform on feature maps using numpy
    See p.4 of the Universal Style Transfer paper for equations:
    https://arxiv.org/pdf/1705.08086.pdf

    Args:
        content: content feature map, 1xHxWxC (extra leading singleton
            axes, such as a one-element list wrapper, are accepted).
        style: style feature map with the same channel count; its spatial
            size may differ from the content's.
        alpha: blend weight; 1 returns the fully transformed features,
            0 returns the content features unchanged.
        eps: value added to the eigenvalues before taking (inverse) roots.

    Returns:
        float32 array of shape 1xHxWxC (the content's H, W, C).
    '''
    # 1xHxWxC -> CxHxW
    content_t = _to_chw(content)
    style_t = _to_chw(style)

    # CxHxW -> CxH*W
    content_flat = content_t.reshape(
        -1, content_t.shape[1] * content_t.shape[2])
    style_flat = style_t.reshape(-1, style_t.shape[1] * style_t.shape[2])

    # Whitening transform: remove the content's channel correlations.
    _, fc, Ec, wc = _centered_eig(content_flat)
    Dc = np.diag((wc + eps) ** -0.5)
    fc_hat = Ec.dot(Dc).dot(Ec.T).dot(fc)

    # Coloring transform: impose the style's covariance, then its mean.
    ms, _, Es, ws = _centered_eig(style_flat)
    Ds = np.diag(np.sqrt(ws + eps))
    fcs_hat = Es.dot(Ds).dot(Es.T).dot(fc_hat) + ms

    # Blend transformed features with the ORIGINAL content features.  Using
    # the mean-centred features here would drop the content mean, so that
    # alpha=0 would not reproduce the input.
    blended = alpha * fcs_hat + (1 - alpha) * content_flat

    # CxH*W -> CxHxW
    blended = blended.reshape(content_t.shape)
    # CxHxW -> 1xHxWxC
    blended = np.expand_dims(np.transpose(blended, (1, 2, 0)), 0)
    return np.float32(blended)
def main():
    '''
    Stylise a content image with a style image and save the result.

    Command line: ``CONTENT_IMAGE STYLE_IMAGE OUTPUT_IMAGE`` (read from
    ``sys.argv[1:4]``).  The content image is passed through the WCT at
    level 3 and then level 1, decoding back to an image after each level.
    The image is displayed with matplotlib before styling and after every
    level, and the final image is written with ``imsave``.

    Raises:
        ValueError: if the two input images do not have the same shape.
    '''
    import matplotlib.pyplot as plt

    if len(sys.argv) < 4:
        # Fail with a usage hint instead of a bare IndexError.
        sys.exit('Usage: %s CONTENT_IMAGE STYLE_IMAGE OUTPUT_IMAGE'
                 % (sys.argv[0] if sys.argv else 'style.py'))
    content_path, style_path, output_path = sys.argv[1:4]

    img_c = image.img_to_array(image.load_img(content_path))
    img_s = image.img_to_array(image.load_img(style_path))
    # Explicit raise rather than assert: asserts are stripped under -O.
    if img_c.shape != img_s.shape:
        raise ValueError(
            'Content and style image should be the same shape, %s != %s'
            % (str(img_c.shape), str(img_s.shape)))
    input_shape = img_c.shape

    # Add the batch axis expected by the models.
    img_c = np.expand_dims(img_c, axis=0)
    img_s = np.expand_dims(img_s, axis=0)

    # Levels applied in order; only the decoders for these are loaded.
    levels = [3, 1]

    print('Loading decoders...')
    decoders = {}
    for level in levels:
        decoders[level] = load_model('./models/decoder_%d.h5' % level)

    print('Loading VGG...')
    vgg = VGG19(input_shape=input_shape, target_layer=5)

    plt.imshow(np.clip(img_c[0] / 255, 0, 1))
    plt.show()

    print('Styling...')
    for level in levels:
        feats_c = get_vgg_features(vgg, img_c, level)
        feats_s = get_vgg_features(vgg, img_s, level)
        feats_cs = wct(feats_c, feats_s)
        # The decoded image becomes the content input of the next level.
        # NOTE(review): vgg was built for input_shape; this assumes the
        # decoder output has that same shape -- confirm for image sizes
        # that are not a multiple of the network's downsampling factor.
        img_c = decoders[level].predict(feats_cs)
        plt.imshow(np.clip(img_c[0] / 255, 0, 1))
        plt.show()

    print('Saving output...')
    imsave(output_path, img_c[0])


if __name__ == '__main__':
    main()