diff --git a/tests/core/test_chunked_prefill_scheduler.py b/tests/core/test_chunked_prefill_scheduler.py
index 2c7c47412f23d..eaaf004df38b2 100644
--- a/tests/core/test_chunked_prefill_scheduler.py
+++ b/tests/core/test_chunked_prefill_scheduler.py
@@ -413,19 +413,24 @@ def cannot_append_second_group2(seq_group, num_lookahead_slots):
     assert out.num_batched_tokens == max_num_batched_tokens
 
 
-def test_chunked_prefill_spec_prefill():
-    """Verify preempt works with chunked prefill requests"""
+@pytest.mark.parametrize("num_scheduler_steps", [1, 5])
+def test_chunked_prefill_spec_prefill(num_scheduler_steps):
+    """Verify that num_lookahead_slots is set appropriately for an all-prefill
+    batch, depending on whether multi-step scheduling is enabled or not.
+    """
     block_size = 4
     max_seqs = 30
     max_model_len = 200
     max_num_batched_tokens = 30
+    num_lookahead_slots = 4
     scheduler_config = SchedulerConfig(
         "generate",
         max_num_batched_tokens,
         max_seqs,
         max_model_len,
         enable_chunked_prefill=True,
-        num_lookahead_slots=5,
+        num_lookahead_slots=num_lookahead_slots,
+        num_scheduler_steps=num_scheduler_steps,
     )
     cache_config = CacheConfig(block_size, 1.0, 1, "auto")
     cache_config.num_cpu_blocks = 16
@@ -433,7 +438,7 @@ def test_chunked_prefill_spec_prefill():
     scheduler = Scheduler(scheduler_config, cache_config, None)
 
     _, seq_group = create_dummy_prompt("1",
-                                       prompt_length=60,
+                                       prompt_length=30,
                                        block_size=block_size)
     scheduler.add_seq_group(seq_group)
     _, out = schedule_and_update_computed_tokens(scheduler)
@@ -441,9 +446,10 @@ def test_chunked_prefill_spec_prefill():
     # prefill scheduled now.
     assert len(out.scheduled_seq_groups) == 1
     assert out.num_prefill_groups == 1
-    assert seq_group.is_prefill()
     assert out.num_batched_tokens == max_num_batched_tokens
-    assert out.num_lookahead_slots == 0
+    print(out.num_lookahead_slots)
+    assert out.num_lookahead_slots == (0 if (num_scheduler_steps == 1) else
+                                       num_lookahead_slots)
 
 
 def test_chunked_prefill_max_seqs():
diff --git a/tests/spec_decode/e2e/test_integration_dist_tp2.py b/tests/spec_decode/e2e/test_integration_dist_tp2.py
index 7fea161a4a7be..02cba92795142 100644
--- a/tests/spec_decode/e2e/test_integration_dist_tp2.py
+++ b/tests/spec_decode/e2e/test_integration_dist_tp2.py
@@ -160,7 +160,8 @@ def test_spec_decode_chunked_prefill_tp2(model, common_llm_kwargs,
                                          per_test_common_llm_kwargs,
                                          baseline_llm_kwargs, test_llm_kwargs,
                                          batch_size: int, seed: int):
-    """Verify spec decode works well with smaller tp for draft models.
+    """Verify spec decode works well with the same and different TP sizes for
+    the draft model with chunked prefill.
     """
     run_equality_correctness_test_tp(model,
                                      common_llm_kwargs,