forked from zhengchuanpan/GMAN
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtf_utils.py
58 lines (54 loc) · 2.31 KB
/
tf_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import tensorflow as tf
def conv2d(x, output_dims, kernel_size, stride=(1, 1),
           padding='SAME', use_bias=True, activation=tf.nn.relu,
           bn=False, bn_decay=None, is_training=None):
    """Apply a 2-D convolution followed by optional bias, batch norm and activation.

    A fresh Glorot-uniform kernel (and, optionally, a zero bias) is created
    as a trainable ``tf.Variable`` on every call.

    Args:
        x: Input tensor handed to ``tf.nn.conv2d``. Its last dimension is
            read as the number of input channels and must be statically
            known, because it is used to build the kernel shape.
        output_dims: Number of output channels.
        kernel_size: Two-element list or tuple with the spatial kernel size.
        stride: Two-element list or tuple with the spatial strides.
        padding: Padding scheme forwarded to ``tf.nn.conv2d``.
        use_bias: Whether to add a trainable bias after the convolution.
        activation: Callable applied to the result, or ``None`` to return the
            linear output.
        bn: Whether to batch-normalise right before the activation. Note
            that this is only honoured when ``activation`` is not ``None``.
        bn_decay: Forwarded to ``batch_norm`` as the moving-average decay.
        is_training: Forwarded to ``batch_norm``; must be supplied whenever
            batch normalisation is actually applied.

    Returns:
        The resulting tensor.
    """
    # as_list() yields plain ints (or None) under both the legacy and the v2
    # TensorShape behaviour, unlike Dimension.value.
    input_dims = x.get_shape().as_list()[-1]
    # list(...) lets callers pass tuples as well as lists.
    kernel_shape = list(kernel_size) + [input_dims, output_dims]
    kernel = tf.Variable(
        tf.glorot_uniform_initializer()(shape=kernel_shape),
        dtype=tf.float32, trainable=True, name='kernel')
    # Strides are padded with 1 on the first and last axes.
    x = tf.nn.conv2d(x, kernel, [1] + list(stride) + [1], padding=padding)
    if use_bias:
        bias = tf.Variable(
            tf.zeros_initializer()(shape=[output_dims]),
            dtype=tf.float32, trainable=True, name='bias')
        x = tf.nn.bias_add(x, bias)
    if activation is not None:
        if bn:
            x = batch_norm(x, is_training=is_training, bn_decay=bn_decay)
        x = activation(x)
    return x
def batch_norm(x, is_training, bn_decay, epsilon=1e-3):
    """Batch-normalise ``x`` over every axis except the last.

    When ``is_training`` is true the current batch mean/variance are used and
    their exponential moving averages are updated; otherwise the stored
    moving averages are used. Trainable ``beta`` (offset, zeros) and
    ``gamma`` (scale, ones) variables are created on every call.

    Args:
        x: Input tensor. Its rank and last dimension must be statically
            known.
        is_training: Predicate passed to ``tf.cond`` (presumably a scalar
            boolean tensor or placeholder -- confirm against callers).
        bn_decay: Decay for the exponential moving average; ``None`` selects
            0.9.
        epsilon: Variance epsilon passed to ``tf.nn.batch_normalization``.

    Returns:
        The normalised tensor.
    """
    # as_list() yields plain ints (or None) under both the legacy and the v2
    # TensorShape behaviour, unlike Dimension.value.
    input_dims = x.get_shape().as_list()[-1]
    moment_dims = list(range(len(x.get_shape()) - 1))
    beta = tf.Variable(
        tf.zeros_initializer()(shape=[input_dims]),
        dtype=tf.float32, trainable=True, name='beta')
    gamma = tf.Variable(
        tf.ones_initializer()(shape=[input_dims]),
        dtype=tf.float32, trainable=True, name='gamma')
    batch_mean, batch_var = tf.nn.moments(x, moment_dims, name='moments')
    decay = bn_decay if bn_decay is not None else 0.9
    ema = tf.train.ExponentialMovingAverage(decay=decay)

    # Operator that maintains the moving averages; a no-op outside training.
    ema_apply_op = tf.cond(
        is_training,
        lambda: ema.apply([batch_mean, batch_var]),
        lambda: tf.no_op())

    def mean_var_with_update():
        # Force the moving-average update to run before the batch statistics
        # are handed back.
        with tf.control_dependencies([ema_apply_op]):
            return tf.identity(batch_mean), tf.identity(batch_var)

    # ema.average returns the variable holding the moving average.
    mean, var = tf.cond(
        is_training,
        mean_var_with_update,
        lambda: (ema.average(batch_mean), ema.average(batch_var)))
    return tf.nn.batch_normalization(x, mean, var, beta, gamma, epsilon)
def dropout(x, drop, is_training):
    """Apply dropout to ``x`` only when ``is_training`` evaluates to true.

    Args:
        x: Input tensor.
        drop: Value passed to ``tf.nn.dropout`` as ``rate``.
        is_training: Predicate passed to ``tf.cond``.

    Returns:
        The output of ``tf.nn.dropout`` when training, otherwise ``x``
        unchanged.
    """
    def _drop_units():
        return tf.nn.dropout(x, rate=drop)

    def _pass_through():
        return x

    return tf.cond(is_training, _drop_units, _pass_through)