Skip to content

Commit

Permalink
Remove additional broadcast needed for proposer prefill pass
Browse files Browse the repository at this point in the history
  • Loading branch information
sroy745 authored and andoorve committed Nov 25, 2024
1 parent 2b7fcc4 commit e875afa
Showing 1 changed file with 16 additions and 7 deletions.
23 changes: 16 additions & 7 deletions vllm/spec_decode/spec_decode_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,10 +408,13 @@ def execute_model(
disable_all_speculation = self._should_disable_all_speculation(
execute_model_req)
num_lookahead_slots = execute_model_req.num_lookahead_slots
all_prompt = None
atleast_one_prompt = False
for sgm in execute_model_req.seq_group_metadata_list:
all_prompt = (sgm.is_prompt if all_prompt is None else all_prompt
and sgm.is_prompt)
atleast_one_prompt = atleast_one_prompt or sgm.is_prompt

all_prompt = (all(
sgm.is_prompt
for sgm in execute_model_req.seq_group_metadata_list))
if all_prompt:
assert num_lookahead_slots == 0, (
"Prompt only runs should have num_lookahead_slots equal to 0. "
Expand Down Expand Up @@ -448,6 +451,15 @@ def execute_model(
num_lookahead_slots=num_lookahead_slots,
no_spec=no_spec,
disable_all_speculation=disable_all_speculation,
# When both chunked prefill and speculative decoding are enabled,
# the same batch may contain both prefills and decodes. When that
# happens, the scorer runs the batch as a single forward pass,
# whereas the proposer runs it as two separate batches - one for
# the prefills and one for the decodes. This flag tells the
# non-driver worker that the speculative batch contains prefills,
# so it needs to run an extra proposer forward pass for them.
run_spec_proposer_for_prefill=atleast_one_prompt,
)
broadcast_tensor_dict(broadcast_dict, src=self._driver_rank)

Expand Down Expand Up @@ -659,8 +671,7 @@ def _run_non_driver_rank(self) -> bool:

if not data["no_spec"]:
self.scorer_worker.execute_model()
data = broadcast_tensor_dict(src=self._driver_rank)
if data["run_spec_proposer"]:
if data["run_spec_proposer_for_prefill"]:
self.proposer_worker.execute_model()

return True
Expand Down Expand Up @@ -715,8 +726,6 @@ def _run_speculative_decoding_step(
idx for idx in non_spec_indices
if execute_model_req.seq_group_metadata_list[idx].is_prompt
]
broadcast_dict = dict(run_spec_proposer=bool(non_spec_indices))
broadcast_tensor_dict(broadcast_dict, src=self._driver_rank)
if len(non_spec_indices):
all_hidden_states = proposal_scores.hidden_states
# TODO fix `return_hidden_states`, same as in `_run_no_spec`
Expand Down

0 comments on commit e875afa

Please sign in to comment.