Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Bugfix] Fallback to outlines for complex json schemas #10899

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions tests/entrypoints/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,37 @@ def sample_json_schema():
}


@pytest.fixture
def sample_complex_json_schema():
    """Return a JSON schema mixing numeric-range and regex constraints.

    The object schema requires four properties: a bounded integer
    ``score``, a single-letter ``grade``, an ``email`` string and a list
    of short lowercase ``tags``.
    """
    # Integer restricted to a closed numeric range.
    score_schema = {"type": "integer", "minimum": 0, "maximum": 100}

    # Strings restricted by regular expressions.
    grade_schema = {"type": "string", "pattern": "^[A-D]$"}
    email_schema = {
        "type": "string",
        "pattern": "^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\\.[a-zA-Z]{2,}$",
    }

    # Array whose items carry both a length and a character restriction,
    # expressed together in a single pattern.
    tags_schema = {
        "type": "array",
        "items": {
            "type": "string",
            "pattern": "^[a-z]{1,10}$",
        },
    }

    return {
        "type": "object",
        "properties": {
            "score": score_schema,
            "grade": grade_schema,
            "email": email_schema,
            "tags": tags_schema,
        },
        "required": ["score", "grade", "email", "tags"],
    }


@pytest.fixture
def sample_guided_choice():
return [
Expand Down
28 changes: 28 additions & 0 deletions tests/entrypoints/llm/test_guided_generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,34 @@ def test_guided_json_completion(sample_json_schema, llm):
jsonschema.validate(instance=output_json, schema=sample_json_schema)


@pytest.mark.skip_global_cleanup
def test_guided_complex_json_completion(sample_complex_json_schema, llm):
    """Generate two schema-guided completions and validate each one.

    Every generated text must parse as JSON and satisfy
    ``sample_complex_json_schema``.
    """
    guided_decoding = GuidedDecodingParams(json=sample_complex_json_schema)
    sampling_params = SamplingParams(temperature=1.0,
                                     max_tokens=1000,
                                     guided_decoding=guided_decoding)

    # The same prompt is submitted twice in a single generate() call.
    request_prompt = (f"Give an example JSON for an assignment grade "
                      f"that fits this schema: {sample_complex_json_schema}")
    outputs = llm.generate(prompts=[request_prompt] * 2,
                           sampling_params=sampling_params,
                           use_tqdm=True)

    assert outputs is not None

    for output in outputs:
        assert output is not None
        assert isinstance(output, RequestOutput)
        prompt = output.prompt

        # Only the first candidate of each request is inspected.
        generated_text = output.outputs[0].text
        assert generated_text is not None
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

        # json.loads raises on malformed JSON; validate raises on a
        # schema violation. Either failure fails the test.
        jsonschema.validate(instance=json.loads(generated_text),
                            schema=sample_complex_json_schema)


@pytest.mark.skip_global_cleanup
def test_guided_choice_completion(sample_guided_choice, llm):
sampling_params = SamplingParams(
Expand Down
43 changes: 43 additions & 0 deletions vllm/model_executor/guided_decoding/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,40 @@
logger = init_logger(__name__)


def has_xgrammar_unsupported_json_features(schema: dict) -> bool:
    """Check if JSON schema contains features unsupported by xgrammar.

    The schema is walked recursively and reported as unsupported when any
    subschema uses the ``pattern`` keyword, or declares a numeric type
    (``integer`` / ``number``, either as a string or inside a ``type``
    list) together with a range keyword (``minimum``, ``maximum``,
    ``exclusiveMinimum``, ``exclusiveMaximum`` or ``multipleOf``).

    Args:
        schema: The JSON schema, already parsed into Python objects.

    Returns:
        True if at least one unsupported feature is found, False otherwise
        (including when ``schema`` is not a dict).
    """
    numeric_types = ("integer", "number")
    range_keywords = ("minimum", "maximum", "exclusiveMinimum",
                      "exclusiveMaximum", "multipleOf")
    # Keywords whose value maps user-chosen names to subschemas. The keys
    # of these maps are names, not schema keywords, so a property that is
    # merely *called* "pattern" must not be mistaken for the keyword.
    subschema_maps = ("properties", "patternProperties", "$defs",
                      "definitions", "dependentSchemas", "dependencies")
    # Keywords whose value is instance data rather than a schema; dicts
    # found in there are not subschemas and are not inspected.
    data_keywords = ("const", "default", "enum", "examples")

    def declares_numeric_type(node: dict) -> bool:
        # "type" may be a single name or a list of names,
        # e.g. ["integer", "null"].
        declared = node.get("type")
        if isinstance(declared, list):
            return any(
                isinstance(name, str) and name in numeric_types
                for name in declared)
        return isinstance(declared, str) and declared in numeric_types

    def check_schema(node) -> bool:
        if not isinstance(node, dict):
            return False

        # Check for pattern restrictions
        if "pattern" in node:
            return True

        # Check for numeric ranges
        if declares_numeric_type(node) and any(key in node
                                               for key in range_keywords):
            return True

        # Recursively check all nested subschemas
        for keyword, value in node.items():
            if keyword in data_keywords:
                continue
            if keyword in subschema_maps and isinstance(value, dict):
                # Skip the names and inspect only the mapped subschemas.
                if any(check_schema(sub) for sub in value.values()):
                    return True
            elif isinstance(value, dict):
                if check_schema(value):
                    return True
            elif isinstance(value, list):
                if any(check_schema(item) for item in value):
                    return True

        return False

    return check_schema(schema)


def maybe_backend_fallback(
guided_params: GuidedDecodingParams) -> GuidedDecodingParams:
# lm-format-enforce doesn't support grammar, fallback to xgrammar
Expand Down Expand Up @@ -47,6 +81,15 @@ def maybe_backend_fallback(
"Falling back to use outlines instead.")
guided_params.backend = "outlines"

# xgrammar doesn't support some JSON schema features
elif (guided_params.json is not None
and has_xgrammar_unsupported_json_features(guided_params.json)):
logger.warning(
"xgrammar does not support advanced JSON schema features like "
"patterns or numeric ranges. "
"Falling back to use outlines instead.")
guided_params.backend = "outlines"

return guided_params


Expand Down
Loading