diff --git a/keras/layers/__init__.py b/keras/layers/__init__.py
index 7e73d451f6f..08096b10a87 100644
--- a/keras/layers/__init__.py
+++ b/keras/layers/__init__.py
@@ -86,6 +86,7 @@
 from keras.layers.regularization.activity_regularization import (
     ActivityRegularization,
 )
+from keras.layers.regularization.alpha_dropout import AlphaDropout
 from keras.layers.regularization.dropout import Dropout
 from keras.layers.regularization.gaussian_dropout import GaussianDropout
 from keras.layers.regularization.gaussian_noise import GaussianNoise
diff --git a/keras/layers/regularization/alpha_dropout.py b/keras/layers/regularization/alpha_dropout.py
new file mode 100644
index 00000000000..fd9e0da0912
--- /dev/null
+++ b/keras/layers/regularization/alpha_dropout.py
@@ -0,0 +1,97 @@
+from keras import backend
+from keras import ops
+from keras.api_export import keras_export
+from keras.layers.layer import Layer
+
+
+@keras_export("keras.layers.AlphaDropout")
+class AlphaDropout(Layer):
+    """Applies Alpha Dropout to the input.
+
+    Alpha Dropout is a `Dropout` that keeps the mean and variance of the
+    inputs at their original values, in order to ensure the self-normalizing
+    property even after this dropout.
+    Alpha Dropout fits well with Scaled Exponential Linear Units (SELU) by
+    randomly setting activations to the negative saturation value.
+
+    Args:
+        rate: Float between 0 and 1. The multiplicative noise will have
+            standard deviation `sqrt(rate / (1 - rate))`.
+        noise_shape: 1D integer tensor representing the shape of the
+            binary alpha dropout mask that will be multiplied with the input.
+            For instance, if your inputs have shape
+            `(batch_size, timesteps, features)` and
+            you want the alpha dropout mask to be the same for all timesteps,
+            you can use `noise_shape=(batch_size, 1, features)`.
+        seed: A Python integer to use as random seed.
+
+    Call arguments:
+        inputs: Input tensor (of any rank).
+        training: Python boolean indicating whether the layer should behave in
+            training mode (adding alpha dropout) or in inference mode
+            (doing nothing).
+    """
+
+    def __init__(self, rate, noise_shape=None, seed=None, **kwargs):
+        super().__init__(**kwargs)
+        if not 0 <= rate <= 1:
+            raise ValueError(
+                "Invalid value received for argument "
+                "`rate`. Expected a float value between 0 and 1. "
+                f"Received: rate={rate}"
+            )
+        self.rate = rate
+        self.seed = seed
+        self.noise_shape = noise_shape
+        self.seed_generator = backend.random.SeedGenerator(seed)
+        self.supports_masking = True
+        self.built = True
+
+    def call(self, inputs, training=False):
+        if training and self.rate > 0:
+            noise_shape = self._get_concrete_noise_shape(
+                inputs, self.noise_shape
+            )
+            alpha = 1.6732632423543772848170429916717
+            scale = 1.0507009873554804934193349852946
+            alpha_p = -alpha * scale
+
+            kept_idx = ops.greater_equal(
+                ops.random.uniform(noise_shape, seed=self.seed_generator),
+                self.rate,
+            )
+            kept_idx = ops.cast(kept_idx, inputs.dtype)
+
+            # Compute affine transformation parameters
+            a = ((1 - self.rate) * (1 + self.rate * alpha_p**2)) ** -0.5
+            b = -a * alpha_p * self.rate
+
+            # Apply mask
+            x = inputs * kept_idx + alpha_p * (1 - kept_idx)
+            return a * x + b
+
+        return inputs
+
+    def compute_output_shape(self, input_shape):
+        return input_shape
+
+    def _get_concrete_noise_shape(self, inputs, noise_shape):
+        if noise_shape is None:
+            return inputs.shape
+
+        concrete_inputs_shape = inputs.shape
+        concrete_noise_shape = []
+        for i, value in enumerate(noise_shape):
+            concrete_noise_shape.append(
+                concrete_inputs_shape[i] if value is None else value
+            )
+        return concrete_noise_shape
+
+    def get_config(self):
+        base_config = super().get_config()
+        config = {
+            "rate": self.rate,
+            "seed": self.seed,
+            "noise_shape": self.noise_shape,
+        }
+        return {**base_config, **config}
diff --git a/keras/layers/regularization/alpha_dropout_test.py b/keras/layers/regularization/alpha_dropout_test.py
new file mode 100644
index 00000000000..916ef1fa66c
--- /dev/null
+++ b/keras/layers/regularization/alpha_dropout_test.py
@@ -0,0 +1,60 @@
+import numpy as np
+import pytest
+
+from keras import backend
+from keras import layers
+from keras import testing
+
+
+class AlphaDropoutTest(testing.TestCase):
+    @pytest.mark.requires_trainable_backend
+    def test_alpha_dropout_basics(self):
+        self.run_layer_test(
+            layers.AlphaDropout,
+            init_kwargs={
+                "rate": 0.2,
+            },
+            input_shape=(2, 3),
+            expected_output_shape=(2, 3),
+            expected_num_trainable_weights=0,
+            expected_num_non_trainable_weights=0,
+            expected_num_seed_generators=1,
+            expected_num_losses=0,
+            supports_masking=True,
+        )
+
+    def test_alpha_dropout_correctness(self):
+        inputs = np.ones((20, 500)).astype("float32")
+        layer = layers.AlphaDropout(0.3, seed=1337)
+        outputs = layer(inputs, training=True)
+        self.assertAllClose(
+            np.std(backend.convert_to_numpy(outputs)), 1.0, atol=1e-1
+        )
+
+    def test_alpha_dropout_partial_noise_shape_dynamic(self):
+        inputs = np.ones((20, 5, 10))
+        layer = layers.AlphaDropout(0.5, noise_shape=(None, 1, None))
+        outputs = layer(inputs, training=True)
+        self.assertAllClose(outputs[:, 0, :], outputs[:, 1, :])
+
+    def test_alpha_dropout_partial_noise_shape_static(self):
+        inputs = np.ones((20, 5, 10))
+        layer = layers.AlphaDropout(0.5, noise_shape=(20, 1, 10))
+        outputs = layer(inputs, training=True)
+        self.assertAllClose(outputs[:, 0, :], outputs[:, 1, :])
+
+    def test_alpha_dropout_negative_rate(self):
+        with self.assertRaisesRegex(
+            ValueError,
+            "Invalid value received for argument `rate`. "
+            "Expected a float value between 0 and 1.",
+        ):
+            _ = layers.AlphaDropout(rate=-0.5)
+
+    def test_alpha_dropout_rate_greater_than_one(self):
+        with self.assertRaisesRegex(
+            ValueError,
+            "Invalid value received for argument `rate`. "
+            "Expected a float value between 0 and 1.",
+        ):
+            _ = layers.AlphaDropout(rate=1.5)