Add ops.random.alpha_dropout and layers.AlphaDropout (#18940)
* Add `ops.random.alpha_dropout` and `layers.AlphaDropout`

* Move the implementation of `ops.random.alpha_dropout` into the layer
james77777778 authored Dec 15, 2023
1 parent 92e7171 commit a14af85
Showing 3 changed files with 158 additions and 0 deletions.
1 change: 1 addition & 0 deletions keras/layers/__init__.py
@@ -86,6 +86,7 @@
from keras.layers.regularization.activity_regularization import (
    ActivityRegularization,
)
from keras.layers.regularization.alpha_dropout import AlphaDropout
from keras.layers.regularization.dropout import Dropout
from keras.layers.regularization.gaussian_dropout import GaussianDropout
from keras.layers.regularization.gaussian_noise import GaussianNoise
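With this import added, the layer becomes reachable from the public `keras.layers` namespace. A minimal sketch of the resulting import path (assuming a Keras build that includes this commit):

    from keras import layers

    layer = layers.AlphaDropout(rate=0.2)  # now importable directly from keras.layers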
97 changes: 97 additions & 0 deletions keras/layers/regularization/alpha_dropout.py
@@ -0,0 +1,97 @@
from keras import backend
from keras import ops
from keras.api_export import keras_export
from keras.layers.layer import Layer


@keras_export("keras.layers.AlphaDropout")
class AlphaDropout(Layer):
    """Applies Alpha Dropout to the input.

    Alpha Dropout is a `Dropout` that keeps the mean and variance of its
    inputs at their original values, in order to ensure the self-normalizing
    property even after this dropout.
    Alpha Dropout fits well with Scaled Exponential Linear Units (SELU) by
    randomly setting activations to the negative saturation value.

    Args:
        rate: Float between 0 and 1. The multiplicative noise will have
            standard deviation `sqrt(rate / (1 - rate))`.
        noise_shape: 1D integer tensor representing the shape of the
            binary alpha dropout mask that will be multiplied with the input.
            For instance, if your inputs have shape
            `(batch_size, timesteps, features)` and
            you want the alpha dropout mask to be the same for all timesteps,
            you can use `noise_shape=(batch_size, 1, features)`.
        seed: A Python integer to use as random seed.

    Call arguments:
        inputs: Input tensor (of any rank).
        training: Python boolean indicating whether the layer should behave
            in training mode (adding alpha dropout) or in inference mode
            (doing nothing).
    """

    def __init__(self, rate, noise_shape=None, seed=None, **kwargs):
        super().__init__(**kwargs)
        if not 0 <= rate <= 1:
            raise ValueError(
                "Invalid value received for argument "
                "`rate`. Expected a float value between 0 and 1. "
                f"Received: rate={rate}"
            )
        self.rate = rate
        self.seed = seed
        self.noise_shape = noise_shape
        self.seed_generator = backend.random.SeedGenerator(seed)
        self.supports_masking = True
        self.built = True

    def call(self, inputs, training=False):
        if training and self.rate > 0:
            noise_shape = self._get_concrete_noise_shape(
                inputs, self.noise_shape
            )
            alpha = 1.6732632423543772848170429916717
            scale = 1.0507009873554804934193349852946
            alpha_p = -alpha * scale

            kept_idx = ops.greater_equal(
                ops.random.uniform(noise_shape, seed=self.seed_generator),
                self.rate,
            )
            kept_idx = ops.cast(kept_idx, inputs.dtype)

            # Compute affine transformation parameters
            a = ((1 - self.rate) * (1 + self.rate * alpha_p**2)) ** -0.5
            b = -a * alpha_p * self.rate

            # Apply mask
            x = inputs * kept_idx + alpha_p * (1 - kept_idx)
            return a * x + b

        return inputs

    def compute_output_shape(self, input_shape):
        return input_shape

    def _get_concrete_noise_shape(self, inputs, noise_shape):
        if noise_shape is None:
            return inputs.shape

        concrete_inputs_shape = inputs.shape
        concrete_noise_shape = []
        for i, value in enumerate(noise_shape):
            concrete_noise_shape.append(
                concrete_inputs_shape[i] if value is None else value
            )
        return concrete_noise_shape

    def get_config(self):
        base_config = super().get_config()
        config = {
            "rate": self.rate,
            "seed": self.seed,
            "noise_shape": self.noise_shape,
        }
        return {**base_config, **config}
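The affine parameters `a` and `b` computed in `call()` are chosen so that, for inputs with zero mean and unit variance (the self-normalizing regime that SELU maintains), pinning dropped units to `alpha_p` and then applying `a * x + b` leaves the mean and variance unchanged. A standalone NumPy sketch of that check, assuming standard-normal inputs as a stand-in for SELU activations (not part of this diff):

    import numpy as np

    rate = 0.3
    alpha = 1.6732632423543772848170429916717   # SELU alpha
    scale = 1.0507009873554804934193349852946   # SELU scale
    alpha_p = -alpha * scale                    # negative saturation value

    # Same formulas as in AlphaDropout.call()
    a = ((1 - rate) * (1 + rate * alpha_p**2)) ** -0.5
    b = -a * alpha_p * rate

    rng = np.random.default_rng(0)
    x = rng.standard_normal(1_000_000)          # stand-in for SELU activations
    kept = (rng.random(x.shape) >= rate).astype(x.dtype)
    dropped = x * kept + alpha_p * (1 - kept)   # dropped units pinned to alpha_p
    out = a * dropped + b

    print(out.mean(), out.std())                # both remain close to 0 and 1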
60 changes: 60 additions & 0 deletions keras/layers/regularization/alpha_dropout_test.py
@@ -0,0 +1,60 @@
import numpy as np
import pytest

from keras import backend
from keras import layers
from keras import testing


class AlphaDropoutTest(testing.TestCase):
    @pytest.mark.requires_trainable_backend
    def test_alpha_dropout_basics(self):
        self.run_layer_test(
            layers.AlphaDropout,
            init_kwargs={
                "rate": 0.2,
            },
            input_shape=(2, 3),
            expected_output_shape=(2, 3),
            expected_num_trainable_weights=0,
            expected_num_non_trainable_weights=0,
            expected_num_seed_generators=1,
            expected_num_losses=0,
            supports_masking=True,
        )

    def test_alpha_dropout_correctness(self):
        inputs = np.ones((20, 500)).astype("float32")
        layer = layers.AlphaDropout(0.3, seed=1337)
        outputs = layer(inputs, training=True)
        self.assertAllClose(
            np.std(backend.convert_to_numpy(outputs)), 1.0, atol=1e-1
        )

    def test_alpha_dropout_partial_noise_shape_dynamic(self):
        inputs = np.ones((20, 5, 10))
        layer = layers.AlphaDropout(0.5, noise_shape=(None, 1, None))
        outputs = layer(inputs, training=True)
        self.assertAllClose(outputs[:, 0, :], outputs[:, 1, :])

    def test_alpha_dropout_partial_noise_shape_static(self):
        inputs = np.ones((20, 5, 10))
        layer = layers.AlphaDropout(0.5, noise_shape=(20, 1, 10))
        outputs = layer(inputs, training=True)
        self.assertAllClose(outputs[:, 0, :], outputs[:, 1, :])

    def test_alpha_dropout_negative_rate(self):
        with self.assertRaisesRegex(
            ValueError,
            "Invalid value received for argument `rate`. "
            "Expected a float value between 0 and 1.",
        ):
            _ = layers.AlphaDropout(rate=-0.5)

    def test_alpha_dropout_rate_greater_than_one(self):
        with self.assertRaisesRegex(
            ValueError,
            "Invalid value received for argument `rate`. "
            "Expected a float value between 0 and 1.",
        ):
            _ = layers.AlphaDropout(rate=1.5)
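For context, a minimal sketch of how the new layer is typically paired with SELU activations. The small functional model below is illustrative only; `keras.Input`, `Dense`, and the `lecun_normal` initializer are assumed from the standard Keras API rather than taken from this commit:

    import numpy as np
    import keras
    from keras import layers

    inputs = keras.Input(shape=(16,))
    x = layers.Dense(64, activation="selu", kernel_initializer="lecun_normal")(inputs)
    x = layers.AlphaDropout(0.2, seed=0)(x)  # only active when training=True
    outputs = layers.Dense(1)(x)
    model = keras.Model(inputs, outputs)

    data = np.random.rand(8, 16).astype("float32")
    preds = model(data, training=True)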
