From caf428997f7b23f9d67038708247c839842caa87 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jannis=20Sch=C3=B6nleber?=
Date: Sat, 30 Nov 2024 22:26:14 +0000
Subject: [PATCH] [Core] add `xgrammar` as guided generation provider
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Essentially a cleaned-up version of this PR:
https://github.com/vllm-project/vllm/pull/10785

Especially since `outlines` is rather slow, and the new version is tough to
integrate because it no longer focuses on being pickleable, which is a key
feature for us since we use the multiprocessing engine:
https://github.com/dottxt-ai/outlines-core/pull/99

I assume more and more users will switch over to `xgrammar`. This is a
minimal implementation.

https://arxiv.org/pdf/2411.15100
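
For reference, a rough usage sketch with the offline API (model name, prompt
and schema are illustrative placeholders; the backend can also be selected
engine-wide via the new `--guided-decoding-backend xgrammar` choice added
below):

    from vllm import LLM, SamplingParams
    from vllm.sampling_params import GuidedDecodingParams

    # Placeholder schema; anything xgrammar can compile works here.
    schema = {
        "type": "object",
        "properties": {"name": {"type": "string"}},
        "required": ["name"],
    }

    llm = LLM(model="Qwen/Qwen2.5-1.5B-Instruct")  # placeholder model
    params = SamplingParams(
        max_tokens=64,
        guided_decoding=GuidedDecodingParams(json=schema, backend="xgrammar"))
    outputs = llm.generate(["Reply with a JSON object containing a name."],
                           params)
    print(outputs[0].outputs[0].text)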

Signed-off-by: Jannis Schönleber
---
 requirements-common.txt                       |   1 +
 vllm/config.py                                |   2 +-
 vllm/engine/arg_utils.py                      |   2 +-
 .../guided_decoding/__init__.py               |  10 +
 .../guided_decoding/xgrammar_decoding.py      | 171 ++++++++++++++++++
 5 files changed, 184 insertions(+), 2 deletions(-)
 create mode 100644 vllm/model_executor/guided_decoding/xgrammar_decoding.py

diff --git a/requirements-common.txt b/requirements-common.txt
index 02e3d65fb774c..818f72e14be96 100644
--- a/requirements-common.txt
+++ b/requirements-common.txt
@@ -19,6 +19,7 @@ prometheus-fastapi-instrumentator >= 7.0.0
 tiktoken >= 0.6.0 # Required for DBRX tokenizer
 lm-format-enforcer >= 0.10.9, < 0.11
 outlines >= 0.0.43, < 0.1
+xgrammar
 typing_extensions >= 4.10
 filelock >= 3.16.1 # need to contain https://github.com/tox-dev/filelock/pull/317
 partial-json-parser # used for parsing partial JSON outputs
diff --git a/vllm/config.py b/vllm/config.py
index b1e5b412fec8f..0390146866d2a 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -1999,7 +1999,7 @@ class DecodingConfig:
     guided_decoding_backend: str = 'outlines'
 
     def __post_init__(self):
-        valid_guided_backends = ['outlines', 'lm-format-enforcer']
+        valid_guided_backends = ['outlines', 'lm-format-enforcer', 'xgrammar']
         backend = self.guided_decoding_backend
         if backend not in valid_guided_backends:
             raise ValueError(f"Invalid guided_decoding_backend '{backend},"
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index f0020562c3c3a..0b06f5a040cac 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -362,7 +362,7 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
             '--guided-decoding-backend',
             type=str,
             default='outlines',
-            choices=['outlines', 'lm-format-enforcer'],
+            choices=['outlines', 'lm-format-enforcer', 'xgrammar'],
             help='Which engine will be used for guided decoding'
             ' (JSON schema / regex etc) by default. Currently support '
             'https://github.com/outlines-dev/outlines and '
diff --git a/vllm/model_executor/guided_decoding/__init__.py b/vllm/model_executor/guided_decoding/__init__.py
index d7b67425fcbc0..d5f6c01072c93 100644
--- a/vllm/model_executor/guided_decoding/__init__.py
+++ b/vllm/model_executor/guided_decoding/__init__.py
@@ -19,6 +19,11 @@ async def get_guided_decoding_logits_processor(
             get_local_lm_format_enforcer_guided_decoding_logits_processor)
         return get_local_lm_format_enforcer_guided_decoding_logits_processor(
             guided_params, tokenizer)
+    if guided_params.backend == 'xgrammar':
+        from vllm.model_executor.guided_decoding.xgrammar_decoding import (  # noqa
+            get_local_xgrammar_guided_decoding_logits_processor)
+        return get_local_xgrammar_guided_decoding_logits_processor(
+            guided_params, tokenizer)
 
     raise ValueError(
         f"Unknown guided decoding backend '{guided_params.backend}'. "
@@ -40,6 +45,11 @@ def get_local_guided_decoding_logits_processor(
             get_local_lm_format_enforcer_guided_decoding_logits_processor)
         return get_local_lm_format_enforcer_guided_decoding_logits_processor(
             guided_params, tokenizer)
+    if guided_params.backend == "xgrammar":
+        from vllm.model_executor.guided_decoding.xgrammar_decoding import (  # noqa
+            get_local_xgrammar_guided_decoding_logits_processor)
+        return get_local_xgrammar_guided_decoding_logits_processor(
+            guided_params, tokenizer)
 
     raise ValueError(
         f"Unknown guided decoding backend '{guided_params.backend}'. "
diff --git a/vllm/model_executor/guided_decoding/xgrammar_decoding.py b/vllm/model_executor/guided_decoding/xgrammar_decoding.py
new file mode 100644
index 0000000000000..192dc847c521e
--- /dev/null
+++ b/vllm/model_executor/guided_decoding/xgrammar_decoding.py
@@ -0,0 +1,171 @@
+# noqa: UP007
+from __future__ import annotations
+
+import json
+from dataclasses import dataclass, field
+from typing import TYPE_CHECKING, Any
+
+import torch
+from transformers import PreTrainedTokenizerFast
+
+try:
+    import xgrammar as xgr
+    from xgrammar.base import _core as xgr_core
+except ImportError:
+    pass
+
+if TYPE_CHECKING:
+    from transformers import PreTrainedTokenizer
+
+    from vllm.sampling_params import GuidedDecodingParams
+
+
+# TODO: passing batch size to max threads here
+def get_local_xgrammar_guided_decoding_logits_processor(
+        guided_params: GuidedDecodingParams,
+        tokenizer: PreTrainedTokenizer,
+        max_threads: int = 8):
+    config = GrammarConfig.from_guided_params(guided_params=guided_params,
+                                              tokenizer=tokenizer,
+                                              max_threads=max_threads)
+    return XGrammarLogitsProcessor(config)
+
+
+@dataclass
+class GrammarConfig:
+    """Serializable configuration for grammar compilation"""
+    vocab_size: int = 0
+    max_threads: int = 8
+    json_str: str | None = None
+    grammar_str: str | None = None
+    encoded_vocab: dict[str, int] | None = None
+    stop_token_ids: list[int] | None = None
+    backend_str: str = ""
+
+    @classmethod
+    def from_guided_params(cls,
+                           guided_params: GuidedDecodingParams,
+                           tokenizer: PreTrainedTokenizer,
+                           max_threads: int = 8) -> GrammarConfig:
+
+        # Vendored from xgrammar logic
+        try:
+            encoded_vocab = tokenizer.get_vocab()
+            encoded_vocab = [
+                token for token, _ in sorted(encoded_vocab.items(),
+                                             key=lambda x: x[1])
+            ]
+        except AttributeError as e:
+            raise ValueError(
+                f"Cannot get the vocabulary of the tokenizer {type(tokenizer)}."
+                " The tokenizer should have a get_vocab method.") from e
+
+        stop_token_ids = None
+        backend_str = xgr.VocabType.RAW
+        if isinstance(tokenizer, PreTrainedTokenizerFast):
+            # the vocabulary is directly obtained from tokenizer.get_vocab()
+            backend_str = tokenizer.backend_tokenizer.to_str()
+            if stop_token_ids is None and hasattr(
+                    tokenizer,
+                    "eos_token_id") and tokenizer.eos_token_id is not None:
+                stop_token_ids = [tokenizer.eos_token_id]
+
+        if guided_params.json:
+            if not isinstance(guided_params.json, str):
+                json_str = json.dumps(guided_params.json)
+            else:
+                json_str = guided_params.json
+            return cls(json_str=json_str,
+                       vocab_size=len(tokenizer.vocab.keys()),
+                       max_threads=max_threads,
+                       encoded_vocab=encoded_vocab,
+                       stop_token_ids=stop_token_ids,
+                       backend_str=backend_str)
+        elif guided_params.grammar:
+            return cls(grammar_str=guided_params.grammar,
+                       vocab_size=len(tokenizer.vocab.keys()),
+                       max_threads=max_threads,
+                       encoded_vocab=encoded_vocab,
+                       stop_token_ids=stop_token_ids,
+                       backend_str=backend_str)
+        else:
+            raise ValueError(
+                "Currently only support JSON and EBNF grammar mode for xgrammar"
+            )
+
+    def create_tokenizer_info(self):
+        return xgr.TokenizerInfo._create_from_handle(
+            xgr_core.TokenizerInfo.from_huggingface(self.encoded_vocab,
+                                                    self.backend_str,
+                                                    self.vocab_size,
+                                                    self.stop_token_ids))
+
+
+@dataclass
+class XGrammarLogitsProcessor:
+    """Wrapper class to support pickle protocol"""
+    config: GrammarConfig
+
+    ctx: xgr.CompiledGrammar | None = None
+    matchers: list[xgr.GrammarMatcher] = field(default_factory=list)
+    batch_size: int = 1
+    token_bitmask: torch.Tensor = None
+    prefilled: bool = False
+
+    def __getstate__(self) -> dict[str, Any]:
+        return {'config': self.config}
+
+    def __setstate__(self, state: dict[str, Any]):
+        self.config = state['config']
+
+        self.ctx = None
+        self.matchers = []
+        self.batch_size = 1
+        self.token_bitmask = None
+        self.prefilled = False
+
+    def _ensure_ctx(self):
+        """Lazily initialize the processor in the worker process"""
+        if self.ctx is None:
+            compiler = xgr.GrammarCompiler(self.config.create_tokenizer_info(),
+                                           max_threads=self.config.max_threads)
+
+            if self.config.json_str is not None:
+                self.ctx = compiler.compile_json_schema(self.config.json_str)
+            else:
+                self.ctx = compiler.compile_grammar(self.config.grammar_str)
+
+    def __call__(self, input_ids: list[int],
+                 scores: torch.Tensor) -> torch.Tensor:
+        if self.ctx is None:
+            self._ensure_ctx()
+
+        if len(self.matchers) == 0:
+            self.matchers = [
+                xgr.GrammarMatcher(self.ctx) for _ in range(self.batch_size)
+            ]
+            self.token_bitmask = xgr.allocate_token_bitmask(
+                self.batch_size, self.config.vocab_size)
+
+        if not self.prefilled:
+            # Have not sampled a token yet
+            self.prefilled = True
+        else:
+            for i, matcher in enumerate(self.matchers):
+                if not matcher.is_terminated():
+                    sampled_token = input_ids[-1]
+                    assert self.matchers[i].accept_token(sampled_token)
+
+        for i, matcher in enumerate(self.matchers):
+            if not matcher.is_terminated():
+                matcher.fill_next_token_bitmask(self.token_bitmask, i)
+
+        device_type = scores.device.type
+        if device_type != "cuda":
+            scores = scores.to("cpu")
+        xgr.apply_token_bitmask_inplace(scores,
+                                        self.token_bitmask.to(scores.device))
+        if device_type != "cuda":
+            scores = scores.to(device_type)
+
+        return scores
\ No newline at end of file