diff --git a/vllm/spec_decode/multi_step_worker.py b/vllm/spec_decode/multi_step_worker.py index d249b37c780e4..676ac5eb3609d 100644 --- a/vllm/spec_decode/multi_step_worker.py +++ b/vllm/spec_decode/multi_step_worker.py @@ -120,6 +120,9 @@ def sampler_output( indices_of_seq_with_bonus_tokens) model_outputs.append(model_output) + # move indices to device to avoid stream sync + indices_of_seq_with_bonus_tokens = torch.tensor( + indices_of_seq_with_bonus_tokens, device=self.device) filtered_model_outputs = self._filter_model_output( model_outputs, indices_of_seq_with_bonus_tokens) return filtered_model_outputs, True @@ -189,7 +192,7 @@ def _expand_execute_model_request( @staticmethod def _filter_model_output( expanded_batch_outputs: List[SamplerOutput], - output_indices_to_retain: List[int]) -> List[SamplerOutput]: + output_indices_to_retain: torch.Tensor) -> List[SamplerOutput]: """ Filters the model output to include only the specified sequence outputs. This method contracts the expanded batch output from the @@ -199,8 +202,8 @@ def _filter_model_output( Args: expanded_batch_output (List[SamplerOutput]): The expanded output batch from the model. - output_indices_to_retain (List[int]): Indices of the model outputs - to retain. + output_indices_to_retain (torch.Tensor): Indices of the model + outputs to retain. Returns: List[SamplerOutput]: A list containing the filtered model