[Core][Performance] Add XGrammar support for guided decoding and set it as default (#10785)

Signed-off-by: Aaron Pham <[email protected]>
Signed-off-by: mgoin <[email protected]>
Co-authored-by: mgoin <[email protected]>
aarnphm and mgoin authored Dec 3, 2024
1 parent 3257d44 commit 9323a31
Showing 11 changed files with 385 additions and 33 deletions.
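
For context, a minimal offline sketch of what this change enables: with no backend specified, guided decoding now routes through xgrammar by default. The snippet mirrors the new test added below; the model and prompt are illustrative, not part of the change.

```python
import json

from vllm import LLM, SamplingParams
from vllm.sampling_params import GuidedDecodingParams

# Illustrative model; any supported completion model works here.
llm = LLM(model="HuggingFaceH4/zephyr-7b-beta")

# No backend given, so the new xgrammar default applies.
sampling_params = SamplingParams(
    temperature=1.0,
    max_tokens=100,
    guided_decoding=GuidedDecodingParams(json_object=True))

outputs = llm.generate(
    prompts=("Generate a JSON object describing a person with name "
             "and age for John Smith who is 31 years old."),
    sampling_params=sampling_params)

# The constrained output should be valid JSON.
print(json.loads(outputs[0].outputs[0].text))
```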
1 change: 1 addition & 0 deletions docs/source/conf.py
@@ -178,6 +178,7 @@ def linkcode_resolve(domain, info):
"tensorizer",
"pynvml",
"outlines",
"xgrammar,"
"librosa",
"soundfile",
"gguf",
1 change: 1 addition & 0 deletions requirements-common.txt
@@ -19,6 +19,7 @@ prometheus-fastapi-instrumentator >= 7.0.0
tiktoken >= 0.6.0 # Required for DBRX tokenizer
lm-format-enforcer >= 0.10.9, < 0.11
outlines >= 0.0.43, < 0.1
xgrammar
typing_extensions >= 4.10
filelock >= 3.16.1 # need to contain https://github.com/tox-dev/filelock/pull/317
partial-json-parser # used for parsing partial JSON outputs
27 changes: 27 additions & 0 deletions tests/entrypoints/llm/test_guided_generate.py
@@ -159,3 +159,30 @@ def test_validation_against_both_guided_decoding_options(sample_regex, llm):
sampling_params=sampling_params,
use_tqdm=True,
guided_options_request=dict(guided_regex=sample_regex))


@pytest.mark.skip_global_cleanup
def test_guided_json_object(llm):
sampling_params = SamplingParams(
temperature=1.0,
max_tokens=100,
guided_decoding=GuidedDecodingParams(json_object=True))

outputs = llm.generate(
prompts=("Generate a JSON object describing a person with name "
"and age for John Smith who is 31 years old."),
sampling_params=sampling_params,
use_tqdm=True)

assert outputs is not None
for output in outputs:
assert output is not None
assert isinstance(output, RequestOutput)

generated_text = output.outputs[0].text
print(generated_text)
assert generated_text is not None

# Parse to verify it is valid JSON
parsed_json = json.loads(generated_text)
assert isinstance(parsed_json, dict)
3 changes: 2 additions & 1 deletion tests/model_executor/test_guided_processors.py
@@ -36,7 +36,8 @@ def test_guided_logits_processors(sample_regex, sample_json_schema):


@pytest.mark.asyncio
@pytest.mark.parametrize("backend", ["outlines", "lm-format-enforcer"])
@pytest.mark.parametrize("backend",
["outlines", "lm-format-enforcer", "xgrammar"])
async def test_guided_logits_processor_black_box(backend: str, sample_regex,
sample_json_schema):
tokenizer = AutoTokenizer.from_pretrained('HuggingFaceH4/zephyr-7b-beta')
15 changes: 8 additions & 7 deletions vllm/config.py
@@ -1789,15 +1789,15 @@ class PoolerConfig:

step_tag_id: Optional[int] = None
"""
If set, only the score corresponding to the ``step_tag_id`` in the
generated sentence should be returned. Otherwise, the scores for all tokens
are returned.
"""

returned_token_ids: Optional[List[int]] = None
"""
A list of indices for the vocabulary dimensions to be extracted,
such as the token IDs of ``good_token`` and ``bad_token`` in the
``math-shepherd-mistral-7b-prm`` model.
"""

@@ -2031,11 +2031,12 @@ def get_served_model_name(model: str,
class DecodingConfig:
"""Dataclass which contains the decoding strategy of the engine"""

# Which guided decoding algo to use. 'outlines' / 'lm-format-enforcer'
guided_decoding_backend: str = 'outlines'
# Which guided decoding algo to use.
# 'outlines' / 'lm-format-enforcer' / 'xgrammar'
guided_decoding_backend: str = 'xgrammar'

def __post_init__(self):
valid_guided_backends = ['outlines', 'lm-format-enforcer']
valid_guided_backends = ['outlines', 'lm-format-enforcer', 'xgrammar']
backend = self.guided_decoding_backend
if backend not in valid_guided_backends:
raise ValueError(f"Invalid guided_decoding_backend '{backend},"
@@ -2222,7 +2223,7 @@ class CompilationConfig(BaseModel):
from Python, functions can also be passed directly via Python object
constructor, e.g. `CompilationConfig(inductor_passes={"a": func})`
- custom inductor passes: see PassConfig for more details
Why we have different sizes for cudagraph and inductor:
- cudagraph: a cudagraph captured for a specific size can only be used
for the same size. We need to capture all the sizes we want to use.
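
A quick sanity sketch of the new default and the `__post_init__` validation above; the invalid backend name is hypothetical.

```python
from vllm.config import DecodingConfig

# The engine-wide default is now xgrammar rather than outlines.
assert DecodingConfig().guided_decoding_backend == "xgrammar"

# Backends outside the valid set are rejected at construction time.
try:
    DecodingConfig(guided_decoding_backend="no-such-backend")
except ValueError as err:
    print(err)
```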
9 changes: 5 additions & 4 deletions vllm/engine/arg_utils.py
@@ -168,7 +168,7 @@ class EngineArgs:
scheduler_delay_factor: float = 0.0
enable_chunked_prefill: Optional[bool] = None

guided_decoding_backend: str = 'outlines'
guided_decoding_backend: str = 'xgrammar'
# Speculative decoding configuration.
speculative_model: Optional[str] = None
speculative_model_quantization: Optional[str] = None
@@ -364,11 +364,12 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
parser.add_argument(
'--guided-decoding-backend',
type=str,
default='outlines',
choices=['outlines', 'lm-format-enforcer'],
default='xgrammar',
choices=['outlines', 'lm-format-enforcer', 'xgrammar'],
help='Which engine will be used for guided decoding'
' (JSON schema / regex etc) by default. Currently supports '
'https://github.com/outlines-dev/outlines and '
'https://github.com/outlines-dev/outlines, '
'https://github.com/mlc-ai/xgrammar, and '
'https://github.com/noamgat/lm-format-enforcer.'
' Can be overridden per request via guided_decoding_backend'
' parameter.')
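
The same default shows up when building engine arguments from Python; a short sketch assuming only the `EngineArgs` fields diffed above (model name illustrative). Per-request overrides via `GuidedDecodingParams(backend=...)` remain available either way.

```python
from vllm.engine.arg_utils import EngineArgs

# Defaults to xgrammar after this commit; set it explicitly to opt the
# whole engine back into outlines.
engine_args = EngineArgs(
    model="HuggingFaceH4/zephyr-7b-beta",
    guided_decoding_backend="outlines",
)
assert engine_args.guided_decoding_backend == "outlines"
```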
18 changes: 14 additions & 4 deletions vllm/engine/async_llm_engine.py
@@ -1,4 +1,5 @@
import asyncio
import copy
import time
import weakref
from functools import partial
@@ -507,7 +508,8 @@ async def add_request_async(
sampling_params=params,
tokenizer=await self.get_tokenizer_async(lora_request),
default_guided_backend=self.decoding_config.
guided_decoding_backend)
guided_decoding_backend,
model_config=self.model_config)

self._add_processed_request(
request_id=request_id,
@@ -528,22 +530,30 @@ async def check_health_async(self) -> None:

async def build_guided_decoding_logits_processor_async(
sampling_params: SamplingParams, tokenizer: AnyTokenizer,
default_guided_backend: str) -> SamplingParams:
default_guided_backend: str,
model_config: ModelConfig) -> SamplingParams:
"""Constructs logits processors based on the guided_decoding,
logits_bias, and allowed_token_ids fields in sampling_params. Deletes
those fields and adds the constructed logits processors to the
logits_processors field. Modifies sampling params in-place and returns
the modified sampling params."""
if (guided_decoding := sampling_params.guided_decoding) is None:
if sampling_params.guided_decoding is None:
return sampling_params

# Defensively copy sampling params since guided decoding logits
# processors can have different state for each request
sampling_params = copy.copy(sampling_params)
guided_decoding = sampling_params.guided_decoding

logger.debug("Building guided decoding logits processor. "
"Params: %s", guided_decoding)

guided_decoding.backend = guided_decoding.backend or default_guided_backend

processor = await get_guided_decoding_logits_processor(
guided_params=guided_decoding, tokenizer=tokenizer)
guided_params=guided_decoding,
tokenizer=tokenizer,
model_config=model_config)

if processor:
if sampling_params.logits_processors is None:
15 changes: 11 additions & 4 deletions vllm/engine/llm_engine.py
@@ -1,3 +1,4 @@
import copy
import time
from collections import Counter as collectionsCounter
from collections import deque
@@ -1024,9 +1025,9 @@ def _update_num_computed_tokens_for_multi_step_prefill(
This function updates num_computed_tokens for prompt sequences
when Multi-Step is enabled.
seq_group: SequenceGroup to update the num_computed_tokens for.
seq_group_meta: Metadata of the given SequenceGroup.
is_first_step_output: Optional[bool] -
When available, is_first_step_output indicates if the appended
output token is the output of the first-step in multi-step.
A value of None indicates that outputs from all steps in
@@ -2036,7 +2037,11 @@ def _build_logits_processors(

logits_processors = []

if (guided_decoding := sampling_params.guided_decoding) is not None:
if sampling_params.guided_decoding is not None:
# Defensively copy sampling params since guided decoding logits
# processors can have different state for each request
sampling_params = copy.copy(sampling_params)
guided_decoding = sampling_params.guided_decoding

logger.debug(
"Building guided decoding logits processor in "
@@ -2047,7 +2052,9 @@ def _build_logits_processors(
self.decoding_config.guided_decoding_backend

processor = get_local_guided_decoding_logits_processor(
guided_params=guided_decoding, tokenizer=tokenizer)
guided_params=guided_decoding,
tokenizer=tokenizer,
model_config=self.model_config)
if processor:
logits_processors.append(processor)

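
The `copy.copy` added in both engines guards against shared mutable state: guided-decoding logits processors carry per-request matcher state, so attaching one to a `SamplingParams` object that callers reuse across requests would leak that state between them. A contrived, self-contained sketch of the pattern; the names here are hypothetical, not vLLM APIs.

```python
import copy


class StatefulProcessor:
    """Hypothetical stand-in for a guided-decoding logits processor."""

    def __init__(self) -> None:
        self.position = 0  # per-request grammar/FSM position


def attach_processor(params: dict) -> dict:
    # Shallow-copy first, as the engine now does: each request gets its
    # own params object and its own processor instance.
    params = copy.copy(params)
    params["logits_processors"] = [StatefulProcessor()]
    return params


shared = {"logits_processors": None}
req_a = attach_processor(shared)
req_b = attach_processor(shared)
assert req_a["logits_processors"] is not req_b["logits_processors"]
assert shared["logits_processors"] is None  # caller's object is untouched
```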
5 changes: 3 additions & 2 deletions vllm/engine/multiprocessing/client.py
@@ -474,8 +474,8 @@ def generate(
trace_headers: OpenTelemetry trace headers.
prompt_adapter_request: Prompt Adapter request to use
for generation, if any.
priority: Priority of the request (lower means earlier handling).
Any priority other than 0 will lead to an error if the
scheduling policy is not "priority".
"""
if inputs is not None:
@@ -589,6 +589,7 @@ async def _process_request(
default_guided_backend=(self.decoding_config.guided_decoding_backend
if self.decoding_config
else DecodingConfig.guided_decoding_backend),
model_config=self.model_config
)

# 1) Create output queue for this requests.
73 changes: 62 additions & 11 deletions vllm/model_executor/guided_decoding/__init__.py
@@ -1,14 +1,54 @@
from typing import Optional
from __future__ import annotations

from vllm.logits_process import LogitsProcessor
from vllm.sampling_params import GuidedDecodingParams
from typing import TYPE_CHECKING

from vllm.logger import init_logger

if TYPE_CHECKING:
from transformers import PreTrainedTokenizer

from vllm.config import ModelConfig
from vllm.logits_process import LogitsProcessor
from vllm.sampling_params import GuidedDecodingParams

logger = init_logger(__name__)


def maybe_backend_fallback(
guided_params: GuidedDecodingParams) -> GuidedDecodingParams:
# lm-format-enforcer doesn't support grammar; fall back to xgrammar
if (guided_params.backend == "lm-format-enforcer"
and guided_params.grammar is not None):
logger.warning(
"lm-format-enforcer does not support grammar guided decoding. "
"Falling back to use xgrammar instead.")
guided_params.backend = "xgrammar"

if guided_params.backend == "xgrammar":
# xgrammar doesn't support regex or choice; fall back to outlines
if guided_params.regex is not None or guided_params.choice is not None:
logger.warning(
"xgrammar only supports json or grammar guided decoding. "
"Falling back to use outlines instead.")
guided_params.backend = "outlines"

# xgrammar only supports EBNF grammars and uses the GBNF format
# https://github.com/ggerganov/llama.cpp/blob/master/grammars/README.md
elif (guided_params.grammar is not None
and "::=" not in guided_params.grammar):
logger.warning("xgrammar only supports EBNF grammars. "
"Falling back to use outlines instead.")
guided_params.backend = "outlines"

return guided_params


async def get_guided_decoding_logits_processor(
guided_params: GuidedDecodingParams,
tokenizer) -> Optional[LogitsProcessor]:
guided_params: GuidedDecodingParams, tokenizer: PreTrainedTokenizer,
model_config: ModelConfig) -> LogitsProcessor | None:
guided_params = maybe_backend_fallback(guided_params)
# CFG grammar not supported by LMFE, so we use outlines instead
if guided_params.backend == 'outlines' or guided_params.grammar:
if guided_params.backend == 'outlines':
# NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193
from vllm.model_executor.guided_decoding.outlines_decoding import ( # noqa
get_outlines_guided_decoding_logits_processor)
@@ -19,17 +59,23 @@ async def get_guided_decoding_logits_processor(
get_local_lm_format_enforcer_guided_decoding_logits_processor)
return get_local_lm_format_enforcer_guided_decoding_logits_processor(
guided_params, tokenizer)
if guided_params.backend == 'xgrammar':
from vllm.model_executor.guided_decoding.xgrammar_decoding import ( # noqa
get_local_xgrammar_guided_decoding_logits_processor)
return get_local_xgrammar_guided_decoding_logits_processor(
guided_params, tokenizer, model_config)

raise ValueError(
f"Unknown guided decoding backend '{guided_params.backend}'. "
"Must be one of 'outlines, 'lm-format-enforcer'")
"Must be one of 'outlines, 'lm-format-enforcer', 'xgrammar'")


def get_local_guided_decoding_logits_processor(
guided_params: GuidedDecodingParams,
tokenizer) -> Optional[LogitsProcessor]:
guided_params: GuidedDecodingParams, tokenizer: PreTrainedTokenizer,
model_config: ModelConfig) -> LogitsProcessor | None:
guided_params = maybe_backend_fallback(guided_params)
# CFG grammar not supported by LMFE, so we use outlines instead
if guided_params.backend == 'outlines' or guided_params.grammar:
if guided_params.backend == 'outlines':
# NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193
from vllm.model_executor.guided_decoding.outlines_decoding import ( # noqa
get_local_outlines_guided_decoding_logits_processor)
@@ -40,7 +86,12 @@ def get_local_guided_decoding_logits_processor(
get_local_lm_format_enforcer_guided_decoding_logits_processor)
return get_local_lm_format_enforcer_guided_decoding_logits_processor(
guided_params, tokenizer)
if guided_params.backend == 'xgrammar':
from vllm.model_executor.guided_decoding.xgrammar_decoding import ( # noqa
get_local_xgrammar_guided_decoding_logits_processor)
return get_local_xgrammar_guided_decoding_logits_processor(
guided_params, tokenizer, model_config)

raise ValueError(
f"Unknown guided decoding backend '{guided_params.backend}'. "
"Must be one of 'outlines, 'lm-format-enforcer'")
"Must be one of 'outlines, 'lm-format-enforcer', 'xgrammar'")