Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 113 additions & 0 deletions Minibatch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
from keras import backend as K
from keras.engine import InputSpec, Layer
from keras import initializers, regularizers, constraints

# From a PR that is not pulled into Keras
# https://github.com/fchollet/keras/pull/3677
# I updated the code to work on Keras 2.x

class MinibatchDiscrimination(Layer):
"""Concatenates to each sample information about how different the input
features for that sample are from features of other samples in the same
minibatch, as described in Salimans et. al. (2016). Useful for preventing
GANs from collapsing to a single output. When using this layer, generated
samples and reference samples should be in separate batches.
# Example
```python
# apply a convolution 1d of length 3 to a sequence with 10 timesteps,
# with 64 output filters
model = Sequential()
model.add(Convolution1D(64, 3, border_mode='same', input_shape=(10, 32)))
# now model.output_shape == (None, 10, 64)
# flatten the output so it can be fed into a minibatch discrimination layer
model.add(Flatten())
# now model.output_shape == (None, 640)
# add the minibatch discrimination layer
model.add(MinibatchDiscrimination(5, 3))
# now model.output_shape = (None, 645)
```
# Arguments
nb_kernels: Number of discrimination kernels to use
(dimensionality concatenated to output).
kernel_dim: The dimensionality of the space where closeness of samples
is calculated.
init: name of initialization function for the weights of the layer
(see [initializations](../initializations.md)),
or alternatively, Theano function to use for weights initialization.
This parameter is only relevant if you don't pass a `weights` argument.
weights: list of numpy arrays to set as initial weights.
W_regularizer: instance of [WeightRegularizer](../regularizers.md)
(eg. L1 or L2 regularization), applied to the main weights matrix.
activity_regularizer: instance of [ActivityRegularizer](../regularizers.md),
applied to the network output.
W_constraint: instance of the [constraints](../constraints.md) module
(eg. maxnorm, nonneg), applied to the main weights matrix.
input_dim: Number of channels/dimensions in the input.
Either this argument or the keyword argument `input_shape`must be
provided when using this layer as the first layer in a model.
# Input shape
2D tensor with shape: `(samples, input_dim)`.
# Output shape
2D tensor with shape: `(samples, input_dim + nb_kernels)`.
# References
- [Improved Techniques for Training GANs](https://arxiv.org/abs/1606.03498)
"""

def __init__(self, nb_kernels, kernel_dim, init='glorot_uniform', weights=None,
W_regularizer=None, activity_regularizer=None,
W_constraint=None, input_dim=None, **kwargs):
self.init = initializers.get(init)
self.nb_kernels = nb_kernels
self.kernel_dim = kernel_dim
self.input_dim = input_dim

self.W_regularizer = regularizers.get(W_regularizer)
self.activity_regularizer = regularizers.get(activity_regularizer)

self.W_constraint = constraints.get(W_constraint)

self.initial_weights = weights
self.input_spec = [InputSpec(ndim=2)]

if self.input_dim:
kwargs['input_shape'] = (self.input_dim,)
super(MinibatchDiscrimination, self).__init__(**kwargs)

def build(self, input_shape):
assert len(input_shape) == 2

input_dim = input_shape[1]
self.input_spec = [InputSpec(dtype=K.floatx(),
shape=(None, input_dim))]

self.W = self.add_weight(shape=(self.nb_kernels, input_dim, self.kernel_dim),
initializer=self.init,
name='kernel',
regularizer=self.W_regularizer,
trainable=True,
constraint=self.W_constraint)

# Set built to true.
super(MinibatchDiscrimination, self).build(input_shape)

def call(self, x, mask=None):
activation = K.reshape(K.dot(x, self.W), (-1, self.nb_kernels, self.kernel_dim))
diffs = K.expand_dims(activation, 3) - K.expand_dims(K.permute_dimensions(activation, [1, 2, 0]), 0)
abs_diffs = K.sum(K.abs(diffs), axis=2)
minibatch_features = K.sum(K.exp(-abs_diffs), axis=2)
return K.concatenate([x, minibatch_features], 1)

def compute_output_shape(self, input_shape):
assert input_shape and len(input_shape) == 2
return input_shape[0], input_shape[1]+self.nb_kernels

def get_config(self):
config = {'nb_kernels': self.nb_kernels,
'kernel_dim': self.kernel_dim,
'init': self.init.__name__,
'W_regularizer': self.W_regularizer.get_config() if self.W_regularizer else None,
'activity_regularizer': self.activity_regularizer.get_config() if self.activity_regularizer else None,
'W_constraint': self.W_constraint.get_config() if self.W_constraint else None,
'input_dim': self.input_dim}
base_config = super(MinibatchDiscrimination, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
Loading