-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy path loss_template.py
30 lines (20 loc) · 1.01 KB
/
loss_template.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
def grammian_matrix(m):
    """Return the normalized Gram matrix of a feature map.

    The feature map is flattened to ``(positions, channels)``; the Gram
    matrix holds the channel-by-channel dot products, divided by the total
    number of elements.

    NOTE(review): assumes the channel axis is the LAST axis of ``m``
    (channels-last, e.g. ``(H, W, C)``) -- confirm against the feature
    extractor that produces ``m``.

    Args:
        m: Array-like feature map supporting ``reshape``, ``.T``, ``@`` and
            ``.size`` (e.g. a NumPy array).

    Returns:
        A ``(C, C)`` matrix, where ``C == m.shape[-1]``.
    """
    # Keep channels as columns so the product below sums over positions.
    # Reshaping to (-1, m.size) would give a single row and turn the
    # product into an (N, N) outer product of every activation.
    features = m.reshape((-1, m.shape[-1]))
    # Matrix product (not elementwise `*`) yields channel-by-channel sums.
    return features.T @ features / features.size
def grammian_difference(m1, m2):
    """Return the mean squared error between the Gram matrices of two maps.

    Args:
        m1: First feature map, as accepted by ``grammian_matrix``.
        m2: Second feature map, as accepted by ``grammian_matrix``.

    Returns:
        The result of ``mean_squared_error`` applied to the two Gram
        matrices.
    """
    # NOTE(review): `mean_squared_error` is not defined in this chunk and is
    # presumably imported elsewhere -- confirm its argument order and how it
    # reduces a 2-D input (scalar vs. per-row).
    return mean_squared_error(grammian_matrix(m1), grammian_matrix(m2))
def create_loss(style_image, content_image, style_weight=0.5,
                content_weight=0.5, tv_weight=0.75):
    """Build a style-transfer loss closure.

    Style targets are computed once from ``style_image`` via ``vgg_values``.
    The content target is computed from ``content_image`` via
    ``vgg_pred_at_layer_x``; it can be replaced later through the returned
    ``set_content`` callback.

    Args:
        style_image: Image whose Gram-matrix style statistics are matched.
        content_image: Image whose content features are matched.
        style_weight: Multiplier for the summed style loss.
        content_weight: Multiplier for the content loss.
        tv_weight: Multiplier for the total-variation loss ``tv(x_pred)``.

    Returns:
        A ``(loss_fn, set_content)`` tuple.
        ``loss_fn(x_pred, x_true)`` returns the weighted sum of the style,
        content and TV losses. ``set_content(content_image)`` replaces the
        content target used by subsequent ``loss_fn`` calls.
    """
    style_values = vgg_values(style_image)
    content_values_at_layer_x = vgg_pred_at_layer_x(content_image)

    def set_content(content_image):
        """Replace the content target used by ``loss_fn``."""
        # `nonlocal` is required: a plain assignment would bind a new local
        # and leave the closure's content target unchanged.
        nonlocal content_values_at_layer_x
        content_values_at_layer_x = vgg_pred_at_layer_x(content_image)

    def loss_fn(x_pred, x_true):
        """Return the weighted style + content + TV loss for ``x_pred``.

        ``x_true`` is accepted for loss-signature compatibility but unused;
        all targets come from the enclosing closure.
        """
        # NOTE(review): Keras invokes losses as loss(y_true, y_pred); this
        # signature takes the prediction first -- confirm against the caller.
        pred_style_values = vgg_values(x_pred)
        style_loss = 0
        for target, pred in zip(style_values, pred_style_values):
            style_loss += grammian_difference(target, pred)

        # Use the same extractor as the content target so both sides of the
        # comparison come from the same layer.
        pred_content_values = vgg_pred_at_layer_x(x_pred)
        content_loss = content_difference(content_values_at_layer_x,
                                          pred_content_values)

        tv_loss = tv(x_pred)
        # Each term is scaled by its own weight (tv_weight multiplies tv_loss).
        return (style_weight * style_loss
                + content_weight * content_loss
                + tv_weight * tv_loss)

    return loss_fn, set_content