From 4fa3e3334978dce74eba296ee8cc2e970ed20e5e Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Sun, 20 Oct 2024 10:57:52 -0700 Subject: [PATCH] [Kernel] Support sliding window in flash attention backend (#9403) --- tests/kernels/test_attention_selector.py | 35 ++++++++++-------------- tests/kernels/test_flash_attn.py | 29 +++++++++++--------- vllm/attention/backends/flash_attn.py | 13 ++++----- vllm/attention/layer.py | 7 ++--- vllm/attention/selector.py | 10 ++----- vllm/worker/cache_engine.py | 1 - vllm/worker/cpu_model_runner.py | 1 - vllm/worker/cpu_worker.py | 1 - vllm/worker/model_runner.py | 1 - vllm/worker/openvino_model_runner.py | 1 - vllm/worker/openvino_worker.py | 1 - vllm/worker/tpu_model_runner.py | 1 - vllm/worker/xpu_model_runner.py | 1 - 13 files changed, 41 insertions(+), 61 deletions(-) diff --git a/tests/kernels/test_attention_selector.py b/tests/kernels/test_attention_selector.py index f471dcee938be..5671207ac847e 100644 --- a/tests/kernels/test_attention_selector.py +++ b/tests/kernels/test_attention_selector.py @@ -20,21 +20,21 @@ def test_env(name: str, device: str, monkeypatch): if device == "cpu": with patch("vllm.attention.selector.is_cpu", return_value=True): - backend = which_attn_to_use(16, None, torch.float16, torch.float16, - 16, False) + backend = which_attn_to_use(16, torch.float16, torch.float16, 16, + False) assert backend.name == "TORCH_SDPA" elif device == "hip": with patch("vllm.attention.selector.is_hip", return_value=True): - backend = which_attn_to_use(16, None, torch.float16, torch.float16, - 16, False) + backend = which_attn_to_use(16, torch.float16, torch.float16, 16, + False) assert backend.name == "ROCM_FLASH" elif device == "openvino": with patch("vllm.attention.selector.is_openvino", return_value=True): - backend = which_attn_to_use(16, None, torch.float16, torch.float16, - 16, False) + backend = which_attn_to_use(16, torch.float16, torch.float16, 16, + False) assert backend.name == "OPENVINO" else: - backend = which_attn_to_use(16, None, torch.float16, torch.float16, 16, + backend = which_attn_to_use(16, torch.float16, torch.float16, 16, False) assert backend.name == name @@ -46,37 +46,32 @@ def test_flash_attn(monkeypatch): # Unsupported CUDA arch with patch("torch.cuda.get_device_capability", return_value=(7, 5)): - backend = which_attn_to_use(16, None, torch.float16, None, 16, False) + backend = which_attn_to_use(16, torch.float16, None, 16, False) assert backend.name != STR_FLASH_ATTN_VAL # Unsupported data type - backend = which_attn_to_use(16, None, torch.float8_e4m3fn, None, 16, False) + backend = which_attn_to_use(16, torch.float8_e4m3fn, None, 16, False) assert backend.name != STR_FLASH_ATTN_VAL # Unsupported kv cache data type - backend = which_attn_to_use(16, None, torch.float16, "fp8", 16, False) + backend = which_attn_to_use(16, torch.float16, "fp8", 16, False) assert backend.name != STR_FLASH_ATTN_VAL # Unsupported block size - backend = which_attn_to_use(16, None, torch.float16, None, 8, False) - assert backend.name != STR_FLASH_ATTN_VAL - - # Unsupported sliding window - backend = which_attn_to_use(16, 1, torch.float16, None, 16, False) + backend = which_attn_to_use(16, torch.float16, None, 8, False) assert backend.name != STR_FLASH_ATTN_VAL # flash-attn is not installed with patch.dict('sys.modules', {'vllm_flash_attn': None}): - backend = which_attn_to_use(16, None, torch.float16, None, 16, False) + backend = which_attn_to_use(16, torch.float16, None, 16, False) assert backend.name != STR_FLASH_ATTN_VAL # Unsupported head size - backend = which_attn_to_use(17, None, torch.float16, None, 16, False) + backend = which_attn_to_use(17, torch.float16, None, 16, False) assert backend.name != STR_FLASH_ATTN_VAL # Attention-free models should bypass env and use PlaceholderAttention - backend = which_attn_to_use(16, None, torch.float16, torch.float16, 16, - True) + backend = which_attn_to_use(16, torch.float16, torch.float16, 16, True) assert backend.name != STR_FLASH_ATTN_VAL @@ -84,4 +79,4 @@ def test_invalid_env(monkeypatch): """Throw an exception if the backend name is invalid.""" override_backend_env_variable(monkeypatch, STR_INVALID_VAL) with pytest.raises(ValueError): - which_attn_to_use(16, None, torch.float16, None, 16, False) + which_attn_to_use(16, torch.float16, None, 16, False) diff --git a/tests/kernels/test_flash_attn.py b/tests/kernels/test_flash_attn.py index 3e9b4d9a4f8a0..35c29c5bd1028 100644 --- a/tests/kernels/test_flash_attn.py +++ b/tests/kernels/test_flash_attn.py @@ -78,6 +78,7 @@ def ref_paged_attn( @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("soft_cap", [None, 10.0, 50.0]) @pytest.mark.parametrize("num_blocks", NUM_BLOCKS) +@pytest.mark.parametrize("sliding_window", [None, 256]) @torch.inference_mode() def test_flash_attn_with_paged_kv( kv_lens: List[int], @@ -87,6 +88,7 @@ def test_flash_attn_with_paged_kv( block_size: int, soft_cap: Optional[float], num_blocks: int, + sliding_window: Optional[int], ) -> None: torch.set_default_device("cuda") seed_everything(0) @@ -96,6 +98,8 @@ def test_flash_attn_with_paged_kv( assert num_query_heads % num_kv_heads == 0 max_kv_len = max(kv_lens) scale = head_size**-0.5 + window_size = ((sliding_window - 1, 0) if sliding_window is not None else + (-1, -1)) query = torch.randn(num_seqs, num_query_heads, head_size, dtype=dtype) key_cache = torch.randn(num_blocks, @@ -121,18 +125,18 @@ def test_flash_attn_with_paged_kv( block_table=block_tables, cache_seqlens=kv_lens_tensor, softcap=soft_cap if soft_cap is not None else 0, + window_size=window_size, ).squeeze(1) - ref_output = ref_paged_attn( - query=query, - key_cache=key_cache, - value_cache=value_cache, - query_lens=[1] * num_seqs, - kv_lens=kv_lens, - block_tables=block_tables, - scale=scale, - soft_cap=soft_cap, - ) + ref_output = ref_paged_attn(query=query, + key_cache=key_cache, + value_cache=value_cache, + query_lens=[1] * num_seqs, + kv_lens=kv_lens, + block_tables=block_tables, + scale=scale, + soft_cap=soft_cap, + sliding_window=sliding_window) torch.testing.assert_close(output, ref_output, atol=2e-2, rtol=1e-2), \ f"{torch.max(torch.abs(output - ref_output))}" @@ -141,7 +145,7 @@ def test_flash_attn_with_paged_kv( @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("head_size", HEAD_SIZES) @pytest.mark.parametrize("block_size", BLOCK_SIZES) -@pytest.mark.parametrize("sliding_window", [None]) +@pytest.mark.parametrize("sliding_window", [None, 256]) @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("soft_cap", [None, 10.0, 50.0]) @pytest.mark.parametrize("num_blocks", NUM_BLOCKS) @@ -166,8 +170,7 @@ def test_varlen_with_paged_kv( assert num_query_heads % num_kv_heads == 0 max_query_len = max(query_lens) max_kv_len = max(kv_lens) - window_size = ((sliding_window, - sliding_window) if sliding_window is not None else + window_size = ((sliding_window - 1, 0) if sliding_window is not None else (-1, -1)) scale = head_size**-0.5 diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index d54dbdcb19495..d538286a0dddd 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -524,8 +524,8 @@ def __init__( if alibi_slopes is not None: alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) self.alibi_slopes = alibi_slopes - self.sliding_window = ((sliding_window, sliding_window) - if sliding_window is not None else (-1, -1)) + self.sliding_window = ((sliding_window - 1, + 0) if sliding_window is not None else (-1, -1)) self.kv_cache_dtype = kv_cache_dtype if logits_soft_cap is None: # In flash-attn, setting logits_soft_cap as 0 means no soft cap. @@ -535,12 +535,6 @@ def __init__( assert self.num_heads % self.num_kv_heads == 0 self.num_queries_per_kv = self.num_heads // self.num_kv_heads - if sliding_window is not None: - # NOTE(woosuk): flash-attn's sliding window does not work with - # paged KV cache. - raise ValueError( - "Sliding window is not supported in FlashAttention.") - support_head_sizes = FlashAttentionBackend.get_supported_head_sizes() if head_size not in support_head_sizes: raise ValueError( @@ -704,6 +698,7 @@ def unified_flash_attention( max_seqlen_k=max_seq_len, softmax_scale=softmax_scale, causal=True, + window_size=window_size, alibi_slopes=alibi_slopes, block_table=prefill_meta.block_tables, softcap=logits_soft_cap, @@ -725,6 +720,7 @@ def unified_flash_attention( max_seqlen_k=decode_meta.max_decode_seq_len, softmax_scale=softmax_scale, causal=True, + window_size=window_size, alibi_slopes=alibi_slopes, softcap=logits_soft_cap, block_table=decode_meta.block_tables, @@ -739,6 +735,7 @@ def unified_flash_attention( cache_seqlens=decode_meta.seq_lens_tensor, softmax_scale=softmax_scale, causal=True, + window_size=window_size, alibi_slopes=alibi_slopes, softcap=logits_soft_cap, ).squeeze(1) diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index b46f0721d0caf..33d05cbd3fe01 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -78,10 +78,9 @@ def __init__( # During model initialization, the default dtype is set as the model # weight and activation dtype. dtype = torch.get_default_dtype() - attn_backend = get_attn_backend(head_size, sliding_window, dtype, - kv_cache_dtype, block_size, - is_attention_free, blocksparse_params - is not None) + attn_backend = get_attn_backend(head_size, dtype, kv_cache_dtype, + block_size, is_attention_free, + blocksparse_params is not None) impl_cls = attn_backend.get_impl_cls() self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads, alibi_slopes, sliding_window, kv_cache_dtype, diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index 7edb7676ea2cd..4ff86573e664d 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -90,7 +90,6 @@ def get_global_forced_attn_backend() -> Optional[_Backend]: @lru_cache(maxsize=None) def get_attn_backend( head_size: int, - sliding_window: Optional[int], dtype: torch.dtype, kv_cache_dtype: Optional[str], block_size: int, @@ -105,8 +104,8 @@ def get_attn_backend( BlocksparseFlashAttentionBackend) return BlocksparseFlashAttentionBackend - backend = which_attn_to_use(head_size, sliding_window, dtype, - kv_cache_dtype, block_size, is_attention_free) + backend = which_attn_to_use(head_size, dtype, kv_cache_dtype, block_size, + is_attention_free) if backend == _Backend.FLASH_ATTN: from vllm.attention.backends.flash_attn import ( # noqa: F401 FlashAttentionBackend) @@ -155,7 +154,6 @@ def get_attn_backend( def which_attn_to_use( head_size: int, - sliding_window: Optional[int], dtype: torch.dtype, kv_cache_dtype: Optional[str], block_size: int, @@ -243,10 +241,6 @@ def which_attn_to_use( "Cannot use FlashAttention-2 backend for block size not " "divisible by 16.") selected_backend = _Backend.XFORMERS - elif sliding_window is not None: - logger.info( - "Cannot use FlashAttention-2 backend due to sliding window.") - selected_backend = _Backend.XFORMERS # FlashAttn is valid for the model, checking if the package is installed. if selected_backend == _Backend.FLASH_ATTN: diff --git a/vllm/worker/cache_engine.py b/vllm/worker/cache_engine.py index 090f95e6e892c..ac3270d1c9909 100644 --- a/vllm/worker/cache_engine.py +++ b/vllm/worker/cache_engine.py @@ -53,7 +53,6 @@ def __init__( # Get attention backend. self.attn_backend = get_attn_backend(self.head_size, - model_config.get_sliding_window(), model_config.dtype, cache_config.cache_dtype, self.block_size, diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py index dd38b550eb011..5032896600b3b 100644 --- a/vllm/worker/cpu_model_runner.py +++ b/vllm/worker/cpu_model_runner.py @@ -420,7 +420,6 @@ def __init__( self.block_size = cache_config.block_size self.attn_backend = get_attn_backend( self.model_config.get_head_size(), - self.model_config.get_sliding_window(), self.model_config.dtype, self.kv_cache_dtype, self.block_size, diff --git a/vllm/worker/cpu_worker.py b/vllm/worker/cpu_worker.py index b84562851f0f8..ab93471b5af74 100644 --- a/vllm/worker/cpu_worker.py +++ b/vllm/worker/cpu_worker.py @@ -57,7 +57,6 @@ def __init__(self, cache_config: CacheConfig, model_config: ModelConfig, # Get attention backend. self.attn_backend = get_attn_backend( self.model_config.get_head_size(), - self.model_config.get_sliding_window(), self.model_config.dtype, cache_config.cache_dtype, self.block_size, diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index a82956985af55..dc1674cd1ea20 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1011,7 +1011,6 @@ def __init__( self.attn_backend = get_attn_backend( self.model_config.get_head_size(), - self.model_config.get_sliding_window(), self.model_config.dtype, self.kv_cache_dtype, self.block_size, diff --git a/vllm/worker/openvino_model_runner.py b/vllm/worker/openvino_model_runner.py index 760b18427e22b..a164fbe3393c4 100644 --- a/vllm/worker/openvino_model_runner.py +++ b/vllm/worker/openvino_model_runner.py @@ -75,7 +75,6 @@ def __init__( self.attn_backend = get_attn_backend( self.model_config.get_head_size(), - self.model_config.get_sliding_window(), self.model_config.dtype, self.kv_cache_dtype, self.block_size, diff --git a/vllm/worker/openvino_worker.py b/vllm/worker/openvino_worker.py index 24425fece850f..bc245d19663d6 100644 --- a/vllm/worker/openvino_worker.py +++ b/vllm/worker/openvino_worker.py @@ -71,7 +71,6 @@ def __init__( # Get attention backend. self.attn_backend = get_attn_backend( self.head_size, - self.model_config.get_sliding_window(), self.model_config.dtype, self.cache_config.cache_dtype, self.block_size, diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py index f7e5f660c0249..87ced7818a676 100644 --- a/vllm/worker/tpu_model_runner.py +++ b/vllm/worker/tpu_model_runner.py @@ -114,7 +114,6 @@ def __init__( dtype=np.int32) self.attn_backend = get_attn_backend( self.model_config.get_head_size(), - self.model_config.get_sliding_window(), self.model_config.dtype, self.cache_config.cache_dtype, self.block_size, diff --git a/vllm/worker/xpu_model_runner.py b/vllm/worker/xpu_model_runner.py index 5ff4626c060b3..75a6de3b24ba4 100644 --- a/vllm/worker/xpu_model_runner.py +++ b/vllm/worker/xpu_model_runner.py @@ -374,7 +374,6 @@ def __init__( self.attn_backend = get_attn_backend( self.model_config.get_head_size(), - self.model_config.get_sliding_window(), self.model_config.dtype, self.kv_cache_dtype, self.block_size,