Skip to content

Commit

Permalink
[Bugfix] Fix illegal memory access error with chunked prefill, prefix caching, block manager v2 and xformers enabled together (vllm-project#9532)
Browse files Browse the repository at this point in the history

Signed-off-by: sasha0552 <[email protected]>
  • Loading branch information
sasha0552 authored Oct 31, 2024
1 parent 77f7ef2 commit 55650c8
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 3 deletions.
28 changes: 28 additions & 0 deletions tests/prefix_caching/test_prefix_caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,22 @@
import pytest

from tests.kernels.utils import override_backend_env_variable
from vllm import SamplingParams, TokensPrompt

from ..models.utils import check_outputs_equal

# Hugging Face model ids exercised by the parametrized prefix-caching tests.
MODELS = [
"facebook/opt-125m",
]

# Shared token prefix (588 + 1332 tokens) that every prompt below starts with,
# so consecutive requests hit the prefix cache at varying depths.
_SHARED_PREFIX = [0] * 588 + [1] * 1332

# Token-id prompts that share long common prefixes but diverge at different
# points; replaying them in order previously triggered an illegal memory
# access with chunked prefill + prefix caching + xformers.
UNSTABLE_PROMPT_SEQUENCE = [
    _SHARED_PREFIX + [2] * 30 + [3] * 1,
    _SHARED_PREFIX + [4] * 3 + [5] * 50,
    _SHARED_PREFIX + [2] * 30 + [6] * 95,
    _SHARED_PREFIX + [4] * 3 + [7] * 174,
    [0] * 588 + [8] * 1539,
]


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("backend", ["FLASH_ATTN", "FLASHINFER", "XFORMERS"])
Expand Down Expand Up @@ -57,3 +66,22 @@ def test_mixed_requests(
name_0="hf",
name_1="vllm",
)


@pytest.mark.parametrize("backend", ["FLASH_ATTN", "FLASHINFER", "XFORMERS"])
def test_unstable_prompt_sequence(
    vllm_runner,
    backend: str,
    monkeypatch,
) -> None:
    """Replay a sequence of prompts with long shared prefixes.

    Regression test for an illegal memory access when chunked prefill,
    prefix caching, and the selected attention backend are enabled
    together; the test passes as long as generation completes without
    crashing (only one token is sampled per prompt).
    """
    override_backend_env_variable(monkeypatch, backend)

    runner_ctx = vllm_runner(
        "Qwen/Qwen2.5-0.5B-Instruct",
        enable_chunked_prefill=True,
        enable_prefix_caching=True,
        max_model_len=4096,
    )
    with runner_ctx as vllm_model:
        for token_ids in UNSTABLE_PROMPT_SEQUENCE:
            vllm_model.generate(TokensPrompt(prompt_token_ids=token_ids),
                                SamplingParams(max_tokens=1))
9 changes: 6 additions & 3 deletions vllm/attention/backends/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,6 @@ def _add_seq_group(
chunked_prefill_enabled: bool):
is_prompt = inter_data.is_prompt
block_tables = inter_data.block_tables
computed_block_nums = inter_data.computed_block_nums

for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len,
curr_sliding_window_block) in zip(
Expand All @@ -164,10 +163,14 @@ def _add_seq_group(
# NOTE: This only works for oooooooxxx style attention.
block_table = []
if inter_data.prefix_cache_hit:
block_table = computed_block_nums
block_table = block_tables[seq_id]
elif ((chunked_prefill_enabled or not is_prompt)
and block_tables is not None):
block_table = block_tables[seq_id][-curr_sliding_window_block:]
if curr_sliding_window_block == 0:
block_table = block_tables[seq_id]
else:
block_table = block_tables[seq_id][
-curr_sliding_window_block:]
self.block_tables.append(block_table)

# Compute slot mapping.
Expand Down

0 comments on commit 55650c8

Please sign in to comment.