-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathgan_vae.py
123 lines (102 loc) · 3.89 KB
/
gan_vae.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import numpy as np
import matplotlib.pyplot as plt
from keras.layers import Input, Dense, Lambda, Activation, Convolution2D, MaxPooling2D, Flatten
from keras.models import Model, Sequential
from keras import backend as K
from keras.layers.core import Reshape
from keras.layers.convolutional import UpSampling2D
from keras.layers.normalization import BatchNormalization
from keras.optimizers import SGD
# Load the paired arrays and rescale them to float32 in [0, 1]
# (presumably the raw values are 8-bit pixel intensities).
x_train, y_train = (
    np.load(npy_path).astype('float32') / 255.
    for npy_path in (
        "D:/Bitcamp/Project/Frontalization/Imagenius/Numpy/data_x.npy",
        "D:/Bitcamp/Project/Frontalization/Imagenius/Numpy/data_y.npy",
    )
)
# Keep the first 20 entries of x_train selected where y_train is exactly 0.
# NOTE(review): this only selects whole samples if y_train is a 1-D label
# vector. If data_y.npy holds images with the same shape as data_x.npy,
# np.where yields per-pixel index tuples and x_train becomes a 1-D array of
# pixel values. Confirm what data_y.npy actually contains.
zero = np.where(y_train == 0)
x_train = x_train[zero][:20]
# ---- Hyperparameters --------------------------------------------------------
shape = 128                        # images are shape x shape pixels
batch_size = 30                    # not referenced; the fit() calls pass 30 directly
nb_classes = 10                    # not referenced in this script
img_rows, img_cols = shape, shape
nb_filters = 32                    # not referenced in this script
pool_size = (2, 2)                 # not referenced in this script
kernel_size = (3, 3)               # not referenced in this script
# Three colour channels, matching the Reshape((128, 128, 3)) applied to each
# flat 49152-value sample by the discriminator.
input_shape = (img_rows, img_cols, 3)
# Presumably leftovers from a Keras MNIST VAE example (784 == 28 * 28);
# none of these are referenced in this script.
original_dim = 784
latent_dim = 2
intermediate_dim = 256
# Standard deviation of the Gaussian noise in the generator's sampling Lambda.
epsilon_std = 1.0
# SGD settings shared by the standalone generator and discriminator models.
learning_rate = 0.028
decay_rate = 5e-5
momentum = 0.9
sgd = SGD(lr=learning_rate, momentum=momentum, decay=decay_rate, nesterov=False)
part = 8                           # not referenced in this script
thre = 1                           # not referenced in this script
### START GENERATOR
# Shared encoder trunk: one 64-unit ReLU layer over a flattened 128x128x3
# image. 128 * 128 * 3 == 49152; a 2-D input_shape of (49, 152) would make
# Dense act on rows of 152 values instead of on the whole image.
recog = Sequential()
recog.add(Dense(64, activation='relu', input_shape=(49152,), init='glorot_uniform'))

# Backend function returning the trunk's output. The trailing 0 passed when
# calling it selects the test-time learning phase.
get_0_layer_output = K.function([recog.layers[0].input, K.learning_phase()],
                                [recog.layers[0].output])
# The 64 trunk activations for the first training image, given as a batch of one.
c = get_0_layer_output([x_train[0].reshape((1, 49152)), 0])[0][0]

# Two separate heads on the shared trunk. Plain assignment
# (recog_left = recog) would only alias one Sequential model and stack both
# Lambdas onto it, so build two functional models that share the trunk.
_trunk_in = Input(shape=(49152,))
_trunk_out = recog(_trunk_in)
# Left head: shift every activation by the mean trunk activation of image 0.
recog_left = Model(_trunk_in,
                   Lambda(lambda x: x + np.mean(c), output_shape=(64,))(_trunk_out))
# Right head: x + exp(x / 2) * eps with eps ~ N(0, epsilon_std ** 2).
# The noise tensor follows the actual batch shape rather than a fixed (1, 64).
recog_right = Model(_trunk_in, Lambda(
    lambda x: x + K.exp(x / 2) * K.random_normal(shape=K.shape(x), mean=0., stddev=epsilon_std),
    output_shape=(64,))(_trunk_out))

# Generator proper: -> 64 -> 49152. No input shape is declared, so the model
# is built on first use (TODO confirm the installed Keras supports deferred
# building of Sequential models).
# TODO: an earlier sketch fed keras.layers.Average()([recog_left, recog_right])
# into this model; recog1 currently uses neither head.
recog1 = Sequential()
recog1.add(Dense(64, activation='relu', init='glorot_uniform'))
recog1.add(Dense(49152, activation='relu', init='glorot_uniform'))
recog1.compile(loss='mean_squared_error', optimizer=sgd, metrics=['mae'])
### END FIRST MODEL
### START DISCRIMINATOR
# Conv net over a flattened 128x128x3 image (128 * 128 * 3 == 49152) that
# emits a 49152-value vector squashed into [0, 1] by a sigmoid. It is trained
# below with MSE to reproduce its own input. The shape comments assume
# channels_last ordering and 'valid' padding (TODO confirm against the Keras config).
recog12 = Sequential()
recog12.add(Reshape((128, 128, 3), input_shape=(49152,)))     # (128, 128, 3)
# input_shape is only honoured on the first layer, so it is not repeated here.
recog12.add(Convolution2D(20, 3, 3, border_mode='valid'))     # (126, 126, 20)
recog12.add(BatchNormalization())
recog12.add(Activation('relu'))
recog12.add(UpSampling2D(size=(2, 2)))                        # (252, 252, 20)
recog12.add(Convolution2D(20, 3, 3, init='glorot_uniform'))   # (250, 250, 20)
recog12.add(BatchNormalization())
recog12.add(Activation('relu'))
recog12.add(Convolution2D(20, 3, 3, init='glorot_uniform'))   # (248, 248, 20)
recog12.add(BatchNormalization())
recog12.add(Activation('relu'))
recog12.add(MaxPooling2D(pool_size=(3, 3)))                   # (82, 82, 20)
recog12.add(Convolution2D(4, 3, 3, init='glorot_uniform'))    # (80, 80, 4)
recog12.add(BatchNormalization())
recog12.add(Activation('relu'))
# 80 * 80 * 4 = 25600 values cannot be reshaped to (128, 128, 3) or (49, 152),
# so flatten whatever the conv stack produced.
recog12.add(Flatten())
# The width is the flat image size. Writing Dense(49, 152, ...) would push 152
# into Dense's second positional parameter instead.
# NOTE(review): 25600 x 49152 is about 1.26e9 weights (~5 GB in float32),
# likely too large to train; consider a convolutional decoder.
recog12.add(Dense(49152, activation='sigmoid', init='glorot_uniform'))
recog12.compile(loss='mean_squared_error', optimizer=sgd, metrics=['mae'])
# One epoch on the first training image, flattened to match input_shape.
_disc_sample = x_train[0].reshape((1, 49152))
recog12.fit(_disc_sample, _disc_sample,
            nb_epoch=1,
            batch_size=30, verbose=1)
################## GAN
def not_train(net, val):
    """Set ``trainable`` to ``val`` on ``net`` itself and on each of its layers.

    Despite the name, the helper works in both directions: ``val=False``
    freezes the model and ``val=True`` unfreezes it. Returns None.
    """
    for component in [net] + list(net.layers):
        component.trainable = val
# Freeze the discriminator (recog12) inside the stacked model so that GAN.fit
# updates only the generator (recog1), the usual GAN arrangement. Freezing
# recog1 instead would leave the generator untrained. Keras applies
# `trainable` when a model is compiled, so recog12's earlier standalone
# compile is not affected.
not_train(recog12, False)
# Flat 128 * 128 * 3 image vectors; the batch size is left open.
gan_input = Input(shape=(49152,))
gan_level2 = recog12(recog1(gan_input))
GAN = Model(gan_input, gan_level2)
GAN.compile(loss='mean_squared_error', optimizer='adam', metrics=['mae'])
# The first training image as a (1, 49152) batch, reused for fit, predict and the plot.
x_train_GAN = x_train[0].reshape(1, 49152)
GAN.fit(x_train_GAN, x_train_GAN,
        batch_size=30, nb_epoch=1, verbose=1)
a = GAN.predict(x_train_GAN, verbose=1)
# Side-by-side comparison: input image (left) and GAN output (right).
plt.figure(figsize=(10, 10))
for panel_idx, flat_img in enumerate((x_train_GAN, a), start=1):
    ax = plt.subplot(1, 2, panel_idx)
    plt.imshow(flat_img.reshape(128, 128, 3))
    plt.gray()  # sets the default colormap; imshow ignores it for RGB data
    # Hide ticks and tick labels on both axes.
    for axis in (ax.get_xaxis(), ax.get_yaxis()):
        axis.set_visible(False)
plt.show()