From 84d1387373257b183b4ebbd4a7843b295acb38bb Mon Sep 17 00:00:00 2001 From: Hui Liu <96135754+hliuca@users.noreply.github.com> Date: Wed, 20 Nov 2024 13:54:30 -0800 Subject: [PATCH 1/6] enable softcap for gemma2 --- vllm/attention/backends/rocm_flash_attn.py | 30 ++++++++++++++-------- 1 file changed, 20 insertions(+), 10 deletions(-) diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index 591bedfa3a6f1..252e5af2bca91 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -218,12 +218,6 @@ def decode_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]: max_encoder_seq_len=self.max_encoder_seq_len, cross_slot_mapping=self.cross_slot_mapping, cross_block_tables=self.cross_block_tables) - # Batch may be composed of prefill|decodes, adjust query start indices - # to refer to the start of decodes when the two are split apart. - # E.g. in tokens:[3 prefills|6 decodes], query_start_loc=[3,9] => [0,6]. - if self._cached_decode_metadata.query_start_loc is not None: - qs = self._cached_decode_metadata.query_start_loc - self._cached_decode_metadata.query_start_loc = qs - qs[0] return self._cached_decode_metadata def advance_step(self, @@ -459,10 +453,12 @@ def __init__( if blocksparse_params is not None: raise ValueError( "ROCmFlashAttention does not support blocksparse attention.") - if logits_soft_cap is not None: - raise ValueError( - "ROCmFlashAttention does not support attention logits soft " - "capping.") + + if logits_soft_cap is None: + # In flash-attn, setting logits_soft_cap as 0 means no soft cap. + logits_soft_cap = 0 + self.logits_soft_cap = logits_soft_cap + self.num_heads = num_heads self.head_size = head_size self.scale = float(scale) @@ -487,6 +483,14 @@ def __init__( # NOTE: Allow for switching between Triton and CK. Defaulting to triton. self.use_triton_flash_attn = envs.VLLM_USE_TRITON_FLASH_ATTN if self.use_triton_flash_attn: + if logits_soft_cap is not None: + raise ValueError( + "ROCm Triton FlashAttention does not support attention" + "logits soft capping." + " please try using the ROCm CK " + "FA backend instead by setting the env var " + "`VLLM_USE_TRITON_FLASH_ATTN=0`") + from vllm.attention.ops.triton_flash_attention import ( # noqa: F401 triton_attention) self.attn_func = triton_attention @@ -511,6 +515,11 @@ def __init__( self.use_naive_attn = True if self.use_naive_attn: + if logits_soft_cap is not None: + raise ValueError( + "ROCm Naive FlashAttention does not support attention" + "logits soft capping.") + self.attn_func = _sdpa_attention logger.debug("Using naive (SDPA) attention in ROCmBackend") @@ -716,6 +725,7 @@ def forward( causal=True, window_size=self.sliding_window, alibi_slopes=self.alibi_slopes, + softcap=self.logits_soft_cap, ) # common code for prefill From 1348c232972306a4827e23a8f93018b5753a0cfa Mon Sep 17 00:00:00 2001 From: Hui Liu <96135754+hliuca@users.noreply.github.com> Date: Wed, 20 Nov 2024 14:02:41 -0800 Subject: [PATCH 2/6] fix lint --- csrc/activation_kernels.cu | 29 +++++++++++----------- csrc/layernorm_kernels.cu | 2 +- vllm/attention/backends/rocm_flash_attn.py | 23 ++++++++--------- 3 files changed, 26 insertions(+), 28 deletions(-) diff --git a/csrc/activation_kernels.cu b/csrc/activation_kernels.cu index 94a1c4a678cf7..8c3e74e502a55 100644 --- a/csrc/activation_kernels.cu +++ b/csrc/activation_kernels.cu @@ -93,20 +93,21 @@ __device__ __forceinline__ T gelu_tanh_kernel(const T& x) { // Launch activation and gating kernel. #ifdef USE_ROCM -#define LAUNCH_SCALED_ACTIVATION_GATE_KERNEL(KERNEL) \ - int d = input.size(-1) / 2; \ - int64_t num_tokens = input.numel() / input.size(-1); \ - dim3 grid(num_tokens); \ - dim3 block(std::min(d, 1024)); \ - const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \ - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \ - VLLM_DISPATCH_FLOATING_TYPES( \ - input.scalar_type(), "scaled_act_and_mul_kernel", [&] { \ - vllm::scaled_act_and_mul_kernel> \ - <<>>(out.data_ptr(), \ - input.data_ptr(), d, \ - 1.0 / (*scale.data_ptr())); \ - }); + #define LAUNCH_SCALED_ACTIVATION_GATE_KERNEL(KERNEL) \ + int d = input.size(-1) / 2; \ + int64_t num_tokens = input.numel() / input.size(-1); \ + dim3 grid(num_tokens); \ + dim3 block(std::min(d, 1024)); \ + const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \ + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \ + VLLM_DISPATCH_FLOATING_TYPES( \ + input.scalar_type(), "scaled_act_and_mul_kernel", [&] { \ + vllm::scaled_act_and_mul_kernel> \ + <<>>( \ + out.data_ptr(), \ + input.data_ptr(), d, \ + 1.0 / (*scale.data_ptr())); \ + }); #endif void silu_and_mul(torch::Tensor& out, // [..., d] diff --git a/csrc/layernorm_kernels.cu b/csrc/layernorm_kernels.cu index de098a9ee0c19..405ba213628f6 100644 --- a/csrc/layernorm_kernels.cu +++ b/csrc/layernorm_kernels.cu @@ -247,7 +247,7 @@ void rms_norm(torch::Tensor& out, // [..., hidden_size] LAUNCH_RMS_NORM(0); } #else - LAUNCH_RMS_NORM(0); + LAUNCH_RMS_NORM(0); #endif } diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index 252e5af2bca91..c2f448d08398e 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -3,6 +3,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type import torch +from flash_attn import flash_attn_varlen_func # noqa: F401 import vllm.envs as envs from vllm import _custom_ops as ops @@ -485,11 +486,11 @@ def __init__( if self.use_triton_flash_attn: if logits_soft_cap is not None: raise ValueError( - "ROCm Triton FlashAttention does not support attention" - "logits soft capping." - " please try using the ROCm CK " - "FA backend instead by setting the env var " - "`VLLM_USE_TRITON_FLASH_ATTN=0`") + "ROCm Triton FlashAttention does not support attention" + "logits soft capping." + " please try using the ROCm CK " + "FA backend instead by setting the env var " + "`VLLM_USE_TRITON_FLASH_ATTN=0`") from vllm.attention.ops.triton_flash_attention import ( # noqa: F401 triton_attention) @@ -507,18 +508,14 @@ def __init__( if not current_platform.has_device_capability(90): self.use_naive_attn = True else: - try: - from flash_attn import flash_attn_varlen_func # noqa: F401 - self.attn_func = flash_attn_varlen_func - logger.debug("Using CK FA in ROCmBackend") - except ModuleNotFoundError: - self.use_naive_attn = True + self.attn_func = flash_attn_varlen_func + logger.debug("Using CK FA in ROCmBackend") if self.use_naive_attn: if logits_soft_cap is not None: raise ValueError( - "ROCm Naive FlashAttention does not support attention" - "logits soft capping.") + "ROCm Naive FlashAttention does not support" + "attention logits soft capping.") self.attn_func = _sdpa_attention logger.debug("Using naive (SDPA) attention in ROCmBackend") From e5cf3dab7582766ee940c73ba84969c4986f6297 Mon Sep 17 00:00:00 2001 From: Hui Liu <96135754+hliuca@users.noreply.github.com> Date: Wed, 20 Nov 2024 14:07:58 -0800 Subject: [PATCH 3/6] restore fa --- vllm/attention/backends/rocm_flash_attn.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index c2f448d08398e..77eadb2997689 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -3,7 +3,6 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type import torch -from flash_attn import flash_attn_varlen_func # noqa: F401 import vllm.envs as envs from vllm import _custom_ops as ops @@ -508,8 +507,12 @@ def __init__( if not current_platform.has_device_capability(90): self.use_naive_attn = True else: - self.attn_func = flash_attn_varlen_func - logger.debug("Using CK FA in ROCmBackend") + try: + from flash_attn import flash_attn_varlen_func # noqa: F401 + self.attn_func = flash_attn_varlen_func + logger.debug("Using CK FA in ROCmBackend") + except ModuleNotFoundError: + self.use_naive_attn = True if self.use_naive_attn: if logits_soft_cap is not None: From 9242621cea291883ee71be072ecf672d1fc5b510 Mon Sep 17 00:00:00 2001 From: Hui Liu <96135754+hliuca@users.noreply.github.com> Date: Mon, 2 Dec 2024 14:14:25 -0800 Subject: [PATCH 4/6] restore accidental deletion --- vllm/attention/backends/rocm_flash_attn.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index abec46b2e599c..904e4ede90f8f 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -218,6 +218,12 @@ def decode_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]: max_encoder_seq_len=self.max_encoder_seq_len, cross_slot_mapping=self.cross_slot_mapping, cross_block_tables=self.cross_block_tables) + # Batch may be composed of prefill|decodes, adjust query start indices + # to refer to the start of decodes when the two are split apart. + # E.g. in tokens:[3 prefills|6 decodes], query_start_loc=[3,9] => [0,6]. + if self._cached_decode_metadata.query_start_loc is not None: + qs = self._cached_decode_metadata.query_start_loc + self._cached_decode_metadata.query_start_loc = qs - qs[0] return self._cached_decode_metadata def advance_step(self, From 8cdb96f1d26beb8609cdb4d56dee1ef9b22b089b Mon Sep 17 00:00:00 2001 From: Hui Liu <96135754+hliuca@users.noreply.github.com> Date: Mon, 2 Dec 2024 14:26:36 -0800 Subject: [PATCH 5/6] fix logits_soft_cap constructor --- vllm/attention/backends/rocm_flash_attn.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index 904e4ede90f8f..cdf1c52b92af8 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -462,8 +462,9 @@ def __init__( if logits_soft_cap is None: # In flash-attn, setting logits_soft_cap as 0 means no soft cap. - logits_soft_cap = 0 - self.logits_soft_cap = logits_soft_cap + self.logits_soft_cap = 0 + else: + self.logits_soft_cap = logits_soft_cap self.num_heads = num_heads self.head_size = head_size From 566ebdbc9a25376fd611ec00b107e610425508f9 Mon Sep 17 00:00:00 2001 From: Hui Liu <96135754+hliuca@users.noreply.github.com> Date: Tue, 3 Dec 2024 15:36:37 -0800 Subject: [PATCH 6/6] use 0.0 instead of 0 --- vllm/attention/backends/rocm_flash_attn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index cdf1c52b92af8..b4f4e5bb1500a 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -462,7 +462,7 @@ def __init__( if logits_soft_cap is None: # In flash-attn, setting logits_soft_cap as 0 means no soft cap. - self.logits_soft_cap = 0 + self.logits_soft_cap = 0.0 else: self.logits_soft_cap = logits_soft_cap