Skip to content

Commit

Permalink
Revert behavior of Dropout2d on 3D inputs to 1D channel-wise dropout behavior & warn
Browse files Browse the repository at this point in the history

Pull Request resolved: pytorch#79549

Approved by: https://github.com/ngimel, https://github.com/albanD
  • Loading branch information
jbschlosser authored and pytorchmergebot committed Jun 15, 2022
1 parent 2d73c8e commit 5953fd9
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 12 deletions.
14 changes: 10 additions & 4 deletions test/test_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -14432,10 +14432,16 @@ def test_Dropout2d(self, device):
with self.assertWarnsRegex(UserWarning, "Received a 2-D input to dropout2d"):
nn.Dropout2d(p=0.5)(torch.rand(1, 2, device=device))

# no batch dims
input = torch.rand(50, 2, 2, device=device)
self._test_dropoutNd_no_batch(nn.Dropout2d(p=0.5), input)
self._test_dropoutNd_no_batch(nn.Dropout2d(p=0.5, inplace=True), input)
# TODO: Uncomment these lines once no-batch-dim inputs are supported.
# For now, the historical dropout1d behavior is performed for 3D inputs.
# See https://github.com/pytorch/pytorch/issues/77081

# input = torch.rand(50, 2, 2, device=device)
# self._test_dropoutNd_no_batch(nn.Dropout2d(p=0.5), input)
# self._test_dropoutNd_no_batch(nn.Dropout2d(p=0.5, inplace=True), input)

with self.assertWarnsRegex(UserWarning, "assuming that channel-wise 1D dropout behavior is desired"):
nn.Dropout2d(p=0.5)(torch.rand(1, 2, 2, device=device))

# check that complete channels are dropped
input = torch.ones(10, 4, 2, 2, device=device)
Expand Down
16 changes: 10 additions & 6 deletions torch/nn/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -1330,15 +1330,19 @@ def dropout2d(input: Tensor, p: float = 0.5, training: bool = True, inplace: boo
"a channel dimension, and an optional batch dimension (i.e. 3D or 4D inputs).")
warnings.warn(warn_msg)

is_batched = inp_dim == 4
if not is_batched:
input = input.unsqueeze_(0) if inplace else input.unsqueeze(0)
# TODO: Properly support no-batch-dim inputs. For now, these are NOT supported; passing
# a 3D input will perform dropout1d behavior instead. This was done historically and the
# behavior is maintained here for now.
# See https://github.com/pytorch/pytorch/issues/77081
if inp_dim == 3:
warnings.warn("dropout2d: Received a 3D input to dropout2d and assuming that channel-wise "
"1D dropout behavior is desired - input is interpreted as shape (N, C, L), where C "
"is the channel dim. This behavior will change in a future release to interpret the "
"input as one without a batch dimension, i.e. shape (C, H, W). To maintain the 1D "
"channel-wise dropout behavior, please switch to using dropout1d instead.")

result = _VF.feature_dropout_(input, p, training) if inplace else _VF.feature_dropout(input, p, training)

if not is_batched:
result = result.squeeze_(0) if inplace else result.squeeze(0)

return result


Expand Down
11 changes: 9 additions & 2 deletions torch/nn/modules/dropout.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,9 +124,16 @@ class Dropout2d(_DropoutNd):
inplace (bool, optional): If set to ``True``, will do this operation
in-place
.. warning ::
Due to historical reasons, this class will perform 1D channel-wise dropout
for 3D inputs (as done by :class:`nn.Dropout1d`). Thus, it currently does NOT
support inputs without a batch dimension of shape :math:`(C, H, W)`. This
behavior will change in a future release to interpret 3D inputs as no-batch-dim
inputs. To maintain the old behavior, switch to :class:`nn.Dropout1d`.
Shape:
- Input: :math:`(N, C, H, W)` or :math:`(C, H, W)`.
- Output: :math:`(N, C, H, W)` or :math:`(C, H, W)` (same shape as input).
- Input: :math:`(N, C, H, W)` or :math:`(N, C, L)`.
- Output: :math:`(N, C, H, W)` or :math:`(N, C, L)` (same shape as input).
Examples::
Expand Down

0 comments on commit 5953fd9

Please sign in to comment.