Skip to content

Commit

Permalink
[Bugfix] Fallback to outlines for complex json schemas (vllm-project#…
Browse files Browse the repository at this point in the history
…10899)

Signed-off-by: mgoin <[email protected]>
  • Loading branch information
mgoin authored and ZenPuzzle committed Dec 24, 2024
1 parent b1bd000 commit 724beae
Show file tree
Hide file tree
Showing 3 changed files with 102 additions and 0 deletions.
31 changes: 31 additions & 0 deletions tests/entrypoints/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,37 @@ def sample_json_schema():
}


@pytest.fixture
def sample_complex_json_schema():
return {
"type": "object",
"properties": {
"score": {
"type": "integer",
"minimum": 0,
"maximum": 100 # Numeric range
},
"grade": {
"type": "string",
"pattern": "^[A-D]$" # Regex pattern
},
"email": {
"type": "string",
"pattern": "^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\\.[a-zA-Z]{2,}$"
},
"tags": {
"type": "array",
"items": {
"type": "string",
"pattern":
"^[a-z]{1,10}$" # Combining length and pattern restrictions
}
}
},
"required": ["score", "grade", "email", "tags"]
}


@pytest.fixture
def sample_guided_choice():
return [
Expand Down
28 changes: 28 additions & 0 deletions tests/entrypoints/llm/test_guided_generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,34 @@ def test_guided_json_completion(sample_json_schema, llm):
jsonschema.validate(instance=output_json, schema=sample_json_schema)


@pytest.mark.skip_global_cleanup
def test_guided_complex_json_completion(sample_complex_json_schema, llm):
sampling_params = SamplingParams(
temperature=1.0,
max_tokens=1000,
guided_decoding=GuidedDecodingParams(json=sample_complex_json_schema))
outputs = llm.generate(prompts=[
f"Give an example JSON for an assignment grade "
f"that fits this schema: {sample_complex_json_schema}"
] * 2,
sampling_params=sampling_params,
use_tqdm=True)

assert outputs is not None

for output in outputs:
assert output is not None
assert isinstance(output, RequestOutput)
prompt = output.prompt

generated_text = output.outputs[0].text
assert generated_text is not None
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
output_json = json.loads(generated_text)
jsonschema.validate(instance=output_json,
schema=sample_complex_json_schema)


@pytest.mark.skip_global_cleanup
def test_guided_choice_completion(sample_guided_choice, llm):
sampling_params = SamplingParams(
Expand Down
43 changes: 43 additions & 0 deletions vllm/model_executor/guided_decoding/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,40 @@
logger = init_logger(__name__)


def has_xgrammar_unsupported_json_features(schema: dict) -> bool:
"""Check if JSON schema contains features unsupported by xgrammar."""

def check_object(obj: dict) -> bool:
if not isinstance(obj, dict):
return False

# Check for pattern restrictions
if "pattern" in obj:
return True

# Check for numeric ranges
if obj.get("type") in ("integer", "number") and any(
key in obj for key in [
"minimum", "maximum", "exclusiveMinimum",
"exclusiveMaximum", "multipleOf"
]):
return True

# Recursively check all nested objects and arrays
for value in obj.values():
if isinstance(value, dict):
if check_object(value):
return True
elif isinstance(value, list):
for item in value:
if isinstance(item, dict) and check_object(item):
return True

return False

return check_object(schema)


def maybe_backend_fallback(
guided_params: GuidedDecodingParams) -> GuidedDecodingParams:
# lm-format-enforce doesn't support grammar, fallback to xgrammar
Expand All @@ -40,6 +74,15 @@ def maybe_backend_fallback(
"Falling back to use outlines instead.")
guided_params.backend = "outlines"

# xgrammar doesn't support some JSON schema features
elif (guided_params.json is not None
and has_xgrammar_unsupported_json_features(guided_params.json)):
logger.warning(
"xgrammar does not support advanced JSON schema features like "
"patterns or numeric ranges. "
"Falling back to use outlines instead.")
guided_params.backend = "outlines"

return guided_params


Expand Down

0 comments on commit 724beae

Please sign in to comment.