From e93ca12c881c0c79e9ff652abf79f6b6e02feb00 Mon Sep 17 00:00:00 2001
From: eqy
Date: Sun, 25 Aug 2024 14:31:30 +0000
Subject: [PATCH] [CUDNN][SDPA] Fix unsupported trivial stride-1 transpose case
 (#134031)

Fixes #134001

The previous check incorrectly assumed that two contiguous tensors of the
same shape must also have the same strides.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134031
Approved by: https://github.com/drisspg, https://github.com/Skylion007

Co-authored-by: Aaron Gokaslan
---
 aten/src/ATen/native/cudnn/MHA.cpp | 10 +++++++---
 test/test_transformers.py          | 16 ++++++++++++++++
 2 files changed, 23 insertions(+), 3 deletions(-)

diff --git a/aten/src/ATen/native/cudnn/MHA.cpp b/aten/src/ATen/native/cudnn/MHA.cpp
index 8e8e6d0df4d0a..d00a8eb30698b 100644
--- a/aten/src/ATen/native/cudnn/MHA.cpp
+++ b/aten/src/ATen/native/cudnn/MHA.cpp
@@ -662,7 +662,10 @@ void run_cudnn_SDP_bprop(
         " Materializing a contiguous tensor which will increase memory usage...");
     dO_ = dO.contiguous();
   }
-  if (!std::equal(
+  if ( // handle trivial transposed case with a transposed dim of size 1
+       // see also: https://github.com/pytorch/pytorch/issues/134001
+      !(dO_.is_contiguous() && o.is_contiguous()) &&
+      !std::equal(
           o.strides().begin(), o.strides().end(), dO.strides().begin())) {
     TORCH_WARN(
         "cuDNN SDPA backward got grad_output.strides() != output.strides(), "
@@ -674,8 +677,9 @@ void run_cudnn_SDP_bprop(
     }
   }
   TORCH_INTERNAL_ASSERT(
-      std::equal(
-          dO_.strides().begin(), dO_.strides().end(), o.strides().begin()),
+      (dO_.is_contiguous() && o.is_contiguous()) ||
+          std::equal(
+              dO_.strides().begin(), dO_.strides().end(), o.strides().begin()),
       "cuDNN SDPA expected grad_output.strides() == output.strides(), "
       "the previous step probably failed to materialize a grad_output "
      "with matching strides...");
diff --git a/test/test_transformers.py b/test/test_transformers.py
index 68f46330a120d..d0e3495a68651 100644
--- a/test/test_transformers.py
+++ b/test/test_transformers.py
@@ -2398,6 +2398,22 @@ def test_cudnn_attention_fail_d128(self, device):
         with self.assertRaisesRegex(RuntimeError, "No available kernel."):
             o = torch.nn.functional.scaled_dot_product_attention(q, k, v)
 
+    @skipIfRocm  # No cuDNN Attention
+    @unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "cudnn Attention is not supported on this system")
+    def test_cudnn_attention_trivial_output_transpose(self, device):
+        # see also: https://github.com/pytorch/pytorch/issues/134001
+        x = torch.randn(2, 4, 1, 64, device='cuda', dtype=torch.float16, requires_grad=True)
+        x2 = x.transpose(1, 2)
+        with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.CUDNN_ATTENTION):
+            o = torch.nn.functional.scaled_dot_product_attention(x2, x2, x2).transpose(1, 2).reshape(2, 64, 4)
+        o.backward(o)
+        x_cpu = x.clone().cpu().detach()
+        x_cpu.requires_grad = True
+        x2_cpu = x_cpu.transpose(1, 2)
+        o = torch.nn.functional.scaled_dot_product_attention(x2_cpu, x2_cpu, x2_cpu).transpose(1, 2).reshape(2, 64, 4)
+        o.backward(o)
+        torch.testing.assert_close(x.grad, x_cpu.grad.cuda(), atol=7e-3, rtol=7e-3)
+
     @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Fused SDPA was not built for this system")
     @parametrize("mask_dim", [1, 2, 3, 4])
     def test_mem_efficient_attention_mask_variants(self, device, mask_dim: List[int]):
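
Note (not part of the patch): the stride mismatch described in the commit message is easy to reproduce on its own. The snippet below is a minimal illustration with shapes chosen to mirror the repro in #134001; it is not taken from the patch itself. Transposing a size-1 dimension leaves a tensor contiguous but changes its reported strides, so two contiguous tensors of identical shape can still fail a raw stride comparison such as the `std::equal` check above.

```python
import torch

# Two tensors with identical shape (2, 1, 4, 64), both contiguous.
o = torch.randn(2, 1, 4, 64)                    # allocated directly in this shape
dO = torch.randn(2, 4, 1, 64).transpose(1, 2)   # same shape via a size-1-dim transpose

assert o.shape == dO.shape                       # torch.Size([2, 1, 4, 64]) for both
assert o.is_contiguous() and dO.is_contiguous()  # strides of size-1 dims don't affect contiguity

print(o.stride())   # (256, 256, 64, 1)
print(dO.stride())  # (256, 64, 64, 1) -- differs despite identical shape and contiguity
```

With the fix, the backward path treats this case as trivially compatible: when both `grad_output` and `output` are contiguous, the stride-equality warning and the extra copy are skipped, and the internal assert accepts the pair even though the size-1 dimension's stride is reported differently.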