From acd03c1c4e4edf3d22601e2ba7a6234d1008bfef Mon Sep 17 00:00:00 2001
From: Hongyu Chiu <20734616+james77777778@users.noreply.github.com>
Date: Thu, 14 Dec 2023 14:12:16 +0800
Subject: [PATCH 1/2] Add `ops.random.alpha_dropout` and `layers.AlphaDropout`

---
 keras/backend/jax/random.py                  | 19 +++++
 keras/backend/numpy/random.py                | 27 ++++++++
 keras/backend/tensorflow/random.py           | 19 +++++
 keras/backend/torch/random.py                | 19 +++++
 keras/layers/__init__.py                     |  1 +
 keras/layers/regularization/alpha_dropout.py | 69 +++++++++++++++++++
 .../regularization/alpha_dropout_test.py     | 51 ++++++++++++++
 keras/random/random.py                       |  7 ++
 keras/random/random_test.py                  | 21 ++++++
 9 files changed, 233 insertions(+)
 create mode 100644 keras/layers/regularization/alpha_dropout.py
 create mode 100644 keras/layers/regularization/alpha_dropout_test.py

diff --git a/keras/backend/jax/random.py b/keras/backend/jax/random.py
index 4b89d3e952f..389cb2a3ac1 100644
--- a/keras/backend/jax/random.py
+++ b/keras/backend/jax/random.py
@@ -1,6 +1,7 @@
 import jax
 
 from keras.backend.config import floatx
+from keras.backend.jax.core import cast
 from keras.random.seed_generator import SeedGenerator
 from keras.random.seed_generator import draw_seed
 from keras.random.seed_generator import make_default_seed
@@ -81,6 +82,24 @@ def dropout(inputs, rate, noise_shape=None, seed=None):
     )
 
 
+def alpha_dropout(inputs, rate, noise_shape=None, seed=None):
+    noise_shape = _get_concrete_noise_shape(inputs, noise_shape)
+    alpha = 1.6732632423543772848170429916717
+    scale = 1.0507009873554804934193349852946
+    alpha_p = -alpha * scale
+
+    kept_idx = jax.numpy.greater_equal(uniform(noise_shape, seed=seed), rate)
+    kept_idx = cast(kept_idx, inputs.dtype)
+
+    # Compute affine transformation parameters
+    a = ((1 - rate) * (1 + rate * alpha_p**2)) ** -0.5
+    b = -a * alpha_p * rate
+
+    # Apply mask
+    x = inputs * kept_idx + alpha_p * (1 - kept_idx)
+    return a * x + b
+
+
 def shuffle(x, axis=0, seed=None):
     seed = jax_draw_seed(seed)
     return jax.random.shuffle(seed, x, axis)
diff --git a/keras/backend/numpy/random.py b/keras/backend/numpy/random.py
index dd7e7234a33..4b3022725d2 100644
--- a/keras/backend/numpy/random.py
+++ b/keras/backend/numpy/random.py
@@ -88,6 +88,33 @@ def dropout(inputs, rate, noise_shape=None, seed=None):
     return np.where(mask, inputs / keep_prob, np.zeros_like(inputs))
 
 
+def alpha_dropout(inputs, rate, noise_shape=None, seed=None):
+    # If noise_shape is not provided, use the shape of inputs
+    if noise_shape is None:
+        noise_shape = inputs.shape
+    else:
+        # If noise_shape is provided, replace None with corresponding
+        # input shape
+        noise_shape = [
+            n if n is not None else inputs.shape[i]
+            for i, n in enumerate(noise_shape)
+        ]
+    alpha = 1.6732632423543772848170429916717
+    scale = 1.0507009873554804934193349852946
+    alpha_p = -alpha * scale
+
+    kept_idx = np.greater_equal(uniform(noise_shape, seed=seed), rate)
+    kept_idx = kept_idx.astype(inputs.dtype)
+
+    # Compute affine transformation parameters
+    a = ((1 - rate) * (1 + rate * alpha_p**2)) ** -0.5
+    b = -a * alpha_p * rate
+
+    # Apply mask
+    x = inputs * kept_idx + alpha_p * (1 - kept_idx)
+    return (a * x + b).astype(inputs.dtype)
+
+
 def shuffle(x, axis=0, seed=None):
     seed = draw_seed(seed)
     rng = np.random.default_rng(seed)
diff --git a/keras/backend/tensorflow/random.py b/keras/backend/tensorflow/random.py
index 5ffb4c6d2d7..8c8873097a7 100644
--- a/keras/backend/tensorflow/random.py
+++ b/keras/backend/tensorflow/random.py
@@ -86,6 +86,25 @@ def dropout(inputs, rate, noise_shape=None, seed=None):
     )
 
 
+def alpha_dropout(inputs, rate, noise_shape=None, seed=None):
+    noise_shape = _get_concrete_noise_shape(inputs, noise_shape)
+
+    alpha = 1.6732632423543772848170429916717
+    scale = 1.0507009873554804934193349852946
+    alpha_p = -alpha * scale
+
+    kept_idx = tf.greater_equal(uniform(noise_shape, seed=seed), rate)
+    kept_idx = tf.cast(kept_idx, inputs.dtype)
+
+    # Compute affine transformation parameters
+    a = ((1 - rate) * (1 + rate * alpha_p**2)) ** -0.5
+    b = -a * alpha_p * rate
+
+    # Apply mask
+    x = inputs * kept_idx + alpha_p * (1 - kept_idx)
+    return a * x + b
+
+
 def shuffle(x, axis=0, seed=None):
     seed = tf_draw_seed(seed)
     if axis == 0:
diff --git a/keras/backend/torch/random.py b/keras/backend/torch/random.py
index c76618a4983..62016f03fe0 100644
--- a/keras/backend/torch/random.py
+++ b/keras/backend/torch/random.py
@@ -3,6 +3,7 @@
 import torch.nn.functional as tnn
 
 from keras.backend.config import floatx
+from keras.backend.torch.core import cast
 from keras.backend.torch.core import convert_to_tensor
 from keras.backend.torch.core import get_device
 from keras.backend.torch.core import to_torch_dtype
@@ -166,6 +167,24 @@ def dropout(inputs, rate, noise_shape=None, seed=None):
     )
 
 
+def alpha_dropout(inputs, rate, noise_shape=None, seed=None):
+    noise_shape = _get_concrete_noise_shape(inputs, noise_shape)
+    alpha = 1.6732632423543772848170429916717
+    scale = 1.0507009873554804934193349852946
+    alpha_p = -alpha * scale
+
+    kept_idx = torch.greater_equal(uniform(noise_shape, seed=seed), rate)
+    kept_idx = cast(kept_idx, inputs.dtype)
+
+    # Compute affine transformation parameters
+    a = ((1 - rate) * (1 + rate * alpha_p**2)) ** -0.5
+    b = -a * alpha_p * rate
+
+    # Apply mask
+    x = inputs * kept_idx + alpha_p * (1 - kept_idx)
+    return a * x + b
+
+
 def shuffle(x, axis=0, seed=None):
     # Ref: https://github.com/pytorch/pytorch/issues/71409
     x = convert_to_tensor(x)
diff --git a/keras/layers/__init__.py b/keras/layers/__init__.py
index 7e73d451f6f..08096b10a87 100644
--- a/keras/layers/__init__.py
+++ b/keras/layers/__init__.py
@@ -86,6 +86,7 @@
 from keras.layers.regularization.activity_regularization import (
     ActivityRegularization,
 )
+from keras.layers.regularization.alpha_dropout import AlphaDropout
 from keras.layers.regularization.dropout import Dropout
 from keras.layers.regularization.gaussian_dropout import GaussianDropout
 from keras.layers.regularization.gaussian_noise import GaussianNoise
diff --git a/keras/layers/regularization/alpha_dropout.py b/keras/layers/regularization/alpha_dropout.py
new file mode 100644
index 00000000000..309f6ae774c
--- /dev/null
+++ b/keras/layers/regularization/alpha_dropout.py
@@ -0,0 +1,69 @@
+from keras import backend
+from keras.api_export import keras_export
+from keras.layers.layer import Layer
+
+
+@keras_export("keras.layers.AlphaDropout")
+class AlphaDropout(Layer):
+    """Applies Alpha Dropout to the input.
+
+    Alpha Dropout is a `Dropout` variant that keeps the mean and variance of
+    its inputs at their original values, in order to preserve the
+    self-normalizing property even after this dropout is applied.
+    Alpha Dropout pairs well with Scaled Exponential Linear Units (SELU)
+    because it randomly sets activations to the negative saturation value.
+
+    Args:
+        rate: Float between 0 and 1. The multiplicative noise will have
+            standard deviation `sqrt(rate / (1 - rate))`.
+        noise_shape: 1D integer tensor representing the shape of the
+            binary alpha dropout mask that will be multiplied with the input.
+            For instance, if your inputs have shape
+            `(batch_size, timesteps, features)` and
+            you want the alpha dropout mask to be the same for all timesteps,
+            you can use `noise_shape=(batch_size, 1, features)`.
+        seed: A Python integer to use as random seed.
+
+    Call arguments:
+        inputs: Input tensor (of any rank).
+        training: Python boolean indicating whether the layer should behave in
+            training mode (adding alpha dropout) or in inference mode
+            (doing nothing).
+    """
+
+    def __init__(self, rate, noise_shape=None, seed=None, **kwargs):
+        super().__init__(**kwargs)
+        if not 0 <= rate <= 1:
+            raise ValueError(
+                "Invalid value received for argument "
+                "`rate`. Expected a float value between 0 and 1. "
+                f"Received: rate={rate}"
+            )
+        self.rate = rate
+        self.seed = seed
+        self.noise_shape = noise_shape
+        self.seed_generator = backend.random.SeedGenerator(seed)
+        self.supports_masking = True
+        self.built = True
+
+    def call(self, inputs, training=False):
+        if training and self.rate > 0:
+            return backend.random.alpha_dropout(
+                inputs,
+                self.rate,
+                noise_shape=self.noise_shape,
+                seed=self.seed_generator,
+            )
+        return inputs
+
+    def compute_output_shape(self, input_shape):
+        return input_shape
+
+    def get_config(self):
+        base_config = super().get_config()
+        config = {
+            "rate": self.rate,
+            "seed": self.seed,
+            "noise_shape": self.noise_shape,
+        }
+        return {**base_config, **config}
diff --git a/keras/layers/regularization/alpha_dropout_test.py b/keras/layers/regularization/alpha_dropout_test.py
new file mode 100644
index 00000000000..d6ea4a0db96
--- /dev/null
+++ b/keras/layers/regularization/alpha_dropout_test.py
@@ -0,0 +1,51 @@
+import numpy as np
+import pytest
+
+from keras import layers
+from keras import testing
+
+
+class AlphaDropoutTest(testing.TestCase):
+    @pytest.mark.requires_trainable_backend
+    def test_alpha_dropout_basics(self):
+        self.run_layer_test(
+            layers.AlphaDropout,
+            init_kwargs={
+                "rate": 0.2,
+            },
+            input_shape=(2, 3),
+            expected_output_shape=(2, 3),
+            expected_num_trainable_weights=0,
+            expected_num_non_trainable_weights=0,
+            expected_num_seed_generators=1,
+            expected_num_losses=0,
+            supports_masking=True,
+        )
+
+    def test_alpha_dropout_partial_noise_shape_dynamic(self):
+        inputs = np.ones((20, 5, 10))
+        layer = layers.AlphaDropout(0.5, noise_shape=(None, 1, None))
+        outputs = layer(inputs, training=True)
+        self.assertAllClose(outputs[:, 0, :], outputs[:, 1, :])
+
+    def test_alpha_dropout_partial_noise_shape_static(self):
+        inputs = np.ones((20, 5, 10))
+        layer = layers.AlphaDropout(0.5, noise_shape=(20, 1, 10))
+        outputs = layer(inputs, training=True)
+        self.assertAllClose(outputs[:, 0, :], outputs[:, 1, :])
+
+    def test_alpha_dropout_negative_rate(self):
+        with self.assertRaisesRegex(
+            ValueError,
+            "Invalid value received for argument `rate`. "
+            "Expected a float value between 0 and 1.",
+        ):
+            _ = layers.AlphaDropout(rate=-0.5)
+
+    def test_alpha_dropout_rate_greater_than_one(self):
+        with self.assertRaisesRegex(
+            ValueError,
+            "Invalid value received for argument `rate`. "
+            "Expected a float value between 0 and 1.",
+        ):
+            _ = layers.AlphaDropout(rate=1.5)
diff --git a/keras/random/random.py b/keras/random/random.py
index 585e90ed8f9..5789175a1af 100644
--- a/keras/random/random.py
+++ b/keras/random/random.py
@@ -190,6 +190,13 @@ def dropout(inputs, rate, noise_shape=None, seed=None):
     )
 
 
+@keras_export("keras.random.alpha_dropout")
+def alpha_dropout(inputs, rate, noise_shape=None, seed=None):
+    return backend.random.alpha_dropout(
+        inputs, rate, noise_shape=noise_shape, seed=seed
+    )
+
+
 @keras_export("keras.random.shuffle")
 def shuffle(x, axis=0, seed=None):
     """Shuffle the elements of a tensor uniformly at random along an axis.
diff --git a/keras/random/random_test.py b/keras/random/random_test.py
index b9b788cc5a6..d13bf927716 100644
--- a/keras/random/random_test.py
+++ b/keras/random/random_test.py
@@ -140,6 +140,27 @@ def test_dropout_noise_shape(self):
         )
         self.assertEqual(x.shape, (2, 3, 5, 7))
 
+    def test_alpha_dropout(self):
+        x = ops.random.normal((10000,))
+        y = random.alpha_dropout(x, rate=0, seed=0)
+        self.assertAllClose(y, x)
+        self.assertEqual(x.dtype, y.dtype)
+
+        # standard deviation check
+        y = random.alpha_dropout(x, rate=0.8, seed=0)
+        self.assertAllClose(ops.std(y), 1.0, atol=1e-1)
+        y = random.alpha_dropout(x, rate=0.5, seed=0)
+        self.assertAllClose(ops.std(y), 1.0, atol=1e-1)
+        y = random.alpha_dropout(x, rate=0.3, seed=0)
+        self.assertAllClose(ops.std(y), 1.0, atol=1e-1)
+
+    def test_alpha_dropout_noise_shape(self):
+        inputs = ops.ones((2, 3, 5, 7))
+        x = random.alpha_dropout(
+            inputs, rate=0.3, noise_shape=[None, 3, 5, None], seed=0
+        )
+        self.assertEqual(x.shape, (2, 3, 5, 7))
+
 @pytest.mark.skipif(
     keras.backend.backend() != "jax",
     reason="This test requires `jax` as the backend.",

From 7b1b3e81154c9ca89fb1e495957270520c65e89c Mon Sep 17 00:00:00 2001
From: Hongyu Chiu <20734616+james77777778@users.noreply.github.com>
Date: Fri, 15 Dec 2023 09:42:45 +0800
Subject: [PATCH 2/2] Move the implementation of `ops.random.alpha_dropout`
 into layer

---
 keras/backend/jax/random.py                  | 19 ----------
 keras/backend/numpy/random.py                | 27 --------------
 keras/backend/tensorflow/random.py           | 19 ----------
 keras/backend/torch/random.py                | 19 ----------
 keras/layers/regularization/alpha_dropout.py | 36 ++++++++++++++---
 .../regularization/alpha_dropout_test.py     |  9 +++++
 keras/random/random.py                       |  7 -------
 keras/random/random_test.py                  | 21 -----------
 8 files changed, 41 insertions(+), 116 deletions(-)

diff --git a/keras/backend/jax/random.py b/keras/backend/jax/random.py
index 389cb2a3ac1..4b89d3e952f 100644
--- a/keras/backend/jax/random.py
+++ b/keras/backend/jax/random.py
@@ -1,7 +1,6 @@
 import jax
 
 from keras.backend.config import floatx
-from keras.backend.jax.core import cast
 from keras.random.seed_generator import SeedGenerator
 from keras.random.seed_generator import draw_seed
 from keras.random.seed_generator import make_default_seed
@@ -82,24 +81,6 @@ def dropout(inputs, rate, noise_shape=None, seed=None):
     )
 
 
-def alpha_dropout(inputs, rate, noise_shape=None, seed=None):
-    noise_shape = _get_concrete_noise_shape(inputs, noise_shape)
-    alpha = 1.6732632423543772848170429916717
-    scale = 1.0507009873554804934193349852946
-    alpha_p = -alpha * scale
-
-    kept_idx = jax.numpy.greater_equal(uniform(noise_shape, seed=seed), rate)
-    kept_idx = cast(kept_idx, inputs.dtype)
-
-    # Compute affine transformation parameters
-    a = ((1 - rate) * (1 + rate * alpha_p**2)) ** -0.5
-    b = -a * alpha_p * rate
-
-    # Apply mask
-    x = inputs * kept_idx + alpha_p * (1 - kept_idx)
-    return a * x + b
-
-
 def shuffle(x, axis=0, seed=None):
     seed = jax_draw_seed(seed)
     return jax.random.shuffle(seed, x, axis)
diff --git a/keras/backend/numpy/random.py b/keras/backend/numpy/random.py
index 4b3022725d2..dd7e7234a33 100644
--- a/keras/backend/numpy/random.py
+++ b/keras/backend/numpy/random.py
@@ -88,33 +88,6 @@ def dropout(inputs, rate, noise_shape=None, seed=None):
     return np.where(mask, inputs / keep_prob, np.zeros_like(inputs))
 
 
-def alpha_dropout(inputs, rate, noise_shape=None, seed=None):
-    # If noise_shape is not provided, use the shape of inputs
-    if noise_shape is None:
-        noise_shape = inputs.shape
-    else:
-        # If noise_shape is provided, replace None with corresponding
-        # input shape
-        noise_shape = [
-            n if n is not None else inputs.shape[i]
-            for i, n in enumerate(noise_shape)
-        ]
-    alpha = 1.6732632423543772848170429916717
-    scale = 1.0507009873554804934193349852946
-    alpha_p = -alpha * scale
-
-    kept_idx = np.greater_equal(uniform(noise_shape, seed=seed), rate)
-    kept_idx = kept_idx.astype(inputs.dtype)
-
-    # Compute affine transformation parameters
-    a = ((1 - rate) * (1 + rate * alpha_p**2)) ** -0.5
-    b = -a * alpha_p * rate
-
-    # Apply mask
-    x = inputs * kept_idx + alpha_p * (1 - kept_idx)
-    return (a * x + b).astype(inputs.dtype)
-
-
 def shuffle(x, axis=0, seed=None):
     seed = draw_seed(seed)
     rng = np.random.default_rng(seed)
diff --git a/keras/backend/tensorflow/random.py b/keras/backend/tensorflow/random.py
index 8c8873097a7..5ffb4c6d2d7 100644
--- a/keras/backend/tensorflow/random.py
+++ b/keras/backend/tensorflow/random.py
@@ -86,25 +86,6 @@ def dropout(inputs, rate, noise_shape=None, seed=None):
     )
 
 
-def alpha_dropout(inputs, rate, noise_shape=None, seed=None):
-    noise_shape = _get_concrete_noise_shape(inputs, noise_shape)
-
-    alpha = 1.6732632423543772848170429916717
-    scale = 1.0507009873554804934193349852946
-    alpha_p = -alpha * scale
-
-    kept_idx = tf.greater_equal(uniform(noise_shape, seed=seed), rate)
-    kept_idx = tf.cast(kept_idx, inputs.dtype)
-
-    # Compute affine transformation parameters
-    a = ((1 - rate) * (1 + rate * alpha_p**2)) ** -0.5
-    b = -a * alpha_p * rate
-
-    # Apply mask
-    x = inputs * kept_idx + alpha_p * (1 - kept_idx)
-    return a * x + b
-
-
 def shuffle(x, axis=0, seed=None):
     seed = tf_draw_seed(seed)
     if axis == 0:
diff --git a/keras/backend/torch/random.py b/keras/backend/torch/random.py
index 62016f03fe0..c76618a4983 100644
--- a/keras/backend/torch/random.py
+++ b/keras/backend/torch/random.py
@@ -3,7 +3,6 @@
 import torch.nn.functional as tnn
 
 from keras.backend.config import floatx
-from keras.backend.torch.core import cast
 from keras.backend.torch.core import convert_to_tensor
 from keras.backend.torch.core import get_device
 from keras.backend.torch.core import to_torch_dtype
@@ -167,24 +166,6 @@ def dropout(inputs, rate, noise_shape=None, seed=None):
     )
 
 
-def alpha_dropout(inputs, rate, noise_shape=None, seed=None):
-    noise_shape = _get_concrete_noise_shape(inputs, noise_shape)
-    alpha = 1.6732632423543772848170429916717
-    scale = 1.0507009873554804934193349852946
-    alpha_p = -alpha * scale
-
-    kept_idx = torch.greater_equal(uniform(noise_shape, seed=seed), rate)
-    kept_idx = cast(kept_idx, inputs.dtype)
-
-    # Compute affine transformation parameters
-    a = ((1 - rate) * (1 + rate * alpha_p**2)) ** -0.5
-    b = -a * alpha_p * rate
-
-    # Apply mask
-    x = inputs * kept_idx + alpha_p * (1 - kept_idx)
-    return a * x + b
-
-
 def shuffle(x, axis=0, seed=None):
     # Ref: https://github.com/pytorch/pytorch/issues/71409
     x = convert_to_tensor(x)
diff --git a/keras/layers/regularization/alpha_dropout.py b/keras/layers/regularization/alpha_dropout.py
index 309f6ae774c..fd9e0da0912 100644
--- a/keras/layers/regularization/alpha_dropout.py
+++ b/keras/layers/regularization/alpha_dropout.py
@@ -1,4 +1,5 @@
 from keras import backend
+from keras import ops
 from keras.api_export import keras_export
 from keras.layers.layer import Layer
 
@@ -48,17 +49,44 @@ def __init__(self, rate, noise_shape=None, seed=None, **kwargs):
 
     def call(self, inputs, training=False):
         if training and self.rate > 0:
-            return backend.random.alpha_dropout(
-                inputs,
+            noise_shape = self._get_concrete_noise_shape(
+                inputs, self.noise_shape
+            )
+            alpha = 1.6732632423543772848170429916717
+            scale = 1.0507009873554804934193349852946
+            alpha_p = -alpha * scale
+
+            kept_idx = ops.greater_equal(
+                ops.random.uniform(noise_shape, seed=self.seed_generator),
                 self.rate,
-                noise_shape=self.noise_shape,
-                seed=self.seed_generator,
             )
+            kept_idx = ops.cast(kept_idx, inputs.dtype)
+
+            # Compute affine transformation parameters
+            a = ((1 - self.rate) * (1 + self.rate * alpha_p**2)) ** -0.5
+            b = -a * alpha_p * self.rate
+
+            # Apply mask
+            x = inputs * kept_idx + alpha_p * (1 - kept_idx)
+            return a * x + b
+
         return inputs
 
     def compute_output_shape(self, input_shape):
         return input_shape
 
+    def _get_concrete_noise_shape(self, inputs, noise_shape):
+        if noise_shape is None:
+            return inputs.shape
+
+        concrete_inputs_shape = inputs.shape
+        concrete_noise_shape = []
+        for i, value in enumerate(noise_shape):
+            concrete_noise_shape.append(
+                concrete_inputs_shape[i] if value is None else value
+            )
+        return concrete_noise_shape
+
     def get_config(self):
         base_config = super().get_config()
         config = {
diff --git a/keras/layers/regularization/alpha_dropout_test.py b/keras/layers/regularization/alpha_dropout_test.py
index d6ea4a0db96..916ef1fa66c 100644
--- a/keras/layers/regularization/alpha_dropout_test.py
+++ b/keras/layers/regularization/alpha_dropout_test.py
@@ -1,6 +1,7 @@
 import numpy as np
 import pytest
 
+from keras import backend
 from keras import layers
 from keras import testing
 
@@ -22,6 +23,14 @@ def test_alpha_dropout_basics(self):
             supports_masking=True,
         )
 
+    def test_alpha_dropout_correctness(self):
+        inputs = np.ones((20, 500)).astype("float32")
+        layer = layers.AlphaDropout(0.3, seed=1337)
+        outputs = layer(inputs, training=True)
+        self.assertAllClose(
+            np.std(backend.convert_to_numpy(outputs)), 1.0, atol=1e-1
+        )
+
     def test_alpha_dropout_partial_noise_shape_dynamic(self):
         inputs = np.ones((20, 5, 10))
         layer = layers.AlphaDropout(0.5, noise_shape=(None, 1, None))
diff --git a/keras/random/random.py b/keras/random/random.py
index 5789175a1af..585e90ed8f9 100644
--- a/keras/random/random.py
+++ b/keras/random/random.py
@@ -190,13 +190,6 @@ def dropout(inputs, rate, noise_shape=None, seed=None):
     )
 
 
-@keras_export("keras.random.alpha_dropout")
-def alpha_dropout(inputs, rate, noise_shape=None, seed=None):
-    return backend.random.alpha_dropout(
-        inputs, rate, noise_shape=noise_shape, seed=seed
-    )
-
-
 @keras_export("keras.random.shuffle")
 def shuffle(x, axis=0, seed=None):
     """Shuffle the elements of a tensor uniformly at random along an axis.
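
The affine correction that now lives in `AlphaDropout.call()` can be
sanity-checked outside Keras. Below is a minimal standalone NumPy sketch
(not part of the patch; the `rate` value and sample size are arbitrary):
with standard-normal inputs, replacing a `rate` fraction of activations with
the SELU negative saturation value `alpha_p` and then applying `a * x + b`
leaves the mean near 0 and the standard deviation near 1.

    import numpy as np

    # SELU constants, exactly as in the patch.
    alpha = 1.6732632423543772848170429916717
    scale = 1.0507009873554804934193349852946
    alpha_p = -alpha * scale  # SELU negative saturation value

    rate = 0.3  # arbitrary dropout rate for this check
    rng = np.random.default_rng(0)
    x = rng.standard_normal(1_000_000).astype("float32")

    # Keep each activation with probability (1 - rate); else set to alpha_p.
    kept = (rng.uniform(size=x.shape) >= rate).astype("float32")
    dropped = x * kept + alpha_p * (1.0 - kept)

    # Affine transformation parameters, matching the patch.
    a = ((1 - rate) * (1 + rate * alpha_p**2)) ** -0.5
    b = -a * alpha_p * rate
    y = a * dropped + b

    print(np.mean(y), np.std(y))  # close to 0.0 and 1.0 respectively
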
diff --git a/keras/random/random_test.py b/keras/random/random_test.py
index d13bf927716..b9b788cc5a6 100644
--- a/keras/random/random_test.py
+++ b/keras/random/random_test.py
@@ -140,27 +140,6 @@ def test_dropout_noise_shape(self):
         )
         self.assertEqual(x.shape, (2, 3, 5, 7))
 
-    def test_alpha_dropout(self):
-        x = ops.random.normal((10000,))
-        y = random.alpha_dropout(x, rate=0, seed=0)
-        self.assertAllClose(y, x)
-        self.assertEqual(x.dtype, y.dtype)
-
-        # standard deviation check
-        y = random.alpha_dropout(x, rate=0.8, seed=0)
-        self.assertAllClose(ops.std(y), 1.0, atol=1e-1)
-        y = random.alpha_dropout(x, rate=0.5, seed=0)
-        self.assertAllClose(ops.std(y), 1.0, atol=1e-1)
-        y = random.alpha_dropout(x, rate=0.3, seed=0)
-        self.assertAllClose(ops.std(y), 1.0, atol=1e-1)
-
-    def test_alpha_dropout_noise_shape(self):
-        inputs = ops.ones((2, 3, 5, 7))
-        x = random.alpha_dropout(
-            inputs, rate=0.3, noise_shape=[None, 3, 5, None], seed=0
-        )
-        self.assertEqual(x.shape, (2, 3, 5, 7))
-
 @pytest.mark.skipif(
     keras.backend.backend() != "jax",
     reason="This test requires `jax` as the backend.",
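
After both patches, alpha dropout is exposed only through
`keras.layers.AlphaDropout`; the `keras.random.alpha_dropout` op added by the
first patch is removed again by the second. A minimal usage sketch pairing
the layer with SELU activations and the `lecun_normal` initializer commonly
recommended for self-normalizing networks (the widths, rates, seeds, and
data below are illustrative):

    import numpy as np

    import keras
    from keras import layers

    model = keras.Sequential(
        [
            keras.Input(shape=(20,)),
            layers.Dense(
                64, activation="selu", kernel_initializer="lecun_normal"
            ),
            layers.AlphaDropout(rate=0.2, seed=0),
            layers.Dense(
                64, activation="selu", kernel_initializer="lecun_normal"
            ),
            layers.AlphaDropout(rate=0.2, seed=1),
            layers.Dense(1),
        ]
    )
    model.compile(optimizer="adam", loss="mse")

    # The dropout is active only in training mode; at inference the layer
    # is the identity.
    x = np.random.standard_normal((128, 20)).astype("float32")
    y = np.random.standard_normal((128, 1)).astype("float32")
    model.fit(x, y, epochs=1, verbose=0)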