diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index b63a01fa5c591..7278674a00623 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -409,8 +409,10 @@ def execute_model( execute_model_req) num_lookahead_slots = execute_model_req.num_lookahead_slots - if (all(sgm.is_prompt - for sgm in execute_model_req.seq_group_metadata_list)): + all_prompt = (all( + sgm.is_prompt + for sgm in execute_model_req.seq_group_metadata_list)) + if (all_prompt): assert num_lookahead_slots == 0, ( "Prompt only runs should have num_lookahead_slots equal to 0. " "This should never happen, please file a bug at "