diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index 70f15c72fbee4..8f3c824c4d2c3 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -408,12 +408,14 @@ def execute_model( disable_all_speculation = self._should_disable_all_speculation( execute_model_req) num_lookahead_slots = execute_model_req.num_lookahead_slots - all_prompt = None + all_prompt = True atleast_one_prompt = False + all_zero_spec_tokens = True for sgm in execute_model_req.seq_group_metadata_list: - all_prompt = (sgm.is_prompt if all_prompt is None else all_prompt - and sgm.is_prompt) + all_prompt = all_prompt and sgm.is_prompt atleast_one_prompt = atleast_one_prompt or sgm.is_prompt + all_zero_spec_tokens = all_zero_spec_tokens and ( + sgm.num_speculative_tokens == 0) if all_prompt: assert num_lookahead_slots == 0, ( @@ -430,9 +432,8 @@ def execute_model( # In any of these cases, the proposer and scorer workers # are called normally. # We expect `num_speculative_tokens` to be None for prefills. - no_spec = num_lookahead_slots == 0 or disable_all_speculation or all( - sgm.num_speculative_tokens == 0 - for sgm in execute_model_req.seq_group_metadata_list) + no_spec = (num_lookahead_slots == 0 or disable_all_speculation + or all_zero_spec_tokens) # Broadcast how many lookahead slots are scheduled for this step, and # whether all speculation is disabled, to all non-driver workers.