From e875afaa50836b34e070b6ee6f5002f8780aba6c Mon Sep 17 00:00:00 2001
From: Sourashis Roy
Date: Fri, 22 Nov 2024 21:32:57 +0000
Subject: [PATCH] Remove additional broadcast needed for proposer prefill pass

---
 vllm/spec_decode/spec_decode_worker.py | 23 ++++++++++++++++-------
 1 file changed, 16 insertions(+), 7 deletions(-)

diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py
index 35e2eb625705a..70f15c72fbee4 100644
--- a/vllm/spec_decode/spec_decode_worker.py
+++ b/vllm/spec_decode/spec_decode_worker.py
@@ -408,10 +408,13 @@ def execute_model(
         disable_all_speculation = self._should_disable_all_speculation(
             execute_model_req)
         num_lookahead_slots = execute_model_req.num_lookahead_slots
+        all_prompt = None
+        atleast_one_prompt = False
+        for sgm in execute_model_req.seq_group_metadata_list:
+            all_prompt = (sgm.is_prompt if all_prompt is None else all_prompt
+                          and sgm.is_prompt)
+            atleast_one_prompt = atleast_one_prompt or sgm.is_prompt
 
-        all_prompt = (all(
-            sgm.is_prompt
-            for sgm in execute_model_req.seq_group_metadata_list))
         if all_prompt:
             assert num_lookahead_slots == 0, (
                 "Prompt only runs should have num_lookahead_slots equal to 0. "
@@ -448,6 +451,15 @@ def execute_model(
                 num_lookahead_slots=num_lookahead_slots,
                 no_spec=no_spec,
                 disable_all_speculation=disable_all_speculation,
+                # When both chunked prefill and speculative decoding are enabled
+                # it is possible that the same batch contains both prefill
+                # and decodes. If that happens in the scorer we run the batch
+                # as one single forward pass. However, in the proposer we
+                # run them as 2 different batches - one for prefill and
+                # the other for decodes. The variable indicates to the non-driver
+                # worker that there are prefills as part of the speculative batch
+                # and hence it needs to run an extra prefill forward pass.
+                run_spec_proposer_for_prefill=atleast_one_prompt,
             )
             broadcast_tensor_dict(broadcast_dict, src=self._driver_rank)
 
@@ -659,8 +671,7 @@ def _run_non_driver_rank(self) -> bool:
             if not data["no_spec"]:
                 self.scorer_worker.execute_model()
 
-            data = broadcast_tensor_dict(src=self._driver_rank)
-            if data["run_spec_proposer"]:
+            if data["run_spec_proposer_for_prefill"]:
                 self.proposer_worker.execute_model()
 
         return True
@@ -715,8 +726,6 @@ def _run_speculative_decoding_step(
             idx for idx in non_spec_indices
             if execute_model_req.seq_group_metadata_list[idx].is_prompt
         ]
-        broadcast_dict = dict(run_spec_proposer=bool(non_spec_indices))
-        broadcast_tensor_dict(broadcast_dict, src=self._driver_rank)
         if len(non_spec_indices):
             all_hidden_states = proposal_scores.hidden_states
             # TODO fix `return_hidden_states`, same as in `_run_no_spec`
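
Reviewer note: the sketch below is a minimal, self-contained model of the
control flow after this patch; it is not vLLM code. fake_broadcast() is a
hypothetical stand-in for broadcast_tensor_dict(), and the print calls stand
in for the scorer/proposer forward passes. It illustrates the point of the
change: the driver now computes atleast_one_prompt in the same pass that
computes all_prompt and ships it inside the one broadcast that already
happens, so non-driver ranks no longer block on a second broadcast before
deciding whether to run the proposer prefill pass.

from typing import Dict, List


def fake_broadcast(payload: Dict[str, bool]) -> Dict[str, bool]:
    # Hypothetical stand-in for broadcast_tensor_dict(src=driver_rank):
    # a single driver-to-workers broadcast per scheduler step.
    return dict(payload)


def driver_pack(is_prompt_flags: List[bool], no_spec: bool) -> Dict[str, bool]:
    # Mirrors the single loop the patch adds to execute_model(): both
    # aggregates come from one pass over seq_group_metadata_list.
    all_prompt = None
    atleast_one_prompt = False
    for is_prompt in is_prompt_flags:
        all_prompt = (is_prompt if all_prompt is None else all_prompt
                      and is_prompt)
        atleast_one_prompt = atleast_one_prompt or is_prompt
    # In the real code all_prompt gates the num_lookahead_slots assertion;
    # only the atleast_one_prompt flag rides along in the broadcast.
    return dict(no_spec=no_spec,
                run_spec_proposer_for_prefill=atleast_one_prompt)


def non_driver_step(data: Dict[str, bool]) -> None:
    # Mirrors _run_non_driver_rank() after the patch: the proposer prefill
    # decision is read from the dict already received, with no extra
    # broadcast in between.
    if not data["no_spec"]:
        print("scorer forward pass")
    if data["run_spec_proposer_for_prefill"]:
        print("proposer prefill forward pass")


# Mixed chunked-prefill + decode batch: one prefill among decodes triggers
# the extra proposer prefill pass on every rank.
non_driver_step(fake_broadcast(driver_pack([True, False, False],
                                           no_spec=False)))

Design-wise, piggybacking the flag on the existing broadcast also removes a
collective call from _run_speculative_decoding_step(), which only the driver
used to issue; every rank now derives the same decision from one payload.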