From 0b3e05742ddb69edc150c0a8f92a6e1bc62287ee Mon Sep 17 00:00:00 2001
From: mgoin
Date: Mon, 30 Dec 2024 16:41:41 +0000
Subject: [PATCH] Fix OpenAI parallel sampling when using xgrammar

Signed-off-by: mgoin
---
 tests/entrypoints/openai/test_completion.py | 19 ++++++++++---------
 .../guided_decoding/xgrammar_decoding.py    |  5 +++++
 vllm/sampling_params.py                     |  9 +++++----
 vllm/sequence.py                            |  2 +-
 4 files changed, 21 insertions(+), 14 deletions(-)

diff --git a/tests/entrypoints/openai/test_completion.py b/tests/entrypoints/openai/test_completion.py
index c81cfdbbe5cff..acdd800bee470 100644
--- a/tests/entrypoints/openai/test_completion.py
+++ b/tests/entrypoints/openai/test_completion.py
@@ -28,6 +28,8 @@
 # need to change to match the prompt adapter
 PA_NUM_VIRTUAL_TOKENS = 8
 
+GUIDED_DECODING_BACKENDS = ["outlines", "lm-format-enforcer", "xgrammar"]
+
 
 @pytest.fixture(scope="module")
 def zephyr_lora_files():
@@ -635,8 +637,7 @@ async def test_allowed_token_ids(client: openai.AsyncOpenAI):
 
 
 @pytest.mark.asyncio
-@pytest.mark.parametrize("guided_decoding_backend",
-                         ["outlines", "lm-format-enforcer"])
+@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS)
 async def test_guided_json_completion(client: openai.AsyncOpenAI,
                                       guided_decoding_backend: str,
                                       sample_json_schema):
@@ -658,8 +659,7 @@ async def test_guided_json_completion(client: openai.AsyncOpenAI,
 
 
 @pytest.mark.asyncio
-@pytest.mark.parametrize("guided_decoding_backend",
-                         ["outlines", "lm-format-enforcer"])
+@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS)
 async def test_guided_regex_completion(client: openai.AsyncOpenAI,
                                        guided_decoding_backend: str,
                                        sample_regex):
@@ -680,8 +680,7 @@ async def test_guided_regex_completion(client: openai.AsyncOpenAI,
 
 
 @pytest.mark.asyncio
-@pytest.mark.parametrize("guided_decoding_backend",
-                         ["outlines", "lm-format-enforcer"])
+@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS)
 async def test_guided_choice_completion(client: openai.AsyncOpenAI,
                                         guided_decoding_backend: str,
                                         sample_guided_choice):
@@ -701,7 +700,9 @@ async def test_guided_choice_completion(client: openai.AsyncOpenAI,
 
 
 @pytest.mark.asyncio
+@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS)
 async def test_guided_grammar(client: openai.AsyncOpenAI,
+                              guided_decoding_backend: str,
                               sample_sql_statements):
 
     completion = await client.completions.create(
@@ -710,7 +711,8 @@ async def test_guided_grammar(client: openai.AsyncOpenAI,
                 "table_1 where it is equals to 1"),
         temperature=1.0,
         max_tokens=500,
-        extra_body=dict(guided_grammar=sample_sql_statements))
+        extra_body=dict(guided_grammar=sample_sql_statements,
+                        guided_decoding_backend=guided_decoding_backend))
 
     content = completion.choices[0].text
 
@@ -761,8 +763,7 @@ async def test_echo_logprob_completion(client: openai.AsyncOpenAI,
 
 
 @pytest.mark.asyncio
-@pytest.mark.parametrize("guided_decoding_backend",
-                         ["outlines", "lm-format-enforcer"])
+@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS)
 async def test_guided_decoding_type_error(client: openai.AsyncOpenAI,
                                           guided_decoding_backend: str,
                                           sample_json_schema, sample_regex):
diff --git a/vllm/model_executor/guided_decoding/xgrammar_decoding.py b/vllm/model_executor/guided_decoding/xgrammar_decoding.py
index 5e1948977bff4..f10a8fb8e03cf 100644
--- a/vllm/model_executor/guided_decoding/xgrammar_decoding.py
+++ b/vllm/model_executor/guided_decoding/xgrammar_decoding.py
@@ -1,6 +1,7 @@
 # noqa: UP007
 from __future__ import annotations
 
+import copy
 import json
 from dataclasses import dataclass, field
 from typing import TYPE_CHECKING, Any
@@ -309,3 +310,7 @@ def __call__(self, input_ids: list[int],
             scores = scores.to(device_type).squeeze()
 
         return scores
+
+    def clone(self) -> XGrammarLogitsProcessor:
+        """Deepcopy due to per-sequence state in the matchers"""
+        return copy.deepcopy(self)
diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py
index fc77f3ca529b2..605c09b8d7225 100644
--- a/vllm/sampling_params.py
+++ b/vllm/sampling_params.py
@@ -450,15 +450,16 @@ def all_stop_token_ids(self) -> Set[int]:
         return self._all_stop_token_ids
 
     def clone(self) -> "SamplingParams":
-        """Deep copy excluding LogitsProcessor objects.
+        """Deep copy, but maybe not the LogitsProcessor objects.
 
-        LogitsProcessor objects are excluded because they may contain an
-        arbitrary, nontrivial amount of data.
+        LogitsProcessor objects may contain an arbitrary, nontrivial amount of
+        data that is expensive to copy. However, if not copied, the processor
+        needs to support parallel decoding for multiple sequences.
         See https://github.com/vllm-project/vllm/issues/3087
         """
 
         logit_processor_refs = None if self.logits_processors is None else {
-            id(lp): lp
+            id(lp): lp.clone() if hasattr(lp, 'clone') else lp
             for lp in self.logits_processors
         }
         return copy.deepcopy(self, memo=logit_processor_refs)
diff --git a/vllm/sequence.py b/vllm/sequence.py
index 34f910d47b7d9..034f89c0ddbe9 100644
--- a/vllm/sequence.py
+++ b/vllm/sequence.py
@@ -1372,7 +1372,7 @@ class ParallelSampleSequenceGroup(SequenceGroupBase):
     @staticmethod
     def add_request(request_id: str, engine, params, **kwargs):
         original_params = params
-        params = copy.deepcopy(original_params)
+        params = original_params.clone()
         params.n = 1
         group = ParallelSampleSequenceGroup(request_id)
         seqs = []
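
Why the memo trick in SamplingParams.clone works: copy.deepcopy consults its
memo dict (keyed by id()) before copying any object, so pre-seeding the memo
with id(lp) -> lp.clone() splices a fresh per-sequence clone in wherever that
processor appears, while processors without a clone() method keep being shared
by reference exactly as before. Below is a minimal standalone sketch of the
same pattern; the StatefulProcessor/StatelessProcessor/Params classes are
illustrative stand-ins (e.g. for XGrammarLogitsProcessor), not vLLM code:

    import copy


    class StatefulProcessor:
        """Stand-in for a processor with per-sequence state (e.g. matchers)."""

        def __init__(self) -> None:
            self.matcher_state: list[int] = []  # mutated as tokens decode

        def clone(self) -> "StatefulProcessor":
            # Per-sequence state must not be shared across parallel samples.
            return copy.deepcopy(self)


    class StatelessProcessor:
        """Stand-in for a processor that is safe to share by reference."""


    class Params:
        """Toy SamplingParams showing the deepcopy-memo dispatch."""

        def __init__(self, logits_processors: list) -> None:
            self.logits_processors = logits_processors

        def clone(self) -> "Params":
            # Objects whose id() is already in memo are not copied again;
            # deepcopy uses the memo value instead. Processors that define
            # clone() are swapped for fresh copies; the rest stay shared.
            refs = {
                id(lp): lp.clone() if hasattr(lp, "clone") else lp
                for lp in self.logits_processors
            }
            return copy.deepcopy(self, memo=refs)


    stateful, stateless = StatefulProcessor(), StatelessProcessor()
    child = Params([stateful, stateless]).clone()
    assert child.logits_processors[0] is not stateful  # fresh copy
    assert child.logits_processors[1] is stateless     # still shared

This is why ParallelSampleSequenceGroup can switch from copy.deepcopy(params)
to params.clone(): each of the n child sequences gets its own xgrammar matcher
state, while cheap-to-share processors are not needlessly duplicated.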