* Add `ops.random.alpha_dropout` and `layers.AlphaDropout`
* Move the implementation of `ops.random.alpha_dropout` into the layer
Commit a14af85 (parent: 92e7171)
Showing 3 changed files with 158 additions and 0 deletions.
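In practice, the new layer slots in like the existing `Dropout` layer, typically after `selu` activations with `lecun_normal` initialization, as the docstring below describes. A minimal usage sketch (illustrative only; `keras.layers.AlphaDropout` is the export added by this commit, everything else is standard Keras):

import keras
from keras import layers

# AlphaDropout is intended to follow SELU activations so that the
# self-normalizing property of the network is preserved.
model = keras.Sequential([
    layers.Dense(64, activation="selu", kernel_initializer="lecun_normal"),
    layers.AlphaDropout(0.1),
    layers.Dense(10),
])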
@@ -0,0 +1,97 @@
from keras import backend
from keras import ops
from keras.api_export import keras_export
from keras.layers.layer import Layer


@keras_export("keras.layers.AlphaDropout")
class AlphaDropout(Layer):
    """Applies Alpha Dropout to the input.

    Alpha Dropout is a `Dropout` variant that keeps the mean and variance of
    its inputs at their original values, in order to preserve the
    self-normalizing property even after this dropout is applied.
    Alpha Dropout pairs well with Scaled Exponential Linear Units (SELU)
    because it randomly sets activations to the negative saturation value.

    Args:
        rate: Float between 0 and 1. The multiplicative noise will have
            standard deviation `sqrt(rate / (1 - rate))`.
        noise_shape: 1D integer tensor representing the shape of the
            binary alpha dropout mask that will be multiplied with the input.
            For instance, if your inputs have shape
            `(batch_size, timesteps, features)` and
            you want the alpha dropout mask to be the same for all timesteps,
            you can use `noise_shape=(batch_size, 1, features)`.
        seed: A Python integer to use as random seed.

    Call arguments:
        inputs: Input tensor (of any rank).
        training: Python boolean indicating whether the layer should behave in
            training mode (adding alpha dropout) or in inference mode
            (doing nothing).
    """

    def __init__(self, rate, noise_shape=None, seed=None, **kwargs):
        super().__init__(**kwargs)
        if not 0 <= rate <= 1:
            raise ValueError(
                "Invalid value received for argument "
                "`rate`. Expected a float value between 0 and 1. "
                f"Received: rate={rate}"
            )
        self.rate = rate
        self.seed = seed
        self.noise_shape = noise_shape
        self.seed_generator = backend.random.SeedGenerator(seed)
        self.supports_masking = True
        self.built = True

    def call(self, inputs, training=False):
        if training and self.rate > 0:
            noise_shape = self._get_concrete_noise_shape(
                inputs, self.noise_shape
            )
            # SELU constants; `alpha_p` is the negative saturation value
            # that dropped activations are set to.
            alpha = 1.6732632423543772848170429916717
            scale = 1.0507009873554804934193349852946
            alpha_p = -alpha * scale

            # Keep mask: 1 where a unit is kept, 0 where it is dropped.
            kept_idx = ops.greater_equal(
                ops.random.uniform(noise_shape, seed=self.seed_generator),
                self.rate,
            )
            kept_idx = ops.cast(kept_idx, inputs.dtype)

            # Compute affine transformation parameters, chosen so that
            # inputs with zero mean and unit variance keep zero mean and
            # unit variance after masking.
            a = ((1 - self.rate) * (1 + self.rate * alpha_p**2)) ** -0.5
            b = -a * alpha_p * self.rate

            # Apply mask: dropped units are set to `alpha_p`, then the
            # affine transform rescales the result.
            x = inputs * kept_idx + alpha_p * (1 - kept_idx)
            return a * x + b

        return inputs

    def compute_output_shape(self, input_shape):
        return input_shape

    def _get_concrete_noise_shape(self, inputs, noise_shape):
        # Replace `None` entries in `noise_shape` with the corresponding
        # dimensions of the inputs.
        if noise_shape is None:
            return inputs.shape

        concrete_inputs_shape = inputs.shape
        concrete_noise_shape = []
        for i, value in enumerate(noise_shape):
            concrete_noise_shape.append(
                concrete_inputs_shape[i] if value is None else value
            )
        return concrete_noise_shape

    def get_config(self):
        base_config = super().get_config()
        config = {
            "rate": self.rate,
            "seed": self.seed,
            "noise_shape": self.noise_shape,
        }
        return {**base_config, **config}
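As a sanity check on the affine parameters computed in `call()` above: for inputs at the SELU fixed point (zero mean, unit variance), masking followed by the `a`/`b` transform should leave the mean near 0 and the standard deviation near 1. The following standalone NumPy sketch (illustrative only, not part of the diff) mirrors that logic:

import numpy as np

rate = 0.3
alpha = 1.6732632423543772848170429916717
scale = 1.0507009873554804934193349852946
alpha_p = -alpha * scale  # negative saturation value of SELU

rng = np.random.default_rng(0)
x = rng.standard_normal(100_000).astype("float32")  # mean ~0, std ~1

# Bernoulli keep mask with keep probability (1 - rate).
kept = (rng.uniform(size=x.shape) >= rate).astype("float32")

a = ((1 - rate) * (1 + rate * alpha_p**2)) ** -0.5
b = -a * alpha_p * rate

y = a * (x * kept + alpha_p * (1 - kept)) + b
print(y.mean(), y.std())  # close to 0 and 1 respectively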
@@ -0,0 +1,60 @@
import numpy as np
import pytest

from keras import backend
from keras import layers
from keras import testing


class AlphaDropoutTest(testing.TestCase):
    @pytest.mark.requires_trainable_backend
    def test_alpha_dropout_basics(self):
        self.run_layer_test(
            layers.AlphaDropout,
            init_kwargs={
                "rate": 0.2,
            },
            input_shape=(2, 3),
            expected_output_shape=(2, 3),
            expected_num_trainable_weights=0,
            expected_num_non_trainable_weights=0,
            expected_num_seed_generators=1,
            expected_num_losses=0,
            supports_masking=True,
        )

    def test_alpha_dropout_correctness(self):
        inputs = np.ones((20, 500)).astype("float32")
        layer = layers.AlphaDropout(0.3, seed=1337)
        outputs = layer(inputs, training=True)
        self.assertAllClose(
            np.std(backend.convert_to_numpy(outputs)), 1.0, atol=1e-1
        )

    def test_alpha_dropout_partial_noise_shape_dynamic(self):
        inputs = np.ones((20, 5, 10))
        layer = layers.AlphaDropout(0.5, noise_shape=(None, 1, None))
        outputs = layer(inputs, training=True)
        self.assertAllClose(outputs[:, 0, :], outputs[:, 1, :])

    def test_alpha_dropout_partial_noise_shape_static(self):
        inputs = np.ones((20, 5, 10))
        layer = layers.AlphaDropout(0.5, noise_shape=(20, 1, 10))
        outputs = layer(inputs, training=True)
        self.assertAllClose(outputs[:, 0, :], outputs[:, 1, :])

    def test_alpha_dropout_negative_rate(self):
        with self.assertRaisesRegex(
            ValueError,
            "Invalid value received for argument `rate`. "
            "Expected a float value between 0 and 1.",
        ):
            _ = layers.AlphaDropout(rate=-0.5)

    def test_alpha_dropout_rate_greater_than_one(self):
        with self.assertRaisesRegex(
            ValueError,
            "Invalid value received for argument `rate`. "
            "Expected a float value between 0 and 1.",
        ):
            _ = layers.AlphaDropout(rate=1.5)
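The two partial `noise_shape` tests above verify that the dropout mask is broadcast across the timestep axis. A small standalone sketch of the same behavior (illustrative only, not part of the diff):

import numpy as np
from keras import backend, layers

x = np.ones((4, 5, 8), dtype="float32")  # (batch, timesteps, features)
layer = layers.AlphaDropout(0.5, noise_shape=(None, 1, None), seed=0)
y = backend.convert_to_numpy(layer(x, training=True))

# With noise_shape=(None, 1, None), the same units are dropped at every
# timestep, so all timesteps within a sample are identical.
assert np.allclose(y[:, 0, :], y[:, 1, :])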