forked from lukedeo/keras-acgan
-
Notifications
You must be signed in to change notification settings - Fork 2
/
Minibatch.py
113 lines (99 loc) · 5.4 KB
/
Minibatch.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
from keras import backend as K
from keras.engine import InputSpec, Layer
from keras import initializers, regularizers, constraints
# Adapted from a pull request that was never merged into Keras:
# https://github.com/fchollet/keras/pull/3677
# The code has been updated to work with Keras 2.x.
class MinibatchDiscrimination(Layer):
    """Concatenates to each sample information about how different the input
    features for that sample are from the features of the other samples in the
    same minibatch, as described in Salimans et al. (2016). Useful for
    preventing GANs from collapsing to a single output. When using this layer,
    generated samples and reference samples should be in separate batches.

    # Example

    ```python
        # apply a 1D convolution of length 3 to a sequence with 10 timesteps,
        # with 64 output filters
        model = Sequential()
        model.add(Conv1D(64, 3, padding='same', input_shape=(10, 32)))
        # now model.output_shape == (None, 10, 64)

        # flatten the output so it can be fed into a minibatch discrimination layer
        model.add(Flatten())
        # now model.output_shape == (None, 640)

        # add the minibatch discrimination layer
        model.add(MinibatchDiscrimination(5, 3))
        # now model.output_shape == (None, 645)
    ```

    # Arguments
        nb_kernels: Number of discrimination kernels to use
            (dimensionality concatenated to the output).
        kernel_dim: Dimensionality of the space in which closeness of
            samples is computed.
        init: Initializer for the weight tensor: anything accepted by
            `keras.initializers.get` (a name such as 'glorot_uniform',
            an initializer instance, or a serialized config).
            Only relevant if you don't pass a `weights` argument.
        weights: Optional list of numpy arrays to set as initial weights.
            It must contain a single array of shape
            `(nb_kernels, input_dim, kernel_dim)`; it is applied in `build`.
        W_regularizer: Regularizer applied to the weight tensor (anything
            accepted by `keras.regularizers.get`), e.g. L1 or L2.
        activity_regularizer: Regularizer applied to the layer output
            (anything accepted by `keras.regularizers.get`).
        W_constraint: Constraint applied to the weight tensor (anything
            accepted by `keras.constraints.get`), e.g. max_norm, non_neg.
        input_dim: Number of features in the input. Either this argument
            or the keyword argument `input_shape` must be provided when
            using this layer as the first layer in a model.

    # Input shape
        2D tensor with shape: `(samples, input_dim)`.

    # Output shape
        2D tensor with shape: `(samples, input_dim + nb_kernels)`.

    # References
        - [Improved Techniques for Training GANs](https://arxiv.org/abs/1606.03498)
    """

    def __init__(self, nb_kernels, kernel_dim, init='glorot_uniform', weights=None,
                 W_regularizer=None, activity_regularizer=None,
                 W_constraint=None, input_dim=None, **kwargs):
        if input_dim:
            kwargs['input_shape'] = (input_dim,)
        # Initialise the base Layer first: its __init__ resets attributes such
        # as `input_spec`, which would otherwise clobber the spec set below.
        super(MinibatchDiscrimination, self).__init__(**kwargs)

        self.init = initializers.get(init)
        self.nb_kernels = nb_kernels
        self.kernel_dim = kernel_dim
        self.input_dim = input_dim

        self.W_regularizer = regularizers.get(W_regularizer)
        self.activity_regularizer = regularizers.get(activity_regularizer)
        self.W_constraint = constraints.get(W_constraint)

        # `weights` is consumed here (not forwarded to the base Layer), so it
        # is applied explicitly in build() once the kernel variable exists.
        self.initial_weights = weights
        self.input_spec = [InputSpec(ndim=2)]

    def build(self, input_shape):
        """Create the `(nb_kernels, input_dim, kernel_dim)` kernel tensor.

        # Raises
            ValueError: if the input is not 2D or its feature dimension
                is undefined.
        """
        if len(input_shape) != 2:
            raise ValueError('MinibatchDiscrimination expects a 2D input '
                             '(samples, input_dim), got shape %s' % (input_shape,))
        input_dim = input_shape[1]
        if input_dim is None:
            raise ValueError('The last dimension of the input to '
                             'MinibatchDiscrimination must be defined.')

        self.input_spec = [InputSpec(dtype=K.floatx(),
                                     shape=(None, input_dim))]

        self.W = self.add_weight(shape=(self.nb_kernels, input_dim, self.kernel_dim),
                                 initializer=self.init,
                                 name='kernel',
                                 regularizer=self.W_regularizer,
                                 trainable=True,
                                 constraint=self.W_constraint)

        # Honour the documented `weights` argument (previously ignored).
        if self.initial_weights is not None:
            self.set_weights(self.initial_weights)
            self.initial_weights = None

        # Marks the layer as built.
        super(MinibatchDiscrimination, self).build(input_shape)

    def call(self, x, mask=None):
        """Append `nb_kernels` minibatch-similarity features to each sample."""
        # Project the features through the kernel, then reshape to
        # (batch, nb_kernels, kernel_dim).
        activation = K.reshape(K.dot(x, self.W), (-1, self.nb_kernels, self.kernel_dim))
        # Broadcast (batch, nb_kernels, kernel_dim, 1) against
        # (1, nb_kernels, kernel_dim, batch) to get all pairwise differences:
        # diffs[i, b, :, j] = activation[i, b] - activation[j, b].
        # Note: this intermediate tensor grows quadratically with batch size.
        diffs = K.expand_dims(activation, 3) - K.expand_dims(K.permute_dimensions(activation, [1, 2, 0]), 0)
        # L1 distance between samples for each kernel: (batch, nb_kernels, batch).
        abs_diffs = K.sum(K.abs(diffs), axis=2)
        # Sum of exp(-L1) over every sample j in the batch, including j == i
        # (which contributes exp(0) = 1): (batch, nb_kernels).
        minibatch_features = K.sum(K.exp(-abs_diffs), axis=2)
        return K.concatenate([x, minibatch_features], 1)

    def compute_output_shape(self, input_shape):
        """Return `(samples, input_dim + nb_kernels)`; an unknown input_dim stays None."""
        if not input_shape or len(input_shape) != 2:
            raise ValueError('MinibatchDiscrimination expects a 2D input '
                             '(samples, input_dim), got shape %s' % (input_shape,))
        features = input_shape[1]
        return input_shape[0], None if features is None else features + self.nb_kernels

    def get_config(self):
        """Return a config dict that `from_config` can turn back into this layer.

        Uses the Keras `serialize` helpers so initializer / regularizer /
        constraint objects round-trip through their `get` functions.
        """
        config = {'nb_kernels': self.nb_kernels,
                  'kernel_dim': self.kernel_dim,
                  'init': initializers.serialize(self.init),
                  'W_regularizer': regularizers.serialize(self.W_regularizer) if self.W_regularizer else None,
                  'activity_regularizer': regularizers.serialize(self.activity_regularizer) if self.activity_regularizer else None,
                  'W_constraint': constraints.serialize(self.W_constraint) if self.W_constraint else None,
                  'input_dim': self.input_dim}
        base_config = super(MinibatchDiscrimination, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))