From d32614810ae719d8e1d0981d417170b8564904f8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jannis=20Sch=C3=B6nleber?= Date: Sat, 30 Nov 2024 22:26:14 +0000 Subject: [PATCH 1/4] [Core] add `xgrammar` as guided generation provider MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Essentially a cleaned-up version of this PR: https://github.com/vllm-project/vllm/pull/10785 Especially since `outlines` is rather slow and the new version is tough to integrate, as they do not focus on being pickleable, which is a key feature for us since we use the multiprocessing engine: https://github.com/dottxt-ai/outlines-core/pull/99 I assume more and more users will switch over to `xgrammar`. This is a minimal implementation. https://arxiv.org/pdf/2411.15100 Signed-off-by: Jannis Schönleber --- requirements-common.txt | 1 + vllm/config.py | 2 +- vllm/engine/arg_utils.py | 2 +- .../guided_decoding/__init__.py | 10 + .../guided_decoding/xgrammar_decoding.py | 171 ++++++++++++++++++ 5 files changed, 184 insertions(+), 2 deletions(-) create mode 100644 vllm/model_executor/guided_decoding/xgrammar_decoding.py diff --git a/requirements-common.txt b/requirements-common.txt index 02e3d65fb774c..818f72e14be96 100644 --- a/requirements-common.txt +++ b/requirements-common.txt @@ -19,6 +19,7 @@ prometheus-fastapi-instrumentator >= 7.0.0 tiktoken >= 0.6.0 # Required for DBRX tokenizer lm-format-enforcer >= 0.10.9, < 0.11 outlines >= 0.0.43, < 0.1 +xgrammar typing_extensions >= 4.10 filelock >= 3.16.1 # need to contain https://github.com/tox-dev/filelock/pull/317 partial-json-parser # used for parsing partial JSON outputs diff --git a/vllm/config.py b/vllm/config.py index b1e5b412fec8f..0390146866d2a 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1999,7 +1999,7 @@ class DecodingConfig: guided_decoding_backend: str = 'outlines' def __post_init__(self): - valid_guided_backends = ['outlines', 'lm-format-enforcer'] + valid_guided_backends = ['outlines', 'lm-format-enforcer', 'xgrammar'] backend = self.guided_decoding_backend if backend not in valid_guided_backends: raise ValueError(f"Invalid guided_decoding_backend '{backend}," diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index f0020562c3c3a..0b06f5a040cac 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -362,7 +362,7 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: '--guided-decoding-backend', type=str, default='outlines', - choices=['outlines', 'lm-format-enforcer'], + choices=['outlines', 'lm-format-enforcer', 'xgrammar'], help='Which engine will be used for guided decoding' ' (JSON schema / regex etc) by default. 
Currently support ' 'https://github.com/outlines-dev/outlines and ' diff --git a/vllm/model_executor/guided_decoding/__init__.py b/vllm/model_executor/guided_decoding/__init__.py index d7b67425fcbc0..d5f6c01072c93 100644 --- a/vllm/model_executor/guided_decoding/__init__.py +++ b/vllm/model_executor/guided_decoding/__init__.py @@ -19,6 +19,11 @@ async def get_guided_decoding_logits_processor( get_local_lm_format_enforcer_guided_decoding_logits_processor) return get_local_lm_format_enforcer_guided_decoding_logits_processor( guided_params, tokenizer) + if guided_params.backend == 'xgrammar': + from vllm.model_executor.guided_decoding.xgrammar_decoding import ( # noqa + get_local_xgrammar_guided_decoding_logits_processor) + return get_local_xgrammar_guided_decoding_logits_processor( + guided_params, tokenizer) raise ValueError( f"Unknown guided decoding backend '{guided_params.backend}'. " @@ -40,6 +45,11 @@ def get_local_guided_decoding_logits_processor( get_local_lm_format_enforcer_guided_decoding_logits_processor) return get_local_lm_format_enforcer_guided_decoding_logits_processor( guided_params, tokenizer) + if guided_params.backend == "xgrammar": + from vllm.model_executor.guided_decoding.xgrammar_decoding import ( # noqa + get_local_xgrammar_guided_decoding_logits_processor) + return get_local_xgrammar_guided_decoding_logits_processor( + guided_params, tokenizer) raise ValueError( f"Unknown guided decoding backend '{guided_params.backend}'. " diff --git a/vllm/model_executor/guided_decoding/xgrammar_decoding.py b/vllm/model_executor/guided_decoding/xgrammar_decoding.py new file mode 100644 index 0000000000000..192dc847c521e --- /dev/null +++ b/vllm/model_executor/guided_decoding/xgrammar_decoding.py @@ -0,0 +1,171 @@ +# noqa: UP007 +from __future__ import annotations + +import json +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any + +import torch +from transformers import PreTrainedTokenizerFast + +try: + import xgrammar as xgr + from xgrammar.base import _core as xgr_core +except ImportError: + pass + +if TYPE_CHECKING: + from transformers import PreTrainedTokenizer + + from vllm.sampling_params import GuidedDecodingParams + + +# TODO: passing batch size to max threads here +def get_local_xgrammar_guided_decoding_logits_processor( + guided_params: GuidedDecodingParams, + tokenizer: PreTrainedTokenizer, + max_threads: int = 8): + config = GrammarConfig.from_guided_params(guided_params=guided_params, + tokenizer=tokenizer, + max_threads=max_threads) + return XGrammarLogitsProcessor(config) + + +@dataclass +class GrammarConfig: + """Serializable configuration for grammar compilation""" + vocab_size: int = 0 + max_threads: int = 8 + json_str: str | None = None + grammar_str: str | None = None + encoded_vocab: dict[str, int] | None = None + stop_token_ids: list[int] | None = None + backend_str: str = "" + + @classmethod + def from_guided_params(cls, + guided_params: GuidedDecodingParams, + tokenizer: PreTrainedTokenizer, + max_threads: int = 8) -> GrammarConfig: + + # Vendorred from xgrammar logics + try: + encoded_vocab = tokenizer.get_vocab() + encoded_vocab = [ + token for token, _ in sorted(encoded_vocab.items(), + key=lambda x: x[1]) + ] + except AttributeError as e: + raise ValueError( + f"Cannot get the vocabulary of the tokenizer {type(tokenizer)}." 
+ " The tokenizer should have a get_vocab method.") from e + + stop_token_ids = None + backend_str = xgr.VocabType.RAW + if isinstance(tokenizer, PreTrainedTokenizerFast): + # the vocabulary is directly obtained from tokenizer.get_vocab() + backend_str = tokenizer.backend_tokenizer.to_str() + if stop_token_ids is None and hasattr( + tokenizer, + "eos_token_id") and tokenizer.eos_token_id is not None: + stop_token_ids = [tokenizer.eos_token_id] + + if guided_params.json: + if not isinstance(guided_params.json, str): + json_str = json.dumps(guided_params.json) + else: + json_str = guided_params.json + return cls(json_str=json_str, + vocab_size=len(tokenizer.vocab.keys()), + max_threads=max_threads, + encoded_vocab=encoded_vocab, + stop_token_ids=stop_token_ids, + backend_str=backend_str) + elif guided_params.grammar: + return cls(grammar_str=guided_params.grammar, + vocab_size=len(tokenizer.vocab.keys()), + max_threads=max_threads, + encoded_vocab=encoded_vocab, + stop_token_ids=stop_token_ids, + backend_str=backend_str) + else: + raise ValueError( + "Currently only support JSON and EBNF grammar mode for xgrammar" + ) + + def create_tokenizer_info(self): + return xgr.TokenizerInfo._create_from_handle( + xgr_core.TokenizerInfo.from_huggingface(self.encoded_vocab, + self.backend_str, + self.vocab_size, + self.stop_token_ids)) + + +@dataclass +class XGrammarLogitsProcessor: + """Wrapper class to support pickle protocol""" + config: GrammarConfig + + ctx: xgr.CompiledGrammar | None = None + matchers: list[xgr.GrammarMatcher] = field(default_factory=list) + batch_size: int = 1 + token_bitmask: torch.Tensor = None + prefilled: bool = False + + def __getstate__(self) -> dict[str, Any]: + return {'config': self.config} + + def __setstate__(self, state: dict[str, Any]): + self.config = state['config'] + + self.ctx = None + self.matchers = [] + self.batch_size = 1 + self.token_bitmask = None + self.prefilled = False + + def _ensure_ctx(self): + """Lazily initialize the processor in the worker process""" + if self.ctx is None: + compiler = xgr.GrammarCompiler(self.config.create_tokenizer_info(), + max_threads=self.config.max_threads) + + if self.config.json_str is not None: + self.ctx = compiler.compile_json_schema(self.config.json_str) + else: + self.ctx = compiler.compile_grammar(self.config.grammar_str) + + def __call__(self, input_ids: list[int], + scores: torch.Tensor) -> torch.Tensor: + if self.ctx is None: + self._ensure_ctx() + + if len(self.matchers) == 0: + self.matchers = [ + xgr.GrammarMatcher(self.ctx) for _ in range(self.batch_size) + ] + self.token_bitmask = xgr.allocate_token_bitmask( + self.batch_size, self.config.vocab_size) + + if not self.prefilled: + # Have not sampled a token yet + self.prefilled = True + else: + for i, matcher in enumerate(self.matchers): + if not matcher.is_terminated(): + sampled_token = input_ids[-1] + assert self.matchers[i].accept_token(sampled_token) + + for i, matcher in enumerate(self.matchers): + if not matcher.is_terminated(): + matcher.fill_next_token_bitmask(self.token_bitmask, i) + + device_type = scores.device.type + if device_type != "cuda": + scores = scores.to("cpu") + xgr.apply_token_bitmask_inplace(scores, + self.token_bitmask.to(scores.device)) + if device_type != "cuda": + scores = scores.to(device_type) + + return scores \ No newline at end of file From 7630b62ea85016ba93a7599593e2162029aad8b3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jannis=20Sch=C3=B6nleber?= Date: Sun, 1 Dec 2024 01:00:29 +0000 Subject: [PATCH 2/4] [CI/Build] add test for 
`xgrammar` guided generation backend MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Jannis Schönleber --- tests/model_executor/test_guided_processors.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/model_executor/test_guided_processors.py b/tests/model_executor/test_guided_processors.py index 45fab8e96b968..9f4d81b583141 100644 --- a/tests/model_executor/test_guided_processors.py +++ b/tests/model_executor/test_guided_processors.py @@ -36,7 +36,8 @@ def test_guided_logits_processors(sample_regex, sample_json_schema): @pytest.mark.asyncio -@pytest.mark.parametrize("backend", ["outlines", "lm-format-enforcer"]) +@pytest.mark.parametrize("backend", + ["outlines", "lm-format-enforcer", "xgrammar"]) async def test_guided_logits_processor_black_box(backend: str, sample_regex, sample_json_schema): tokenizer = AutoTokenizer.from_pretrained('HuggingFaceH4/zephyr-7b-beta') From 33682d4dd1b28c487f0a57786ed4d251296a6c1b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jannis=20Sch=C3=B6nleber?= Date: Sun, 1 Dec 2024 10:07:08 +0000 Subject: [PATCH 3/4] [Core] add `caching` mechanism to `xgrammar` MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit credits to: https://github.com/vllm-project/vllm/pull/10785/commits/865e2a3b4a6d5e5b12ef5ce9fffdfb624b5a3ab6 Signed-off-by: Jannis Schönleber --- .../guided_decoding/xgrammar_decoding.py | 143 +++++++++++++----- 1 file changed, 101 insertions(+), 42 deletions(-) diff --git a/vllm/model_executor/guided_decoding/xgrammar_decoding.py b/vllm/model_executor/guided_decoding/xgrammar_decoding.py index 192dc847c521e..b01508e3c3e4e 100644 --- a/vllm/model_executor/guided_decoding/xgrammar_decoding.py +++ b/vllm/model_executor/guided_decoding/xgrammar_decoding.py @@ -3,7 +3,7 @@ import json from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, NamedTuple import torch from transformers import PreTrainedTokenizerFast @@ -31,16 +31,90 @@ def get_local_xgrammar_guided_decoding_logits_processor( return XGrammarLogitsProcessor(config) +class TokenizerData(NamedTuple): + """Immutable container for cached tokenizer data.""" + encoded_vocab: list[str] + stop_token_ids: list[int] | None + backend_str: str + + +class TokenizerDataCache: + """Cache manager for tokenizer data to avoid repeated processing.""" + _cache: dict[int, TokenizerData] = {} + + @classmethod + def get_tokenizer_data(cls, + tokenizer: PreTrainedTokenizer) -> TokenizerData: + tokenizer_hash = hash(tokenizer) + + if tokenizer_hash not in cls._cache: + # Vendored from xgrammar logic since we cannot pickle the tokenizer + # https://github.com/mlc-ai/xgrammar/blob/d77c0a0173ef14779c918e3be7966ba852f7910f/python/xgrammar/tokenizer_info.py#L98 # noqa: E501 + try: + encoded_vocab = [ + token for token, _ in sorted(tokenizer.get_vocab().items(), + key=lambda x: x[1]) + ] + except AttributeError as e: + raise ValueError( + f"Cannot get the vocabulary of the tokenizer " + f"{type(tokenizer)}. 
The tokenizer should have a " + "get_vocab method.") from e + + stop_token_ids = None + backend_str = xgr.VocabType.RAW + if isinstance(tokenizer, PreTrainedTokenizerFast): + backend_str = tokenizer.backend_tokenizer.to_str() + if stop_token_ids is None and hasattr( + tokenizer, + "eos_token_id") and tokenizer.eos_token_id is not None: + stop_token_ids = [tokenizer.eos_token_id] + + cls._cache[tokenizer_hash] = TokenizerData( + encoded_vocab=encoded_vocab, + stop_token_ids=stop_token_ids, + backend_str=backend_str) + + return cls._cache[tokenizer_hash] + + +class GrammarCompilerCache: + """ + Cache for GrammarCompiler instances based on tokenizer. + + This cache reduces the overhead of creating new compiler instances when + using the same tokenizer configuration. + """ + _cache: dict[str, xgr.GrammarCompiler] = {} + + @classmethod + def get_compiler(cls, config: GrammarConfig) -> xgr.GrammarCompiler: + cache_key = str(config.tokenizer_hash) + + if cache_key not in cls._cache: + assert config.encoded_vocab is not None + tokenizer_info = xgr.TokenizerInfo._create_from_handle( + xgr_core.TokenizerInfo.from_huggingface( + config.encoded_vocab, config.backend_str, + config.vocab_size, config.stop_token_ids)) + cls._cache[cache_key] = xgr.GrammarCompiler( + tokenizer_info, max_threads=config.max_threads) + + return cls._cache[cache_key] + + @dataclass class GrammarConfig: """Serializable configuration for grammar compilation""" - vocab_size: int = 0 - max_threads: int = 8 + tokenizer_hash: int + vocab_size: int json_str: str | None = None grammar_str: str | None = None - encoded_vocab: dict[str, int] | None = None + max_threads: int = 8 + # Only populated if tokenizer_hash not in cache + encoded_vocab: list[str] | None = None stop_token_ids: list[int] | None = None - backend_str: str = "" + backend_str: str | None = None @classmethod def from_guided_params(cls, @@ -48,27 +122,17 @@ def from_guided_params(cls, tokenizer: PreTrainedTokenizer, max_threads: int = 8) -> GrammarConfig: - # Vendorred from xgrammar logics - try: - encoded_vocab = tokenizer.get_vocab() - encoded_vocab = [ - token for token, _ in sorted(encoded_vocab.items(), - key=lambda x: x[1]) - ] - except AttributeError as e: - raise ValueError( - f"Cannot get the vocabulary of the tokenizer {type(tokenizer)}." 
- " The tokenizer should have a get_vocab method.") from e - - stop_token_ids = None - backend_str = xgr.VocabType.RAW - if isinstance(tokenizer, PreTrainedTokenizerFast): - # the vocabulary is directly obtained from tokenizer.get_vocab() - backend_str = tokenizer.backend_tokenizer.to_str() - if stop_token_ids is None and hasattr( - tokenizer, - "eos_token_id") and tokenizer.eos_token_id is not None: - stop_token_ids = [tokenizer.eos_token_id] + tokenizer_hash = hash(tokenizer) + # Only get tokenizer data if not already cached + if tokenizer_hash in TokenizerDataCache._cache: + encoded_vocab = None + stop_token_ids = None + backend_str = None + else: + tokenizer_data = TokenizerDataCache.get_tokenizer_data(tokenizer) + encoded_vocab = tokenizer_data.encoded_vocab + stop_token_ids = tokenizer_data.stop_token_ids + backend_str = tokenizer_data.backend_str if guided_params.json: if not isinstance(guided_params.json, str): @@ -76,30 +140,25 @@ def from_guided_params(cls, else: json_str = guided_params.json return cls(json_str=json_str, - vocab_size=len(tokenizer.vocab.keys()), - max_threads=max_threads, + vocab_size=len(tokenizer.get_vocab().keys()), encoded_vocab=encoded_vocab, stop_token_ids=stop_token_ids, - backend_str=backend_str) + backend_str=backend_str, + tokenizer_hash=tokenizer_hash, + max_threads=max_threads) elif guided_params.grammar: return cls(grammar_str=guided_params.grammar, - vocab_size=len(tokenizer.vocab.keys()), - max_threads=max_threads, + vocab_size=len(tokenizer.get_vocab().keys()), encoded_vocab=encoded_vocab, stop_token_ids=stop_token_ids, - backend_str=backend_str) + backend_str=backend_str, + tokenizer_hash=tokenizer_hash, + max_threads=max_threads) else: raise ValueError( "Currently only support JSON and EBNF grammar mode for xgrammar" ) - def create_tokenizer_info(self): - return xgr.TokenizerInfo._create_from_handle( - xgr_core.TokenizerInfo.from_huggingface(self.encoded_vocab, - self.backend_str, - self.vocab_size, - self.stop_token_ids)) - @dataclass class XGrammarLogitsProcessor: @@ -127,9 +186,7 @@ def __setstate__(self, state: dict[str, Any]): def _ensure_ctx(self): """Lazily initialize the processor in the worker process""" if self.ctx is None: - compiler = xgr.GrammarCompiler(self.config.create_tokenizer_info(), - max_threads=self.config.max_threads) - + compiler = GrammarCompilerCache.get_compiler(self.config) if self.config.json_str is not None: self.ctx = compiler.compile_json_schema(self.config.json_str) else: @@ -160,6 +217,8 @@ def __call__(self, input_ids: list[int], if not matcher.is_terminated(): matcher.fill_next_token_bitmask(self.token_bitmask, i) + # token_bitmask is a CPU tensor for use with accept_token and + # fill_next_token_bitmask so we move it to the device of scores device_type = scores.device.type if device_type != "cuda": scores = scores.to("cpu") From c241a7e6e2406cc3be295d4cee8c65cef4636a57 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jannis=20Sch=C3=B6nleber?= Date: Sun, 1 Dec 2024 10:25:35 +0000 Subject: [PATCH 4/4] [Frontend] log `xgrammar` compile time for long compilations MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Jannis Schönleber --- .../guided_decoding/xgrammar_decoding.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/guided_decoding/xgrammar_decoding.py b/vllm/model_executor/guided_decoding/xgrammar_decoding.py index b01508e3c3e4e..bfb739cbd5ec7 100644 --- 
a/vllm/model_executor/guided_decoding/xgrammar_decoding.py +++ b/vllm/model_executor/guided_decoding/xgrammar_decoding.py @@ -2,12 +2,15 @@ from __future__ import annotations import json +import time from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any, NamedTuple import torch from transformers import PreTrainedTokenizerFast +from vllm.logger import init_logger + try: import xgrammar as xgr from xgrammar.base import _core as xgr_core @@ -19,6 +22,8 @@ from vllm.sampling_params import GuidedDecodingParams +logger = init_logger(__name__) + # TODO: passing batch size to max threads here def get_local_xgrammar_guided_decoding_logits_processor( @@ -186,11 +191,17 @@ def __setstate__(self, state: dict[str, Any]): def _ensure_ctx(self): """Lazily initialize the processor in the worker process""" if self.ctx is None: + start_time = time.time() compiler = GrammarCompilerCache.get_compiler(self.config) if self.config.json_str is not None: self.ctx = compiler.compile_json_schema(self.config.json_str) else: self.ctx = compiler.compile_grammar(self.config.grammar_str) + compile_time = time.time() - start_time + if compile_time > 0.1: + log_str = ("Xgrammar compiled grammar/schema " + f"in {compile_time:.2f} seconds") + logger.info(log_str) def __call__(self, input_ids: list[int], scores: torch.Tensor) -> torch.Tensor: @@ -227,4 +238,4 @@ def __call__(self, input_ids: list[int], if device_type != "cuda": scores = scores.to(device_type) - return scores \ No newline at end of file + return scores
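
Usage note (illustrative, not part of the patches above): the sketch below drives the new backend the same way the added test does, through the local logits-processor helper. It assumes `xgrammar` is installed and that the names used here (`GuidedDecodingParams`, `get_local_guided_decoding_logits_processor`, the zephyr tokenizer from the test) match this tree; the JSON schema and the final masking check are made up purely for demonstration.

import pickle

import torch
from transformers import AutoTokenizer

from vllm.model_executor.guided_decoding import (
    get_local_guided_decoding_logits_processor)
from vllm.sampling_params import GuidedDecodingParams

# Same tokenizer as tests/model_executor/test_guided_processors.py
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")

# Hypothetical schema, only for demonstration
schema = {"type": "object", "properties": {"name": {"type": "string"}}}
guided_params = GuidedDecodingParams(json=schema, backend="xgrammar")

# Builds a pickleable GrammarConfig and wraps it in XGrammarLogitsProcessor;
# the actual grammar compilation happens lazily on the first call.
processor = get_local_guided_decoding_logits_processor(guided_params, tokenizer)

# The processor survives a pickle round-trip (only `config` is serialized),
# which is the property the multiprocessing engine relies on.
processor = pickle.loads(pickle.dumps(processor))

# Simulate the first decoding step for a batch of one sequence: tokens that
# cannot start a valid JSON object should be masked to -inf by the bitmask.
scores = torch.zeros(1, len(tokenizer.get_vocab()), dtype=torch.float32)
scores = processor([], scores)
print(int(torch.isinf(scores).sum()), "tokens disallowed at the first step")

In normal operation the backend is selected instead via `--guided-decoding-backend xgrammar` (or `guided_decoding_backend` in `DecodingConfig`), as wired up in patch 1.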