Add ops.random.alpha_dropout and layers.AlphaDropout #18940

Merged: 2 commits, Dec 15, 2023
19 changes: 19 additions & 0 deletions keras/backend/jax/random.py
@@ -1,6 +1,7 @@
import jax

from keras.backend.config import floatx
from keras.backend.jax.core import cast
from keras.random.seed_generator import SeedGenerator
from keras.random.seed_generator import draw_seed
from keras.random.seed_generator import make_default_seed
@@ -81,6 +82,24 @@ def dropout(inputs, rate, noise_shape=None, seed=None):
)


def alpha_dropout(inputs, rate, noise_shape=None, seed=None):
noise_shape = _get_concrete_noise_shape(inputs, noise_shape)
alpha = 1.6732632423543772848170429916717
Review comment (Contributor): Any specific reason for having more than 6-7 decimal points?

Reply (Collaborator): In practice it will be cast to float32, so some precision will be lost. The number above is just taken from the paper.

scale = 1.0507009873554804934193349852946
alpha_p = -alpha * scale

kept_idx = jax.numpy.greater_equal(uniform(noise_shape, seed=seed), rate)
kept_idx = cast(kept_idx, inputs.dtype)

# Compute affine transformation parameters
a = ((1 - rate) * (1 + rate * alpha_p**2)) ** -0.5
b = -a * alpha_p * rate

# Apply mask
x = inputs * kept_idx + alpha_p * (1 - kept_idx)
return a * x + b
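
As a side note on the precision question raised in the review thread above, a quick standalone check (not part of the diff, assuming NumPy is available) shows roughly what survives the float32 cast:

import numpy as np

alpha = 1.6732632423543772848170429916717
scale = 1.0507009873554804934193349852946

# float32 keeps only about 7 significant decimal digits, so the extra
# digits in the literals above are dropped once the values are cast.
print(np.float32(alpha))  # approximately 1.6732632
print(np.float32(scale))  # approximately 1.050701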


def shuffle(x, axis=0, seed=None):
seed = jax_draw_seed(seed)
return jax.random.shuffle(seed, x, axis)
27 changes: 27 additions & 0 deletions keras/backend/numpy/random.py
@@ -88,6 +88,33 @@ def dropout(inputs, rate, noise_shape=None, seed=None):
return np.where(mask, inputs / keep_prob, np.zeros_like(inputs))


def alpha_dropout(inputs, rate, noise_shape=None, seed=None):
# If noise_shape is not provided, use the shape of inputs
if noise_shape is None:
noise_shape = inputs.shape
else:
# If noise_shape is provided, replace None with corresponding
# input shape
noise_shape = [
n if n is not None else inputs.shape[i]
for i, n in enumerate(noise_shape)
]
alpha = 1.6732632423543772848170429916717
scale = 1.0507009873554804934193349852946
alpha_p = -alpha * scale

kept_idx = np.greater_equal(uniform(noise_shape, seed=seed), rate)
kept_idx = kept_idx.astype(inputs.dtype)

# Compute affine transformation parameters
a = ((1 - rate) * (1 + rate * alpha_p**2)) ** -0.5
b = -a * alpha_p * rate

# Apply mask
x = inputs * kept_idx + alpha_p * (1 - kept_idx)
return (a * x + b).astype(inputs.dtype)
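
As a sanity check on the affine parameters `a` and `b` above, a short standalone NumPy snippet (not part of the diff) confirms that for standard-normal inputs the output keeps mean close to 0 and standard deviation close to 1, which is the point of the `a * x + b` correction:

import numpy as np

rate = 0.3
alpha = 1.6732632423543772848170429916717
scale = 1.0507009873554804934193349852946
alpha_p = -alpha * scale

rng = np.random.default_rng(0)
inputs = rng.standard_normal(100_000).astype("float32")

# Same formulas as the backend implementation above
kept_idx = (rng.uniform(size=inputs.shape) >= rate).astype(inputs.dtype)
a = ((1 - rate) * (1 + rate * alpha_p**2)) ** -0.5
b = -a * alpha_p * rate
outputs = a * (inputs * kept_idx + alpha_p * (1 - kept_idx)) + b

print(outputs.mean(), outputs.std())  # both close to 0.0 and 1.0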


def shuffle(x, axis=0, seed=None):
seed = draw_seed(seed)
rng = np.random.default_rng(seed)
19 changes: 19 additions & 0 deletions keras/backend/tensorflow/random.py
@@ -86,6 +86,25 @@ def dropout(inputs, rate, noise_shape=None, seed=None):
)


def alpha_dropout(inputs, rate, noise_shape=None, seed=None):
noise_shape = _get_concrete_noise_shape(inputs, noise_shape)

alpha = 1.6732632423543772848170429916717
scale = 1.0507009873554804934193349852946
alpha_p = -alpha * scale

kept_idx = tf.greater_equal(uniform(noise_shape, seed=seed), rate)
kept_idx = tf.cast(kept_idx, inputs.dtype)

# Compute affine transformation parameters
a = ((1 - rate) * (1 + rate * alpha_p**2)) ** -0.5
b = -a * alpha_p * rate

# Apply mask
x = inputs * kept_idx + alpha_p * (1 - kept_idx)
return a * x + b


def shuffle(x, axis=0, seed=None):
seed = tf_draw_seed(seed)
if axis == 0:
19 changes: 19 additions & 0 deletions keras/backend/torch/random.py
@@ -3,6 +3,7 @@
import torch.nn.functional as tnn

from keras.backend.config import floatx
from keras.backend.torch.core import cast
from keras.backend.torch.core import convert_to_tensor
from keras.backend.torch.core import get_device
from keras.backend.torch.core import to_torch_dtype
@@ -166,6 +167,24 @@ def dropout(inputs, rate, noise_shape=None, seed=None):
)


def alpha_dropout(inputs, rate, noise_shape=None, seed=None):
noise_shape = _get_concrete_noise_shape(inputs, noise_shape)
alpha = 1.6732632423543772848170429916717
scale = 1.0507009873554804934193349852946
alpha_p = -alpha * scale

kept_idx = torch.greater_equal(uniform(noise_shape, seed=seed), rate)
kept_idx = cast(kept_idx, inputs.dtype)

# Compute affine transformation parameters
a = ((1 - rate) * (1 + rate * alpha_p**2)) ** -0.5
b = -a * alpha_p * rate

# Apply mask
x = inputs * kept_idx + alpha_p * (1 - kept_idx)
return a * x + b


def shuffle(x, axis=0, seed=None):
# Ref: https://github.com/pytorch/pytorch/issues/71409
x = convert_to_tensor(x)
1 change: 1 addition & 0 deletions keras/layers/__init__.py
@@ -86,6 +86,7 @@
from keras.layers.regularization.activity_regularization import (
ActivityRegularization,
)
from keras.layers.regularization.alpha_dropout import AlphaDropout
from keras.layers.regularization.dropout import Dropout
from keras.layers.regularization.gaussian_dropout import GaussianDropout
from keras.layers.regularization.gaussian_noise import GaussianNoise
69 changes: 69 additions & 0 deletions keras/layers/regularization/alpha_dropout.py
@@ -0,0 +1,69 @@
from keras import backend
from keras.api_export import keras_export
from keras.layers.layer import Layer


@keras_export("keras.layers.AlphaDropout")
class AlphaDropout(Layer):
"""Applies Alpha Dropout to the input.

Alpha Dropout is a `Dropout` variant that keeps the mean and variance of its
inputs at their original values, in order to preserve the self-normalizing
property even after this dropout.
Alpha Dropout pairs well with Scaled Exponential Linear Units (SELU) because
it randomly sets activations to the negative saturation value.

Args:
rate: Float between 0 and 1. The multiplicative noise will have
standard deviation `sqrt(rate / (1 - rate))`.
noise_shape: 1D integer tensor representing the shape of the
binary alpha dropout mask that will be multiplied with the input.
For instance, if your inputs have shape
`(batch_size, timesteps, features)` and
you want the alpha dropout mask to be the same for all timesteps,
you can use `noise_shape=(batch_size, 1, features)`.
seed: A Python integer to use as random seed.

Call arguments:
inputs: Input tensor (of any rank).
training: Python boolean indicating whether the layer should behave in
training mode (adding alpha dropout) or in inference mode
(doing nothing).
"""

def __init__(self, rate, noise_shape=None, seed=None, **kwargs):
super().__init__(**kwargs)
if not 0 <= rate <= 1:
raise ValueError(
f"Invalid value received for argument "
"`rate`. Expected a float value between 0 and 1. "
f"Received: rate={rate}"
)
self.rate = rate
self.seed = seed
self.noise_shape = noise_shape
self.seed_generator = backend.random.SeedGenerator(seed)
self.supports_masking = True
self.built = True

def call(self, inputs, training=False):
if training and self.rate > 0:
return backend.random.alpha_dropout(
Review comment (Collaborator): We can implement the layer in terms of backend ops and random.uniform. There is no real need for the alpha_dropout op -- only the layer would get used (even then, it's fairly niche usage).

Reply (Contributor, author): Sure, I've made the changes.
inputs,
self.rate,
noise_shape=self.noise_shape,
seed=self.seed_generator,
)
return inputs

def compute_output_shape(self, input_shape):
return input_shape

def get_config(self):
base_config = super().get_config()
config = {
"rate": self.rate,
"seed": self.seed,
"noise_shape": self.noise_shape,
}
return {**base_config, **config}
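
For reference, a minimal sketch of the change the reviewer suggests above: writing the layer's `call` in terms of backend-agnostic `keras.ops` and `keras.random.uniform` instead of a dedicated `backend.random.alpha_dropout` op. This is an illustration only, not the actual follow-up commit; the inline noise-shape resolution mirrors the NumPy backend above.

# Module-level imports assumed: from keras import ops; from keras import random
def call(self, inputs, training=False):
    if training and self.rate > 0:
        alpha = 1.6732632423543772848170429916717
        scale = 1.0507009873554804934193349852946
        alpha_p = -alpha * scale

        # Resolve None entries in noise_shape against the concrete input shape
        if self.noise_shape is None:
            noise_shape = ops.shape(inputs)
        else:
            noise_shape = [
                n if n is not None else ops.shape(inputs)[i]
                for i, n in enumerate(self.noise_shape)
            ]

        kept_idx = ops.cast(
            ops.greater_equal(
                random.uniform(noise_shape, seed=self.seed_generator),
                self.rate,
            ),
            inputs.dtype,
        )

        # Affine parameters that restore zero mean and unit variance
        a = ((1 - self.rate) * (1 + self.rate * alpha_p**2)) ** -0.5
        b = -a * alpha_p * self.rate

        x = inputs * kept_idx + alpha_p * (1 - kept_idx)
        return a * x + b
    return inputs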
51 changes: 51 additions & 0 deletions keras/layers/regularization/alpha_dropout_test.py
@@ -0,0 +1,51 @@
import numpy as np
import pytest

from keras import layers
from keras import testing


class AlphaDropoutTest(testing.TestCase):
@pytest.mark.requires_trainable_backend
def test_alpha_dropout_basics(self):
self.run_layer_test(
layers.AlphaDropout,
init_kwargs={
"rate": 0.2,
},
input_shape=(2, 3),
expected_output_shape=(2, 3),
expected_num_trainable_weights=0,
expected_num_non_trainable_weights=0,
expected_num_seed_generators=1,
expected_num_losses=0,
supports_masking=True,
)

def test_alpha_dropout_partial_noise_shape_dynamic(self):
inputs = np.ones((20, 5, 10))
layer = layers.AlphaDropout(0.5, noise_shape=(None, 1, None))
outputs = layer(inputs, training=True)
self.assertAllClose(outputs[:, 0, :], outputs[:, 1, :])

def test_alpha_dropout_partial_noise_shape_static(self):
inputs = np.ones((20, 5, 10))
layer = layers.AlphaDropout(0.5, noise_shape=(20, 1, 10))
outputs = layer(inputs, training=True)
self.assertAllClose(outputs[:, 0, :], outputs[:, 1, :])

def test_alpha_dropout_negative_rate(self):
with self.assertRaisesRegex(
ValueError,
"Invalid value received for argument `rate`. "
"Expected a float value between 0 and 1.",
):
_ = layers.AlphaDropout(rate=-0.5)

def test_alpha_dropout_rate_greater_than_one(self):
with self.assertRaisesRegex(
ValueError,
"Invalid value received for argument `rate`. "
"Expected a float value between 0 and 1.",
):
_ = layers.AlphaDropout(rate=1.5)
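
For context, a brief hypothetical usage example, not part of the diff: the layer is intended to be paired with SELU activations (and, conventionally, `lecun_normal` initialization) in a self-normalizing network.

import keras
from keras import layers

model = keras.Sequential(
    [
        keras.Input(shape=(20,)),
        layers.Dense(64, activation="selu", kernel_initializer="lecun_normal"),
        layers.AlphaDropout(0.1),  # active only when training=True
        layers.Dense(64, activation="selu", kernel_initializer="lecun_normal"),
        layers.AlphaDropout(0.1),
        layers.Dense(1),
    ]
)
model.compile(optimizer="adam", loss="mse")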
7 changes: 7 additions & 0 deletions keras/random/random.py
@@ -190,6 +190,13 @@ def dropout(inputs, rate, noise_shape=None, seed=None):
)


@keras_export("keras.random.alpha_dropout")
def alpha_dropout(inputs, rate, noise_shape=None, seed=None):
return backend.random.alpha_dropout(
inputs, rate, noise_shape=noise_shape, seed=seed
)
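
A minimal usage sketch of the new public op (not part of the diff):

from keras import random

x = random.normal((4, 8), seed=1)
# Dropped entries are pushed to the SELU negative saturation value,
# then the whole tensor is affinely rescaled to keep mean and variance.
y = random.alpha_dropout(x, rate=0.2, seed=2)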


@keras_export("keras.random.shuffle")
def shuffle(x, axis=0, seed=None):
"""Shuffle the elements of a tensor uniformly at random along an axis.
21 changes: 21 additions & 0 deletions keras/random/random_test.py
@@ -140,6 +140,27 @@ def test_dropout_noise_shape(self):
)
self.assertEqual(x.shape, (2, 3, 5, 7))

def test_alpha_dropout(self):
x = ops.random.normal((10000,))
y = random.alpha_dropout(x, rate=0, seed=0)
self.assertAllClose(y, x)
self.assertEqual(x.dtype, y.dtype)

# standard deviation check
y = random.alpha_dropout(x, rate=0.8, seed=0)
self.assertAllClose(ops.std(y), 1.0, atol=1e-1)
y = random.alpha_dropout(x, rate=0.5, seed=0)
self.assertAllClose(ops.std(y), 1.0, atol=1e-1)
y = random.alpha_dropout(x, rate=0.3, seed=0)
self.assertAllClose(ops.std(y), 1.0, atol=1e-1)

def test_alpha_dropout_noise_shape(self):
inputs = ops.ones((2, 3, 5, 7))
x = random.alpha_dropout(
inputs, rate=0.3, noise_shape=[None, 3, 5, None], seed=0
)
self.assertEqual(x.shape, (2, 3, 5, 7))

@pytest.mark.skipif(
keras.backend.backend() != "jax",
reason="This test requires `jax` as the backend.",