diff --git a/tests/entrypoints/conftest.py b/tests/entrypoints/conftest.py index 0f7d15e1d85aa..ef74062ce4b41 100644 --- a/tests/entrypoints/conftest.py +++ b/tests/entrypoints/conftest.py @@ -100,6 +100,45 @@ def sample_complex_json_schema(): } +@pytest.fixture +def sample_definition_json_schema(): + return { + '$defs': { + 'Step': { + 'properties': { + 'explanation': { + 'title': 'Explanation', + 'type': 'string' + }, + 'output': { + 'title': 'Output', + 'type': 'string' + } + }, + 'required': ['explanation', 'output'], + 'title': 'Step', + 'type': 'object' + } + }, + 'properties': { + 'steps': { + 'items': { + '$ref': '#/$defs/Step' + }, + 'title': 'Steps', + 'type': 'array' + }, + 'final_answer': { + 'title': 'Final Answer', + 'type': 'string' + } + }, + 'required': ['steps', 'final_answer'], + 'title': 'MathReasoning', + 'type': 'object' + } + + @pytest.fixture def sample_guided_choice(): return [ diff --git a/tests/entrypoints/llm/test_guided_generate.py b/tests/entrypoints/llm/test_guided_generate.py index de6257cfc551c..ed50ec6bbc9eb 100644 --- a/tests/entrypoints/llm/test_guided_generate.py +++ b/tests/entrypoints/llm/test_guided_generate.py @@ -104,6 +104,34 @@ def test_guided_complex_json_completion(sample_complex_json_schema, llm): schema=sample_complex_json_schema) +@pytest.mark.skip_global_cleanup +def test_guided_definition_json_completion(sample_definition_json_schema, llm): + sampling_params = SamplingParams(temperature=1.0, + max_tokens=1000, + guided_decoding=GuidedDecodingParams( + json=sample_definition_json_schema)) + outputs = llm.generate(prompts=[ + f"Give an example JSON for solving 8x + 7 = -23 " + f"that fits this schema: {sample_definition_json_schema}" + ] * 2, + sampling_params=sampling_params, + use_tqdm=True) + + assert outputs is not None + + for output in outputs: + assert output is not None + assert isinstance(output, RequestOutput) + prompt = output.prompt + + generated_text = output.outputs[0].text + assert generated_text is not None + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + output_json = json.loads(generated_text) + jsonschema.validate(instance=output_json, + schema=sample_definition_json_schema) + + @pytest.mark.skip_global_cleanup def test_guided_choice_completion(sample_guided_choice, llm): sampling_params = SamplingParams( diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 6ed7c2e9dcd6b..5a70e0952666b 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -387,7 +387,7 @@ def to_sampling_params( assert json_schema is not None self.guided_json = json_schema.json_schema if self.guided_decoding_backend is None: - self.guided_decoding_backend = "lm-format-enforcer" + self.guided_decoding_backend = "xgrammar" guided_decoding = GuidedDecodingParams.from_optional( json=self._get_guided_json_from_tool() or self.guided_json,