diff --git a/src/ATen/native/transformers/Attention.cpp b/src/ATen/native/transformers/Attention.cpp
index bb8b4602b..3090dfbee 100644
--- a/src/ATen/native/transformers/Attention.cpp
+++ b/src/ATen/native/transformers/Attention.cpp
@@ -93,36 +93,6 @@ static bool check_for_seq_len_1_nested_tensor(
   return true;
 }
 
-int64_t _fused_sdp_choice_xpu(
-    const Tensor& query,
-    const Tensor& key,
-    const Tensor& value,
-    const std::optional<Tensor>& attn_mask_,
-    double dropout_p,
-    bool is_causal,
-    std::optional<double> scale,
-    bool enable_gqa) {
-  // We have implemented efficient_attention backend with xetla, flash_attention
-  // backend is not supported now, which will be implemented in the future. So
-  // we provide two backends here.
-  sdp::sdp_params kernel_params{
-      query, key, value, attn_mask_, dropout_p, is_causal, enable_gqa};
-  // Because TORCHCHECK checks if condition is true we negate debug so that
-  // The statements will be printed when debug is true
-  bool print_debug = false;
-  sdp::SDPBackend backend =
-      sdp::can_use_mem_efficient_attention(kernel_params, print_debug)
-      ? sdp::SDPBackend::efficient_attention
-      : sdp::SDPBackend::math;
-  if (backend == sdp::SDPBackend::error) {
-    TORCH_CHECK(
-        false,
-        "No viable backend for scaled_dot_product_attention was found. ",
-        "This is likely due to turning off both the math kernel and the fused kernels.");
-  }
-  return static_cast<int64_t>(backend);
-}
-
 std::tuple<Tensor, Tensor> native_multi_head_attention_xpu(
     const Tensor& query,
     const Tensor& key,
@@ -204,8 +174,12 @@ std::tuple<Tensor, Tensor> native_multi_head_attention_xpu(
       value.view({value.size(0), -1, num_head, dim_per_head}).transpose(1, 2);
 
   sdp::sdp_params kernel_params{q, k, v, mask, 0.0, false, false};
-  auto backend = static_cast<sdp::SDPBackend>(
-      _fused_sdp_choice_xpu(q, k, v, mask, 0.0, false, {}, false));
+
+  sdp::SDPBackend backend = sdp::SDPBackend::math;
+  if (_fused_sdp_choice_stub.is_device_supported(q.device().type())) {
+    backend = static_cast<sdp::SDPBackend>(_fused_sdp_choice_stub(
+        q.device().type(), q, k, v, mask, 0.0, false, std::nullopt, false));
+  }
 
   // strides from packed projection for nested tensors when seq_len is 1 will
   // be and will trigger a contiguous call in the kernel, so we prevent this
diff --git a/src/ATen/native/xpu/XPUFallback.template b/src/ATen/native/xpu/XPUFallback.template
index 8492a98be..1df3cd072 100644
--- a/src/ATen/native/xpu/XPUFallback.template
+++ b/src/ATen/native/xpu/XPUFallback.template
@@ -184,7 +184,6 @@ TORCH_LIBRARY_IMPL(aten, XPU, m) {
       "_linalg_svd.U",
       "lu_unpack.out",
       "ormqr",
-      "_scaled_dot_product_efficient_attention",
       "_scaled_mm",
       "_thnn_fused_gru_cell",
       "_to_sparse_csr",
diff --git a/yaml/native/native_functions.yaml b/yaml/native/native_functions.yaml
index e3bec5484..40b710c12 100644
--- a/yaml/native/native_functions.yaml
+++ b/yaml/native/native_functions.yaml
@@ -5969,12 +5969,6 @@
     XPU: native_multi_head_attention_xpu
   autogen: _native_multi_head_attention.out
 
-# This aten function is kept so that we can test the choice function from Python
-- func: _fused_sdp_choice(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False, *, float? scale=None, bool enable_gqa=False) -> int
-  dispatch:
-    XPU: _fused_sdp_choice_xpu
-  tags: nondeterministic_seeded
-
 - func: argmin(Tensor self, int? dim=None, bool keepdim=False) -> Tensor
   structured_delegate: argmin.out
   device_check: NoCheck # TensorIterator
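
For context, the Attention.cpp hunk replaces the direct `_fused_sdp_choice_xpu` call with the shared `_fused_sdp_choice_stub` dispatch that the stock `scaled_dot_product_attention` front end uses, falling back to the math backend when no choice kernel is registered for the device. Below is a minimal sketch of that selection pattern, not code from this change; the header paths and the `at::native::` qualification are assumptions based on the upstream PyTorch layout (inside Attention.cpp the stub is referenced unqualified because the file is already in that namespace).

```cpp
// Sketch only: mirrors the guarded-stub backend selection adopted above.
// Header paths below are assumptions based on upstream PyTorch's layout.
#include <ATen/native/transformers/attention.h>     // _fused_sdp_choice_stub (assumed path)
#include <ATen/native/transformers/sdp_utils_cpp.h> // sdp::SDPBackend (assumed path)

#include <optional>

namespace {

sdp::SDPBackend choose_sdp_backend_sketch(
    const at::Tensor& q,
    const at::Tensor& k,
    const at::Tensor& v,
    const std::optional<at::Tensor>& mask) {
  // The math backend needs no registered kernel, so it is the safe default.
  auto backend = sdp::SDPBackend::math;
  // Only call the stub when a choice kernel is registered for this device
  // type; invoking an unregistered DispatchStub errors out.
  if (at::native::_fused_sdp_choice_stub.is_device_supported(q.device().type())) {
    backend = static_cast<sdp::SDPBackend>(at::native::_fused_sdp_choice_stub(
        q.device().type(),
        q, k, v, mask,
        /*dropout_p=*/0.0,
        /*is_causal=*/false,
        /*scale=*/std::nullopt,
        /*enable_gqa=*/false));
  }
  return backend;
}

} // namespace
```

Because the choice now flows through the common stub path, the XPU tree no longer needs its own yaml-level `_fused_sdp_choice` override, which is what the native_functions.yaml hunk removes.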