Move sdp_choice to pytorch & remove unimplemented sdpa_mem fallback (#1138)

As title.

To work with pytorch/pytorch#140389
DDEle authored Dec 4, 2024
1 parent 41a06fc commit be810b5
Showing 3 changed files with 6 additions and 39 deletions.
38 changes: 6 additions & 32 deletions src/ATen/native/transformers/Attention.cpp
@@ -93,36 +93,6 @@ static bool check_for_seq_len_1_nested_tensor(
return true;
}

int64_t _fused_sdp_choice_xpu(
const Tensor& query,
const Tensor& key,
const Tensor& value,
const std::optional<Tensor>& attn_mask_,
double dropout_p,
bool is_causal,
std::optional<double> scale,
bool enable_gqa) {
// We have implemented efficient_attention backend with xetla, flash_attention
// backend is not supported now, which will be implemented in the future. So
// we provide two backends here.
sdp::sdp_params kernel_params{
query, key, value, attn_mask_, dropout_p, is_causal, enable_gqa};
// Because TORCHCHECK checks if condition is true we negate debug so that
// The statements will be printed when debug is true
bool print_debug = false;
sdp::SDPBackend backend =
sdp::can_use_mem_efficient_attention(kernel_params, print_debug)
? sdp::SDPBackend::efficient_attention
: sdp::SDPBackend::math;
if (backend == sdp::SDPBackend::error) {
TORCH_CHECK(
false,
"No viable backend for scaled_dot_product_attention was found. ",
"This is likely due to turning off both the math kernel and the fused kernels.");
}
return static_cast<int64_t>(backend);
}

std::tuple<Tensor, Tensor> native_multi_head_attention_xpu(
const Tensor& query,
const Tensor& key,
@@ -204,8 +174,12 @@ std::tuple<Tensor, Tensor> native_multi_head_attention_xpu(
value.view({value.size(0), -1, num_head, dim_per_head}).transpose(1, 2);

sdp::sdp_params kernel_params{q, k, v, mask, 0.0, false, false};
auto backend = static_cast<sdp::SDPBackend>(
_fused_sdp_choice_xpu(q, k, v, mask, 0.0, false, {}, false));

sdp::SDPBackend backend = sdp::SDPBackend::math;
if (_fused_sdp_choice_stub.is_device_supported(q.device().type())) {
backend = static_cast<sdp::SDPBackend>(_fused_sdp_choice_stub(
q.device().type(), q, k, v, mask, 0.0, false, std::nullopt, false));
}

// strides from packed projection for nested tensors when seq_len is 1 will
// be and will trigger a contiguous call in the kernel, so we prevent this
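The replacement code above relies on _fused_sdp_choice_stub, PyTorch's per-device DispatchStub for the SDPA backend choice: if no backend has registered a choice function for the tensor's device, the code falls back to sdp::SDPBackend::math. A minimal sketch of how an XPU choice function could be wired into that stub on the pytorch/pytorch side follows; the function body, name, and file placement are assumptions here, and the actual registration is defined by the companion PR pytorch/pytorch#140389.

// Sketch only: registering a backend choice function with
// _fused_sdp_choice_stub in pytorch/pytorch. The implementation below is a
// placeholder; only the stub and the registration macro come from PyTorch.
#include <ATen/native/DispatchStub.h>
#include <ATen/native/transformers/attention.h>      // declares _fused_sdp_choice_stub
#include <ATen/native/transformers/sdp_utils_cpp.h>  // sdp::SDPBackend

namespace at::native {

// Hypothetical XPU implementation matching the fused_sdp_choice_fn signature.
int64_t _fused_sdp_choice_xpu(
    const Tensor& query,
    const Tensor& key,
    const Tensor& value,
    const std::optional<Tensor>& attn_mask,
    double dropout_p,
    bool is_causal,
    std::optional<double> scale,
    bool enable_gqa) {
  // Real backend-selection logic goes here; math is the safe default.
  return static_cast<int64_t>(sdp::SDPBackend::math);
}

// After this registration, _fused_sdp_choice_stub.is_device_supported(at::kXPU)
// reports true and _fused_sdp_choice_stub(at::kXPU, ...) dispatches to the
// function above, which is what the updated Attention.cpp code queries.
REGISTER_XPU_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_xpu);

} // namespace at::native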
1 change: 0 additions & 1 deletion src/ATen/native/xpu/XPUFallback.template
@@ -184,7 +184,6 @@ TORCH_LIBRARY_IMPL(aten, XPU, m) {
"_linalg_svd.U",
"lu_unpack.out",
"ormqr",
"_scaled_dot_product_efficient_attention",
"_scaled_mm",
"_thnn_fused_gru_cell",
"_to_sparse_csr",
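Dropping "_scaled_dot_product_efficient_attention" from this list means the op is no longer routed through the boxed CPU fallback on XPU, so dispatch reaches the native XPU kernel instead. For context, a per-op CPU fallback registration in PyTorch generally takes the shape sketched below; the wrapper name is made up for illustration and the real template's plumbing may differ.

// Sketch only: routing one aten op through PyTorch's generic CPU fallback for
// the XPU dispatch key. The wrapper name xpu_cpu_fallback is illustrative.
#include <ATen/native/CPUFallback.h>
#include <torch/library.h>

// Boxed kernel: copies XPU inputs to CPU, runs the CPU kernel, copies results back.
static void xpu_cpu_fallback(
    const c10::OperatorHandle& op,
    torch::jit::Stack* stack) {
  at::native::cpu_fallback(op, stack);
}

TORCH_LIBRARY_IMPL(aten, XPU, m) {
  // Each op registered like this is served by the CPU fallback instead of an
  // XPU kernel; deleting the entry restores dispatch to the native implementation.
  m.impl(
      "_scaled_dot_product_efficient_attention",
      torch::CppFunction::makeFromBoxedFunction<&xpu_cpu_fallback>());
}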
6 changes: 0 additions & 6 deletions yaml/native/native_functions.yaml
@@ -5969,12 +5969,6 @@
XPU: native_multi_head_attention_xpu
autogen: _native_multi_head_attention.out

# This aten function is kept so that we can test the choice function from Python
- func: _fused_sdp_choice(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False, *, float? scale=None, bool enable_gqa=False) -> int
dispatch:
XPU: _fused_sdp_choice_xpu
tags: nondeterministic_seeded

- func: argmin(Tensor self, int? dim=None, bool keepdim=False) -> Tensor
structured_delegate: argmin.out
device_check: NoCheck # TensorIterator
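With this entry removed, the _fused_sdp_choice schema is expected to live in pytorch/pytorch's own native_functions.yaml (it was only kept here so the choice function could be tested from Python). Assuming the upstream op keeps the schema shown above, the integer it returns still maps onto sdp::SDPBackend, mirroring the cast done in native_multi_head_attention_xpu; a minimal sketch:

// Sketch only: calling the choice op through the generated at::_fused_sdp_choice
// wrapper and interpreting the result. Assumes the schema above is exposed
// unchanged from pytorch/pytorch.
#include <ATen/ATen.h>
#include <ATen/native/transformers/sdp_utils_cpp.h>  // sdp::SDPBackend

sdp::SDPBackend pick_sdp_backend(
    const at::Tensor& query,
    const at::Tensor& key,
    const at::Tensor& value) {
  int64_t raw = at::_fused_sdp_choice(
      query,
      key,
      value,
      /*attn_mask=*/std::nullopt,
      /*dropout_p=*/0.0,
      /*is_causal=*/false,
      /*scale=*/std::nullopt,
      /*enable_gqa=*/false);
  // The returned integer is a member of sdp::SDPBackend (math, flash_attention,
  // efficient_attention, ...); math is the always-available reference path.
  return static_cast<sdp::SDPBackend>(raw);
}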
