From 2185ac96824f51cca28fdbbf4f7f2d1446c54560 Mon Sep 17 00:00:00 2001 From: Aaron Pham Date: Tue, 3 Dec 2024 02:17:00 -0500 Subject: [PATCH] [Core][Performance] Add XGrammar support for guided decoding and set it as default (#10785) Signed-off-by: Aaron Pham Signed-off-by: mgoin Co-authored-by: mgoin --- docs/source/conf.py | 1 + requirements-common.txt | 1 + tests/entrypoints/llm/test_guided_generate.py | 27 + .../model_executor/test_guided_processors.py | 3 +- vllm/config.py | 1028 ++++++++++------- vllm/engine/arg_utils.py | 9 +- vllm/engine/async_llm_engine.py | 18 +- vllm/engine/llm_engine.py | 15 +- vllm/engine/multiprocessing/client.py | 5 +- .../guided_decoding/__init__.py | 73 +- .../guided_decoding/xgrammar_decoding.py | 251 ++++ 11 files changed, 1012 insertions(+), 419 deletions(-) create mode 100644 vllm/model_executor/guided_decoding/xgrammar_decoding.py diff --git a/docs/source/conf.py b/docs/source/conf.py index 96ad9a4c26b09..260fc588d05d5 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -112,6 +112,7 @@ def setup(app): "tensorizer", "pynvml", "outlines", + "xgrammar," "librosa", "soundfile", "gguf", diff --git a/requirements-common.txt b/requirements-common.txt index f62ad66a1ecc4..2df21eae87d00 100644 --- a/requirements-common.txt +++ b/requirements-common.txt @@ -19,6 +19,7 @@ prometheus-fastapi-instrumentator >= 7.0.0 tiktoken >= 0.6.0 # Required for DBRX tokenizer lm-format-enforcer >= 0.10.9, < 0.11 outlines >= 0.0.43, < 0.1 +xgrammar typing_extensions >= 4.10 filelock >= 3.10.4 # filelock starts to support `mode` argument from 3.10.4 partial-json-parser # used for parsing partial JSON outputs diff --git a/tests/entrypoints/llm/test_guided_generate.py b/tests/entrypoints/llm/test_guided_generate.py index 67c79415f322a..c3706f696b264 100644 --- a/tests/entrypoints/llm/test_guided_generate.py +++ b/tests/entrypoints/llm/test_guided_generate.py @@ -159,3 +159,30 @@ def test_validation_against_both_guided_decoding_options(sample_regex, llm): sampling_params=sampling_params, use_tqdm=True, guided_options_request=dict(guided_regex=sample_regex)) + + +@pytest.mark.skip_global_cleanup +def test_guided_json_object(llm): + sampling_params = SamplingParams( + temperature=1.0, + max_tokens=100, + guided_decoding=GuidedDecodingParams(json_object=True)) + + outputs = llm.generate( + prompts=("Generate a JSON object describing a person with name " + "and age for John Smith who is 31 years old."), + sampling_params=sampling_params, + use_tqdm=True) + + assert outputs is not None + for output in outputs: + assert output is not None + assert isinstance(output, RequestOutput) + + generated_text = output.outputs[0].text + print(generated_text) + assert generated_text is not None + + # Parse to verify it is valid JSON + parsed_json = json.loads(generated_text) + assert isinstance(parsed_json, dict) diff --git a/tests/model_executor/test_guided_processors.py b/tests/model_executor/test_guided_processors.py index 45fab8e96b968..9f4d81b583141 100644 --- a/tests/model_executor/test_guided_processors.py +++ b/tests/model_executor/test_guided_processors.py @@ -36,7 +36,8 @@ def test_guided_logits_processors(sample_regex, sample_json_schema): @pytest.mark.asyncio -@pytest.mark.parametrize("backend", ["outlines", "lm-format-enforcer"]) +@pytest.mark.parametrize("backend", + ["outlines", "lm-format-enforcer", "xgrammar"]) async def test_guided_logits_processor_black_box(backend: str, sample_regex, sample_json_schema): tokenizer = AutoTokenizer.from_pretrained('HuggingFaceH4/zephyr-7b-beta') diff --git a/vllm/config.py b/vllm/config.py index 1c190da1d327e..57296b07a6d43 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -3,8 +3,22 @@ import json import warnings from dataclasses import dataclass, field, replace -from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Dict, Final, List, - Literal, Mapping, Optional, Set, Tuple, Type, Union) +from typing import ( + TYPE_CHECKING, + Any, + Callable, + ClassVar, + Dict, + Final, + List, + Literal, + Mapping, + Optional, + Set, + Tuple, + Type, + Union, +) import torch from transformers import PretrainedConfig @@ -16,21 +30,34 @@ from vllm.platforms import current_platform from vllm.tracing import is_otel_available, otel_import_error_traceback from vllm.transformers_utils.config import ( - ConfigFormat, get_config, get_hf_image_processor_config, - get_hf_text_config, get_pooling_config, - get_sentence_transformer_tokenizer_config, is_encoder_decoder, uses_mrope) -from vllm.utils import (GiB_bytes, cuda_device_count_stateless, get_cpu_memory, - identity, print_warning_once) + ConfigFormat, + get_config, + get_hf_image_processor_config, + get_hf_text_config, + get_pooling_config, + get_sentence_transformer_tokenizer_config, + is_encoder_decoder, + uses_mrope, +) +from vllm.utils import ( + GiB_bytes, + cuda_device_count_stateless, + get_cpu_memory, + identity, + print_warning_once, +) if TYPE_CHECKING: from ray.util.placement_group import PlacementGroup from vllm.executor.executor_base import ExecutorBase from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) + QuantizationConfig, + ) from vllm.model_executor.model_loader.loader import BaseModelLoader from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import ( - BaseTokenizerGroup) + BaseTokenizerGroup, + ) else: QuantizationConfig = None @@ -44,8 +71,9 @@ # "draft" is only used internally for speculative decoding _Task = Literal["generate", "embedding", "draft"] -HfOverrides = Union[Dict[str, Any], Callable[[PretrainedConfig], - PretrainedConfig]] +HfOverrides = Union[ + Dict[str, Any], Callable[[PretrainedConfig], PretrainedConfig] +] class ModelConfig: @@ -128,38 +156,39 @@ class ModelConfig: """ def __init__( - self, - model: str, - task: Union[TaskOption, _Task], - tokenizer: str, - tokenizer_mode: str, - trust_remote_code: bool, - dtype: Union[str, torch.dtype], - seed: int, - allowed_local_media_path: str = "", - revision: Optional[str] = None, - code_revision: Optional[str] = None, - rope_scaling: Optional[Dict[str, Any]] = None, - rope_theta: Optional[float] = None, - tokenizer_revision: Optional[str] = None, - max_model_len: Optional[int] = None, - spec_target_max_model_len: Optional[int] = None, - quantization: Optional[str] = None, - quantization_param_path: Optional[str] = None, - enforce_eager: Optional[bool] = None, - max_seq_len_to_capture: Optional[int] = None, - max_logprobs: int = 20, - disable_sliding_window: bool = False, - skip_tokenizer_init: bool = False, - served_model_name: Optional[Union[str, List[str]]] = None, - limit_mm_per_prompt: Optional[Mapping[str, int]] = None, - use_async_output_proc: bool = True, - config_format: ConfigFormat = ConfigFormat.AUTO, - chat_template_text_format: str = "string", - hf_overrides: Optional[HfOverrides] = None, - mm_processor_kwargs: Optional[Dict[str, Any]] = None, - override_neuron_config: Optional[Dict[str, Any]] = None, - override_pooler_config: Optional["PoolerConfig"] = None) -> None: + self, + model: str, + task: Union[TaskOption, _Task], + tokenizer: str, + tokenizer_mode: str, + trust_remote_code: bool, + dtype: Union[str, torch.dtype], + seed: int, + allowed_local_media_path: str = "", + revision: Optional[str] = None, + code_revision: Optional[str] = None, + rope_scaling: Optional[Dict[str, Any]] = None, + rope_theta: Optional[float] = None, + tokenizer_revision: Optional[str] = None, + max_model_len: Optional[int] = None, + spec_target_max_model_len: Optional[int] = None, + quantization: Optional[str] = None, + quantization_param_path: Optional[str] = None, + enforce_eager: Optional[bool] = None, + max_seq_len_to_capture: Optional[int] = None, + max_logprobs: int = 20, + disable_sliding_window: bool = False, + skip_tokenizer_init: bool = False, + served_model_name: Optional[Union[str, List[str]]] = None, + limit_mm_per_prompt: Optional[Mapping[str, int]] = None, + use_async_output_proc: bool = True, + config_format: ConfigFormat = ConfigFormat.AUTO, + chat_template_text_format: str = "string", + hf_overrides: Optional[HfOverrides] = None, + mm_processor_kwargs: Optional[Dict[str, Any]] = None, + override_neuron_config: Optional[Dict[str, Any]] = None, + override_pooler_config: Optional["PoolerConfig"] = None, + ) -> None: self.model = model self.tokenizer = tokenizer self.tokenizer_mode = tokenizer_mode @@ -182,14 +211,18 @@ def __init__( if rope_scaling is not None: hf_override: Dict[str, Any] = {"rope_scaling": rope_scaling} hf_overrides_kw.update(hf_override) - msg = ("`--rope-scaling` will be removed in a future release. " - f"'Please instead use `--hf-overrides '{hf_override!r}'`") + msg = ( + "`--rope-scaling` will be removed in a future release. " + f"'Please instead use `--hf-overrides '{hf_override!r}'`" + ) warnings.warn(DeprecationWarning(msg), stacklevel=2) if rope_theta is not None: hf_override = {"rope_theta": rope_theta} hf_overrides_kw.update(hf_override) - msg = ("`--rope-theta` will be removed in a future release. " - f"'Please instead use `--hf-overrides '{hf_override!r}'`") + msg = ( + "`--rope-theta` will be removed in a future release. " + f"'Please instead use `--hf-overrides '{hf_override!r}'`" + ) warnings.warn(DeprecationWarning(msg), stacklevel=2) # The tokenizer version is consistent with the model version by default. @@ -205,15 +238,22 @@ def __init__( self.disable_sliding_window = disable_sliding_window self.skip_tokenizer_init = skip_tokenizer_init - hf_config = get_config(self.model, trust_remote_code, revision, - code_revision, config_format, **hf_overrides_kw) + hf_config = get_config( + self.model, + trust_remote_code, + revision, + code_revision, + config_format, + **hf_overrides_kw, + ) hf_config = hf_overrides_fn(hf_config) self.hf_config = hf_config self.hf_text_config = get_hf_text_config(self.hf_config) self.encoder_config = self._get_encoder_config() self.hf_image_processor_config = get_hf_image_processor_config( - self.model, revision) + self.model, revision + ) self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype) self.use_async_output_proc = use_async_output_proc self.chat_template_text_format = chat_template_text_format @@ -225,18 +265,21 @@ def __init__( sliding_window = getattr(self.hf_text_config, "sliding_window", None) has_interleaved_attention = (sliding_window is not None) and ( - isinstance(sliding_window, list) or - (self.hf_text_config.model_type in ["gemma2"])) + isinstance(sliding_window, list) + or (self.hf_text_config.model_type in ["gemma2"]) + ) - if (not self.disable_sliding_window and has_interleaved_attention): + if not self.disable_sliding_window and has_interleaved_attention: sliding_window_len_min = get_min_sliding_window( - self.hf_text_config.sliding_window) + self.hf_text_config.sliding_window + ) print_warning_once( f"{self.hf_text_config.model_type} has interleaved attention, " "which is currently not supported by vLLM. Disabling sliding " "window and capping the max length to the sliding window size " - f"({sliding_window_len_min}).") + f"({sliding_window_len_min})." + ) self.disable_sliding_window = True self.max_model_len = _get_and_verify_max_len( @@ -245,11 +288,12 @@ def __init__( disable_sliding_window=self.disable_sliding_window, sliding_window_len=self.get_hf_config_sliding_window(), spec_target_max_model_len=spec_target_max_model_len, - encoder_config=self.encoder_config) - self.served_model_name = get_served_model_name(model, - served_model_name) + encoder_config=self.encoder_config, + ) + self.served_model_name = get_served_model_name(model, served_model_name) self.multimodal_config = self._init_multimodal_config( - limit_mm_per_prompt) + limit_mm_per_prompt + ) if not self.skip_tokenizer_init: self._verify_tokenizer_mode() @@ -278,20 +322,22 @@ def _init_multimodal_config( return MultiModalConfig(limit_per_prompt=limit_mm_per_prompt or {}) if limit_mm_per_prompt: - raise ValueError("`limit_mm_per_prompt` is only supported for " - "multimodal models.") + raise ValueError( + "`limit_mm_per_prompt` is only supported for " + "multimodal models." + ) return None def _get_encoder_config(self): return get_sentence_transformer_tokenizer_config( - self.model, self.revision) + self.model, self.revision + ) def _init_pooler_config( self, override_pooler_config: Optional["PoolerConfig"], ) -> Optional["PoolerConfig"]: - if self.task == "embedding": user_config = override_pooler_config or PoolerConfig() @@ -319,7 +365,8 @@ def _verify_tokenizer_mode(self) -> None: if tokenizer_mode not in ["auto", "slow", "mistral"]: raise ValueError( f"Unknown tokenizer mode: {self.tokenizer_mode}. Must be " - "either 'auto', 'slow' or 'mistral'.") + "either 'auto', 'slow' or 'mistral'." + ) self.tokenizer_mode = tokenizer_mode def _resolve_task( @@ -349,12 +396,16 @@ def _resolve_task( if len(supported_tasks) > 1: logger.info( "This model supports multiple tasks: %s. " - "Defaulting to '%s'.", supported_tasks, selected_task) + "Defaulting to '%s'.", + supported_tasks, + selected_task, + ) else: if task_option not in supported_tasks: msg = ( f"This model does not support the '{task_option}' task. " - f"Supported tasks: {supported_tasks}") + f"Supported tasks: {supported_tasks}" + ) raise ValueError(msg) selected_task = task_option @@ -371,13 +422,24 @@ def _parse_quant_hf_config(self): def _verify_quantization(self) -> None: supported_quantization = [*QUANTIZATION_METHODS] rocm_supported_quantization = [ - "awq", "gptq", "fp8", "compressed_tensors", "compressed-tensors", - "fbgemm_fp8" + "awq", + "gptq", + "fp8", + "compressed_tensors", + "compressed-tensors", + "fbgemm_fp8", ] optimized_quantization_methods = [ - "fp8", "marlin", "modelopt", "gptq_marlin_24", "gptq_marlin", - "awq_marlin", "fbgemm_fp8", "compressed_tensors", - "compressed-tensors", "experts_int8" + "fp8", + "marlin", + "modelopt", + "gptq_marlin_24", + "gptq_marlin", + "awq_marlin", + "fbgemm_fp8", + "compressed_tensors", + "compressed-tensors", + "experts_int8", ] tpu_supported_quantization = ["tpu_int8"] neuron_supported_quantization = ["neuron_quant"] @@ -393,7 +455,8 @@ def _verify_quantization(self) -> None: # Detect which checkpoint is it for _, method in QUANTIZATION_METHODS.items(): quantization_override = method.override_quantization_method( - quant_cfg, self.quantization) + quant_cfg, self.quantization + ) if quantization_override: quant_method = quantization_override self.quantization = quantization_override @@ -407,45 +470,63 @@ def _verify_quantization(self) -> None: "Quantization method specified in the model config " f"({quant_method}) does not match the quantization " f"method specified in the `quantization` argument " - f"({self.quantization}).") + f"({self.quantization})." + ) if self.quantization is not None: if self.quantization not in supported_quantization: raise ValueError( f"Unknown quantization method: {self.quantization}. Must " - f"be one of {supported_quantization}.") - if current_platform.is_rocm( - ) and self.quantization not in rocm_supported_quantization: + f"be one of {supported_quantization}." + ) + if ( + current_platform.is_rocm() + and self.quantization not in rocm_supported_quantization + ): raise ValueError( f"{self.quantization} quantization is currently not " - f"supported in ROCm.") - if current_platform.is_tpu( - ) and self.quantization not in tpu_supported_quantization: + f"supported in ROCm." + ) + if ( + current_platform.is_tpu() + and self.quantization not in tpu_supported_quantization + ): raise ValueError( f"{self.quantization} quantization is currently not " - f"supported in TPU Backend.") + f"supported in TPU Backend." + ) if self.quantization not in optimized_quantization_methods: logger.warning( "%s quantization is not fully " "optimized yet. The speed can be slower than " - "non-quantized models.", self.quantization) - if (self.quantization == "awq" and current_platform.is_rocm() - and not envs.VLLM_USE_TRITON_AWQ): + "non-quantized models.", + self.quantization, + ) + if ( + self.quantization == "awq" + and current_platform.is_rocm() + and not envs.VLLM_USE_TRITON_AWQ + ): logger.warning( "Using AWQ quantization with ROCm, but VLLM_USE_TRITON_AWQ" - " is not set, enabling VLLM_USE_TRITON_AWQ.") + " is not set, enabling VLLM_USE_TRITON_AWQ." + ) envs.VLLM_USE_TRITON_AWQ = True - if current_platform.is_neuron( - ) and self.quantization not in neuron_supported_quantization: + if ( + current_platform.is_neuron() + and self.quantization not in neuron_supported_quantization + ): raise ValueError( f"{self.quantization} quantization is currently not " - f"supported in Neuron Backend.") + f"supported in Neuron Backend." + ) def _verify_cuda_graph(self) -> None: if self.max_seq_len_to_capture is None: self.max_seq_len_to_capture = self.max_model_len - self.max_seq_len_to_capture = min(self.max_seq_len_to_capture, - self.max_model_len) + self.max_seq_len_to_capture = min( + self.max_seq_len_to_capture, self.max_model_len + ) def _verify_bnb_config(self) -> None: """ @@ -453,31 +534,40 @@ def _verify_bnb_config(self) -> None: yet support CUDA graph. """ is_bitsandbytes = self.quantization == "bitsandbytes" - has_quantization_config = (getattr(self.hf_config, - "quantization_config", None) - is not None) - is_8bit = (self.hf_config.quantization_config.get( - "load_in_8bit", False) if has_quantization_config else False) - if all([ + has_quantization_config = ( + getattr(self.hf_config, "quantization_config", None) is not None + ) + is_8bit = ( + self.hf_config.quantization_config.get("load_in_8bit", False) + if has_quantization_config + else False + ) + if all( + [ is_bitsandbytes, has_quantization_config, is_8bit, not self.enforce_eager, - ]): + ] + ): logger.warning( "CUDA graph is not supported on BitAndBytes 8bit yet, " - "fallback to the eager mode.") + "fallback to the eager mode." + ) self.enforce_eager = True - def verify_async_output_proc(self, parallel_config, speculative_config, - device_config) -> None: + def verify_async_output_proc( + self, parallel_config, speculative_config, device_config + ) -> None: if not self.use_async_output_proc: # Nothing to check return if parallel_config.pipeline_parallel_size > 1: - logger.warning("Async output processing can not be enabled " - "with pipeline parallel") + logger.warning( + "Async output processing can not be enabled " + "with pipeline parallel" + ) self.use_async_output_proc = False return @@ -487,13 +577,15 @@ def verify_async_output_proc(self, parallel_config, speculative_config, logger.warning( "Async output processing is only supported for CUDA, TPU, XPU " "and HPU." - "Disabling it for other platforms.") + "Disabling it for other platforms." + ) self.use_async_output_proc = False return if envs.VLLM_USE_RAY_SPMD_WORKER: logger.warning( - "Async output processing can not be enabled with ray spmd") + "Async output processing can not be enabled with ray spmd" + ) self.use_async_output_proc = False return @@ -503,7 +595,8 @@ def verify_async_output_proc(self, parallel_config, speculative_config, logger.warning( "To see benefits of async output processing, enable CUDA " "graph. Since, enforce-eager is enabled, async output " - "processor cannot be used") + "processor cannot be used" + ) self.use_async_output_proc = not self.enforce_eager return @@ -515,22 +608,26 @@ def verify_async_output_proc(self, parallel_config, speculative_config, # Reminder: Please update docs/source/serving/compatibility_matrix.rst # If the feature combo become valid if speculative_config: - logger.warning("Async output processing is not supported with" - " speculative decoding currently.") + logger.warning( + "Async output processing is not supported with" + " speculative decoding currently." + ) self.use_async_output_proc = False def verify_with_parallel_config( self, parallel_config: "ParallelConfig", ) -> None: - total_num_attention_heads = getattr(self.hf_text_config, - "num_attention_heads", 0) + total_num_attention_heads = getattr( + self.hf_text_config, "num_attention_heads", 0 + ) tensor_parallel_size = parallel_config.tensor_parallel_size if total_num_attention_heads % tensor_parallel_size != 0: raise ValueError( f"Total number of attention heads ({total_num_attention_heads})" " must be divisible by tensor parallel size " - f"({tensor_parallel_size}).") + f"({tensor_parallel_size})." + ) pipeline_parallel_size = parallel_config.pipeline_parallel_size if pipeline_parallel_size > 1: @@ -538,28 +635,33 @@ def verify_with_parallel_config( if not ModelRegistry.is_pp_supported_model(architectures): raise NotImplementedError( "Pipeline parallelism is not supported for this model. " - "Supported models implement the `SupportsPP` interface.") + "Supported models implement the `SupportsPP` interface." + ) if self.use_async_output_proc: - logger.warning("Async output processor is not supported with " - "pipeline parallelism currently. Disabling it.") + logger.warning( + "Async output processor is not supported with " + "pipeline parallelism currently. Disabling it." + ) self.use_async_output_proc = False def get_hf_config_sliding_window( - self) -> Union[Optional[int], List[Optional[int]]]: + self, + ) -> Union[Optional[int], List[Optional[int]]]: """Get the sliding window size, or None if disabled.""" # Some models, like Qwen2 and Qwen1.5, use `use_sliding_window` in # addition to sliding window size. We check if that field is present # and if it's False, return None. - if (hasattr(self.hf_text_config, "use_sliding_window") - and not self.hf_text_config.use_sliding_window): + if ( + hasattr(self.hf_text_config, "use_sliding_window") + and not self.hf_text_config.use_sliding_window + ): return None return getattr(self.hf_text_config, "sliding_window", None) def get_sliding_window(self) -> Optional[Union[int, List[Optional[int]]]]: - """Get the sliding window size, or None if disabled. - """ + """Get the sliding window size, or None if disabled.""" # If user disables sliding window, return None. if self.disable_sliding_window: return None @@ -574,8 +676,10 @@ def get_hidden_size(self) -> int: def get_head_size(self) -> int: # TODO remove hard code - if hasattr(self.hf_text_config, "model_type" - ) and self.hf_text_config.model_type == 'deepseek_v2': + if ( + hasattr(self.hf_text_config, "model_type") + and self.hf_text_config.model_type == "deepseek_v2" + ): # FlashAttention supports only head_size 32, 64, 128, 256, # we need to pad head_size 192 to 256 return 256 @@ -586,8 +690,10 @@ def get_head_size(self) -> int: if hasattr(self.hf_text_config, "head_dim"): return self.hf_text_config.head_dim # FIXME(woosuk): This may not be true for all models. - return (self.hf_text_config.hidden_size // - self.hf_text_config.num_attention_heads) + return ( + self.hf_text_config.hidden_size + // self.hf_text_config.num_attention_heads + ) def get_total_num_kv_heads(self) -> int: """Returns the total number of KV heads.""" @@ -598,9 +704,11 @@ def get_total_num_kv_heads(self) -> int: falcon_model_types = ["falcon", "RefinedWeb", "RefinedWebModel"] new_decoder_arch_falcon = ( self.hf_config.model_type in falcon_model_types - and getattr(self.hf_config, "new_decoder_architecture", False)) - if not new_decoder_arch_falcon and getattr(self.hf_text_config, - "multi_query", False): + and getattr(self.hf_config, "new_decoder_architecture", False) + ) + if not new_decoder_arch_falcon and getattr( + self.hf_text_config, "multi_query", False + ): # Multi-query attention, only one KV head. # Currently, tensor parallelism is not supported in this case. return 1 @@ -611,8 +719,11 @@ def get_total_num_kv_heads(self) -> int: return self.hf_config.attn_config["kv_n_heads"] return self.hf_config.num_attention_heads if self.hf_config.model_type == "dbrx": - return getattr(self.hf_config.attn_config, "kv_n_heads", - self.hf_config.num_attention_heads) + return getattr( + self.hf_config.attn_config, + "kv_n_heads", + self.hf_config.num_attention_heads, + ) if self.is_attention_free: return 0 @@ -642,33 +753,37 @@ def get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int: # the tensor parallel size. We will replicate the KV heads in the # case where the number of KV heads is smaller than the tensor # parallel size so each GPU has at least one KV head. - return max(1, - total_num_kv_heads // parallel_config.tensor_parallel_size) + return max( + 1, total_num_kv_heads // parallel_config.tensor_parallel_size + ) - def get_num_attention_heads(self, - parallel_config: "ParallelConfig") -> int: + def get_num_attention_heads(self, parallel_config: "ParallelConfig") -> int: num_heads = getattr(self.hf_text_config, "num_attention_heads", 0) return num_heads // parallel_config.tensor_parallel_size def get_num_layers(self, parallel_config: "ParallelConfig") -> int: from vllm.distributed.utils import get_pp_indices - total_num_hidden_layers = getattr(self.hf_text_config, - "num_hidden_layers", 0) + + total_num_hidden_layers = getattr( + self.hf_text_config, "num_hidden_layers", 0 + ) pp_rank = parallel_config.rank // parallel_config.tensor_parallel_size pp_size = parallel_config.pipeline_parallel_size start, end = get_pp_indices(total_num_hidden_layers, pp_rank, pp_size) return end - start - def get_num_attention_layers(self, - parallel_config: "ParallelConfig") -> int: + def get_num_attention_layers( + self, parallel_config: "ParallelConfig" + ) -> int: if self.is_attention_free: return 0 num_layers = self.get_num_layers(parallel_config) # Transformers supports layers_block_type @property - layers = getattr(self.hf_config, "layers_block_type", - ["attention"] * num_layers) + layers = getattr( + self.hf_config, "layers_block_type", ["attention"] * num_layers + ) return len([t for t in layers if t == "attention"]) def get_multimodal_config(self) -> "MultiModalConfig": @@ -749,7 +864,8 @@ def _verify_args(self) -> None: if self.gpu_memory_utilization > 1.0: raise ValueError( "GPU memory utilization must be less than 1.0. Got " - f"{self.gpu_memory_utilization}.") + f"{self.gpu_memory_utilization}." + ) def _verify_cache_dtype(self) -> None: if self.cache_dtype == "auto": @@ -759,7 +875,8 @@ def _verify_cache_dtype(self) -> None: "Using fp8 data type to store kv cache. It reduces the GPU " "memory footprint and boosts the performance. " "Meanwhile, it may cause accuracy drop without a proper " - "scaling factor") + "scaling factor" + ) else: raise ValueError(f"Unknown kv cache dtype: {self.cache_dtype}") @@ -770,7 +887,8 @@ def _verify_prefix_caching(self) -> None: if self.sliding_window is not None: raise NotImplementedError( "Prefix caching is not supported with sliding window. " - "Run with --disable-sliding-window to use prefix caching.") + "Run with --disable-sliding-window to use prefix caching." + ) def verify_with_parallel_config( self, @@ -782,9 +900,11 @@ def verify_with_parallel_config( num_gpus_per_node = parallel_config.tensor_parallel_size cpu_memory_usage = self.swap_space_bytes * num_gpus_per_node - msg = (f"{cpu_memory_usage / GiB_bytes:.2f} GiB out of the " - f"{total_cpu_memory / GiB_bytes:.2f} GiB total CPU memory " - "is allocated for the swap space.") + msg = ( + f"{cpu_memory_usage / GiB_bytes:.2f} GiB out of the " + f"{total_cpu_memory / GiB_bytes:.2f} GiB total CPU memory " + "is allocated for the swap space." + ) if cpu_memory_usage > 0.7 * total_cpu_memory: raise ValueError("Too large swap space. " + msg) elif cpu_memory_usage > 0.4 * total_cpu_memory: @@ -802,22 +922,25 @@ class TokenizerPoolConfig: The way the config will be used depends on the pool type. """ + pool_size: int pool_type: Union[str, Type["BaseTokenizerGroup"]] extra_config: dict def __post_init__(self): - if self.pool_type not in ("ray", ) and not isinstance( - self.pool_type, type): + if self.pool_type not in ("ray",) and not isinstance( + self.pool_type, type + ): raise ValueError(f"Unknown pool type: {self.pool_type}") if not isinstance(self.extra_config, dict): raise ValueError("extra_config must be a dictionary.") @classmethod def create_config( - cls, tokenizer_pool_size: int, + cls, + tokenizer_pool_size: int, tokenizer_pool_type: Union[str, Type["BaseTokenizerGroup"]], - tokenizer_pool_extra_config: Optional[Union[str, dict]] + tokenizer_pool_extra_config: Optional[Union[str, dict]], ) -> Optional["TokenizerPoolConfig"]: """Create a TokenizerPoolConfig from the given parameters. @@ -833,13 +956,17 @@ def create_config( if tokenizer_pool_size: if isinstance(tokenizer_pool_extra_config, str): tokenizer_pool_extra_config_parsed = json.loads( - tokenizer_pool_extra_config) + tokenizer_pool_extra_config + ) else: tokenizer_pool_extra_config_parsed = ( - tokenizer_pool_extra_config or {}) - tokenizer_pool_config = cls(tokenizer_pool_size, - tokenizer_pool_type, - tokenizer_pool_extra_config_parsed) + tokenizer_pool_extra_config or {} + ) + tokenizer_pool_config = cls( + tokenizer_pool_size, + tokenizer_pool_type, + tokenizer_pool_extra_config_parsed, + ) else: tokenizer_pool_config = None return tokenizer_pool_config @@ -861,43 +988,46 @@ class LoadFormat(str, enum.Enum): @dataclass class LoadConfig: """ - download_dir: Directory to download and load the weights, default to the - default cache directory of huggingface. - load_format: The format of the model weights to load: - "auto" will try to load the weights in the safetensors format and - fall back to the pytorch bin format if safetensors format is - not available. - "pt" will load the weights in the pytorch bin format. - "safetensors" will load the weights in the safetensors format. - "npcache" will load the weights in pytorch format and store - a numpy cache to speed up the loading. - "dummy" will initialize the weights with random values, which is - mainly for profiling. - "tensorizer" will use CoreWeave's tensorizer library for - fast weight loading. - "bitsandbytes" will load nf4 type weights. - ignore_patterns: The list of patterns to ignore when loading the model. - Default to "original/**/*" to avoid repeated loading of llama's - checkpoints. + download_dir: Directory to download and load the weights, default to the + default cache directory of huggingface. + load_format: The format of the model weights to load: + "auto" will try to load the weights in the safetensors format and + fall back to the pytorch bin format if safetensors format is + not available. + "pt" will load the weights in the pytorch bin format. + "safetensors" will load the weights in the safetensors format. + "npcache" will load the weights in pytorch format and store + a numpy cache to speed up the loading. + "dummy" will initialize the weights with random values, which is + mainly for profiling. + "tensorizer" will use CoreWeave's tensorizer library for + fast weight loading. + "bitsandbytes" will load nf4 type weights. + ignore_patterns: The list of patterns to ignore when loading the model. + Default to "original/**/*" to avoid repeated loading of llama's + checkpoints. """ load_format: Union[str, LoadFormat, "BaseModelLoader"] = LoadFormat.AUTO download_dir: Optional[str] = None model_loader_extra_config: Optional[Union[str, dict]] = field( - default_factory=dict) + default_factory=dict + ) ignore_patterns: Optional[Union[List[str], str]] = None def __post_init__(self): model_loader_extra_config = self.model_loader_extra_config or {} if isinstance(model_loader_extra_config, str): self.model_loader_extra_config = json.loads( - model_loader_extra_config) + model_loader_extra_config + ) self._verify_load_format() if self.ignore_patterns is not None and len(self.ignore_patterns) > 0: logger.info( "Ignoring the following patterns when downloading weights: %s", - self.ignore_patterns) + self.ignore_patterns, + ) else: self.ignore_patterns = ["original/**/*"] @@ -909,16 +1039,20 @@ def _verify_load_format(self) -> None: self.load_format = LoadFormat(load_format) rocm_not_supported_load_format: List[str] = [] - if current_platform.is_rocm( - ) and load_format in rocm_not_supported_load_format: + if ( + current_platform.is_rocm() + and load_format in rocm_not_supported_load_format + ): rocm_supported_load_format = [ - f for f in LoadFormat.__members__ + f + for f in LoadFormat.__members__ if (f not in rocm_not_supported_load_format) ] raise ValueError( f"load format '{load_format}' is not supported in ROCm. " f"Supported load formats are " - f"{rocm_supported_load_format}") + f"{rocm_supported_load_format}" + ) class ParallelConfig: @@ -957,8 +1091,9 @@ def __init__( tokenizer_pool_config: Optional[TokenizerPoolConfig] = None, ray_workers_use_nsight: bool = False, placement_group: Optional["PlacementGroup"] = None, - distributed_executor_backend: Optional[Union[ - str, Type["ExecutorBase"]]] = None, + distributed_executor_backend: Optional[ + Union[str, Type["ExecutorBase"]] + ] = None, ) -> None: self.pipeline_parallel_size = pipeline_parallel_size self.tensor_parallel_size = tensor_parallel_size @@ -974,51 +1109,63 @@ def __init__( if self.distributed_executor_backend is None: self.distributed_executor_backend = "ray" elif not self.use_ray: - raise ValueError(f"worker-use-ray can't be used with " - f"distributed executor backend " - f"'{self.distributed_executor_backend}'.") + raise ValueError( + f"worker-use-ray can't be used with " + f"distributed executor backend " + f"'{self.distributed_executor_backend}'." + ) if current_platform.is_tpu() and self.world_size > 1: if self.distributed_executor_backend is None: self.distributed_executor_backend = "ray" if self.distributed_executor_backend != "ray": raise ValueError( - "TPU backend only supports Ray for distributed inference.") + "TPU backend only supports Ray for distributed inference." + ) if current_platform.is_hpu() and self.world_size > 1: if self.distributed_executor_backend is None: self.distributed_executor_backend = "ray" if self.distributed_executor_backend != "ray": raise ValueError( - "HPU backend only supports Ray for distributed inference.") + "HPU backend only supports Ray for distributed inference." + ) if self.distributed_executor_backend is None and self.world_size > 1: # We use multiprocessing by default if world_size fits on the # current node and we aren't in a ray placement group. from vllm.executor import ray_utils + backend = "mp" ray_found = ray_utils.ray_is_available() - if (current_platform.is_cuda() - and cuda_device_count_stateless() < self.world_size): + if ( + current_platform.is_cuda() + and cuda_device_count_stateless() < self.world_size + ): if not ray_found: - raise ValueError("Unable to load Ray which is " - "required for multi-node inference, " - "please install Ray with `pip install " - "ray`.") from ray_utils.ray_import_err + raise ValueError( + "Unable to load Ray which is " + "required for multi-node inference, " + "please install Ray with `pip install " + "ray`." + ) from ray_utils.ray_import_err backend = "ray" elif ray_found: if self.placement_group: backend = "ray" else: from ray import is_initialized as ray_is_initialized + if ray_is_initialized(): from ray.util import get_current_placement_group + if get_current_placement_group(): backend = "ray" self.distributed_executor_backend = backend - logger.info("Defaulting to use %s for distributed inference", - backend) + logger.info( + "Defaulting to use %s for distributed inference", backend + ) self._verify_args() self.rank: int = 0 @@ -1027,31 +1174,40 @@ def __init__( def use_ray(self) -> bool: return self.distributed_executor_backend == "ray" or ( isinstance(self.distributed_executor_backend, type) - and self.distributed_executor_backend.uses_ray) + and self.distributed_executor_backend.uses_ray + ) def _verify_args(self) -> None: # Lazy import to avoid circular import from vllm.executor.executor_base import ExecutorBase if self.distributed_executor_backend not in ( - "ray", "mp", None) and not (isinstance( - self.distributed_executor_backend, type) and issubclass( - self.distributed_executor_backend, ExecutorBase)): + "ray", + "mp", + None, + ) and not ( + isinstance(self.distributed_executor_backend, type) + and issubclass(self.distributed_executor_backend, ExecutorBase) + ): raise ValueError( "Unrecognized distributed executor backend " f"{self.distributed_executor_backend}. Supported " - "values are 'ray', 'mp' or custom ExecutorBase subclass.") + "values are 'ray', 'mp' or custom ExecutorBase subclass." + ) if self.use_ray: from vllm.executor import ray_utils + ray_utils.assert_ray_available() if current_platform.is_rocm(): self.disable_custom_all_reduce = True logger.info( "Disabled the custom all-reduce kernel because it is not " - "supported on AMD GPUs.") + "supported on AMD GPUs." + ) if self.ray_workers_use_nsight and not self.use_ray: - raise ValueError("Unable to use nsight profiling unless workers " - "run with Ray.") + raise ValueError( + "Unable to use nsight profiling unless workers " "run with Ray." + ) class SchedulerConfig: @@ -1086,20 +1242,22 @@ class SchedulerConfig: policy: The scheduling policy to use. "fcfs" (default) or "priority". """ - def __init__(self, - task: _Task, - max_num_batched_tokens: Optional[int], - max_num_seqs: int, - max_model_len: int, - num_lookahead_slots: int = 0, - delay_factor: float = 0.0, - enable_chunked_prefill: bool = False, - is_multimodal_model: bool = False, - preemption_mode: Optional[str] = None, - num_scheduler_steps: int = 1, - multi_step_stream_outputs: bool = False, - send_delta_data: bool = False, - policy: str = "fcfs") -> None: + def __init__( + self, + task: _Task, + max_num_batched_tokens: Optional[int], + max_num_seqs: int, + max_model_len: int, + num_lookahead_slots: int = 0, + delay_factor: float = 0.0, + enable_chunked_prefill: bool = False, + is_multimodal_model: bool = False, + preemption_mode: Optional[str] = None, + num_scheduler_steps: int = 1, + multi_step_stream_outputs: bool = False, + send_delta_data: bool = False, + policy: str = "fcfs", + ) -> None: if max_num_batched_tokens is None: if enable_chunked_prefill: if num_scheduler_steps > 1: @@ -1135,7 +1293,8 @@ def __init__(self, if enable_chunked_prefill: logger.info( "Chunked prefill is enabled with max_num_batched_tokens=%d.", - self.max_num_batched_tokens) + self.max_num_batched_tokens, + ) self.task: Final = task self.max_num_seqs = max_num_seqs @@ -1151,33 +1310,39 @@ def __init__(self, self._verify_args() def _verify_args(self) -> None: - if (self.max_num_batched_tokens < self.max_model_len - and not self.chunked_prefill_enabled): + if ( + self.max_num_batched_tokens < self.max_model_len + and not self.chunked_prefill_enabled + ): raise ValueError( f"max_num_batched_tokens ({self.max_num_batched_tokens}) is " f"smaller than max_model_len ({self.max_model_len}). " "This effectively limits the maximum sequence length to " "max_num_batched_tokens and makes vLLM reject longer " "sequences. Please increase max_num_batched_tokens or " - "decrease max_model_len.") + "decrease max_model_len." + ) if self.max_num_batched_tokens < self.max_num_seqs: raise ValueError( f"max_num_batched_tokens ({self.max_num_batched_tokens}) must " "be greater than or equal to max_num_seqs " - f"({self.max_num_seqs}).") + f"({self.max_num_seqs})." + ) if self.num_lookahead_slots < 0: raise ValueError( "num_lookahead_slots " f"({self.num_lookahead_slots}) must be greater than or " - "equal to 0.") + "equal to 0." + ) if self.num_scheduler_steps < 1: raise ValueError( "num_scheduler_steps " f"({self.num_scheduler_steps}) must be greater than or " - "equal to 1.") + "equal to 1." + ) @property def is_multi_step(self) -> bool: @@ -1312,15 +1477,21 @@ def maybe_create_spec_config( if speculative_model is None: if num_speculative_tokens is not None: - raise ValueError("num_speculative_tokens was provided without " - "speculative_model.") + raise ValueError( + "num_speculative_tokens was provided without " + "speculative_model." + ) return None - if (speculative_disable_by_batch_size is not None - and speculative_disable_by_batch_size < 2): - raise ValueError("Expect the batch size threshold of disabling " - "speculative decoding is > 1, but got " - f"{speculative_disable_by_batch_size=}") + if ( + speculative_disable_by_batch_size is not None + and speculative_disable_by_batch_size < 2 + ): + raise ValueError( + "Expect the batch size threshold of disabling " + "speculative decoding is > 1, but got " + f"{speculative_disable_by_batch_size=}" + ) # TODO: The user should be able to specify revision/max model len # for the draft model. It is not currently supported. @@ -1336,8 +1507,10 @@ def maybe_create_spec_config( if ngram_prompt_lookup_min < 1: raise ValueError(f"{ngram_prompt_lookup_min=} must be > 0") if ngram_prompt_lookup_min > ngram_prompt_lookup_max: - raise ValueError(f"{ngram_prompt_lookup_min=} cannot be " - f"larger than {ngram_prompt_lookup_max=}") + raise ValueError( + f"{ngram_prompt_lookup_min=} cannot be " + f"larger than {ngram_prompt_lookup_max=}" + ) # TODO: current we still need extract vocab_size from target model # config, in future, we may try refactor it out, and set @@ -1353,8 +1526,7 @@ def maybe_create_spec_config( tokenizer=target_model_config.tokenizer, tokenizer_mode=target_model_config.tokenizer_mode, trust_remote_code=target_model_config.trust_remote_code, - allowed_local_media_path=target_model_config. - allowed_local_media_path, + allowed_local_media_path=target_model_config.allowed_local_media_path, dtype=target_model_config.dtype, seed=target_model_config.seed, revision=draft_revision, @@ -1364,15 +1536,15 @@ def maybe_create_spec_config( spec_target_max_model_len=target_model_config.max_model_len, quantization=draft_quantization, enforce_eager=target_model_config.enforce_eager, - max_seq_len_to_capture=target_model_config. - max_seq_len_to_capture, + max_seq_len_to_capture=target_model_config.max_seq_len_to_capture, max_logprobs=target_model_config.max_logprobs, ) draft_hf_config = draft_model_config.hf_config - if (num_speculative_tokens is not None - and hasattr(draft_hf_config, "num_lookahead_tokens")): + if num_speculative_tokens is not None and hasattr( + draft_hf_config, "num_lookahead_tokens" + ): draft_hf_config.num_lookahead_tokens = num_speculative_tokens n_predict = getattr(draft_hf_config, "n_predict", None) @@ -1386,48 +1558,60 @@ def maybe_create_spec_config( raise ValueError( "This speculative model supports a maximum of " f"num_speculative_tokens={n_predict}, but " - f"{num_speculative_tokens=} was provided.") + f"{num_speculative_tokens=} was provided." + ) if enable_chunked_prefill and draft_hf_config.model_type in ( - "medusa", "mlp_speculator", "eagle"): + "medusa", + "mlp_speculator", + "eagle", + ): raise ValueError( "Chunked prefill and hidden-state based draft models are " - "not compatible.") + "not compatible." + ) - speculative_draft_tensor_parallel_size = \ - SpeculativeConfig._verify_and_get_draft_model_tensor_parallel_size( - target_parallel_config, - speculative_draft_tensor_parallel_size, - draft_hf_config + speculative_draft_tensor_parallel_size = SpeculativeConfig._verify_and_get_draft_model_tensor_parallel_size( + target_parallel_config, + speculative_draft_tensor_parallel_size, + draft_hf_config, ) - if (enable_chunked_prefill and \ - speculative_draft_tensor_parallel_size != 1): + if ( + enable_chunked_prefill + and speculative_draft_tensor_parallel_size != 1 + ): # TODO - Investigate why the error reported in # https://github.com/vllm-project/vllm/pull/9291#issuecomment-2463266258 # is happening and re-enable it. raise ValueError( "Chunked prefill and speculative decoding can be enabled " "simultaneously only for draft models with tensor " - "parallel size 1.") + "parallel size 1." + ) draft_model_config.max_model_len = ( SpeculativeConfig._maybe_override_draft_max_model_len( speculative_max_model_len, draft_model_config.max_model_len, target_model_config.max_model_len, - )) + ) + ) draft_parallel_config = ( SpeculativeConfig.create_draft_parallel_config( target_parallel_config, - speculative_draft_tensor_parallel_size, draft_hf_config)) + speculative_draft_tensor_parallel_size, + draft_hf_config, + ) + ) if num_speculative_tokens is None: raise ValueError( "num_speculative_tokens must be provided with " "speculative_model unless the draft model config contains an " - "n_predict parameter.") + "n_predict parameter." + ) if typical_acceptance_sampler_posterior_threshold is None: typical_acceptance_sampler_posterior_threshold = 0.09 @@ -1445,10 +1629,8 @@ def maybe_create_spec_config( ngram_prompt_lookup_max, ngram_prompt_lookup_min, draft_token_acceptance_method=draft_token_acceptance_method, - typical_acceptance_sampler_posterior_threshold=\ - typical_acceptance_sampler_posterior_threshold, - typical_acceptance_sampler_posterior_alpha=\ - typical_acceptance_sampler_posterior_alpha, + typical_acceptance_sampler_posterior_threshold=typical_acceptance_sampler_posterior_threshold, + typical_acceptance_sampler_posterior_alpha=typical_acceptance_sampler_posterior_alpha, disable_logprobs=disable_logprobs, disable_log_stats=disable_log_stats, ) @@ -1472,14 +1654,17 @@ def _maybe_override_draft_max_model_len( """ if speculative_max_model_len is not None: - if speculative_max_model_len > draft_max_model_len: - raise ValueError(f"{speculative_max_model_len=} cannot be " - f"larger than {draft_max_model_len=}") + raise ValueError( + f"{speculative_max_model_len=} cannot be " + f"larger than {draft_max_model_len=}" + ) if speculative_max_model_len > target_max_model_len: - raise ValueError(f"{speculative_max_model_len=} cannot be " - f"larger than {target_max_model_len=}") + raise ValueError( + f"{speculative_max_model_len=} cannot be " + f"larger than {target_max_model_len=}" + ) return speculative_max_model_len @@ -1490,9 +1675,10 @@ def _maybe_override_draft_max_model_len( @staticmethod def _verify_and_get_draft_model_tensor_parallel_size( - target_parallel_config: ParallelConfig, - speculative_draft_tensor_parallel_size: Optional[int], - draft_hf_config: PretrainedConfig) -> int: + target_parallel_config: ParallelConfig, + speculative_draft_tensor_parallel_size: Optional[int], + draft_hf_config: PretrainedConfig, + ) -> int: """ Verifies and adjusts the tensor parallel size for a draft model specified using speculative_draft_tensor_parallel_size. @@ -1505,15 +1691,20 @@ def _verify_and_get_draft_model_tensor_parallel_size( if target_parallel_config.tensor_parallel_size > 1: logger.warning( "MLPSpeculator cannot currently be run with tp>1; " - "setting speculative_draft_tensor_parallel_size=1") + "setting speculative_draft_tensor_parallel_size=1" + ) else: - speculative_draft_tensor_parallel_size = \ + speculative_draft_tensor_parallel_size = ( target_parallel_config.tensor_parallel_size + ) elif speculative_draft_tensor_parallel_size not in ( - 1, target_parallel_config.tensor_parallel_size): + 1, + target_parallel_config.tensor_parallel_size, + ): raise ValueError( f"{speculative_draft_tensor_parallel_size=} cannot be " - f"other value than 1 or target model tensor_parallel_size") + f"other value than 1 or target model tensor_parallel_size" + ) return speculative_draft_tensor_parallel_size @staticmethod @@ -1527,18 +1718,13 @@ def create_draft_parallel_config( This is mostly a copy of the target parallel config, except the tp_size. """ draft_parallel_config = ParallelConfig( - pipeline_parallel_size=target_parallel_config. - pipeline_parallel_size, + pipeline_parallel_size=target_parallel_config.pipeline_parallel_size, tensor_parallel_size=speculative_draft_tensor_parallel_size, - distributed_executor_backend=target_parallel_config. - distributed_executor_backend, - max_parallel_loading_workers=target_parallel_config. - max_parallel_loading_workers, - disable_custom_all_reduce=target_parallel_config. - disable_custom_all_reduce, + distributed_executor_backend=target_parallel_config.distributed_executor_backend, + max_parallel_loading_workers=target_parallel_config.max_parallel_loading_workers, + disable_custom_all_reduce=target_parallel_config.disable_custom_all_reduce, tokenizer_pool_config=target_parallel_config.tokenizer_pool_config, - ray_workers_use_nsight=target_parallel_config. - ray_workers_use_nsight, + ray_workers_use_nsight=target_parallel_config.ray_workers_use_nsight, placement_group=target_parallel_config.placement_group, ) @@ -1597,15 +1783,18 @@ def __init__( self.draft_parallel_config = draft_parallel_config self.num_speculative_tokens = num_speculative_tokens self.speculative_disable_mqa_scorer = speculative_disable_mqa_scorer - self.speculative_disable_by_batch_size = \ + self.speculative_disable_by_batch_size = ( speculative_disable_by_batch_size + ) self.ngram_prompt_lookup_max = ngram_prompt_lookup_max or 0 self.ngram_prompt_lookup_min = ngram_prompt_lookup_min or 0 self.draft_token_acceptance_method = draft_token_acceptance_method - self.typical_acceptance_sampler_posterior_threshold = \ + self.typical_acceptance_sampler_posterior_threshold = ( typical_acceptance_sampler_posterior_threshold - self.typical_acceptance_sampler_posterior_alpha = \ + ) + self.typical_acceptance_sampler_posterior_alpha = ( typical_acceptance_sampler_posterior_alpha + ) self.disable_logprobs = disable_logprobs self.disable_log_stats = disable_log_stats @@ -1613,29 +1802,39 @@ def __init__( def _verify_args(self) -> None: if self.num_speculative_tokens <= 0: - raise ValueError("Expected num_speculative_tokens to be greater " - f"than zero ({self.num_speculative_tokens}).") + raise ValueError( + "Expected num_speculative_tokens to be greater " + f"than zero ({self.num_speculative_tokens})." + ) if self.draft_model_config: self.draft_model_config.verify_with_parallel_config( - self.draft_parallel_config) + self.draft_parallel_config + ) # Validate and set draft token acceptance related settings. - if (self.draft_token_acceptance_method is None): - raise ValueError("draft_token_acceptance_method is not set. " - "Expected values are rejection_sampler or " - "typical_acceptance_sampler.") + if self.draft_token_acceptance_method is None: + raise ValueError( + "draft_token_acceptance_method is not set. " + "Expected values are rejection_sampler or " + "typical_acceptance_sampler." + ) - if (self.draft_token_acceptance_method != 'rejection_sampler' - and self.draft_token_acceptance_method != - 'typical_acceptance_sampler'): + if ( + self.draft_token_acceptance_method != "rejection_sampler" + and self.draft_token_acceptance_method + != "typical_acceptance_sampler" + ): raise ValueError( "Expected draft_token_acceptance_method to be either " "rejection_sampler or typical_acceptance_sampler. Instead it " - f"is {self.draft_token_acceptance_method}") + f"is {self.draft_token_acceptance_method}" + ) - if (self.typical_acceptance_sampler_posterior_threshold < 0 - or self.typical_acceptance_sampler_posterior_alpha < 0): + if ( + self.typical_acceptance_sampler_posterior_threshold < 0 + or self.typical_acceptance_sampler_posterior_alpha < 0 + ): raise ValueError( "Expected typical_acceptance_sampler_posterior_threshold " "and typical_acceptance_sampler_posterior_alpha to be > 0. " @@ -1643,7 +1842,8 @@ def _verify_args(self) -> None: f"typical_acceptance_sampler_posterior_threshold = " f"{self.typical_acceptance_sampler_posterior_threshold} and " f"typical_acceptance_sampler_posterior_alpha = " - f"{self.typical_acceptance_sampler_posterior_alpha}") + f"{self.typical_acceptance_sampler_posterior_alpha}" + ) @property def num_lookahead_slots(self) -> int: @@ -1685,11 +1885,13 @@ def __post_init__(self): if self.max_lora_rank not in possible_max_ranks: raise ValueError( f"max_lora_rank ({self.max_lora_rank}) must be one of " - f"{possible_max_ranks}.") + f"{possible_max_ranks}." + ) if self.lora_extra_vocab_size not in possible_lora_extra_vocab_size: raise ValueError( f"lora_extra_vocab_size ({self.lora_extra_vocab_size}) " - f"must be one of {possible_lora_extra_vocab_size}.") + f"must be one of {possible_lora_extra_vocab_size}." + ) if self.max_loras < 1: raise ValueError(f"max_loras ({self.max_loras}) must be >= 1.") if self.max_cpu_loras is None: @@ -1697,7 +1899,8 @@ def __post_init__(self): elif self.max_cpu_loras < self.max_loras: raise ValueError( f"max_cpu_loras ({self.max_cpu_loras}) must be >= " - f"max_loras ({self.max_loras})") + f"max_loras ({self.max_loras})" + ) def verify_with_model_config(self, model_config: ModelConfig): if self.lora_dtype in (None, "auto"): @@ -1705,11 +1908,14 @@ def verify_with_model_config(self, model_config: ModelConfig): elif isinstance(self.lora_dtype, str): self.lora_dtype = getattr(torch, self.lora_dtype) if model_config.quantization and model_config.quantization not in [ - "awq", "gptq" + "awq", + "gptq", ]: # TODO support marlin - logger.warning("%s quantization is not tested with LoRA yet.", - model_config.quantization) + logger.warning( + "%s quantization is not tested with LoRA yet.", + model_config.quantization, + ) def verify_with_scheduler_config(self, scheduler_config: SchedulerConfig): # Reminder: Please update docs/source/serving/compatibility_matrix.rst @@ -1726,10 +1932,11 @@ class PromptAdapterConfig: prompt_adapter_dtype: Optional[torch.dtype] = None def __post_init__(self): - if self.max_prompt_adapters < 1: - raise ValueError(f"max_prompt_adapters " - f"({self.max_prompt_adapters}) must be >= 1.") + raise ValueError( + f"max_prompt_adapters " + f"({self.max_prompt_adapters}) must be >= 1." + ) if self.max_prompt_adapter_token == 0: raise ValueError("max_prompt_adapter_token must be set.") if self.max_cpu_prompt_adapters is None: @@ -1739,8 +1946,9 @@ def verify_with_model_config(self, model_config: ModelConfig): if self.prompt_adapter_dtype in (None, "auto"): self.prompt_adapter_dtype = model_config.dtype elif isinstance(self.prompt_adapter_dtype, str): - self.prompt_adapter_dtype = getattr(torch, - self.prompt_adapter_dtype) + self.prompt_adapter_dtype = getattr( + torch, self.prompt_adapter_dtype + ) @dataclass @@ -1780,15 +1988,15 @@ class PoolerConfig: step_tag_id: Optional[int] = None """ - If set, only the score corresponding to the ``step_tag_id`` in the + If set, only the score corresponding to the ``step_tag_id`` in the generated sentence should be returned. Otherwise, the scores for all tokens are returned. """ returned_token_ids: Optional[List[int]] = None """ - A list of indices for the vocabulary dimensions to be extracted, - such as the token IDs of ``good_token`` and ``bad_token`` in the + A list of indices for the vocabulary dimensions to be extracted, + such as the token IDs of ``good_token`` and ``bad_token`` in the ``math-shepherd-mistral-7b-prm`` model. """ @@ -1826,7 +2034,8 @@ def _get_and_verify_dtype( logger.info( "For Gemma 2, we downcast float32 to bfloat16 instead " "of float16 by default. Please specify `dtype` if you " - "want to use float16.") + "want to use float16." + ) torch_dtype = torch.bfloat16 else: # Following the common practice, we use float16 for float32 @@ -1839,7 +2048,8 @@ def _get_and_verify_dtype( logger.info( "For HPU, we cast models to bfloat16 instead of" "using float16 by default. Please specify `dtype` if you " - "want to use float16.") + "want to use float16." + ) torch_dtype = torch.bfloat16 else: if dtype not in _STR_DTYPE_TO_TORCH_DTYPE: @@ -1898,19 +2108,23 @@ def _get_and_verify_max_len( for key in possible_keys: max_len = getattr(hf_config, key, None) if max_len is not None: - max_len_key = key if max_len < derived_max_model_len \ - else max_len_key + max_len_key = ( + key if max_len < derived_max_model_len else max_len_key + ) derived_max_model_len = min(derived_max_model_len, max_len) # If sliding window is manually disabled, max_length should be less # than the sliding window length in the model config. if disable_sliding_window and sliding_window_len is not None: - sliding_window_len_min = get_min_sliding_window(sliding_window_len) - max_len_key = "sliding_window" \ - if sliding_window_len_min < derived_max_model_len else max_len_key - derived_max_model_len = min(derived_max_model_len, - sliding_window_len_min) + max_len_key = ( + "sliding_window" + if sliding_window_len_min < derived_max_model_len + else max_len_key + ) + derived_max_model_len = min( + derived_max_model_len, sliding_window_len_min + ) # If none of the keys were found in the config, use a default and # log a warning. @@ -1928,8 +2142,10 @@ def _get_and_verify_max_len( logger.warning( "The model's config.json does not contain any of the following " "keys to determine the original maximum length of the model: " - "%s. Assuming the model's maximum length is %d.", possible_keys, - default_max_len) + "%s. Assuming the model's maximum length is %d.", + possible_keys, + default_max_len, + ) derived_max_model_len = default_max_len rope_scaling = getattr(hf_config, "rope_scaling", None) @@ -1945,7 +2161,8 @@ def _get_and_verify_max_len( raise NotImplementedError( "Disabling sliding window is not supported for models " "with rope_scaling. Please raise an issue so we can " - "investigate.") + "investigate." + ) # NOTE: rope_type == "default" does not define factor # https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/modeling_rope_utils.py @@ -1953,7 +2170,8 @@ def _get_and_verify_max_len( if rope_type == "yarn": derived_max_model_len = rope_scaling[ - "original_max_position_embeddings"] + "original_max_position_embeddings" + ] derived_max_model_len *= scaling_factor if encoder_config and "max_seq_length" in encoder_config: @@ -1975,35 +2193,42 @@ def _get_and_verify_max_len( raise NotImplementedError( "Disabling sliding window is not supported for models " "model_max_length in the config. Please raise an issue " - "so we can investigate.") + "so we can investigate." + ) else: msg = ( f"User-specified max_model_len ({max_model_len}) is greater " f"than the derived max_model_len ({max_len_key}=" f"{derived_max_model_len} or model_max_length=" f"{model_max_length} in model's config.json). This may lead " - "to incorrect model outputs or CUDA errors.") + "to incorrect model outputs or CUDA errors." + ) if envs.VLLM_ALLOW_LONG_MAX_MODEL_LEN: logger.warning( "%s Make sure the value is correct and within the " - "model context size.", msg) + "model context size.", + msg, + ) else: raise ValueError( f"{msg} To allow overriding this maximum, set " - "the env var VLLM_ALLOW_LONG_MAX_MODEL_LEN=1") + "the env var VLLM_ALLOW_LONG_MAX_MODEL_LEN=1" + ) return int(max_model_len) def get_min_sliding_window( - sliding_window: Union[int, List[Optional[int]]]) -> int: + sliding_window: Union[int, List[Optional[int]]], +) -> int: if isinstance(sliding_window, list): return min(s for s in sliding_window if s is not None) return sliding_window -def get_served_model_name(model: str, - served_model_name: Optional[Union[str, List[str]]]): +def get_served_model_name( + model: str, served_model_name: Optional[Union[str, List[str]]] +): """ If the input is a non-empty list, the first model_name in `served_model_name` is taken. @@ -2022,20 +2247,24 @@ def get_served_model_name(model: str, class DecodingConfig: """Dataclass which contains the decoding strategy of the engine""" - # Which guided decoding algo to use. 'outlines' / 'lm-format-enforcer' - guided_decoding_backend: str = 'outlines' + # Which guided decoding algo to use. + # 'outlines' / 'lm-format-enforcer' / 'xgrammar' + guided_decoding_backend: str = "xgrammar" def __post_init__(self): - valid_guided_backends = ['outlines', 'lm-format-enforcer'] + valid_guided_backends = ["outlines", "lm-format-enforcer", "xgrammar"] backend = self.guided_decoding_backend if backend not in valid_guided_backends: - raise ValueError(f"Invalid guided_decoding_backend '{backend}," - f"must be one of {valid_guided_backends}") + raise ValueError( + f"Invalid guided_decoding_backend '{backend}," + f"must be one of {valid_guided_backends}" + ) @dataclass class ObservabilityConfig: """Configuration for observability.""" + otlp_traces_endpoint: Optional[str] = None # Collecting detailed timing information for each request can be expensive. @@ -2051,7 +2280,8 @@ def __post_init__(self): raise ValueError( "OpenTelemetry is not available. Unable to configure " "'otlp_traces_endpoint'. Ensure OpenTelemetry packages are " - f"installed. Original error:\n{otel_import_error_traceback}") + f"installed. Original error:\n{otel_import_error_traceback}" + ) @dataclass @@ -2062,12 +2292,9 @@ class VllmConfig: model_config: ModelConfig = field(default=None, init=True) # type: ignore cache_config: CacheConfig = field(default=None, init=True) # type: ignore - parallel_config: ParallelConfig = field(default=None, - init=True) # type: ignore - scheduler_config: SchedulerConfig = field(default=None, - init=True) # type: ignore - device_config: DeviceConfig = field(default=None, - init=True) # type: ignore + parallel_config: ParallelConfig = field(default=None, init=True) # type: ignore + scheduler_config: SchedulerConfig = field(default=None, init=True) # type: ignore + device_config: DeviceConfig = field(default=None, init=True) # type: ignore load_config: LoadConfig = field(default=None, init=True) # type: ignore lora_config: Optional[LoRAConfig] = None speculative_config: Optional[SpeculativeConfig] = None @@ -2078,12 +2305,14 @@ class VllmConfig: @staticmethod def _get_quantization_config( - model_config: ModelConfig, - load_config: LoadConfig) -> Optional[QuantizationConfig]: + model_config: ModelConfig, load_config: LoadConfig + ) -> Optional[QuantizationConfig]: """Get the quantization config.""" if model_config.quantization is not None: from vllm.model_executor.model_loader.weight_utils import ( - get_quant_config) + get_quant_config, + ) + quant_config = get_quant_config(model_config, load_config) capability_tuple = current_platform.get_device_capability() @@ -2094,13 +2323,15 @@ def _get_quantization_config( f"The quantization method {model_config.quantization} " "is not supported for the current GPU. Minimum " f"capability: {quant_config.get_min_capability()}. " - f"Current capability: {capability}.") + f"Current capability: {capability}." + ) supported_dtypes = quant_config.get_supported_act_dtypes() if model_config.dtype not in supported_dtypes: raise ValueError( f"{model_config.dtype} is not supported for quantization " f"method {model_config.quantization}. Supported dtypes: " - f"{supported_dtypes}") + f"{supported_dtypes}" + ) return quant_config return None @@ -2111,12 +2342,13 @@ def with_hf_config(self, hf_config: PretrainedConfig) -> "VllmConfig": return replace(self, model_config=model_config) def __post_init__(self): - """Verify configs are valid & consistent with each other. - """ + """Verify configs are valid & consistent with each other.""" if self.model_config is not None: - self.model_config.verify_async_output_proc(self.parallel_config, - self.speculative_config, - self.device_config) + self.model_config.verify_async_output_proc( + self.parallel_config, + self.speculative_config, + self.device_config, + ) self.model_config.verify_with_parallel_config(self.parallel_config) if self.cache_config is not None: @@ -2124,54 +2356,64 @@ def __post_init__(self): if self.lora_config: self.lora_config.verify_with_model_config(self.model_config) - self.lora_config.verify_with_scheduler_config( - self.scheduler_config) + self.lora_config.verify_with_scheduler_config(self.scheduler_config) if self.prompt_adapter_config: self.prompt_adapter_config.verify_with_model_config( - self.model_config) + self.model_config + ) - if self.quant_config is None and \ - self.model_config is not None and self.load_config is not None: + if ( + self.quant_config is None + and self.model_config is not None + and self.load_config is not None + ): self.quant_config = VllmConfig._get_quantization_config( - self.model_config, self.load_config) + self.model_config, self.load_config + ) def __str__(self): - return ("model=%r, speculative_config=%r, tokenizer=%r, " - "skip_tokenizer_init=%s, tokenizer_mode=%s, revision=%s, " - "override_neuron_config=%s, tokenizer_revision=%s, " - "trust_remote_code=%s, dtype=%s, max_seq_len=%d, " - "download_dir=%r, load_format=%s, tensor_parallel_size=%d, " - "pipeline_parallel_size=%d, " - "disable_custom_all_reduce=%s, quantization=%s, " - "enforce_eager=%s, kv_cache_dtype=%s, " - "quantization_param_path=%s, device_config=%s, " - "decoding_config=%r, observability_config=%r, " - "seed=%d, served_model_name=%s, " - "num_scheduler_steps=%d, enable_prefix_caching=%s, " - "use_async_output_proc=%s, mm_processor_kwargs=%s") % \ - (self.model_config.model, self.speculative_config, - self.model_config.tokenizer, - self.model_config.skip_tokenizer_init, - self.model_config.tokenizer_mode, - self.model_config.revision, - self.model_config.override_neuron_config, - self.model_config.tokenizer_revision, - self.model_config.trust_remote_code, - self.model_config.dtype, - self.model_config.max_model_len, - self.load_config.download_dir, - self.load_config.load_format, - self.parallel_config.tensor_parallel_size, - self.parallel_config.pipeline_parallel_size, - self.parallel_config.disable_custom_all_reduce, - self.model_config.quantization, - self.model_config.enforce_eager, - self.cache_config.cache_dtype, - self.model_config.quantization_param_path, - self.device_config.device, self.decoding_config, - self.observability_config, self.model_config.seed, - self.model_config.served_model_name, - self.scheduler_config.num_scheduler_steps, - self.cache_config.enable_prefix_caching, - self.model_config.use_async_output_proc, - self.model_config.mm_processor_kwargs) + return ( + "model=%r, speculative_config=%r, tokenizer=%r, " + "skip_tokenizer_init=%s, tokenizer_mode=%s, revision=%s, " + "override_neuron_config=%s, tokenizer_revision=%s, " + "trust_remote_code=%s, dtype=%s, max_seq_len=%d, " + "download_dir=%r, load_format=%s, tensor_parallel_size=%d, " + "pipeline_parallel_size=%d, " + "disable_custom_all_reduce=%s, quantization=%s, " + "enforce_eager=%s, kv_cache_dtype=%s, " + "quantization_param_path=%s, device_config=%s, " + "decoding_config=%r, observability_config=%r, " + "seed=%d, served_model_name=%s, " + "num_scheduler_steps=%d, enable_prefix_caching=%s, " + "use_async_output_proc=%s, mm_processor_kwargs=%s" + ) % ( + self.model_config.model, + self.speculative_config, + self.model_config.tokenizer, + self.model_config.skip_tokenizer_init, + self.model_config.tokenizer_mode, + self.model_config.revision, + self.model_config.override_neuron_config, + self.model_config.tokenizer_revision, + self.model_config.trust_remote_code, + self.model_config.dtype, + self.model_config.max_model_len, + self.load_config.download_dir, + self.load_config.load_format, + self.parallel_config.tensor_parallel_size, + self.parallel_config.pipeline_parallel_size, + self.parallel_config.disable_custom_all_reduce, + self.model_config.quantization, + self.model_config.enforce_eager, + self.cache_config.cache_dtype, + self.model_config.quantization_param_path, + self.device_config.device, + self.decoding_config, + self.observability_config, + self.model_config.seed, + self.model_config.served_model_name, + self.scheduler_config.num_scheduler_steps, + self.cache_config.enable_prefix_caching, + self.model_config.use_async_output_proc, + self.model_config.mm_processor_kwargs, + ) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index d73f95f59c71f..8426e606e56bb 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -166,7 +166,7 @@ class EngineArgs: scheduler_delay_factor: float = 0.0 enable_chunked_prefill: Optional[bool] = None - guided_decoding_backend: str = 'outlines' + guided_decoding_backend: str = 'xgrammar' # Speculative decoding configuration. speculative_model: Optional[str] = None speculative_model_quantization: Optional[str] = None @@ -351,11 +351,12 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: parser.add_argument( '--guided-decoding-backend', type=str, - default='outlines', - choices=['outlines', 'lm-format-enforcer'], + default='xgrammar', + choices=['outlines', 'lm-format-enforcer', 'xgrammar'], help='Which engine will be used for guided decoding' ' (JSON schema / regex etc) by default. Currently support ' - 'https://github.com/outlines-dev/outlines and ' + 'https://github.com/outlines-dev/outlines,' + 'https://github.com/mlc-ai/xgrammar, and ' 'https://github.com/noamgat/lm-format-enforcer.' ' Can be overridden per request via guided_decoding_backend' ' parameter.') diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 5a5388708b1c6..8eb53a8fccf72 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -1,4 +1,5 @@ import asyncio +import copy import time import weakref from functools import partial @@ -501,7 +502,8 @@ async def add_request_async( sampling_params=params, tokenizer=await self.get_tokenizer_async(lora_request), default_guided_backend=self.decoding_config. - guided_decoding_backend) + guided_decoding_backend, + model_config=self.model_config) self._add_processed_request( request_id=request_id, @@ -522,22 +524,30 @@ async def check_health_async(self) -> None: async def build_guided_decoding_logits_processor_async( sampling_params: SamplingParams, tokenizer: AnyTokenizer, - default_guided_backend: str) -> SamplingParams: + default_guided_backend: str, + model_config: ModelConfig) -> SamplingParams: """Constructs logits processors based on the guided_decoding, logits_bias, and allowed_token_ids fields in sampling_params. Deletes those fields and adds the constructed logits processors to the logits_processors field. Modifies sampling params in-place and returns the modified sampling params.""" - if (guided_decoding := sampling_params.guided_decoding) is None: + if sampling_params.guided_decoding is None: return sampling_params + # Defensively copy sampling params since guided decoding logits + # processors can have different state for each request + sampling_params = copy.copy(sampling_params) + guided_decoding = sampling_params.guided_decoding + logger.debug("Building guided decoding logits processor. " "Params: %s", guided_decoding) guided_decoding.backend = guided_decoding.backend or default_guided_backend processor = await get_guided_decoding_logits_processor( - guided_params=guided_decoding, tokenizer=tokenizer) + guided_params=guided_decoding, + tokenizer=tokenizer, + model_config=model_config) if processor: if sampling_params.logits_processors is None: diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index aa9c7893c4cfe..ad372587a8362 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -1,3 +1,4 @@ +import copy import time from collections import Counter as collectionsCounter from collections import deque @@ -1034,9 +1035,9 @@ def _update_num_computed_tokens_for_multi_step_prefill( This function updates num_computed_tokens for prompt sequences when Multi-Step is enabled. - seq_group: SequenceGroup to update the num_computed_tokens for. + seq_group: SequenceGroup to update the num_computed_tokens for. seq_group_meta: Metadata of the given SequenceGroup. - is_first_step_output: Optional[bool] - + is_first_step_output: Optional[bool] - When available, is_first_step_output indicates if the appended output token is the output of the first-step in multi-step. A value of None indicates that outputs from all steps in @@ -2043,7 +2044,11 @@ def _build_logits_processors( logits_processors = [] - if (guided_decoding := sampling_params.guided_decoding) is not None: + if sampling_params.guided_decoding is not None: + # Defensively copy sampling params since guided decoding logits + # processors can have different state for each request + sampling_params = copy.copy(sampling_params) + guided_decoding = sampling_params.guided_decoding logger.debug( "Building guided decoding logits processor in " @@ -2054,7 +2059,9 @@ def _build_logits_processors( self.decoding_config.guided_decoding_backend processor = get_local_guided_decoding_logits_processor( - guided_params=guided_decoding, tokenizer=tokenizer) + guided_params=guided_decoding, + tokenizer=tokenizer, + model_config=self.model_config) if processor: logits_processors.append(processor) diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py index fe21c58c775fe..66c1ba9df0eda 100644 --- a/vllm/engine/multiprocessing/client.py +++ b/vllm/engine/multiprocessing/client.py @@ -472,8 +472,8 @@ def generate( trace_headers: OpenTelemetry trace headers. prompt_adapter_request: Prompt Adapter request to use for generation, if any. - priority: Priority of the request (lower means earlier handling). - Any priority other than 0 will lead to an error if the + priority: Priority of the request (lower means earlier handling). + Any priority other than 0 will lead to an error if the scheduling policy is not "priority". """ if inputs is not None: @@ -586,6 +586,7 @@ async def _process_request( default_guided_backend=(self.decoding_config.guided_decoding_backend if self.decoding_config else DecodingConfig.guided_decoding_backend), + model_config=self.model_config ) # 1) Create output queue for this requests. diff --git a/vllm/model_executor/guided_decoding/__init__.py b/vllm/model_executor/guided_decoding/__init__.py index d7b67425fcbc0..23c31fcfd7f05 100644 --- a/vllm/model_executor/guided_decoding/__init__.py +++ b/vllm/model_executor/guided_decoding/__init__.py @@ -1,14 +1,54 @@ -from typing import Optional +from __future__ import annotations -from vllm.logits_process import LogitsProcessor -from vllm.sampling_params import GuidedDecodingParams +from typing import TYPE_CHECKING + +from vllm.logger import init_logger + +if TYPE_CHECKING: + from transformers import PreTrainedTokenizer + + from vllm.config import ModelConfig + from vllm.logits_process import LogitsProcessor + from vllm.sampling_params import GuidedDecodingParams + +logger = init_logger(__name__) + + +def maybe_backend_fallback( + guided_params: GuidedDecodingParams) -> GuidedDecodingParams: + # lm-format-enforce doesn't support grammar, fallback to xgrammar + if (guided_params.backend == "lm-format-enforcer" + and guided_params.grammar is not None): + logger.warning( + "lm-format-enforcer does not support grammar guided decoding. " + "Falling back to use xgrammar instead.") + guided_params.backend = "xgrammar" + + if guided_params.backend == "xgrammar": + # xgrammar doesn't support regex or choice, fallback to outlines + if guided_params.regex is not None or guided_params.choice is not None: + logger.warning( + "xgrammar only supports json or grammar guided decoding. " + "Falling back to use outlines instead.") + guided_params.backend = "outlines" + + # xgrammar only supports EBNF grammars and uses the GBNF format + # https://github.com/ggerganov/llama.cpp/blob/master/grammars/README.md + elif (guided_params.grammar is not None + and "::=" not in guided_params.grammar): + logger.warning("xgrammar only supports EBNF grammars. " + "Falling back to use outlines instead.") + guided_params.backend = "outlines" + + return guided_params async def get_guided_decoding_logits_processor( - guided_params: GuidedDecodingParams, - tokenizer) -> Optional[LogitsProcessor]: + guided_params: GuidedDecodingParams, tokenizer: PreTrainedTokenizer, + model_config: ModelConfig) -> LogitsProcessor | None: + guided_params = maybe_backend_fallback(guided_params) # CFG grammar not supported by LMFE, so we use outlines instead - if guided_params.backend == 'outlines' or guided_params.grammar: + if guided_params.backend == 'outlines': # NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193 from vllm.model_executor.guided_decoding.outlines_decoding import ( # noqa get_outlines_guided_decoding_logits_processor) @@ -19,17 +59,23 @@ async def get_guided_decoding_logits_processor( get_local_lm_format_enforcer_guided_decoding_logits_processor) return get_local_lm_format_enforcer_guided_decoding_logits_processor( guided_params, tokenizer) + if guided_params.backend == 'xgrammar': + from vllm.model_executor.guided_decoding.xgrammar_decoding import ( # noqa + get_local_xgrammar_guided_decoding_logits_processor) + return get_local_xgrammar_guided_decoding_logits_processor( + guided_params, tokenizer, model_config) raise ValueError( f"Unknown guided decoding backend '{guided_params.backend}'. " - "Must be one of 'outlines, 'lm-format-enforcer'") + "Must be one of 'outlines, 'lm-format-enforcer', 'xgrammar'") def get_local_guided_decoding_logits_processor( - guided_params: GuidedDecodingParams, - tokenizer) -> Optional[LogitsProcessor]: + guided_params: GuidedDecodingParams, tokenizer: PreTrainedTokenizer, + model_config: ModelConfig) -> LogitsProcessor | None: + guided_params = maybe_backend_fallback(guided_params) # CFG grammar not supported by LMFE, so we use outlines instead - if guided_params.backend == 'outlines' or guided_params.grammar: + if guided_params.backend == 'outlines': # NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193 from vllm.model_executor.guided_decoding.outlines_decoding import ( # noqa get_local_outlines_guided_decoding_logits_processor) @@ -40,7 +86,12 @@ def get_local_guided_decoding_logits_processor( get_local_lm_format_enforcer_guided_decoding_logits_processor) return get_local_lm_format_enforcer_guided_decoding_logits_processor( guided_params, tokenizer) + if guided_params.backend == 'xgrammar': + from vllm.model_executor.guided_decoding.xgrammar_decoding import ( # noqa + get_local_xgrammar_guided_decoding_logits_processor) + return get_local_xgrammar_guided_decoding_logits_processor( + guided_params, tokenizer, model_config) raise ValueError( f"Unknown guided decoding backend '{guided_params.backend}'. " - "Must be one of 'outlines, 'lm-format-enforcer'") + "Must be one of 'outlines, 'lm-format-enforcer', 'xgrammar'") diff --git a/vllm/model_executor/guided_decoding/xgrammar_decoding.py b/vllm/model_executor/guided_decoding/xgrammar_decoding.py new file mode 100644 index 0000000000000..8287cd6cf3aa0 --- /dev/null +++ b/vllm/model_executor/guided_decoding/xgrammar_decoding.py @@ -0,0 +1,251 @@ +# noqa: UP007 +from __future__ import annotations + +import json +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any, NamedTuple + +import torch +from transformers import PreTrainedTokenizerFast + +try: + import xgrammar as xgr + from xgrammar.base import _core as xgr_core +except ImportError: + pass + +if TYPE_CHECKING: + from transformers import PreTrainedTokenizer + + from vllm.config import ModelConfig + from vllm.sampling_params import GuidedDecodingParams + + +# TODO: passing batch size to max threads here +def get_local_xgrammar_guided_decoding_logits_processor( + guided_params: GuidedDecodingParams, + tokenizer: PreTrainedTokenizer, + model_config: ModelConfig, + max_threads: int = 8): + config = GrammarConfig.from_guided_params(guided_params=guided_params, + model_config=model_config, + tokenizer=tokenizer, + max_threads=max_threads) + return XGrammarLogitsProcessor(config) + + +class TokenizerData(NamedTuple): + """Immutable container for cached tokenizer data.""" + encoded_vocab: list[str] + stop_token_ids: list[int] | None + backend_str: str + + +class TokenizerDataCache: + """Cache manager for tokenizer data to avoid repeated processing.""" + _cache: dict[int, TokenizerData] = {} + + @classmethod + def get_tokenizer_data(cls, + tokenizer: PreTrainedTokenizer) -> TokenizerData: + tokenizer_hash = hash(tokenizer) + + if tokenizer_hash not in cls._cache: + # Vendored from xgrammar logic since we cannot pickle the tokenizer + # https://github.com/mlc-ai/xgrammar/blob/d77c0a0173ef14779c918e3be7966ba852f7910f/python/xgrammar/tokenizer_info.py#L98 # noqa: E501 + try: + encoded_vocab = [ + token for token, _ in sorted(tokenizer.get_vocab().items(), + key=lambda x: x[1]) + ] + except AttributeError as e: + raise ValueError( + f"Cannot get the vocabulary of the tokenizer " + f"{type(tokenizer)}. The tokenizer should have a " + "get_vocab method.") from e + + stop_token_ids = None + backend_str = xgr.VocabType.RAW + if isinstance(tokenizer, PreTrainedTokenizerFast): + backend_str = tokenizer.backend_tokenizer.to_str() + if stop_token_ids is None and hasattr( + tokenizer, + "eos_token_id") and tokenizer.eos_token_id is not None: + stop_token_ids = [tokenizer.eos_token_id] + + cls._cache[tokenizer_hash] = TokenizerData( + encoded_vocab=encoded_vocab, + stop_token_ids=stop_token_ids, + backend_str=backend_str) + + return cls._cache[tokenizer_hash] + + +class GrammarCompilerCache: + """ + Cache for GrammarCompiler instances based on tokenizer. + + This cache reduces the overhead of creating new compiler instances when + using the same tokenizer configuration. + """ + _cache: dict[str, xgr.GrammarCompiler] = {} + + @classmethod + def get_compiler(cls, config: GrammarConfig) -> xgr.GrammarCompiler: + cache_key = str(config.tokenizer_hash) + + if cache_key not in cls._cache: + assert config.encoded_vocab is not None + tokenizer_info = xgr.TokenizerInfo._create_from_handle( + xgr_core.TokenizerInfo.from_huggingface( + config.encoded_vocab, config.backend_str, + config.vocab_size, config.stop_token_ids)) + cls._cache[cache_key] = xgr.GrammarCompiler( + tokenizer_info, max_threads=config.max_threads) + + return cls._cache[cache_key] + + +@dataclass +class GrammarConfig: + """Serializable configuration for grammar compilation""" + tokenizer_hash: int + vocab_size: int + json_str: str | None = None + grammar_str: str | None = None + json_object: bool | None = None + max_threads: int = 8 + # Only populated if tokenizer_hash not in cache + encoded_vocab: list[str] | None = None + stop_token_ids: list[int] | None = None + backend_str: str | None = None + + @classmethod + def from_guided_params(cls, + guided_params: GuidedDecodingParams, + model_config: ModelConfig, + tokenizer: PreTrainedTokenizer, + max_threads: int = 8) -> GrammarConfig: + + tokenizer_hash = hash(tokenizer) + # Only get tokenizer data if not already cached + if tokenizer_hash in TokenizerDataCache._cache: + encoded_vocab = None + stop_token_ids = None + backend_str = None + else: + tokenizer_data = TokenizerDataCache.get_tokenizer_data(tokenizer) + encoded_vocab = tokenizer_data.encoded_vocab + stop_token_ids = tokenizer_data.stop_token_ids + backend_str = tokenizer_data.backend_str + + if guided_params.json: + if not isinstance(guided_params.json, str): + json_str = json.dumps(guided_params.json) + else: + json_str = guided_params.json + return cls(json_str=json_str, + vocab_size=model_config.hf_config.vocab_size, + encoded_vocab=encoded_vocab, + stop_token_ids=stop_token_ids, + backend_str=backend_str, + tokenizer_hash=tokenizer_hash, + max_threads=max_threads) + elif guided_params.grammar: + return cls(grammar_str=guided_params.grammar, + vocab_size=model_config.hf_config.vocab_size, + encoded_vocab=encoded_vocab, + stop_token_ids=stop_token_ids, + backend_str=backend_str, + tokenizer_hash=tokenizer_hash, + max_threads=max_threads) + elif guided_params.json_object: + return cls(json_object=True, + vocab_size=model_config.hf_config.vocab_size, + encoded_vocab=encoded_vocab, + stop_token_ids=stop_token_ids, + backend_str=backend_str, + tokenizer_hash=tokenizer_hash, + max_threads=max_threads) + else: + raise ValueError( + "Currently only support JSON and EBNF grammar mode for xgrammar" + ) + + +@dataclass +class XGrammarLogitsProcessor: + """Wrapper class to support pickle protocol""" + config: GrammarConfig + + ctx: xgr.CompiledGrammar | None = None + token_bitmask: torch.Tensor = None # type: ignore[assignment] + matchers: list[xgr.GrammarMatcher] = field(default_factory=list) + batch_size: int = field(default=1) + prefilled: bool = field(default=False) + + def __getstate__(self) -> dict[str, Any]: + return {'config': self.config} + + def __setstate__(self, state: dict[str, Any]): + self.config = state['config'] + + self.ctx = None + self.matchers = [] + self.batch_size = 1 + self.token_bitmask = None # type: ignore[assignment] + self.prefilled = False + + def _ensure_ctx(self): + """Lazily initialize the processor in the worker process""" + if self.ctx is None: + compiler = GrammarCompilerCache.get_compiler(self.config) + if self.config.json_str is not None: + self.ctx = compiler.compile_json_schema(self.config.json_str) + elif self.config.grammar_str is not None: + self.ctx = compiler.compile_grammar(self.config.grammar_str) + elif self.config.json_object: + self.ctx = compiler.compile_builtin_json_grammar() + else: + raise ValueError( + "Invalid configuration for xgrammar logits processor") + + def __call__(self, input_ids: list[int], + scores: torch.Tensor) -> torch.Tensor: + if self.ctx is None: + self._ensure_ctx() + + if len(self.matchers) == 0: + self.matchers = [ + xgr.GrammarMatcher(self.ctx) for _ in range(self.batch_size) + ] + self.token_bitmask = xgr.allocate_token_bitmask( + self.batch_size, self.config.vocab_size) + + if not self.prefilled: + # Have not sampled a token yet + self.prefilled = True + else: + for i, matcher in enumerate(self.matchers): + if not matcher.is_terminated(): + sampled_token = input_ids[-1] + assert self.matchers[i].accept_token(sampled_token) + + for i, matcher in enumerate(self.matchers): + if not matcher.is_terminated(): + # @ubospica: ideally, fill_next_token_bitmask should be + # parallelized with model decoding + # See https://github.com/vllm-project/vllm/pull/10785/files#r1864278303 + matcher.fill_next_token_bitmask(self.token_bitmask, i) + + # token_bitmask is a CPU tensor for use with accept_token and + # fill_next_token_bitmask so we move it to the device of scores + device_type = scores.device.type + if device_type != "cuda": + scores = scores.to("cpu") + xgr.apply_token_bitmask_inplace(scores, + self.token_bitmask.to(scores.device)) + if device_type != "cuda": + scores = scores.to(device_type) + + return scores