[Core] add xgrammar as guided generation provider #10803

Closed · wants to merge 4 commits
1 change: 1 addition & 0 deletions requirements-common.txt
@@ -19,6 +19,7 @@ prometheus-fastapi-instrumentator >= 7.0.0
tiktoken >= 0.6.0 # Required for DBRX tokenizer
lm-format-enforcer >= 0.10.9, < 0.11
outlines >= 0.0.43, < 0.1
xgrammar
typing_extensions >= 4.10
filelock >= 3.16.1 # need to contain https://github.com/tox-dev/filelock/pull/317
partial-json-parser # used for parsing partial JSON outputs
3 changes: 2 additions & 1 deletion tests/model_executor/test_guided_processors.py
@@ -36,7 +36,8 @@ def test_guided_logits_processors(sample_regex, sample_json_schema):


@pytest.mark.asyncio
@pytest.mark.parametrize("backend", ["outlines", "lm-format-enforcer"])
@pytest.mark.parametrize("backend",
["outlines", "lm-format-enforcer", "xgrammar"])
async def test_guided_logits_processor_black_box(backend: str, sample_regex,
sample_json_schema):
tokenizer = AutoTokenizer.from_pretrained('HuggingFaceH4/zephyr-7b-beta')
2 changes: 1 addition & 1 deletion vllm/config.py
@@ -1999,7 +1999,7 @@ class DecodingConfig:
guided_decoding_backend: str = 'outlines'

def __post_init__(self):
-        valid_guided_backends = ['outlines', 'lm-format-enforcer']
+        valid_guided_backends = ['outlines', 'lm-format-enforcer', 'xgrammar']
backend = self.guided_decoding_backend
if backend not in valid_guided_backends:
raise ValueError(f"Invalid guided_decoding_backend '{backend},"
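For reference, a minimal sketch (not part of this diff) of the widened validation; it assumes DecodingConfig is constructed directly:

from vllm.config import DecodingConfig

# 'xgrammar' now passes __post_init__ validation alongside the existing
# backends; any other value still raises ValueError.
config = DecodingConfig(guided_decoding_backend="xgrammar")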
2 changes: 1 addition & 1 deletion vllm/engine/arg_utils.py
@@ -362,7 +362,7 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
'--guided-decoding-backend',
type=str,
default='outlines',
-        choices=['outlines', 'lm-format-enforcer'],
+        choices=['outlines', 'lm-format-enforcer', 'xgrammar'],
help='Which engine will be used for guided decoding'
' (JSON schema / regex etc) by default. Currently support '
'https://github.com/outlines-dev/outlines and '
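Usage sketch (not part of this diff; the model name and schema are placeholders): the same backend choice can be made engine-wide from Python, assuming vLLM's LLM entrypoint forwards guided_decoding_backend to the engine args as the CLI flag above does:

from vllm import LLM, SamplingParams
from vllm.sampling_params import GuidedDecodingParams

llm = LLM(model="HuggingFaceH4/zephyr-7b-beta",
          guided_decoding_backend="xgrammar")
schema = {"type": "object", "properties": {"name": {"type": "string"}}}
params = SamplingParams(
    max_tokens=64,
    guided_decoding=GuidedDecodingParams(json=schema, backend="xgrammar"))
outputs = llm.generate("Return a JSON object with a name field.", params)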
10 changes: 10 additions & 0 deletions vllm/model_executor/guided_decoding/__init__.py
@@ -19,6 +19,11 @@ async def get_guided_decoding_logits_processor(
get_local_lm_format_enforcer_guided_decoding_logits_processor)
return get_local_lm_format_enforcer_guided_decoding_logits_processor(
guided_params, tokenizer)
if guided_params.backend == 'xgrammar':
from vllm.model_executor.guided_decoding.xgrammar_decoding import ( # noqa
get_local_xgrammar_guided_decoding_logits_processor)
return get_local_xgrammar_guided_decoding_logits_processor(
guided_params, tokenizer)

raise ValueError(
f"Unknown guided decoding backend '{guided_params.backend}'. "
@@ -40,6 +45,11 @@ def get_local_guided_decoding_logits_processor(
get_local_lm_format_enforcer_guided_decoding_logits_processor)
return get_local_lm_format_enforcer_guided_decoding_logits_processor(
guided_params, tokenizer)
if guided_params.backend == "xgrammar":
from vllm.model_executor.guided_decoding.xgrammar_decoding import ( # noqa
get_local_xgrammar_guided_decoding_logits_processor)
return get_local_xgrammar_guided_decoding_logits_processor(
guided_params, tokenizer)

raise ValueError(
f"Unknown guided decoding backend '{guided_params.backend}'. "
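Illustrative sketch (not part of this diff; the tokenizer and schema are placeholders) of what the new branch returns when the backend is requested explicitly:

from transformers import AutoTokenizer

from vllm.model_executor.guided_decoding import (
    get_local_guided_decoding_logits_processor)
from vllm.sampling_params import GuidedDecodingParams

tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
guided_params = GuidedDecodingParams(json='{"type": "object"}',
                                     backend="xgrammar")
# Dispatches to get_local_xgrammar_guided_decoding_logits_processor and
# returns an XGrammarLogitsProcessor wrapping a serializable GrammarConfig.
processor = get_local_guided_decoding_logits_processor(guided_params, tokenizer)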
241 changes: 241 additions & 0 deletions vllm/model_executor/guided_decoding/xgrammar_decoding.py
@@ -0,0 +1,241 @@
# noqa: UP007
from __future__ import annotations

import json
import time
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, NamedTuple

import torch
from transformers import PreTrainedTokenizerFast

from vllm.logger import init_logger

try:
import xgrammar as xgr
from xgrammar.base import _core as xgr_core
except ImportError:
pass

if TYPE_CHECKING:
from transformers import PreTrainedTokenizer

from vllm.sampling_params import GuidedDecodingParams

logger = init_logger(__name__)


# TODO: pass the batch size through to max_threads here
def get_local_xgrammar_guided_decoding_logits_processor(
guided_params: GuidedDecodingParams,
tokenizer: PreTrainedTokenizer,
max_threads: int = 8):
config = GrammarConfig.from_guided_params(guided_params=guided_params,
tokenizer=tokenizer,
max_threads=max_threads)
return XGrammarLogitsProcessor(config)


class TokenizerData(NamedTuple):
"""Immutable container for cached tokenizer data."""
encoded_vocab: list[str]
stop_token_ids: list[int] | None
backend_str: str


class TokenizerDataCache:
"""Cache manager for tokenizer data to avoid repeated processing."""
_cache: dict[int, TokenizerData] = {}

@classmethod
def get_tokenizer_data(cls,
tokenizer: PreTrainedTokenizer) -> TokenizerData:
tokenizer_hash = hash(tokenizer)

if tokenizer_hash not in cls._cache:
# Vendored from xgrammar logic since we cannot pickle the tokenizer
# https://github.com/mlc-ai/xgrammar/blob/d77c0a0173ef14779c918e3be7966ba852f7910f/python/xgrammar/tokenizer_info.py#L98 # noqa: E501
try:
encoded_vocab = [
token for token, _ in sorted(tokenizer.get_vocab().items(),
key=lambda x: x[1])
]
except AttributeError as e:
raise ValueError(
f"Cannot get the vocabulary of the tokenizer "
f"{type(tokenizer)}. The tokenizer should have a "
"get_vocab method.") from e

stop_token_ids = None
backend_str = xgr.VocabType.RAW
if isinstance(tokenizer, PreTrainedTokenizerFast):
backend_str = tokenizer.backend_tokenizer.to_str()
if stop_token_ids is None and hasattr(
tokenizer,
"eos_token_id") and tokenizer.eos_token_id is not None:
stop_token_ids = [tokenizer.eos_token_id]

cls._cache[tokenizer_hash] = TokenizerData(
encoded_vocab=encoded_vocab,
stop_token_ids=stop_token_ids,
backend_str=backend_str)

return cls._cache[tokenizer_hash]


class GrammarCompilerCache:
"""
Cache for GrammarCompiler instances based on tokenizer.

This cache reduces the overhead of creating new compiler instances when
using the same tokenizer configuration.
"""
_cache: dict[str, xgr.GrammarCompiler] = {}

@classmethod
def get_compiler(cls, config: GrammarConfig) -> xgr.GrammarCompiler:
cache_key = str(config.tokenizer_hash)

if cache_key not in cls._cache:
assert config.encoded_vocab is not None
tokenizer_info = xgr.TokenizerInfo._create_from_handle(
xgr_core.TokenizerInfo.from_huggingface(
config.encoded_vocab, config.backend_str,
config.vocab_size, config.stop_token_ids))
cls._cache[cache_key] = xgr.GrammarCompiler(
tokenizer_info, max_threads=config.max_threads)

return cls._cache[cache_key]


@dataclass
class GrammarConfig:
"""Serializable configuration for grammar compilation"""
tokenizer_hash: int
vocab_size: int
json_str: str | None = None
grammar_str: str | None = None
max_threads: int = 8
# Only populated if tokenizer_hash not in cache
encoded_vocab: list[str] | None = None
stop_token_ids: list[int] | None = None
backend_str: str | None = None

@classmethod
def from_guided_params(cls,
guided_params: GuidedDecodingParams,
tokenizer: PreTrainedTokenizer,
max_threads: int = 8) -> GrammarConfig:

tokenizer_hash = hash(tokenizer)
# Only get tokenizer data if not already cached
if tokenizer_hash in TokenizerDataCache._cache:
encoded_vocab = None
stop_token_ids = None
backend_str = None
else:
tokenizer_data = TokenizerDataCache.get_tokenizer_data(tokenizer)
encoded_vocab = tokenizer_data.encoded_vocab
stop_token_ids = tokenizer_data.stop_token_ids
backend_str = tokenizer_data.backend_str

if guided_params.json:
if not isinstance(guided_params.json, str):
json_str = json.dumps(guided_params.json)
else:
json_str = guided_params.json
return cls(json_str=json_str,
vocab_size=len(tokenizer.get_vocab().keys()),
encoded_vocab=encoded_vocab,
stop_token_ids=stop_token_ids,
backend_str=backend_str,
tokenizer_hash=tokenizer_hash,
max_threads=max_threads)
elif guided_params.grammar:
return cls(grammar_str=guided_params.grammar,
vocab_size=len(tokenizer.get_vocab().keys()),
encoded_vocab=encoded_vocab,
stop_token_ids=stop_token_ids,
backend_str=backend_str,
tokenizer_hash=tokenizer_hash,
max_threads=max_threads)
else:
raise ValueError(
"Currently only support JSON and EBNF grammar mode for xgrammar"
)


@dataclass
class XGrammarLogitsProcessor:
"""Wrapper class to support pickle protocol"""
config: GrammarConfig

ctx: xgr.CompiledGrammar | None = None
matchers: list[xgr.GrammarMatcher] = field(default_factory=list)
batch_size: int = 1
token_bitmask: torch.Tensor | None = None
prefilled: bool = False

def __getstate__(self) -> dict[str, Any]:
return {'config': self.config}

def __setstate__(self, state: dict[str, Any]):
self.config = state['config']

self.ctx = None
self.matchers = []
self.batch_size = 1
self.token_bitmask = None
self.prefilled = False

def _ensure_ctx(self):
"""Lazily initialize the processor in the worker process"""
if self.ctx is None:
start_time = time.time()
compiler = GrammarCompilerCache.get_compiler(self.config)
if self.config.json_str is not None:
self.ctx = compiler.compile_json_schema(self.config.json_str)
else:
self.ctx = compiler.compile_grammar(self.config.grammar_str)
compile_time = time.time() - start_time
if compile_time > 0.1:
log_str = ("Xgrammar compiled grammar/schema "
f"in {compile_time:.2f} seconds")
logger.info(log_str)

def __call__(self, input_ids: list[int],
scores: torch.Tensor) -> torch.Tensor:
if self.ctx is None:
self._ensure_ctx()

if len(self.matchers) == 0:
self.matchers = [
xgr.GrammarMatcher(self.ctx) for _ in range(self.batch_size)
]
self.token_bitmask = xgr.allocate_token_bitmask(
self.batch_size, self.config.vocab_size)

if not self.prefilled:
# Have not sampled a token yet
self.prefilled = True
else:
for i, matcher in enumerate(self.matchers):
if not matcher.is_terminated():
sampled_token = input_ids[-1]
assert self.matchers[i].accept_token(sampled_token)

for i, matcher in enumerate(self.matchers):
if not matcher.is_terminated():
matcher.fill_next_token_bitmask(self.token_bitmask, i)

# token_bitmask is a CPU tensor for use with accept_token and
# fill_next_token_bitmask so we move it to the device of scores
device_type = scores.device.type
if device_type != "cuda":
scores = scores.to("cpu")
xgr.apply_token_bitmask_inplace(scores,
self.token_bitmask.to(scores.device))
if device_type != "cuda":
scores = scores.to(device_type)

return scores
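To make the pickle/lazy-init flow above concrete, here is a minimal standalone sketch (the model name and schema are placeholders; assumes xgrammar is installed): the processor is built in the driver, pickled down to its GrammarConfig, and rebuilds the compiled grammar, matchers and bitmask on first call:

import pickle

import torch
from transformers import AutoTokenizer

from vllm.model_executor.guided_decoding.xgrammar_decoding import (
    GrammarConfig, XGrammarLogitsProcessor)
from vllm.sampling_params import GuidedDecodingParams

tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
guided_params = GuidedDecodingParams(json='{"type": "string"}',
                                     backend="xgrammar")
config = GrammarConfig.from_guided_params(guided_params, tokenizer)
processor = XGrammarLogitsProcessor(config)

# Only `config` survives the round trip; ctx, matchers and token_bitmask are
# rebuilt lazily inside __call__ on the worker side.
processor = pickle.loads(pickle.dumps(processor))

scores = torch.zeros(config.vocab_size)  # fake logits for one sequence
masked = processor([], scores)           # prefill step: no token sampled yet
# Tokens that would violate the JSON-string grammar are masked out in `masked`.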