[Kernel] Support sliding window in flash attention backend (vllm-proj…
heheda12345 authored Oct 20, 2024
1 parent 962d2c6 commit 4fa3e33
Showing 13 changed files with 41 additions and 61 deletions.
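In short: the attention selector no longer takes a sliding_window argument (and no longer falls back to XFORMERS for sliding-window models), the FlashAttention backend maps the model's window to flash-attn's (left, right) window_size and forwards it to its prefill and decode calls, and the kernel tests gain a sliding_window=256 case. The snippet below is an illustrative check, not part of the diff, of what the (W - 1, 0) convention used throughout this change means:

# With window_size = (W - 1, 0), a query at position q attends to keys in
# [q - (W - 1), q]: the current token plus the W - 1 tokens before it.
def visible_keys(q: int, window: int) -> range:
    left, right = window - 1, 0
    return range(max(0, q - left), q + right + 1)

assert list(visible_keys(q=5, window=3)) == [3, 4, 5]
assert len(visible_keys(q=1000, window=256)) == 256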
35 changes: 15 additions & 20 deletions tests/kernels/test_attention_selector.py
@@ -20,21 +20,21 @@ def test_env(name: str, device: str, monkeypatch):

     if device == "cpu":
         with patch("vllm.attention.selector.is_cpu", return_value=True):
-            backend = which_attn_to_use(16, None, torch.float16, torch.float16,
-                                        16, False)
+            backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
+                                        False)
         assert backend.name == "TORCH_SDPA"
     elif device == "hip":
         with patch("vllm.attention.selector.is_hip", return_value=True):
-            backend = which_attn_to_use(16, None, torch.float16, torch.float16,
-                                        16, False)
+            backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
+                                        False)
         assert backend.name == "ROCM_FLASH"
     elif device == "openvino":
         with patch("vllm.attention.selector.is_openvino", return_value=True):
-            backend = which_attn_to_use(16, None, torch.float16, torch.float16,
-                                        16, False)
+            backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
+                                        False)
         assert backend.name == "OPENVINO"
     else:
-        backend = which_attn_to_use(16, None, torch.float16, torch.float16, 16,
+        backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
                                     False)
         assert backend.name == name

@@ -46,42 +46,37 @@ def test_flash_attn(monkeypatch):

     # Unsupported CUDA arch
     with patch("torch.cuda.get_device_capability", return_value=(7, 5)):
-        backend = which_attn_to_use(16, None, torch.float16, None, 16, False)
+        backend = which_attn_to_use(16, torch.float16, None, 16, False)
         assert backend.name != STR_FLASH_ATTN_VAL

     # Unsupported data type
-    backend = which_attn_to_use(16, None, torch.float8_e4m3fn, None, 16, False)
+    backend = which_attn_to_use(16, torch.float8_e4m3fn, None, 16, False)
     assert backend.name != STR_FLASH_ATTN_VAL

     # Unsupported kv cache data type
-    backend = which_attn_to_use(16, None, torch.float16, "fp8", 16, False)
+    backend = which_attn_to_use(16, torch.float16, "fp8", 16, False)
     assert backend.name != STR_FLASH_ATTN_VAL

     # Unsupported block size
-    backend = which_attn_to_use(16, None, torch.float16, None, 8, False)
-    assert backend.name != STR_FLASH_ATTN_VAL
-
-    # Unsupported sliding window
-    backend = which_attn_to_use(16, 1, torch.float16, None, 16, False)
+    backend = which_attn_to_use(16, torch.float16, None, 8, False)
     assert backend.name != STR_FLASH_ATTN_VAL

     # flash-attn is not installed
     with patch.dict('sys.modules', {'vllm_flash_attn': None}):
-        backend = which_attn_to_use(16, None, torch.float16, None, 16, False)
+        backend = which_attn_to_use(16, torch.float16, None, 16, False)
         assert backend.name != STR_FLASH_ATTN_VAL

     # Unsupported head size
-    backend = which_attn_to_use(17, None, torch.float16, None, 16, False)
+    backend = which_attn_to_use(17, torch.float16, None, 16, False)
     assert backend.name != STR_FLASH_ATTN_VAL

     # Attention-free models should bypass env and use PlaceholderAttention
-    backend = which_attn_to_use(16, None, torch.float16, torch.float16, 16,
-                                True)
+    backend = which_attn_to_use(16, torch.float16, torch.float16, 16, True)
     assert backend.name != STR_FLASH_ATTN_VAL


 def test_invalid_env(monkeypatch):
     """Throw an exception if the backend name is invalid."""
     override_backend_env_variable(monkeypatch, STR_INVALID_VAL)
     with pytest.raises(ValueError):
-        which_attn_to_use(16, None, torch.float16, None, 16, False)
+        which_attn_to_use(16, torch.float16, None, 16, False)
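With sliding_window removed from the signature, backend choice now depends only on head size, dtypes, block size, and whether the model is attention-free. A minimal sketch of the updated call, assuming which_attn_to_use is imported from vllm.attention.selector as in this test file:

import torch

from vllm.attention.selector import which_attn_to_use

# Positional order after this change:
# (head_size, dtype, kv_cache_dtype, block_size, is_attention_free)
backend = which_attn_to_use(16, torch.float16, None, 16, False)
print(backend.name)  # e.g. "FLASH_ATTN" on a supported GPU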
29 changes: 16 additions & 13 deletions tests/kernels/test_flash_attn.py
@@ -78,6 +78,7 @@ def ref_paged_attn(
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("soft_cap", [None, 10.0, 50.0])
 @pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
+@pytest.mark.parametrize("sliding_window", [None, 256])
 @torch.inference_mode()
 def test_flash_attn_with_paged_kv(
     kv_lens: List[int],
@@ -87,6 +88,7 @@ def test_flash_attn_with_paged_kv(
     block_size: int,
     soft_cap: Optional[float],
     num_blocks: int,
+    sliding_window: Optional[int],
 ) -> None:
     torch.set_default_device("cuda")
     seed_everything(0)
@@ -96,6 +98,8 @@ def test_flash_attn_with_paged_kv(
     assert num_query_heads % num_kv_heads == 0
     max_kv_len = max(kv_lens)
     scale = head_size**-0.5
+    window_size = ((sliding_window - 1, 0) if sliding_window is not None else
+                   (-1, -1))

     query = torch.randn(num_seqs, num_query_heads, head_size, dtype=dtype)
     key_cache = torch.randn(num_blocks,
@@ -121,18 +125,18 @@ def test_flash_attn_with_paged_kv(
         block_table=block_tables,
         cache_seqlens=kv_lens_tensor,
         softcap=soft_cap if soft_cap is not None else 0,
+        window_size=window_size,
     ).squeeze(1)

-    ref_output = ref_paged_attn(
-        query=query,
-        key_cache=key_cache,
-        value_cache=value_cache,
-        query_lens=[1] * num_seqs,
-        kv_lens=kv_lens,
-        block_tables=block_tables,
-        scale=scale,
-        soft_cap=soft_cap,
-    )
+    ref_output = ref_paged_attn(query=query,
+                                key_cache=key_cache,
+                                value_cache=value_cache,
+                                query_lens=[1] * num_seqs,
+                                kv_lens=kv_lens,
+                                block_tables=block_tables,
+                                scale=scale,
+                                soft_cap=soft_cap,
+                                sliding_window=sliding_window)
     torch.testing.assert_close(output, ref_output, atol=2e-2, rtol=1e-2), \
         f"{torch.max(torch.abs(output - ref_output))}"

@@ -141,7 +145,7 @@ def test_flash_attn_with_paged_kv(
 @pytest.mark.parametrize("num_heads", NUM_HEADS)
 @pytest.mark.parametrize("head_size", HEAD_SIZES)
 @pytest.mark.parametrize("block_size", BLOCK_SIZES)
-@pytest.mark.parametrize("sliding_window", [None])
+@pytest.mark.parametrize("sliding_window", [None, 256])
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("soft_cap", [None, 10.0, 50.0])
 @pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@@ -166,8 +170,7 @@ def test_varlen_with_paged_kv(
     assert num_query_heads % num_kv_heads == 0
     max_query_len = max(query_lens)
     max_kv_len = max(kv_lens)
-    window_size = ((sliding_window,
-                    sliding_window) if sliding_window is not None else
+    window_size = ((sliding_window - 1, 0) if sliding_window is not None else
                    (-1, -1))
     scale = head_size**-0.5

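ref_paged_attn, the pure-PyTorch reference these tests compare against, now receives sliding_window as well. A minimal sketch (illustrative, not the exact reference code) of the mask this implies, namely causal attention restricted to the last sliding_window keys, which matches the window_size=(sliding_window - 1, 0) passed to flash-attn above:

import torch


def sliding_window_causal_mask(query_len: int, kv_len: int,
                               sliding_window: int) -> torch.Tensor:
    # Queries are the last `query_len` tokens of a `kv_len`-token sequence.
    q_pos = torch.arange(kv_len - query_len, kv_len).unsqueeze(1)
    k_pos = torch.arange(kv_len).unsqueeze(0)
    causal = k_pos <= q_pos                      # never attend to the future
    in_window = k_pos > q_pos - sliding_window   # only the last `sliding_window` keys
    return causal & in_window                    # True = attend, False = masked out


# A query at the end of a 512-token context with sliding_window=256 sees
# exactly the last 256 keys.
mask = sliding_window_causal_mask(query_len=1, kv_len=512, sliding_window=256)
assert int(mask.sum()) == 256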
13 changes: 5 additions & 8 deletions vllm/attention/backends/flash_attn.py
@@ -524,8 +524,8 @@ def __init__(
         if alibi_slopes is not None:
             alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
         self.alibi_slopes = alibi_slopes
-        self.sliding_window = ((sliding_window, sliding_window)
-                               if sliding_window is not None else (-1, -1))
+        self.sliding_window = ((sliding_window - 1,
+                                0) if sliding_window is not None else (-1, -1))
         self.kv_cache_dtype = kv_cache_dtype
         if logits_soft_cap is None:
             # In flash-attn, setting logits_soft_cap as 0 means no soft cap.
@@ -535,12 +535,6 @@ def __init__(
         assert self.num_heads % self.num_kv_heads == 0
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads

-        if sliding_window is not None:
-            # NOTE(woosuk): flash-attn's sliding window does not work with
-            # paged KV cache.
-            raise ValueError(
-                "Sliding window is not supported in FlashAttention.")
-
         support_head_sizes = FlashAttentionBackend.get_supported_head_sizes()
         if head_size not in support_head_sizes:
             raise ValueError(
@@ -704,6 +698,7 @@ def unified_flash_attention(
             max_seqlen_k=max_seq_len,
             softmax_scale=softmax_scale,
             causal=True,
+            window_size=window_size,
             alibi_slopes=alibi_slopes,
             block_table=prefill_meta.block_tables,
             softcap=logits_soft_cap,
@@ -725,6 +720,7 @@ def unified_flash_attention(
                 max_seqlen_k=decode_meta.max_decode_seq_len,
                 softmax_scale=softmax_scale,
                 causal=True,
+                window_size=window_size,
                 alibi_slopes=alibi_slopes,
                 softcap=logits_soft_cap,
                 block_table=decode_meta.block_tables,
@@ -739,6 +735,7 @@ def unified_flash_attention(
             cache_seqlens=decode_meta.seq_lens_tensor,
             softmax_scale=softmax_scale,
             causal=True,
+            window_size=window_size,
             alibi_slopes=alibi_slopes,
             softcap=logits_soft_cap,
         ).squeeze(1)
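The implementation now expresses the model's sliding window in flash-attn's asymmetric (left, right) form and forwards it as window_size to its flash_attn_varlen_func and flash_attn_with_kvcache calls, instead of raising at construction time. A minimal sketch of the conversion done in __init__ above (the helper name is invented for illustration):

from typing import Optional, Tuple


def to_flash_attn_window(sliding_window: Optional[int]) -> Tuple[int, int]:
    """flash-attn's left window excludes the current token, so a model
    window of W tokens (the current token plus the W - 1 before it)
    becomes (W - 1, 0); (-1, -1) disables windowing."""
    if sliding_window is None:
        return (-1, -1)
    return (sliding_window - 1, 0)


assert to_flash_attn_window(None) == (-1, -1)
assert to_flash_attn_window(4096) == (4095, 0)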
7 changes: 3 additions & 4 deletions vllm/attention/layer.py
@@ -78,10 +78,9 @@ def __init__(
         # During model initialization, the default dtype is set as the model
         # weight and activation dtype.
         dtype = torch.get_default_dtype()
-        attn_backend = get_attn_backend(head_size, sliding_window, dtype,
-                                        kv_cache_dtype, block_size,
-                                        is_attention_free, blocksparse_params
-                                        is not None)
+        attn_backend = get_attn_backend(head_size, dtype, kv_cache_dtype,
+                                        block_size, is_attention_free,
+                                        blocksparse_params is not None)
         impl_cls = attn_backend.get_impl_cls()
         self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads,
                              alibi_slopes, sliding_window, kv_cache_dtype,
10 changes: 2 additions & 8 deletions vllm/attention/selector.py
@@ -90,7 +90,6 @@ def get_global_forced_attn_backend() -> Optional[_Backend]:
 @lru_cache(maxsize=None)
 def get_attn_backend(
     head_size: int,
-    sliding_window: Optional[int],
     dtype: torch.dtype,
     kv_cache_dtype: Optional[str],
     block_size: int,
@@ -105,8 +104,8 @@ def get_attn_backend(
             BlocksparseFlashAttentionBackend)
         return BlocksparseFlashAttentionBackend

-    backend = which_attn_to_use(head_size, sliding_window, dtype,
-                                kv_cache_dtype, block_size, is_attention_free)
+    backend = which_attn_to_use(head_size, dtype, kv_cache_dtype, block_size,
+                                is_attention_free)
     if backend == _Backend.FLASH_ATTN:
         from vllm.attention.backends.flash_attn import (  # noqa: F401
             FlashAttentionBackend)
@@ -155,7 +154,6 @@ def get_attn_backend(

 def which_attn_to_use(
     head_size: int,
-    sliding_window: Optional[int],
     dtype: torch.dtype,
     kv_cache_dtype: Optional[str],
     block_size: int,
@@ -243,10 +241,6 @@ def which_attn_to_use(
                 "Cannot use FlashAttention-2 backend for block size not "
                 "divisible by 16.")
             selected_backend = _Backend.XFORMERS
-        elif sliding_window is not None:
-            logger.info(
-                "Cannot use FlashAttention-2 backend due to sliding window.")
-            selected_backend = _Backend.XFORMERS

     # FlashAttn is valid for the model, checking if the package is installed.
     if selected_backend == _Backend.FLASH_ATTN:
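get_attn_backend is wrapped in lru_cache, so dropping sliding_window also shrinks the cache key: layers that differ only in their sliding window now resolve to the same backend class. A minimal sketch of a call under the new signature (keyword names follow the diff; it assumes a working vLLM install on a supported GPU):

import torch

from vllm.attention.selector import get_attn_backend

backend = get_attn_backend(
    head_size=64,
    dtype=torch.float16,
    kv_cache_dtype="auto",
    block_size=16,
    is_attention_free=False,
)
print(backend)  # e.g. the FlashAttentionBackend class from vllm.attention.backends.flash_attn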
1 change: 0 additions & 1 deletion vllm/worker/cache_engine.py
@@ -53,7 +53,6 @@ def __init__(

         # Get attention backend.
         self.attn_backend = get_attn_backend(self.head_size,
-                                             model_config.get_sliding_window(),
                                              model_config.dtype,
                                              cache_config.cache_dtype,
                                              self.block_size,
1 change: 0 additions & 1 deletion vllm/worker/cpu_model_runner.py
@@ -420,7 +420,6 @@ def __init__(
         self.block_size = cache_config.block_size
         self.attn_backend = get_attn_backend(
             self.model_config.get_head_size(),
-            self.model_config.get_sliding_window(),
             self.model_config.dtype,
             self.kv_cache_dtype,
             self.block_size,
1 change: 0 additions & 1 deletion vllm/worker/cpu_worker.py
@@ -57,7 +57,6 @@ def __init__(self, cache_config: CacheConfig, model_config: ModelConfig,
         # Get attention backend.
         self.attn_backend = get_attn_backend(
             self.model_config.get_head_size(),
-            self.model_config.get_sliding_window(),
             self.model_config.dtype,
             cache_config.cache_dtype,
             self.block_size,
1 change: 0 additions & 1 deletion vllm/worker/model_runner.py
@@ -1011,7 +1011,6 @@ def __init__(

         self.attn_backend = get_attn_backend(
             self.model_config.get_head_size(),
-            self.model_config.get_sliding_window(),
             self.model_config.dtype,
             self.kv_cache_dtype,
             self.block_size,
1 change: 0 additions & 1 deletion vllm/worker/openvino_model_runner.py
@@ -75,7 +75,6 @@ def __init__(

         self.attn_backend = get_attn_backend(
             self.model_config.get_head_size(),
-            self.model_config.get_sliding_window(),
             self.model_config.dtype,
             self.kv_cache_dtype,
             self.block_size,
1 change: 0 additions & 1 deletion vllm/worker/openvino_worker.py
@@ -71,7 +71,6 @@ def __init__(
         # Get attention backend.
         self.attn_backend = get_attn_backend(
             self.head_size,
-            self.model_config.get_sliding_window(),
             self.model_config.dtype,
             self.cache_config.cache_dtype,
             self.block_size,
1 change: 0 additions & 1 deletion vllm/worker/tpu_model_runner.py
@@ -114,7 +114,6 @@ def __init__(
                                          dtype=np.int32)
         self.attn_backend = get_attn_backend(
             self.model_config.get_head_size(),
-            self.model_config.get_sliding_window(),
             self.model_config.dtype,
             self.cache_config.cache_dtype,
             self.block_size,
1 change: 0 additions & 1 deletion vllm/worker/xpu_model_runner.py
@@ -374,7 +374,6 @@ def __init__(

         self.attn_backend = get_attn_backend(
             self.model_config.get_head_size(),
-            self.model_config.get_sliding_window(),
             self.model_config.dtype,
             self.kv_cache_dtype,
             self.block_size,
