[CUDNN][SDPA] Fix unsupported trivial stride-1 transpose case (pytorch#134031)

Fixes pytorch#134001
The previous check incorrectly assumed that two contiguous tensors with the same shape must have identical strides; when a transposed dimension has size 1, both tensors can be contiguous yet report different strides.

Pull Request resolved: pytorch#134031
Approved by: https://github.com/drisspg, https://github.com/Skylion007

Co-authored-by: Aaron Gokaslan <[email protected]>
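
For illustration only (not part of the commit), a minimal Python sketch of the failure mode described above, using the shapes from the linked issue: both tensors are contiguous and have the same shape, yet their reported strides differ because the transposed dimension has size 1.

    import torch

    # o: a plainly allocated output; dO: the same shape reached via a size-1 transpose.
    o = torch.empty(2, 1, 4, 64)                   # strides (256, 256, 64, 1)
    dO = torch.empty(2, 4, 1, 64).transpose(1, 2)  # shape (2, 1, 4, 64), strides (256, 64, 64, 1)

    print(o.is_contiguous(), dO.is_contiguous())   # True True
    print(o.stride() == dO.stride())               # False, so a strides-only comparison rejects the pair

The fix below therefore accepts the pair whenever both tensors are contiguous, and only falls back to the stride comparison otherwise.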
2 people authored and pytorchmergebot committed Aug 25, 2024
1 parent 08d1112 commit e93ca12
Showing 2 changed files with 23 additions and 3 deletions.
aten/src/ATen/native/cudnn/MHA.cpp: 7 additions & 3 deletions
@@ -662,7 +662,10 @@ void run_cudnn_SDP_bprop(
         " Materializing a contiguous tensor which will increase memory usage...");
     dO_ = dO.contiguous();
   }
-  if (!std::equal(
+  if ( // handle trivial transposed case with a transposed dim of size 1
+       // see also: https://github.com/pytorch/pytorch/issues/134001
+      !(dO_.is_contiguous() && o.is_contiguous()) &&
+      !std::equal(
           o.strides().begin(), o.strides().end(), dO.strides().begin())) {
     TORCH_WARN(
         "cuDNN SDPA backward got grad_output.strides() != output.strides(), "
@@ -674,8 +677,9 @@
     }
   }
   TORCH_INTERNAL_ASSERT(
-      std::equal(
-          dO_.strides().begin(), dO_.strides().end(), o.strides().begin()),
+      (dO_.is_contiguous() && o.is_contiguous()) ||
+          std::equal(
+              dO_.strides().begin(), dO_.strides().end(), o.strides().begin()),
       "cuDNN SDPA expected grad_output.strides() == output.strides(), "
       "the previous step probably failed to materialize a grad_output "
       "with matching strides...");
test/test_transformers.py: 16 additions & 0 deletions
@@ -2398,6 +2398,22 @@ def test_cudnn_attention_fail_d128(self, device):
         with self.assertRaisesRegex(RuntimeError, "No available kernel."):
             o = torch.nn.functional.scaled_dot_product_attention(q, k, v)

+    @skipIfRocm  # No cuDNN Attention
+    @unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "cudnn Attention is not supported on this system")
+    def test_cudnn_attention_trivial_output_transpose(self, device):
+        # see also: https://github.com/pytorch/pytorch/issues/134001
+        x = torch.randn(2, 4, 1, 64, device='cuda', dtype=torch.float16, requires_grad=True)
+        x2 = x.transpose(1, 2)
+        with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.CUDNN_ATTENTION):
+            o = torch.nn.functional.scaled_dot_product_attention(x2, x2, x2).transpose(1, 2).reshape(2, 64, 4)
+        o.backward(o)
+        x_cpu = x.clone().cpu().detach()
+        x_cpu.requires_grad = True
+        x2_cpu = x_cpu.transpose(1, 2)
+        o = torch.nn.functional.scaled_dot_product_attention(x2_cpu, x2_cpu, x2_cpu).transpose(1, 2).reshape(2, 64, 4)
+        o.backward(o)
+        torch.testing.assert_close(x.grad, x_cpu.grad.cuda(), atol=7e-3, rtol=7e-3)
+
     @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Fused SDPA was not built for this system")
     @parametrize("mask_dim", [1, 2, 3, 4])
     def test_mem_efficient_attention_mask_variants(self, device, mask_dim: List[int]):
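For reference, the new test can be run on its own (assuming a CUDA build with cuDNN attention support and pytest installed) with: pytest test/test_transformers.py -k test_cudnn_attention_trivial_output_transpose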
