Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ingest FP8 attn scales and use them in ROCm FlashAttention #338

Merged
merged 7 commits into from
Dec 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion vllm/attention/backends/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,6 @@ def forward(
v_scale: torch.Tensor,
attn_type: str = AttentionType.DECODER,
output: Optional[torch.Tensor] = None,
- fp8_out_scale: Optional[torch.Tensor] = None,
+ fp8_comp_scales: Optional[Tuple[torch.Tensor, ...]] = None,
) -> torch.Tensor:
raise NotImplementedError
2 changes: 1 addition & 1 deletion vllm/attention/backends/blocksparse_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,7 +363,7 @@ def forward(
v_scale: torch.Tensor,
attn_type: str = AttentionType.DECODER,
output: Optional[torch.Tensor] = None,
- fp8_out_scale: Optional[torch.Tensor] = None,
+ fp8_comp_scales: Optional[Tuple[torch.Tensor, ...]] = None,
) -> torch.Tensor:
"""Forward pass with FlashAttention and PagedAttention.

Expand Down
2 changes: 1 addition & 1 deletion vllm/attention/backends/flash_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -642,7 +642,7 @@ def forward(
v_scale: float = 1.0,
attn_type: str = AttentionType.DECODER,
output: Optional[torch.Tensor] = None,
- fp8_out_scale: Optional[torch.Tensor] = None,
+ fp8_comp_scales: Optional[Tuple[torch.Tensor, ...]] = None,
) -> torch.Tensor:
"""Forward pass with FlashAttention.
Expand Down
2 changes: 1 addition & 1 deletion vllm/attention/backends/flashinfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -777,7 +777,7 @@ def forward(
v_scale: float = 1.0,
attn_type: str = AttentionType.DECODER,
output: Optional[torch.Tensor] = None,
- fp8_out_scale: Optional[torch.Tensor] = None,
+ fp8_comp_scales: Optional[Tuple[torch.Tensor, ...]] = None,
) -> torch.Tensor:

# TODO: directly write to output tensor
Expand Down
2 changes: 1 addition & 1 deletion vllm/attention/backends/hpu_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ def forward(
v_scale: float = 1.0,
attn_type: str = AttentionType.DECODER,
output: Optional[torch.Tensor] = None,
- fp8_out_scale: Optional[torch.Tensor] = None,
+ fp8_comp_scales: Optional[Tuple[torch.Tensor, ...]] = None,
) -> torch.Tensor:
"""Forward pass with xFormers and PagedAttention.
Expand Down
2 changes: 1 addition & 1 deletion vllm/attention/backends/ipex_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ def forward(
v_scale: float = 1.0,
attn_type: str = AttentionType.DECODER,
output: Optional[torch.Tensor] = None,
- fp8_out_scale: Optional[torch.Tensor] = None,
+ fp8_comp_scales: Optional[Tuple[torch.Tensor, ...]] = None,
) -> torch.Tensor:
"""Forward pass with IPEX varlen_attention and PagedAttention.
Expand Down
2 changes: 1 addition & 1 deletion vllm/attention/backends/pallas.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ def forward(
v_scale: float = 1.0,
attn_type: str = AttentionType.DECODER,
output: Optional[torch.Tensor] = None,
- fp8_out_scale: Optional[torch.Tensor] = None,
+ fp8_comp_scales: Optional[Tuple[torch.Tensor, ...]] = None,
) -> torch.Tensor:
"""Forward pass with Pallas attention.
Expand Down
12 changes: 10 additions & 2 deletions vllm/attention/backends/rocm_flash_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -551,7 +551,7 @@ def forward(
v_scale: torch.Tensor,
attn_type: str = AttentionType.DECODER,
output: Optional[torch.Tensor] = None,
- fp8_out_scale: torch.Tensor = None,
+ fp8_comp_scales: Optional[Tuple[torch.Tensor, ...]] = None,
) -> torch.Tensor:
"""Forward pass with FlashAttention and PagedAttention.

Expand Down Expand Up @@ -601,6 +601,8 @@ def forward(
Returns:
shape = [num_tokens, num_heads * head_size]
"""
+ q_scale, prob_scale, fp8_out_scale = fp8_comp_scales or (None, None,
+ None)

query = query.view(-1, self.num_heads, self.head_size)
if key is not None:
Expand Down Expand Up @@ -681,6 +683,12 @@ def forward(
query.dtype,
seq_lens,
make_attn_mask=False) # type: ignore
+ full_scales = (
+ 1.0 / q_scale.item(), 1.0 / k_scale.item(),
+ 1.0 / v_scale.item(), 1.0 / prob_scale.item(),
+ fp8_out_scale.item()) if (
+ fp8_out_scale
+ and envs.VLLM_USE_ROCM_FP8_FLASH_ATTN) else None
out, _ = self.attn_func(
query,
key,
Expand All @@ -694,7 +702,7 @@ def forward(
self.scale,
attn_masks[0][None]
if attn_masks is not None else None,
- None,
+ full_scales,
)
elif self.use_naive_attn:
if self.num_kv_heads != self.num_heads:
Expand Down
2 changes: 1 addition & 1 deletion vllm/attention/backends/torch_sdpa.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,7 +434,7 @@ def forward(
v_scale: float = 1.0,
attn_type: str = AttentionType.DECODER,
output: Optional[torch.Tensor] = None,
- fp8_out_scale: Optional[torch.Tensor] = None,
+ fp8_comp_scales: Optional[Tuple[torch.Tensor, ...]] = None,
) -> torch.Tensor:
"""Forward pass with torch SDPA and PagedAttention.
Expand Down
2 changes: 1 addition & 1 deletion vllm/attention/backends/xformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,7 +420,7 @@ def forward(
v_scale: float = 1.0,
attn_type: str = AttentionType.DECODER,
output: Optional[torch.Tensor] = None,
- fp8_out_scale: Optional[torch.Tensor] = None,
+ fp8_comp_scales: Optional[Tuple[torch.Tensor, ...]] = None,
) -> torch.Tensor:
"""Forward pass with xFormers and PagedAttention.
Expand Down
19 changes: 11 additions & 8 deletions vllm/attention/layer.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""Attention layer."""
- from typing import Any, Dict, List, Optional
+ from typing import Any, Dict, List, Optional, Tuple

import torch
import torch.nn as nn
Expand Down Expand Up @@ -75,6 +75,8 @@ def __init__(
self.calculate_kv_scales = calculate_kv_scales
self._k_scale = torch.tensor(1.0, dtype=torch.float32)
self._v_scale = torch.tensor(1.0, dtype=torch.float32)
+ self._q_scale = torch.tensor(1.0, dtype=torch.float32)
+ self._prob_scale = torch.tensor(1.0, dtype=torch.float32)
quant_method = quant_config.get_quant_method(
self, prefix=prefix) if quant_config else None
if quant_method is not None:
Expand Down Expand Up @@ -106,11 +108,11 @@ def __init__(
self.num_kv_heads = num_kv_heads
self.backend = backend_name_to_enum(attn_backend.get_name())

- # For cuda-alike (CUDA and ROCM) and cpu platforms, we control how
+ # For cuda and cpu platforms, we control how
# torch.compile works by registering the attention as one giant
# opaque custom op. For other platforms, we directly call them
# and let torch.compile handle them.
- self.use_direct_call = not current_platform.is_cuda_alike(
+ self.use_direct_call = not current_platform.is_cuda(
) and not current_platform.is_cpu()

# For some attention backends, we allocate an output tensor before
Expand All @@ -124,6 +126,7 @@ def __init__(
compilation_config.static_forward_context[prefix] = self
self.layer_name = prefix

+ self.q_range = torch.tensor(envs.Q_SCALE_CONSTANT, dtype=torch.float32)
self.k_range = torch.tensor(envs.K_SCALE_CONSTANT, dtype=torch.float32)
self.v_range = torch.tensor(envs.V_SCALE_CONSTANT, dtype=torch.float32)

Expand All @@ -135,12 +138,11 @@ def forward(
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
attn_type: str = AttentionType.DECODER,
- fp8_out_scale: Optional[torch.Tensor] = None,
+ fp8_comp_scales: Optional[Tuple[torch.Tensor, ...]] = None,
) -> torch.Tensor:
if self.calculate_kv_scales and \
attn_metadata.enable_kv_scales_calculation:
- self.calc_kv_scales(key, value)
-
+ self.calc_kv_scales(query, key, value)
if self.use_direct_call:
return self.impl.forward(query,
key,
Expand All @@ -150,7 +152,7 @@ def forward(
self._k_scale,
self._v_scale,
attn_type=attn_type,
- fp8_out_scale=fp8_out_scale)
+ fp8_comp_scales=fp8_comp_scales)
elif self.use_output:
output = torch.empty_like(query)
hidden_size = query.size(-1)
Expand All @@ -172,7 +174,8 @@ def forward(
kv_cache, attn_type,
self.layer_name)

- def calc_kv_scales(self, key, value):
+ def calc_kv_scales(self, query, key, value):
+ self._q_scale.copy_(torch.abs(query).max() / self.q_range)
self._k_scale.copy_(torch.abs(key).max() / self.k_range)
self._v_scale.copy_(torch.abs(value).max() / self.v_range)
# We only calculate the scales once
Expand Down
2 changes: 1 addition & 1 deletion vllm/attention/ops/triton_flash_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -742,7 +742,7 @@ def attn_fwd(
mask_m_offsets = start_m_idx + tl.arange(0, BLOCK_M)
out_ptrs_mask = (mask_m_offsets[:, None] >=
out_mask_boundary[None, :])
- z = 0.0
+ z = tl.zeros((1, ), tl.float32)
acc = tl.where(out_ptrs_mask, acc, z.to(acc.type.element_ty))
# write back LSE
# l_ptrs = L + off_z * HQ * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + offs_m
Expand Down
29 changes: 21 additions & 8 deletions vllm/envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
VLLM_USE_ROCM_SKINNY_GEMM: bool = True
VLLM_USE_ROCM_CUSTOM_PAGED_ATTN: bool = True
VLLM_USE_ROCM_CUSTOM_PAGED_ATTN_FP8_OUT: bool = True
+ VLLM_USE_ROCM_FP8_FLASH_ATTN: bool = False
RANK: int = 0
LOCAL_RANK: int = 0
CUDA_VISIBLE_DEVICES: Optional[str] = None
Expand Down Expand Up @@ -83,8 +84,9 @@
VLLM_FP8_PADDING: bool = True
VLLM_ENABLE_V1_MULTIPROCESSING: bool = True
VLLM_LOG_BATCHSIZE_INTERVAL: float = -1
- K_SCALE_CONSTANT: int = 200
- V_SCALE_CONSTANT: int = 100
+ Q_SCALE_CONSTANT: int = 20
+ K_SCALE_CONSTANT: int = 20
+ V_SCALE_CONSTANT: int = 10


def get_default_cache_root():
Expand Down Expand Up @@ -242,13 +244,18 @@ def get_default_config_root():
# custom paged attention implemented for MI3* cards
"VLLM_USE_ROCM_CUSTOM_PAGED_ATTN":
lambda: (os.getenv("VLLM_USE_ROCM_CUSTOM_PAGED_ATTN", "True").lower() in
- ("true", "1") != "0"),
+ ("true", "1")),

# have custom paged attention implemented for MI3* cards write out fp8
"VLLM_USE_ROCM_CUSTOM_PAGED_ATTN_FP8_OUT":
lambda:
(os.getenv("VLLM_USE_ROCM_CUSTOM_PAGED_ATTN_FP8_OUT", "True").lower() in
- ("true", "1") != "0"),
+ ("true", "1")),

+ # use quantized q,k,v,softmax(qk^T), attn output during prefill
+ "VLLM_USE_ROCM_FP8_FLASH_ATTN":
+ lambda: (os.getenv("VLLM_USE_ROCM_FP8_FLASH_ATTN", "False").lower() in
+ ("true", "1")),

# rank of the process in the distributed setting, used to determine
# the driver worker
Expand Down Expand Up @@ -530,13 +537,19 @@ def get_default_config_root():
"VLLM_FP8_PADDING":
lambda: bool(int(os.getenv("VLLM_FP8_PADDING", "1"))),

- # Divisor for dynamic key scale factor calculation for FP8 KV Cache
+ # Divisor for dynamic query scale factor calculation for FP8 attention
+ "Q_SCALE_CONSTANT":
+ lambda: int(os.getenv("Q_SCALE_CONSTANT", "20")),
+
+ # Divisor for dynamic key scale factor calculation
+ # for FP8 KV Cache and attention
"K_SCALE_CONSTANT":
- lambda: int(os.getenv("K_SCALE_CONSTANT", "200")),
+ lambda: int(os.getenv("K_SCALE_CONSTANT", "20")),

- # Divisor for dynamic value scale factor calculation for FP8 KV Cache
+ # Divisor for dynamic value scale factor calculation
+ # for FP8 KV Cache and attention
"V_SCALE_CONSTANT":
- lambda: int(os.getenv("V_SCALE_CONSTANT", "100")),
+ lambda: int(os.getenv("V_SCALE_CONSTANT", "10")),

# If set, enable multiprocessing in LLM for the V1 code path.
"VLLM_ENABLE_V1_MULTIPROCESSING":
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,10 @@ def get_compressed_tensors_cache_scale(name: str) -> Optional[str]:
return name.replace(".k_proj.output_scale", ".attn.k_scale")
if name.endswith(".output_scale") and ".v_proj" in name:
return name.replace(".v_proj.output_scale", ".attn.v_scale")
+ if name.endswith(".output_scale") and ".q_proj" in name:
+ return name.replace(".q_proj.output_scale", ".attn.q_scale")
+ if name.endswith("self_attn.prob_output_scale"):
+ return name.replace(".prob_output_scale", ".attn.prob_scale")
# If no matches, return None
return None

Expand Down
Loading
Loading