[Bugfix][Hardware][CPU] Enable Gemma2 with SDPA on CPU backend (vllm-…
janimo authored Dec 13, 2024
1 parent 0920ab9 commit 0a56bcc
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions in vllm/attention/backends/torch_sdpa.py
@@ -13,7 +13,7 @@
 from vllm.attention.backends.utils import CommonAttentionState
 from vllm.attention.ops.ipex_attn import PagedAttention
 from vllm.attention.ops.paged_attn import PagedAttentionMetadata
-from vllm.utils import make_tensor_with_pad
+from vllm.utils import make_tensor_with_pad, print_warning_once
 from vllm.worker.cpu_model_runner import ModelInputForCPUBuilder

@@ -395,7 +395,8 @@ def __init__(
             raise ValueError(
                 "Torch SPDA does not support block-sparse attention.")
         if logits_soft_cap is not None:
-            raise ValueError("Torch SPDA does not support logits soft cap.")
+            print_warning_once("Torch SPDA does not support logits soft cap. "
+                               "Outputs may be slightly off.")
         self.num_heads = num_heads
         self.head_size = head_size
         self.scale = float(scale)

@@ -619,7 +620,7 @@ def _run_sdpa_forward(
                 value[None, :, start_kv:end_kv, :],
                 attn_mask=mask,
                 dropout_p=0.0,
-                is_causal=causal_attn and not self.need_mask,
+                is_causal=causal_attn and mask is None,
                 scale=self.scale).squeeze(0).movedim(query.dim() - 2, 0)
             output[start_q:end_q, :, :] = sub_out
             start_q, start_kv = end_q, end_kv
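Why the second hunk downgrades the error to a warning: Gemma2 applies a tanh soft cap to its attention logits, and torch.nn.functional.scaled_dot_product_attention has no parameter for that, so the CPU SDPA backend now skips the cap and warns once instead of erroring out in __init__; hence the note that outputs may be slightly off. Below is a minimal sketch of what the skipped soft cap would do, for illustration only (not vllm code; the cap value of 50.0 is an assumption matching Gemma2's attn_logit_softcapping):

import torch

def softcap(logits: torch.Tensor, cap: float) -> torch.Tensor:
    # Gemma2-style tanh soft cap: squashes raw attention logits into (-cap, cap).
    return cap * torch.tanh(logits / cap)

logits = torch.tensor([1.0, 10.0, 100.0])
print(softcap(logits, cap=50.0))  # roughly [1.00, 9.87, 48.20]: large logits saturate near the cap
print(logits)                     # uncapped values, which is what the SDPA path now uses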

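The third hunk matters because PyTorch's scaled_dot_product_attention rejects calls that set is_causal=True while also passing an explicit attn_mask. The old guard keyed on self.need_mask, which can disagree with whether a mask is actually supplied for the current batch; keying on mask is None makes the two arguments impossible to combine. A small sketch of the guard using only the public PyTorch API (tensor shapes and the helper name are made up for illustration):

import torch
import torch.nn.functional as F

def sdpa(q, k, v, mask=None, causal_attn=True):
    # Same guard as in the diff: only request built-in causal masking when no
    # explicit mask is supplied, since attn_mask plus is_causal=True is rejected.
    return F.scaled_dot_product_attention(
        q, k, v, attn_mask=mask, is_causal=causal_attn and mask is None)

q = torch.randn(1, 4, 8, 16)  # (batch, heads, seq_len, head_dim)
k = torch.randn(1, 4, 8, 16)
v = torch.randn(1, 4, 8, 16)
causal_mask = torch.ones(8, 8, dtype=torch.bool).tril()  # True = may attend

out_explicit = sdpa(q, k, v, mask=causal_mask)  # is_causal stays False here
out_builtin = sdpa(q, k, v)                     # built-in causal masking
print((out_explicit - out_builtin).abs().max())  # expected to be ~0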