Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Core][Performance] Add XGrammar support for guided decoding and set it as default #10785

Merged
merged 20 commits into from
Dec 3, 2024

Conversation

aarnphm
Copy link
Contributor

@aarnphm aarnphm commented Nov 29, 2024

Add initial support for XGrammar for V0 and makes it the default for grammar and json usage. Written in collaboration with @mgoin

I'm using the benchmark scripts from #10557

Results for using XGrammar as backend:

Throughput: 0.94 requests/s, 1022.46 total tokens/s, 480.27 output tokens/s Correct rate is 100.0 %
First token latency(msecs):
count      10.000000
mean     4552.206317
std       734.671745
min      3289.774953
25%      3864.269087
50%      5102.686635
75%      5102.717258
max      5114.346570
dtype: float64
Next token latency(msecs):
count    10.000000
mean     11.906452
std       1.409063
min      10.831970
25%      10.837367
50%      10.854235
75%      13.227200
max      14.325024
dtype: float64

Comparing to outlines

Throughput: 0.22 requests/s, 241.22 total tokens/s, 113.31 output tokens/s Correct rate is 100.0 %
First token latency(msecs):
count       10.000000
mean     38533.083248
std         35.807892
min      38491.813741
25%      38491.826321
50%      38556.601226
75%      38556.628519
max      38568.547848
dtype: float64
Next token latency(msecs):
count    10.000000
mean     12.955556
std       0.042220
min      12.901755
25%      12.914099
50%      12.953058
75%      12.996646
max      13.003127
dtype: float64

NOTE: Running on A100 80GB, with Llama 3.2 3B with chunked prefill enable and JSON grammar

Signed-off-by: Aaron Pham <[email protected]>
Signed-off-by: Aaron Pham <[email protected]>
Copy link

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add ready label to the PR
  • Enable auto-merge.

🚀

@mergify mergify bot added documentation Improvements or additions to documentation ci/build labels Nov 29, 2024
Signed-off-by: Aaron Pham <[email protected]>
@aarnphm aarnphm marked this pull request as draft November 29, 2024 23:46
@aarnphm aarnphm marked this pull request as ready for review November 30, 2024 00:16
Signed-off-by: Aaron Pham <[email protected]>
Signed-off-by: Aaron Pham <[email protected]>
Signed-off-by: Aaron Pham <[email protected]>
joennlae added a commit to 44ai-labs/vllm that referenced this pull request Nov 30, 2024
Copy link

@Ubospica Ubospica left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for your contribution to integrating XGrammar into vLLM! It overall looks good, but there are some minor points to enhance parallelism.

guided_params: GuidedDecodingParams,
tokenizer) -> Optional[LogitsProcessor]:
guided_params: GuidedDecodingParams, tokenizer: PreTrainedTokenizer,
model_config: ModelConfig) -> LogitsProcessor | None:
# CFG grammar not supported by LMFE, so we use outlines instead
if guided_params.backend == 'outlines' or guided_params.grammar:
Copy link

@Ubospica Ubospica Nov 30, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

XGrammar can also do grammar decoding and accelerate it. The grammar formats for XGrammar and Outlines are different. XGrammar uses GBNF format, while Outlines uses lark grammar. That might be documented.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i see, I will add this difference into the docs

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should just remove the grammar check here.

If user send grammar they should also specify the backend (probably better to document the cartesian product of the combinations)

joennlae added a commit to 44ai-labs/vllm that referenced this pull request Dec 1, 2024
Essentially a cleaned up version of this `pr`:
vllm-project#10785

Especially since `outlines` is rather slow and the new version is though
to intergrate as they do not focus on being pickleable which is a key
feature for us using the multiprocessing engine: dottxt-ai/outlines-core#99

I assume more and more will change over to `xgrammar`.

This is a minimum implementation.

https://arxiv.org/pdf/2411.15100

Signed-off-by: Jannis Schönleber <[email protected]>
@mgoin
Copy link
Member

mgoin commented Dec 1, 2024

Updated this PR with caches for the tokenizer data and the grammar compiler to avoid constructing these data structures for each request. It isn't pretty but it boosts throughput by about 1.4x.

I need to perform more profiling but we are limited by the required-serialization architecture that we currently have. We plan to move the FSM initialization out of the frontend to both simplify the implementation and speed up TTFT.

Setup: Llama-3.1-8B-Instruct, 1xH100

Command:

python benchmark_guided.py --model meta-llama/Llama-3.1-8B-Instruct --dataset xgrammar_bench --async-engine --output-len 512 --num-prompts 20 --enable-chunked-prefill --guided-decoding-ratio 1

Before:

Throughput: 1.46 requests/s, 1189.12 total tokens/s, 748.00 output tokens/s Correct rate is 95.0 % 
First token latency(msecs):
count      20.000000
mean     7180.142369
std      1212.973158
min      4644.173431
25%      7012.610644
50%      7578.541221
75%      8079.524654
max      8092.886029
dtype: float64
Next token latency(msecs):
count    20.000000
mean     12.662371
std       2.336552
min      10.942158
25%      10.942283
50%      11.864077
75%      12.990130
max      17.550802
dtype: float64

After:

Throughput: 2.12 requests/s, 1726.67 total tokens/s, 1086.13 output tokens/s Correct rate is 95.0 % 
First token latency(msecs):
count      20.000000
mean     3254.682581
std       290.516334
min      2869.083916
25%      2869.120228
50%      3449.280638
75%      3477.460549
max      3477.504314
dtype: float64
Next token latency(msecs):
count    20.000000
mean     12.054585
std       0.550868
min      11.643879
25%      11.643967
50%      11.674903
75%      12.786106
max      12.786302
dtype: float64

joennlae added a commit to 44ai-labs/vllm that referenced this pull request Dec 1, 2024
Essentially a cleaned up version of this `pr`:
vllm-project#10785

Especially since `outlines` is rather slow and the new version is though
to intergrate as they do not focus on being pickleable which is a key
feature for us using the multiprocessing engine: dottxt-ai/outlines-core#99

I assume more and more will change over to `xgrammar`.

This is a minimum implementation.

https://arxiv.org/pdf/2411.15100

Signed-off-by: Jannis Schönleber <[email protected]>
@mgoin
Copy link
Member

mgoin commented Dec 2, 2024

@Ubospica do you know when XGrammar can support regex? This would help with covering existing use cases

@mgoin mgoin changed the title feat(guided): xgrammar support [Core][Performance] Add XGrammar support for guided decoding Dec 2, 2024
@joennlae
Copy link
Contributor

joennlae commented Dec 2, 2024

@mgoin I added a pull request yesterday that adds some simple regex pattern + integer ranges support:

mlc-ai/xgrammar#106

@mergify mergify bot added the frontend label Dec 2, 2024
@simon-mo simon-mo changed the title [Core][Performance] Add XGrammar support for guided decoding [Core][Performance] Add XGrammar support for guided decoding and set it as default Dec 3, 2024
simon-mo
simon-mo previously approved these changes Dec 3, 2024
vllm/entrypoints/llm.py Outdated Show resolved Hide resolved
@simon-mo simon-mo dismissed their stale review December 3, 2024 01:41

if isinstance(params, Sequence) else copy.copy(params), is actually a blocking review. We can only introduce it if it is not perf regression.

@mgoin
Copy link
Member

mgoin commented Dec 3, 2024

Thanks for review @simon-mo I moved the copy into a specific if sampling_params.guided_decoding is not None case - ready for re-review

@DarkLight1337 DarkLight1337 merged commit 9323a31 into vllm-project:main Dec 3, 2024
73 checks passed
@hmellor
Copy link
Collaborator

hmellor commented Dec 3, 2024

The new dependency in this PR appears to have broken installation on ARM

8.373 ERROR: Could not find a version that satisfies the requirement xgrammar (from versions: none)
8.419 ERROR: No matching distribution found for xgrammar
------
Dockerfile.arm:37
--------------------
  36 |     
  37 | >>> RUN --mount=type=cache,target=/root/.cache/pip \
  38 | >>>     --mount=type=bind,src=requirements-common.txt,target=requirements-common.txt \
  39 | >>>     --mount=type=bind,src=requirements-cpu.txt,target=requirements-cpu.txt \
  40 | >>>     pip install -v -r requirements-cpu.txt
  41 |     
--------------------
ERROR: failed to solve: process "/bin/sh -c pip install -v -r requirements-cpu.txt" did not complete successfully: exit code: 1

@mgoin
Copy link
Member

mgoin commented Dec 3, 2024

Thanks for reporting @hmellor indeed it seems there isn't a manylinux arm wheel available https://pypi.org/project/xgrammar/#files

I'll work on a patch fix

@stefanobranco
Copy link

Obviously super cool to see new integrations, but it does seem a bit hasty to me to immediately change the default? The implementation with outlines core should be able to close the gap after all, and this one does not support regex yet. Or is xgrammar just objectively better?

@joennlae
Copy link
Contributor

joennlae commented Dec 3, 2024

I second this opinion. Currently, the same behaviour cannot be expected from 'grammar`. I added a simple PR with some rudimentary regex + integer range support (mlc-ai/xgrammar#106).

I can attest that it is much faster, especially if one uses dynamic schemas. However, we should use outlines as the default, as it supports more cases for now, and the change is not breaking for many.

I introduced it as an option in my closed PR (#10803). But I forgot it when I discussed it with @mgoin.

@mgoin
Copy link
Member

mgoin commented Dec 3, 2024

Hi @stefanobranco and @joennlae thanks for raising your concern. Our primary concern is immediately improving structured output performance where it is easy to do so while maintaining the same behavior. With xgrammar as the default in supported cases, we still fallback to outlines in several cases covered here https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/guided_decoding/__init__.py#L18-L48

Please let me know if a case isn't being accounted for that is affecting your usage. We do not want to change external behavior. We have several integration tests that I have been using to create these rules, but more test points are certainly welcome!

We have several fast-followup items to reduce the special cases around using xgrammar and improving performance even further in V0. We are also working on enabling outlines>=0.1.8 support with the devs of that project. Then of course we will enable the usage of structured output in V1.

I hope this is helpful context and we will work on making a public roadmap for longer term goals. Please join the #feat-structured-output channel in slack if you want to have more direct discussion with the people working on this.

@Ubospica
Copy link

Ubospica commented Dec 5, 2024

Thanks @stefanobranco, @joennlae, @@mgoin for great feedbacks.

The first initial release of XGrammar focuses on performance across grammar and json schema. We would like to ensure the system is holistically design to ensure zero overhead structure output, which aligns with many users needs we also see.

Now that initial release land, we are working full steam to enable full support for JSON schema and regex. Thank you for these great feedbacks and please feel free to open new issues on XGrammar to give us feedbacks.

Our general mission is to enable bringing flexible, zero-overhead structured generation everywhere, and we are excited to work with the community here to achieve that mission together, thank you for these feedbacks and we love contributions and collaborations to bring better, zero-overhead structured output for everyone

sleepwalker2017 pushed a commit to sleepwalker2017/vllm that referenced this pull request Dec 13, 2024
…it as default (vllm-project#10785)

Signed-off-by: Aaron Pham <[email protected]>
Signed-off-by: mgoin <[email protected]>
Co-authored-by: mgoin <[email protected]>
@ktrapeznikov
Copy link

will this support models that use mistral tokenizers?

ZenPuzzle pushed a commit to ZenPuzzle/vllm that referenced this pull request Dec 19, 2024
…it as default (vllm-project#10785)

Signed-off-by: Aaron Pham <[email protected]>
Signed-off-by: mgoin <[email protected]>
Co-authored-by: mgoin <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
ci/build documentation Improvements or additions to documentation frontend performance Performance-related issues ready ONLY add when PR is ready to merge/full CI is needed
Projects
None yet
Development

Successfully merging this pull request may close these issues.

9 participants