Skip to content

Commit

Permalink
multi-step + chunked-prefill + flashinfer
Browse files Browse the repository at this point in the history
  • Loading branch information
elfiegg committed Dec 20, 2024
1 parent 47a0b61 commit 32f925b
Show file tree
Hide file tree
Showing 5 changed files with 233 additions and 114 deletions.
10 changes: 10 additions & 0 deletions csrc/prepare_inputs/advance_step.cu
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,16 @@ __global__ void advance_step_flashinfer_kernel(
long* input_positions_ptr, int* seq_lens_ptr, long* slot_mapping_ptr,
int const* block_tables_ptr, int64_t const block_tables_stride,
int* paged_kv_last_page_len_ptr, int* block_table_bound_ptr) {
int const n_pad = num_seqs - num_queries;
if (n_pad && blockIdx.x == 0) {
// Handle cuda graph padding
int const offset = num_queries;
for (int i = threadIdx.x; i < n_pad; i += blockDim.x) {
input_tokens_ptr[offset + i] = 0;
input_positions_ptr[offset + i] = 0;
slot_mapping_ptr[offset + i] = -1;
}
}
int num_query_blocks = div_ceil(num_queries, num_threads);

if (blockIdx.x < num_query_blocks) {
Expand Down
19 changes: 17 additions & 2 deletions tests/multi_step/test_correctness_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

import pytest

from tests.kernels.utils import override_backend_env_variable

from ..models.utils import check_logprobs_close, check_outputs_equal

MODELS = [
Expand All @@ -19,10 +21,11 @@
@pytest.mark.parametrize("tp_size", [1])
@pytest.mark.parametrize("enable_chunked_prefill", [False, True])
@pytest.mark.parametrize("max_tokens", [5])
@pytest.mark.parametrize("enforce_eager", [True])
@pytest.mark.parametrize("enforce_eager", [True, False])
@pytest.mark.parametrize("num_scheduler_steps", NUM_SCHEDULER_STEPS)
@pytest.mark.parametrize("num_prompts", NUM_PROMPTS)
@pytest.mark.parametrize("num_logprobs", [None, 5])
@pytest.mark.parametrize("attention_backend", ["FLASH_ATTN", "FLASHINFER"])
def test_multi_step_llm(
hf_runner,
vllm_runner,
Expand All @@ -36,6 +39,8 @@ def test_multi_step_llm(
num_scheduler_steps: int,
num_prompts: int,
num_logprobs: Optional[int],
attention_backend: str,
monkeypatch,
) -> None:
"""Test vLLM engine with multi-step scheduling via sync LLM Engine.
Expand Down Expand Up @@ -63,6 +68,7 @@ def test_multi_step_llm(
num_logprobs: corresponds to the `logprobs` argument to the OpenAI
completions endpoint; `None` -> 1 logprob returned.
"""
override_backend_env_variable(monkeypatch, attention_backend)

prompts = example_prompts
if len(prompts) < num_prompts:
Expand Down Expand Up @@ -110,10 +116,11 @@ def test_multi_step_llm(
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("tp_size", [1])
@pytest.mark.parametrize("max_tokens", [5])
@pytest.mark.parametrize("enforce_eager", [True])
@pytest.mark.parametrize("enforce_eager", [True, False])
@pytest.mark.parametrize("num_scheduler_steps", NUM_SCHEDULER_STEPS)
@pytest.mark.parametrize("num_prompts", NUM_PROMPTS)
@pytest.mark.parametrize("num_logprobs,num_prompt_logprobs", [(5, 5)])
@pytest.mark.parametrize("attention_backend", ["FLASH_ATTN", "FLASHINFER"])
def test_multi_step_llm_w_prompt_logprobs(
vllm_runner,
example_prompts,
Expand All @@ -126,6 +133,8 @@ def test_multi_step_llm_w_prompt_logprobs(
num_prompts: int,
num_logprobs: Optional[int],
num_prompt_logprobs: Optional[int],
attention_backend: str,
monkeypatch,
) -> None:
"""Test prompt logprobs with multi-step scheduling via sync LLM Engine.
Expand Down Expand Up @@ -155,6 +164,7 @@ def test_multi_step_llm_w_prompt_logprobs(
note that this argument is not supported by the
OpenAI completions endpoint.
"""
override_backend_env_variable(monkeypatch, attention_backend)

prompts = example_prompts
if len(prompts) < num_prompts:
Expand Down Expand Up @@ -205,6 +215,7 @@ def test_multi_step_llm_w_prompt_logprobs(
@pytest.mark.parametrize("num_scheduler_steps", NUM_SCHEDULER_STEPS)
@pytest.mark.parametrize("num_prompts", NUM_PROMPTS)
@pytest.mark.parametrize("num_logprobs", [None, 5])
@pytest.mark.parametrize("attention_backend", ["FLASH_ATTN"])
def test_multi_step_llm_chunked_prefill_prefix_cache(
vllm_runner,
example_prompts,
Expand All @@ -216,6 +227,8 @@ def test_multi_step_llm_chunked_prefill_prefix_cache(
num_scheduler_steps: int,
num_prompts: int,
num_logprobs: Optional[int],
attention_backend: str,
monkeypatch,
) -> None:
"""Test vLLM engine with multi-step+"single-step chunked prefill"+APC.
Expand Down Expand Up @@ -278,6 +291,8 @@ def test_multi_step_llm_chunked_prefill_prefix_cache(
#
# The Incorrect scheduling behavior - if it occurs - will cause an exception
# in the model runner resulting from `do_sample=False`.
override_backend_env_variable(monkeypatch, attention_backend)

assert len(example_prompts) >= 2
challenge_prompts = copy.deepcopy(example_prompts)
challenge_prompts[0] = ('vLLM is a high-throughput and memory-efficient '
Expand Down
Loading

0 comments on commit 32f925b

Please sign in to comment.