From 724beaed7c2b892d219605d9b257bdf344c643f2 Mon Sep 17 00:00:00 2001 From: Michael Goin Date: Wed, 4 Dec 2024 22:14:06 -0500 Subject: [PATCH] [Bugfix] Fallback to outlines for complex json schemas (#10899) Signed-off-by: mgoin --- tests/entrypoints/conftest.py | 31 +++++++++++++ tests/entrypoints/llm/test_guided_generate.py | 28 ++++++++++++ .../guided_decoding/__init__.py | 43 +++++++++++++++++++ 3 files changed, 102 insertions(+) diff --git a/tests/entrypoints/conftest.py b/tests/entrypoints/conftest.py index e7ef5637c8ccb..0f7d15e1d85aa 100644 --- a/tests/entrypoints/conftest.py +++ b/tests/entrypoints/conftest.py @@ -69,6 +69,37 @@ def sample_json_schema(): } +@pytest.fixture +def sample_complex_json_schema(): + return { + "type": "object", + "properties": { + "score": { + "type": "integer", + "minimum": 0, + "maximum": 100 # Numeric range + }, + "grade": { + "type": "string", + "pattern": "^[A-D]$" # Regex pattern + }, + "email": { + "type": "string", + "pattern": "^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\\.[a-zA-Z]{2,}$" + }, + "tags": { + "type": "array", + "items": { + "type": "string", + "pattern": + "^[a-z]{1,10}$" # Combining length and pattern restrictions + } + } + }, + "required": ["score", "grade", "email", "tags"] + } + + @pytest.fixture def sample_guided_choice(): return [ diff --git a/tests/entrypoints/llm/test_guided_generate.py b/tests/entrypoints/llm/test_guided_generate.py index c3706f696b264..de6257cfc551c 100644 --- a/tests/entrypoints/llm/test_guided_generate.py +++ b/tests/entrypoints/llm/test_guided_generate.py @@ -76,6 +76,34 @@ def test_guided_json_completion(sample_json_schema, llm): jsonschema.validate(instance=output_json, schema=sample_json_schema) +@pytest.mark.skip_global_cleanup +def test_guided_complex_json_completion(sample_complex_json_schema, llm): + sampling_params = SamplingParams( + temperature=1.0, + max_tokens=1000, + guided_decoding=GuidedDecodingParams(json=sample_complex_json_schema)) + outputs = llm.generate(prompts=[ + f"Give an example JSON for an assignment grade " + f"that fits this schema: {sample_complex_json_schema}" + ] * 2, + sampling_params=sampling_params, + use_tqdm=True) + + assert outputs is not None + + for output in outputs: + assert output is not None + assert isinstance(output, RequestOutput) + prompt = output.prompt + + generated_text = output.outputs[0].text + assert generated_text is not None + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + output_json = json.loads(generated_text) + jsonschema.validate(instance=output_json, + schema=sample_complex_json_schema) + + @pytest.mark.skip_global_cleanup def test_guided_choice_completion(sample_guided_choice, llm): sampling_params = SamplingParams( diff --git a/vllm/model_executor/guided_decoding/__init__.py b/vllm/model_executor/guided_decoding/__init__.py index 23c31fcfd7f05..13beec5676fda 100644 --- a/vllm/model_executor/guided_decoding/__init__.py +++ b/vllm/model_executor/guided_decoding/__init__.py @@ -14,6 +14,40 @@ logger = init_logger(__name__) +def has_xgrammar_unsupported_json_features(schema: dict) -> bool: + """Check if JSON schema contains features unsupported by xgrammar.""" + + def check_object(obj: dict) -> bool: + if not isinstance(obj, dict): + return False + + # Check for pattern restrictions + if "pattern" in obj: + return True + + # Check for numeric ranges + if obj.get("type") in ("integer", "number") and any( + key in obj for key in [ + "minimum", "maximum", "exclusiveMinimum", + "exclusiveMaximum", "multipleOf" + ]): + return True + + # Recursively check all nested objects and arrays + for value in obj.values(): + if isinstance(value, dict): + if check_object(value): + return True + elif isinstance(value, list): + for item in value: + if isinstance(item, dict) and check_object(item): + return True + + return False + + return check_object(schema) + + def maybe_backend_fallback( guided_params: GuidedDecodingParams) -> GuidedDecodingParams: # lm-format-enforce doesn't support grammar, fallback to xgrammar @@ -40,6 +74,15 @@ def maybe_backend_fallback( "Falling back to use outlines instead.") guided_params.backend = "outlines" + # xgrammar doesn't support some JSON schema features + elif (guided_params.json is not None + and has_xgrammar_unsupported_json_features(guided_params.json)): + logger.warning( + "xgrammar does not support advanced JSON schema features like " + "patterns or numeric ranges. " + "Falling back to use outlines instead.") + guided_params.backend = "outlines" + return guided_params