-
Notifications
You must be signed in to change notification settings - Fork 0
/
mnist_triplet_example.py
145 lines (110 loc) · 5.09 KB
/
mnist_triplet_example.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
"""
This is a modified version of the Keras mnist example.
https://keras.io/examples/mnist_cnn/
Instead of using a fixed number of epochs, this version continues to train until a stopping criterion is reached.
A triplet neural network is used to pre-train an embedding for the network. The resulting embedding is then extended
with a softmax output layer for categorical predictions.
Model performance should be around 99.87% after training. The resulting model is identical in structure to the one in
the example yet shows considerable improvement in relative error confirming that the embedding learned by the triplet
network is useful.
The performance of the final model is slightly better than a siamese neural network on the same model architecture. The
triplet network makes it possible to train more complex models than the siamese network architecture at the cost of
more memory.
"""
from __future__ import print_function
import keras
from keras.datasets import mnist
from keras.layers import Conv2D, MaxPooling2D, BatchNormalization, Activation, Concatenate
from keras import backend as K
from keras.callbacks import ModelCheckpoint, EarlyStopping
from keras.models import Model
from keras.layers import Input, Flatten, Dense
from triplet import TripletNetwork
# Run settings. `epochs` is deliberately huge: the EarlyStopping callbacks
# passed to fit() are what actually end each training run.
batch_size = 128
num_classes = 10
epochs = 999999

# input image dimensions
img_rows, img_cols = 28, 28

# the data, split between train and test sets
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Work out where the backend wants the single channel axis, then reshape both
# splits to (num_samples,) + input_shape.
if K.image_data_format() == 'channels_first':
    input_shape = (1, img_rows, img_cols)
else:
    input_shape = (img_rows, img_cols, 1)
x_train = x_train.reshape((x_train.shape[0],) + input_shape)
x_test = x_test.reshape((x_test.shape[0],) + input_shape)

# Cast to float32 and divide the pixel values by 255.
x_train = x_train.astype('float32') / 255
x_test = x_test.astype('float32') / 255
def create_base_model(input_shape):
    """Build the embedding network that is applied to each input image.

    The network is two convolutional stages (32 then 64 filters, 3x3 kernels),
    each followed by batch normalisation, ReLU and 2x2 max pooling, then a
    flattened Dense(128) projection that is also batch-normalised and passed
    through ReLU.

    Args:
        input_shape: Per-sample image shape without the batch axis, e.g.
            ``(28, 28, 1)`` for channels-last data.

    Returns:
        A Keras ``Model`` mapping an image tensor to a 128-unit embedding.
    """
    model_input = Input(shape=input_shape)

    features = model_input
    for filters in (32, 64):
        # No `input_shape=` kwarg here: in the functional API the shape is
        # already fixed by `model_input`, so passing it to the layer as well
        # is redundant.
        features = Conv2D(filters, kernel_size=(3, 3))(features)
        features = BatchNormalization()(features)
        features = Activation(activation='relu')(features)
        features = MaxPooling2D(pool_size=(2, 2))(features)

    features = Flatten()(features)
    features = Dense(128)(features)
    features = BatchNormalization()(features)
    features = Activation(activation='relu')(features)

    return Model(model_input, features)
def create_head_model(embedding_shape):
    """Build the head model that consumes three embeddings at once.

    The three embeddings are concatenated and passed through
    Dense(32) -> BatchNormalization -> sigmoid -> Dense(2) ->
    BatchNormalization -> sigmoid.

    Args:
        embedding_shape: Shape of one embedding. Either a per-sample shape
            such as ``(128,)`` or a full ``Model.output_shape`` such as
            ``(None, 128)``; a leading ``None`` batch axis is removed.

    Returns:
        A Keras ``Model`` taking a list of three embedding tensors and
        producing 2 sigmoid outputs.
        NOTE(review): what the two outputs encode depends on how
        TripletNetwork builds its targets, which is not visible here.
    """
    # `Input(shape=...)` expects the shape WITHOUT the batch axis, whereas
    # `Model.output_shape` starts with a `None` batch axis. Passing it through
    # unchanged would declare rank-3 inputs for rank-2 embeddings.
    per_sample_shape = tuple(embedding_shape)
    if len(per_sample_shape) > 1 and per_sample_shape[0] is None:
        per_sample_shape = per_sample_shape[1:]

    embedding_a = Input(shape=per_sample_shape)
    embedding_b = Input(shape=per_sample_shape)
    embedding_c = Input(shape=per_sample_shape)
    branch_inputs = [embedding_a, embedding_b, embedding_c]

    head = Concatenate()(branch_inputs)
    head = Dense(32)(head)
    head = BatchNormalization()(head)
    head = Activation(activation='sigmoid')(head)

    head = Dense(2)(head)
    head = BatchNormalization()(head)
    head = Activation(activation='sigmoid')(head)

    return Model(branch_inputs, head)
# ---------------------------------------------------------------------------
# Stage 1: pre-train the embedding with the triplet network.
# `num_classes` and `epochs` come from the settings defined earlier in this
# file; they are not re-assigned here.
# ---------------------------------------------------------------------------
base_model = create_base_model(input_shape)
head_model = create_head_model(base_model.output_shape)

triplet_network = TripletNetwork(base_model, head_model, num_classes)
# `Adam` is the canonical optimizer class name; the lower-case `adam` alias is
# only provided by older Keras releases.
triplet_network.compile(loss='binary_crossentropy',
                        optimizer=keras.optimizers.Adam(),
                        metrics=['accuracy'])

triplet_checkpoint_path = "./triplet_checkpoint"

# NOTE(review): 'val_acc' is the log key older Keras releases use for
# metrics=['accuracy']; newer releases log 'val_accuracy'. Confirm the key
# against the installed Keras version, otherwise neither callback will fire.
triplet_callbacks = [
    EarlyStopping(monitor='val_acc', patience=10, verbose=0),
    ModelCheckpoint(triplet_checkpoint_path, monitor='val_acc', save_best_only=True, verbose=0)
]

triplet_network.fit(x_train, y_train,
                    validation_data=(x_test, y_test),
                    batch_size=1000,
                    epochs=epochs,
                    callbacks=triplet_callbacks)

# Restore the best weights written by ModelCheckpoint (save_best_only=True).
triplet_network.load_weights(triplet_checkpoint_path)

# ---------------------------------------------------------------------------
# Stage 2: put a classifier on top of the pre-trained embedding.
# ---------------------------------------------------------------------------
embedding = base_model.outputs[-1]

# One-hot encode the integer labels for the categorical loss below.
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)

# Add softmax layer to the pre-trained embedding network.
# The labels are one-hot over mutually exclusive classes, so the output is a
# softmax trained with categorical cross-entropy. With a sigmoid output and
# binary cross-entropy, metrics=['accuracy'] is resolved to per-output binary
# accuracy, which over-states classification accuracy. Figures obtained with
# the sigmoid/binary set-up are therefore not directly comparable.
embedding = Dense(num_classes)(embedding)
embedding = BatchNormalization()(embedding)
embedding = Activation(activation='softmax')(embedding)

model = Model(base_model.inputs[0], embedding)
model.compile(loss=keras.losses.categorical_crossentropy,
              optimizer=keras.optimizers.Adam(),
              metrics=['accuracy'])

model_checkpoint_path = "./model_checkpoint"

# NOTE(review): same 'val_acc' vs 'val_accuracy' caveat as for the triplet
# callbacks above in this block.
model_callbacks = [
    EarlyStopping(monitor='val_acc', patience=10, verbose=0),
    ModelCheckpoint(model_checkpoint_path, monitor='val_acc', save_best_only=True, verbose=0)
]

model.fit(x_train, y_train,
          batch_size=128,
          epochs=epochs,
          callbacks=model_callbacks,
          validation_data=(x_test, y_test))

# Evaluate the best checkpoint rather than the last epoch's weights.
model.load_weights(model_checkpoint_path)

score = model.evaluate(x_test, y_test, verbose=0)
print('Test loss:', score[0])
print('Test accuracy:', score[1])