diff --git a/vllm/config.py b/vllm/config.py index 028f4eed8f4a2..0524514f6633a 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1068,7 +1068,7 @@ def maybe_create_spec_config( draft_parallel_config = ( SpeculativeConfig.create_draft_parallel_config( target_parallel_config, - speculative_draft_tensor_parallel_size)) + speculative_draft_tensor_parallel_size, draft_hf_config)) if num_speculative_tokens is None: raise ValueError( @@ -1136,15 +1136,23 @@ def _maybe_override_draft_max_model_len( @staticmethod def create_draft_parallel_config( target_parallel_config: ParallelConfig, - speculative_draft_tensor_parallel_size: Optional[int] + speculative_draft_tensor_parallel_size: Optional[int], + draft_hf_config: PretrainedConfig, ) -> ParallelConfig: """Create a parallel config for use by the draft worker. This is mostly a copy of the target parallel config, except the tp_size. """ if speculative_draft_tensor_parallel_size is None: - speculative_draft_tensor_parallel_size = \ - target_parallel_config.tensor_parallel_size + if draft_hf_config.model_type == "mlp_speculator": + speculative_draft_tensor_parallel_size = 1 + if target_parallel_config.tensor_parallel_size > 1: + logger.warning( + "MLPSpeculator cannot currently be run with tp>1; " + "setting speculative_draft_tensor_parallel_size=1") + else: + speculative_draft_tensor_parallel_size = \ + target_parallel_config.tensor_parallel_size elif speculative_draft_tensor_parallel_size != 1: # TODO(wooyeon): allow tp values larger than 1 raise ValueError(