From 0a56bcc03de0857be464c3f8783258d590cbc762 Mon Sep 17 00:00:00 2001
From: Jani Monoses
Date: Fri, 13 Dec 2024 20:00:40 +0200
Subject: [PATCH] [Bugfix][Hardware][CPU] Enable Gemma2 with SDPA on CPU backend (#11169)

---
 vllm/attention/backends/torch_sdpa.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/vllm/attention/backends/torch_sdpa.py b/vllm/attention/backends/torch_sdpa.py
index 86e952a903f36..0cff6f5952aba 100644
--- a/vllm/attention/backends/torch_sdpa.py
+++ b/vllm/attention/backends/torch_sdpa.py
@@ -13,7 +13,7 @@
 from vllm.attention.backends.utils import CommonAttentionState
 from vllm.attention.ops.ipex_attn import PagedAttention
 from vllm.attention.ops.paged_attn import PagedAttentionMetadata
-from vllm.utils import make_tensor_with_pad
+from vllm.utils import make_tensor_with_pad, print_warning_once
 from vllm.worker.cpu_model_runner import ModelInputForCPUBuilder
 
 
@@ -395,7 +395,8 @@ def __init__(
             raise ValueError(
                 "Torch SPDA does not support block-sparse attention.")
         if logits_soft_cap is not None:
-            raise ValueError("Torch SPDA does not support logits soft cap.")
+            print_warning_once("Torch SPDA does not support logits soft cap. "
+                               "Outputs may be slightly off.")
         self.num_heads = num_heads
         self.head_size = head_size
         self.scale = float(scale)
@@ -619,7 +620,7 @@ def _run_sdpa_forward(
                 value[None, :, start_kv:end_kv, :],
                 attn_mask=mask,
                 dropout_p=0.0,
-                is_causal=causal_attn and not self.need_mask,
+                is_causal=causal_attn and mask is None,
                 scale=self.scale).squeeze(0).movedim(query.dim() - 2, 0)
             output[start_q:end_q, :, :] = sub_out
             start_q, start_kv = end_q, end_kv
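
For reference, the `is_causal` change follows how torch.nn.functional.scaled_dot_product_attention treats its arguments: an explicit attn_mask and is_causal=True are mutually exclusive (recent PyTorch versions reject the combination), so built-in causal masking should only be requested when no mask tensor is supplied. The following is a minimal standalone sketch of that rule, not part of the patch; the helper name and tensor shapes are chosen for illustration and are not taken from vLLM.

import torch
import torch.nn.functional as F

def run_sdpa(query, key, value, mask=None, causal=True):
    # Request built-in causal masking only when no explicit mask is given,
    # mirroring `is_causal=causal_attn and mask is None` in the patch.
    # Recent PyTorch releases reject attn_mask combined with is_causal=True.
    return F.scaled_dot_product_attention(
        query, key, value,
        attn_mask=mask,
        dropout_p=0.0,
        is_causal=causal and mask is None)

q = torch.randn(1, 8, 16, 64)  # (batch, num_heads, seq_len, head_dim)
k = torch.randn(1, 8, 16, 64)
v = torch.randn(1, 8, 16, 64)

# Explicit boolean causal mask (True = attend); is_causal stays False here.
causal_mask = torch.tril(torch.ones(16, 16, dtype=torch.bool))
out_masked = run_sdpa(q, k, v, mask=causal_mask)

# No mask supplied; the kernel applies causal masking internally.
out_causal = run_sdpa(q, k, v)

The previous condition, `causal_attn and not self.need_mask`, could still request is_causal while a mask was being passed for this sequence, which is the failure the new `mask is None` check avoids.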