[Core][Performance] Add XGrammar support for guided decoding and set it as default #10785

Merged: 20 commits, Dec 3, 2024
Changes from 9 commits
1 change: 1 addition & 0 deletions docs/source/conf.py
@@ -112,6 +112,7 @@ def setup(app):
"tensorizer",
"pynvml",
"outlines",
"xgrammar,"
"librosa",
"soundfile",
"gguf",
1 change: 1 addition & 0 deletions requirements-common.txt
@@ -19,6 +19,7 @@ prometheus-fastapi-instrumentator >= 7.0.0
tiktoken >= 0.6.0 # Required for DBRX tokenizer
lm-format-enforcer >= 0.10.9, < 0.11
outlines >= 0.0.43, < 0.1
+ xgrammar
typing_extensions >= 4.10
filelock >= 3.16.1 # need to contain https://github.com/tox-dev/filelock/pull/317
partial-json-parser # used for parsing partial JSON outputs
3 changes: 2 additions & 1 deletion tests/model_executor/test_guided_processors.py
@@ -36,7 +36,8 @@ def test_guided_logits_processors(sample_regex, sample_json_schema):


@pytest.mark.asyncio
@pytest.mark.parametrize("backend", ["outlines", "lm-format-enforcer"])
@pytest.mark.parametrize("backend",
["outlines", "lm-format-enforcer", "xgrammar"])
async def test_guided_logits_processor_black_box(backend: str, sample_regex,
sample_json_schema):
tokenizer = AutoTokenizer.from_pretrained('HuggingFaceH4/zephyr-7b-beta')
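For reference, a minimal sketch of the round trip this black-box test exercises with the new backend. It assumes the module's sample_json_schema fixture and an engine ModelConfig are in scope; only the function and parameter names come from this diff.

import torch
from transformers import AutoTokenizer

from vllm.model_executor.guided_decoding import (
    get_local_guided_decoding_logits_processor)
from vllm.sampling_params import GuidedDecodingParams

tokenizer = AutoTokenizer.from_pretrained('HuggingFaceH4/zephyr-7b-beta')
guided_params = GuidedDecodingParams(json=sample_json_schema,
                                     backend='xgrammar')
processor = get_local_guided_decoding_logits_processor(
    guided_params=guided_params,
    tokenizer=tokenizer,
    model_config=model_config)  # model_config assumed available, as in the engine
# The processor pushes logits of schema-violating tokens toward -inf.
logits = torch.zeros(len(tokenizer.get_vocab()))
masked = processor([], logits)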
15 changes: 8 additions & 7 deletions vllm/config.py
@@ -1753,15 +1753,15 @@ class PoolerConfig:

step_tag_id: Optional[int] = None
"""
If set, only the score corresponding to the ``step_tag_id`` in the
generated sentence should be returned. Otherwise, the scores for all tokens
are returned.
"""

returned_token_ids: Optional[List[int]] = None
"""
A list of indices for the vocabulary dimensions to be extracted,
such as the token IDs of ``good_token`` and ``bad_token`` in the
``math-shepherd-mistral-7b-prm`` model.
"""

@@ -1995,11 +1995,12 @@ def get_served_model_name(model: str,
class DecodingConfig:
"""Dataclass which contains the decoding strategy of the engine"""

- # Which guided decoding algo to use. 'outlines' / 'lm-format-enforcer'
- guided_decoding_backend: str = 'outlines'
+ # Which guided decoding algo to use.
+ # 'outlines' / 'lm-format-enforcer' / 'xgrammar'
+ guided_decoding_backend: str = 'xgrammar'

def __post_init__(self):
- valid_guided_backends = ['outlines', 'lm-format-enforcer']
+ valid_guided_backends = ['outlines', 'lm-format-enforcer', 'xgrammar']
backend = self.guided_decoding_backend
if backend not in valid_guided_backends:
raise ValueError(f"Invalid guided_decoding_backend '{backend},"
@@ -2104,7 +2105,7 @@ class CompilationConfig(BaseModel):
from Python, functions can also be passed directly via Python object
constructor, e.g. `CompilationConfig(inductor_passes={"a": func})`
- custom inductor passes: see PassConfig for more details

Why we have different sizes for cudagraph and inductor:
- cudagraph: a cudagraph captured for a specific size can only be used
for the same size. We need to capture all the sizes we want to use.
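As a quick illustration of the behavior after this change (a sketch, not part of the diff):

from vllm.config import DecodingConfig

config = DecodingConfig()
assert config.guided_decoding_backend == 'xgrammar'  # new default

DecodingConfig(guided_decoding_backend='outlines')   # still accepted
DecodingConfig(guided_decoding_backend='unknown')    # raises ValueError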
9 changes: 5 additions & 4 deletions vllm/engine/arg_utils.py
@@ -167,7 +167,7 @@ class EngineArgs:
scheduler_delay_factor: float = 0.0
enable_chunked_prefill: Optional[bool] = None

- guided_decoding_backend: str = 'outlines'
+ guided_decoding_backend: str = 'xgrammar'
# Speculative decoding configuration.
speculative_model: Optional[str] = None
speculative_model_quantization: Optional[str] = None
@@ -361,11 +361,12 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
parser.add_argument(
'--guided-decoding-backend',
type=str,
- default='outlines',
- choices=['outlines', 'lm-format-enforcer'],
+ default='xgrammar',
+ choices=['outlines', 'lm-format-enforcer', 'xgrammar'],
help='Which engine will be used for guided decoding'
' (JSON schema / regex etc) by default. Currently support '
- 'https://github.com/outlines-dev/outlines and '
+ 'https://github.com/outlines-dev/outlines, '
+ 'https://github.com/mlc-ai/xgrammar, and '
'https://github.com/noamgat/lm-format-enforcer.'
' Can be overridden per request via guided_decoding_backend'
' parameter.')
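The flag maps onto EngineArgs in the usual way; a hedged usage sketch (the model name is illustrative):

from vllm.engine.arg_utils import EngineArgs

# Omitting guided_decoding_backend now yields 'xgrammar'; any of the
# three choices can still be selected explicitly.
engine_args = EngineArgs(model='HuggingFaceH4/zephyr-7b-beta',
                         guided_decoding_backend='lm-format-enforcer')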
10 changes: 7 additions & 3 deletions vllm/engine/async_llm_engine.py
@@ -504,7 +504,8 @@ async def add_request_async(
sampling_params=params,
tokenizer=await self.get_tokenizer_async(lora_request),
default_guided_backend=self.decoding_config.
- guided_decoding_backend)
+ guided_decoding_backend,
+ model_config=self.model_config)

self._add_processed_request(
request_id=request_id,
@@ -525,7 +526,8 @@ async def check_health_async(self) -> None:

async def build_guided_decoding_logits_processor_async(
sampling_params: SamplingParams, tokenizer: AnyTokenizer,
- default_guided_backend: str) -> SamplingParams:
+ default_guided_backend: str,
+ model_config: ModelConfig) -> SamplingParams:
"""Constructs logits processors based on the guided_decoding,
logits_bias, and allowed_token_ids fields in sampling_params. Deletes
those fields and adds the constructed logits processors to the
@@ -540,7 +542,9 @@ async def build_guided_decoding_logits_processor_async(
guided_decoding.backend = guided_decoding.backend or default_guided_backend

processor = await get_guided_decoding_logits_processor(
- guided_params=guided_decoding, tokenizer=tokenizer)
+ guided_params=guided_decoding,
+ tokenizer=tokenizer,
+ model_config=model_config)

if processor:
if sampling_params.logits_processors is None:
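Taken together, the async path now threads the engine's ModelConfig through to the backend, presumably so that xgrammar can size its token bitmask from the model's vocabulary. A consolidated sketch of the call as wired up here, with self standing for the engine instance:

sampling_params = await build_guided_decoding_logits_processor_async(
    sampling_params=params,
    tokenizer=await self.get_tokenizer_async(lora_request),
    default_guided_backend=self.decoding_config.guided_decoding_backend,
    model_config=self.model_config)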
8 changes: 5 additions & 3 deletions vllm/engine/llm_engine.py
@@ -1023,9 +1023,9 @@ def _update_num_computed_tokens_for_multi_step_prefill(
This function updates num_computed_tokens for prompt sequences
when Multi-Step is enabled.

seq_group: SequenceGroup to update the num_computed_tokens for.
seq_group_meta: Metadata of the given SequenceGroup.
is_first_step_output: Optional[bool] -
When available, is_first_step_output indicates if the appended
output token is the output of the first-step in multi-step.
A value of None indicates that outputs from all steps in
@@ -2046,7 +2046,9 @@ def _build_logits_processors(
self.decoding_config.guided_decoding_backend

processor = get_local_guided_decoding_logits_processor(
- guided_params=guided_decoding, tokenizer=tokenizer)
+ guided_params=guided_decoding,
+ tokenizer=tokenizer,
+ model_config=self.model_config)
if processor:
logits_processors.append(processor)

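At the API level, the per-request override mentioned in the CLI help still applies; a minimal offline-inference sketch (the regex value is illustrative):

from vllm import LLM, SamplingParams
from vllm.sampling_params import GuidedDecodingParams

llm = LLM(model='HuggingFaceH4/zephyr-7b-beta')
params = SamplingParams(guided_decoding=GuidedDecodingParams(
    regex=r'\d{3}-\d{4}',
    backend='outlines'))  # overrides the engine-wide default ('xgrammar')
outputs = llm.generate('My phone number is ', params)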
5 changes: 3 additions & 2 deletions vllm/engine/multiprocessing/client.py
@@ -472,8 +472,8 @@ def generate(
trace_headers: OpenTelemetry trace headers.
prompt_adapter_request: Prompt Adapter request to use
for generation, if any.
priority: Priority of the request (lower means earlier handling).
    Any priority other than 0 will lead to an error if the
scheduling policy is not "priority".
"""
if inputs is not None:
@@ -586,6 +586,7 @@ async def _process_request(
default_guided_backend=(self.decoding_config.guided_decoding_backend
if self.decoding_config
else DecodingConfig.guided_decoding_backend),
+ model_config=self.model_config
)

# 1) Create output queue for this request.
34 changes: 25 additions & 9 deletions vllm/model_executor/guided_decoding/__init__.py
@@ -1,12 +1,18 @@
- from typing import Optional
+ from __future__ import annotations

- from vllm.logits_process import LogitsProcessor
- from vllm.sampling_params import GuidedDecodingParams
+ from typing import TYPE_CHECKING
+
+ if TYPE_CHECKING:
+     from transformers import PreTrainedTokenizer
+
+     from vllm.config import ModelConfig
+     from vllm.logits_process import LogitsProcessor
+     from vllm.sampling_params import GuidedDecodingParams


async def get_guided_decoding_logits_processor(
-     guided_params: GuidedDecodingParams,
-     tokenizer) -> Optional[LogitsProcessor]:
+     guided_params: GuidedDecodingParams, tokenizer: PreTrainedTokenizer,
+     model_config: ModelConfig) -> LogitsProcessor | None:
# CFG grammar not supported by LMFE, so we use outlines instead
if guided_params.backend == 'outlines' or guided_params.grammar:
Review comment from @Ubospica (Nov 30, 2024):
XGrammar can also do grammar decoding and accelerate it. The grammar formats differ between backends: XGrammar uses the GBNF format, while Outlines uses Lark grammar. That should be documented.

Reply from the Contributor Author:
I see, I will add this difference to the docs.

Reply from the Contributor Author:
I think we should just remove the grammar check here. If a user sends a grammar, they should also specify the backend (it is probably better to document the Cartesian product of the valid combinations).

# NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193
@@ -19,15 +25,20 @@ async def get_guided_decoding_logits_processor(
get_local_lm_format_enforcer_guided_decoding_logits_processor)
return get_local_lm_format_enforcer_guided_decoding_logits_processor(
guided_params, tokenizer)
+ if guided_params.backend == 'xgrammar':
+     from vllm.model_executor.guided_decoding.xgrammar_decoding import (  # noqa
+         get_local_xgrammar_guided_decoding_logits_processor)
+     return get_local_xgrammar_guided_decoding_logits_processor(
+         guided_params, tokenizer, model_config)

raise ValueError(
f"Unknown guided decoding backend '{guided_params.backend}'. "
"Must be one of 'outlines, 'lm-format-enforcer'")
"Must be one of 'outlines, 'lm-format-enforcer', 'xgrammar'")


def get_local_guided_decoding_logits_processor(
- guided_params: GuidedDecodingParams,
- tokenizer) -> Optional[LogitsProcessor]:
+ guided_params: GuidedDecodingParams, tokenizer: PreTrainedTokenizer,
+ model_config: ModelConfig) -> LogitsProcessor | None:
# CFG grammar not supported by LMFE, so we use outlines instead
if guided_params.backend == 'outlines' or guided_params.grammar:
# NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193
@@ -40,7 +51,12 @@ def get_local_guided_decoding_logits_processor(
get_local_lm_format_enforcer_guided_decoding_logits_processor)
return get_local_lm_format_enforcer_guided_decoding_logits_processor(
guided_params, tokenizer)
+ if guided_params.backend == 'xgrammar':
+     from vllm.model_executor.guided_decoding.xgrammar_decoding import (  # noqa
+         get_local_xgrammar_guided_decoding_logits_processor)
+     return get_local_xgrammar_guided_decoding_logits_processor(
+         guided_params, tokenizer, model_config)

raise ValueError(
f"Unknown guided decoding backend '{guided_params.backend}'. "
"Must be one of 'outlines, 'lm-format-enforcer'")
"Must be one of 'outlines, 'lm-format-enforcer', 'xgrammar'")