
Commit

Nits and add multi-step test
Signed-off-by: andoorve <[email protected]>
andoorve committed Nov 21, 2024
1 parent 3c266a2 commit 472ab34
Showing 2 changed files with 14 additions and 7 deletions.
18 changes: 12 additions & 6 deletions tests/core/test_chunked_prefill_scheduler.py
@@ -413,37 +413,43 @@ def cannot_append_second_group2(seq_group, num_lookahead_slots):
     assert out.num_batched_tokens == max_num_batched_tokens


-def test_chunked_prefill_spec_prefill():
-    """Verify preempt works with chunked prefill requests"""
+@pytest.mark.parametrize("num_scheduler_steps", [1, 5])
+def test_chunked_prefill_spec_prefill(num_scheduler_steps):
+    """Verify that num_lookahead_slots is set appropriately for an
+    all-prefill batch, depending on whether multi-step scheduling is
+    enabled."""
     block_size = 4
     max_seqs = 30
     max_model_len = 200
     max_num_batched_tokens = 30
+    num_lookahead_slots = 4
     scheduler_config = SchedulerConfig(
         "generate",
         max_num_batched_tokens,
         max_seqs,
         max_model_len,
         enable_chunked_prefill=True,
-        num_lookahead_slots=5,
+        num_lookahead_slots=num_lookahead_slots,
+        num_scheduler_steps=num_scheduler_steps,
     )
     cache_config = CacheConfig(block_size, 1.0, 1, "auto")
     cache_config.num_cpu_blocks = 16
     cache_config.num_gpu_blocks = 16
     scheduler = Scheduler(scheduler_config, cache_config, None)

     _, seq_group = create_dummy_prompt("1",
-                                       prompt_length=60,
+                                       prompt_length=30,
                                        block_size=block_size)
     scheduler.add_seq_group(seq_group)
     _, out = schedule_and_update_computed_tokens(scheduler)
-    # The request is chunked.
     # prefill scheduled now.
     assert len(out.scheduled_seq_groups) == 1
     assert out.num_prefill_groups == 1
     assert seq_group.is_prefill()
     assert out.num_batched_tokens == max_num_batched_tokens
-    assert out.num_lookahead_slots == 0
+    print(out.num_lookahead_slots)
+    assert out.num_lookahead_slots == (0 if (num_scheduler_steps == 1) else
+                                       num_lookahead_slots)


 def test_chunked_prefill_max_seqs():
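For readers skimming the diff, the new assertion encodes a simple rule: an all-prefill batch reserves no lookahead slots under single-step scheduling, but reserves the configured lookahead slots when multi-step scheduling is enabled, so the batch can run several steps before returning to the scheduler. A minimal sketch of that rule, assuming a hypothetical helper name (expected_lookahead_slots is not part of the vLLM API):

# Sketch of the rule the parametrized test asserts above.
# expected_lookahead_slots is a hypothetical helper, not a vLLM API.
def expected_lookahead_slots(num_scheduler_steps: int,
                             num_lookahead_slots: int) -> int:
    if num_scheduler_steps == 1:
        # Single-step scheduling: prefill produces at most one new token
        # per sequence, so no lookahead slots are reserved.
        return 0
    # Multi-step scheduling: reserve the configured slots so the batch
    # can run multiple steps without being rescheduled.
    return num_lookahead_slots

# Mirrors the two parametrized cases, num_scheduler_steps in [1, 5]:
assert expected_lookahead_slots(1, num_lookahead_slots=4) == 0
assert expected_lookahead_slots(5, num_lookahead_slots=4) == 4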
3 changes: 2 additions & 1 deletion tests/spec_decode/e2e/test_integration_dist_tp2.py
@@ -160,7 +160,8 @@ def test_spec_decode_chunked_prefill_tp2(model, common_llm_kwargs,
                                          per_test_common_llm_kwargs,
                                          baseline_llm_kwargs, test_llm_kwargs,
                                          batch_size: int, seed: int):
-    """Verify spec decode works well with smaller tp for draft models.
+    """Verify spec decode works well with same and different TP size for
+    the draft model with chunked prefill.
     """
     run_equality_correctness_test_tp(model,
                                      common_llm_kwargs,
