From 9b73a2f498e8bf8e18305a8afea84536e9330088 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Wed, 21 Aug 2024 12:23:22 -0400 Subject: [PATCH] [Spec Decoding] Use target model max length as default for draft model (#7706) --- vllm/config.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/vllm/config.py b/vllm/config.py index 0d5d098bc8858..7e62a727115ef 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -127,6 +127,7 @@ def __init__( rope_theta: Optional[float] = None, tokenizer_revision: Optional[str] = None, max_model_len: Optional[int] = None, + spec_target_max_model_len: Optional[int] = None, quantization: Optional[str] = None, quantization_param_path: Optional[str] = None, enforce_eager: Optional[bool] = None, @@ -210,7 +211,8 @@ def __init__( hf_config=self.hf_text_config, max_model_len=max_model_len, disable_sliding_window=self.disable_sliding_window, - sliding_window_len=self.get_hf_config_sliding_window()) + sliding_window_len=self.get_hf_config_sliding_window(), + spec_target_max_model_len=spec_target_max_model_len) self.served_model_name = get_served_model_name(model, served_model_name) self.multimodal_config = self._init_multimodal_config( @@ -1134,6 +1136,7 @@ def maybe_create_spec_config( code_revision=draft_code_revision, tokenizer_revision=target_model_config.tokenizer_revision, max_model_len=None, + spec_target_max_model_len=target_model_config.max_model_len, quantization=draft_quantization, enforce_eager=target_model_config.enforce_eager, max_seq_len_to_capture=target_model_config. @@ -1563,6 +1566,7 @@ def _get_and_verify_max_len( max_model_len: Optional[int], disable_sliding_window: bool, sliding_window_len: Optional[int], + spec_target_max_model_len: Optional[int] = None, ) -> int: """Get and verify the model's maximum length.""" derived_max_model_len = float("inf") @@ -1605,6 +1609,11 @@ def _get_and_verify_max_len( # If max_model_len is specified, we use it. return max_model_len + if spec_target_max_model_len is not None: + # If this is a speculative draft model, we use the max model len + # from the target model. + return spec_target_max_model_len + default_max_len = 2048 logger.warning( "The model's config.json does not contain any of the following "