Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

enable softcap and gemma2 #288

Merged
merged 9 commits into from
Dec 4, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 21 additions & 4 deletions vllm/attention/backends/rocm_flash_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,10 +459,13 @@ def __init__(
if blocksparse_params is not None:
raise ValueError(
"ROCmFlashAttention does not support blocksparse attention.")
if logits_soft_cap is not None:
raise ValueError(
"ROCmFlashAttention does not support attention logits soft "
"capping.")

if logits_soft_cap is None:
# In flash-attn, setting logits_soft_cap as 0 means no soft cap.
self.logits_soft_cap = 0.0
else:
self.logits_soft_cap = logits_soft_cap

self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
Expand All @@ -487,6 +490,14 @@ def __init__(
# NOTE: Allow for switching between Triton and CK. Defaulting to triton.
self.use_triton_flash_attn = envs.VLLM_USE_TRITON_FLASH_ATTN
if self.use_triton_flash_attn:
if logits_soft_cap is not None:
raise ValueError(
"ROCm Triton FlashAttention does not support attention"
"logits soft capping."
" please try using the ROCm CK "
"FA backend instead by setting the env var "
"`VLLM_USE_TRITON_FLASH_ATTN=0`")

from vllm.attention.ops.triton_flash_attention import ( # noqa: F401
triton_attention)
self.attn_func = triton_attention
Expand All @@ -511,6 +522,11 @@ def __init__(
self.use_naive_attn = True

if self.use_naive_attn:
if logits_soft_cap is not None:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You make sure it's not None in the constructor. So naive flash attention can never be used now

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed. thanks.

raise ValueError(
"ROCm Naive FlashAttention does not support"
"attention logits soft capping.")

self.attn_func = _sdpa_attention
logger.debug("Using naive (SDPA) attention in ROCmBackend")

Expand Down Expand Up @@ -717,6 +733,7 @@ def forward(
causal=True,
window_size=self.sliding_window,
alibi_slopes=self.alibi_slopes,
softcap=self.logits_soft_cap,
)

# common code for prefill
Expand Down
Loading