-
Notifications
You must be signed in to change notification settings - Fork 1
/
transform.py
57 lines (51 loc) · 2.23 KB
/
transform.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
from keras.layers import Conv2D, Conv2DTranspose, Input, Lambda, Activation
from keras.models import Model
import keras.layers
from keras.initializers import TruncatedNormal
from keras_contrib.layers.normalization import InstanceNormalization
# Standard deviation handed to the TruncatedNormal kernel initializer used by
# the convolution helpers in this module.
WEIGHTS_INIT_STDEV = .1
def Conv2DInstanceNorm(inputs, filters, kernel_size,
                       strides=1, relu=True):
    """Apply a square Conv2D, instance normalization and optional ReLU.

    Args:
        inputs: Keras tensor fed to the convolution.
            NOTE(review): ``InstanceNormalization(axis=3)`` suggests a
            4-D channels-last tensor is expected -- confirm with callers.
        filters: Number of output filters of the convolution.
        kernel_size: Side length of the square kernel; expanded to
            ``(kernel_size, kernel_size)``.
        strides: Stride passed straight to ``Conv2D``.
        relu: When true, a ``'relu'`` activation is applied after the
            normalization; otherwise the normalized tensor is returned.

    Returns:
        The resulting Keras tensor.
    """
    # NOTE(review): the seed is fixed at 1 for every call, so each layer
    # gets an initializer built from the same seed -- confirm this is the
    # intended way to get reproducible weights.
    kernel_init = TruncatedNormal(stddev=WEIGHTS_INIT_STDEV, seed=1)
    conv_layer = Conv2D(
        filters,
        (kernel_size, kernel_size),
        strides=strides,
        padding='same',
        kernel_initializer=kernel_init,
        # No bias term; presumably redundant because a normalization
        # layer follows immediately -- TODO confirm.
        use_bias=False,
    )
    features = conv_layer(inputs)
    normalized = InstanceNormalization(axis=3)(features)
    if not relu:
        return normalized
    return Activation('relu')(normalized)
def Conv2DTransposeInstanceNorm(inputs, filters, kernel_size,
                                strides=1, relu=True):
    """Apply a square Conv2DTranspose, instance normalization and optional ReLU.

    Args:
        inputs: Keras tensor fed to the transposed convolution.
            NOTE(review): ``InstanceNormalization(axis=3)`` suggests a
            4-D channels-last tensor is expected -- confirm with callers.
        filters: Number of output filters of the transposed convolution.
        kernel_size: Side length of the square kernel; expanded to
            ``(kernel_size, kernel_size)``.
        strides: Stride passed straight to ``Conv2DTranspose``.
        relu: When true, a ``'relu'`` activation is applied after the
            normalization; otherwise the normalized tensor is returned.

    Returns:
        The resulting Keras tensor.
    """
    # NOTE(review): the seed is fixed at 1 for every call, so each layer
    # gets an initializer built from the same seed -- confirm this is the
    # intended way to get reproducible weights.
    kernel_init = TruncatedNormal(stddev=WEIGHTS_INIT_STDEV, seed=1)
    deconv_layer = Conv2DTranspose(
        filters,
        (kernel_size, kernel_size),
        strides=strides,
        padding='same',
        kernel_initializer=kernel_init,
        # No bias term; presumably redundant because a normalization
        # layer follows immediately -- TODO confirm.
        use_bias=False,
    )
    features = deconv_layer(inputs)
    normalized = InstanceNormalization(axis=3)(features)
    if not relu:
        return normalized
    return Activation('relu')(normalized)
def Conv2DResidualBlock(inputs, filters=128, kernel_size=3):
    """Apply a two-convolution residual block and add the skip connection.

    The residual branch is ``Conv2DInstanceNorm`` with ReLU followed by
    ``Conv2DInstanceNorm`` without ReLU; its output is summed element-wise
    with ``inputs`` via ``keras.layers.add``.

    Args:
        inputs: Keras tensor entering the block. Its channel count must
            equal ``filters`` so the final addition is shape-compatible
            (the branch uses stride 1 with ``'same'`` padding, so spatial
            dimensions are unchanged).
        filters: Number of filters for both convolutions. Defaults to 128,
            the value previously hard-coded here.
        kernel_size: Side length of the square kernels. Defaults to 3,
            the value previously hard-coded here.

    Returns:
        Keras tensor equal to ``inputs`` plus the residual branch output.
    """
    hidden = Conv2DInstanceNorm(inputs, filters, kernel_size)
    # No activation on the second convolution: the raw normalized output
    # is what gets added back onto the skip path.
    residual = Conv2DInstanceNorm(hidden, filters, kernel_size, relu=False)
    return keras.layers.add([inputs, residual])
def TransformNet(inputs):
    """Build the image transformation network graph on top of ``inputs``.

    Layer sequence, in order:
      1. Three ``Conv2DInstanceNorm`` layers: 32 filters / kernel 9,
         then 64 and 128 filters / kernel 3, the latter two with stride 2.
      2. Five ``Conv2DResidualBlock`` blocks.
      3. Two ``Conv2DTransposeInstanceNorm`` layers with stride 2
         (64 then 32 filters, kernel 3).
      4. A final ``Conv2DInstanceNorm`` with 3 filters / kernel 9 and no
         ReLU, followed by ``tanh``.
      5. A ``Lambda`` that maps the tanh output ``x`` to
         ``x * 150 + 127.5``.

    Args:
        inputs: Keras tensor to transform.
            NOTE(review): presumably a channels-last image batch, given
            the 3-filter output and axis-3 normalization in the helpers
            -- confirm with callers.

    Returns:
        The Keras tensor produced by the final ``Lambda`` layer.
    """
    net = Conv2DInstanceNorm(inputs, 32, 9)
    net = Conv2DInstanceNorm(net, 64, 3, strides=2)
    net = Conv2DInstanceNorm(net, 128, 3, strides=2)

    # Five identical residual blocks applied back to back.
    for _ in range(5):
        net = Conv2DResidualBlock(net)

    net = Conv2DTransposeInstanceNorm(net, 64, 3, strides=2)
    net = Conv2DTransposeInstanceNorm(net, 32, 3, strides=2)
    net = Conv2DInstanceNorm(net, 3, 9, relu=False)

    squashed = Activation('tanh')(net)
    # tanh yields values in (-1, 1); scale by 150 and shift by 255/2.
    return Lambda(lambda x: x * 150 + 255./2)(squashed)