From 10f69a414401687ae52f3a5425330d5e2f42720c Mon Sep 17 00:00:00 2001 From: andoorve <37849411+andoorve@users.noreply.github.com> Date: Tue, 12 Nov 2024 19:12:20 +0000 Subject: [PATCH] Move fix to scheduler Signed-off-by: andoorve <37849411+andoorve@users.noreply.github.com> --- vllm/core/scheduler.py | 26 +++++++++++++++++--------- 1 file changed, 17 insertions(+), 9 deletions(-) diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index af4671ec29be9..102a0ebf9cdd6 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -1148,15 +1148,23 @@ def _schedule_chunked_prefill(self) -> SchedulerOutputs: # Update swapped requests. self.swapped.extend(running_scheduled.swapped_out) # Put prefills first due to Attention backend ordering assumption. + scheduled_seq_groups = (prefills.seq_groups + + running_scheduled.prefill_seq_groups + + swapped_in.prefill_seq_groups + + running_scheduled.decode_seq_groups + + swapped_in.decode_seq_groups) + num_prefill_groups = (len(prefills.seq_groups) + + len(swapped_in.prefill_seq_groups) + + len(running_scheduled.prefill_seq_groups)) + # If all prompts, then we set num_lookahead_slots to 0 + # this alloows us to go through the `no_spec` path in + # `spec_decode_worker.py` + all_prefills = (len(scheduled_seq_groups) == num_prefill_groups) + num_lookahead_slots = (0 if all_prefills else + running_scheduled.num_lookahead_slots) return SchedulerOutputs( - scheduled_seq_groups=(prefills.seq_groups + - running_scheduled.prefill_seq_groups + - swapped_in.prefill_seq_groups + - running_scheduled.decode_seq_groups + - swapped_in.decode_seq_groups), - num_prefill_groups=(len(prefills.seq_groups) + - len(swapped_in.prefill_seq_groups) + - len(running_scheduled.prefill_seq_groups)), + scheduled_seq_groups=scheduled_seq_groups, + num_prefill_groups=num_prefill_groups, num_batched_tokens=budget.num_batched_tokens, blocks_to_swap_in=swapped_in.blocks_to_swap_in, blocks_to_swap_out=running_scheduled.blocks_to_swap_out, @@ -1164,7 +1172,7 @@ def _schedule_chunked_prefill(self) -> SchedulerOutputs: swapped_in.blocks_to_copy, ignored_seq_groups=prefills.ignored_seq_groups + swapped_in.infeasible_seq_groups, - num_lookahead_slots=running_scheduled.num_lookahead_slots, + num_lookahead_slots=num_lookahead_slots, running_queue_size=len(self.running), preempted=(len(running_scheduled.preempted) + len(running_scheduled.swapped_out)),