Skip to content

Commit

Permalink
Move fix to scheduler
Browse files Browse the repository at this point in the history
Signed-off-by: andoorve <[email protected]>
  • Loading branch information
andoorve committed Nov 12, 2024
1 parent 902daaa commit 10f69a4
Showing 1 changed file with 17 additions and 9 deletions.
26 changes: 17 additions & 9 deletions vllm/core/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1148,23 +1148,31 @@ def _schedule_chunked_prefill(self) -> SchedulerOutputs:
# Update swapped requests.
self.swapped.extend(running_scheduled.swapped_out)
# Put prefills first due to Attention backend ordering assumption.
scheduled_seq_groups = (prefills.seq_groups +
running_scheduled.prefill_seq_groups +
swapped_in.prefill_seq_groups +
running_scheduled.decode_seq_groups +
swapped_in.decode_seq_groups)
num_prefill_groups = (len(prefills.seq_groups) +
len(swapped_in.prefill_seq_groups) +
len(running_scheduled.prefill_seq_groups))
# If all scheduled groups are prefills (prompts), we set num_lookahead_slots
# to 0. This allows us to go through the `no_spec` path in
# `spec_decode_worker.py`.
all_prefills = (len(scheduled_seq_groups) == num_prefill_groups)
num_lookahead_slots = (0 if all_prefills else
running_scheduled.num_lookahead_slots)
return SchedulerOutputs(
scheduled_seq_groups=(prefills.seq_groups +
running_scheduled.prefill_seq_groups +
swapped_in.prefill_seq_groups +
running_scheduled.decode_seq_groups +
swapped_in.decode_seq_groups),
num_prefill_groups=(len(prefills.seq_groups) +
len(swapped_in.prefill_seq_groups) +
len(running_scheduled.prefill_seq_groups)),
scheduled_seq_groups=scheduled_seq_groups,
num_prefill_groups=num_prefill_groups,
num_batched_tokens=budget.num_batched_tokens,
blocks_to_swap_in=swapped_in.blocks_to_swap_in,
blocks_to_swap_out=running_scheduled.blocks_to_swap_out,
blocks_to_copy=running_scheduled.blocks_to_copy +
swapped_in.blocks_to_copy,
ignored_seq_groups=prefills.ignored_seq_groups +
swapped_in.infeasible_seq_groups,
num_lookahead_slots=running_scheduled.num_lookahead_slots,
num_lookahead_slots=num_lookahead_slots,
running_queue_size=len(self.running),
preempted=(len(running_scheduled.preempted) +
len(running_scheduled.swapped_out)),
Expand Down

0 comments on commit 10f69a4

Please sign in to comment.