[Core][Frontend] Add faster-outlines as guided decoding backend #10277

Open · wants to merge 6 commits into main
1 change: 1 addition & 0 deletions requirements-common.txt
@@ -19,6 +19,7 @@ prometheus-fastapi-instrumentator >= 7.0.0
 tiktoken >= 0.6.0  # Required for DBRX tokenizer
 lm-format-enforcer == 0.10.6
 outlines >= 0.0.43, < 0.1
+faster-outlines >= 2024.11.14
 typing_extensions >= 4.10
 filelock >= 3.10.4  # filelock starts to support `mode` argument from 3.10.4
 partial-json-parser  # used for parsing partial JSON outputs
46 changes: 45 additions & 1 deletion tests/model_executor/test_guided_processors.py
@@ -35,6 +35,50 @@ def test_guided_logits_processors(sample_regex, sample_json_schema):
     assert not torch.allclose(tensor, original_tensor)
 
 
+@pytest.mark.asyncio
+@pytest.mark.parametrize("backend", ["faster-outlines"])
+async def test_guided_logits_processor_black_box_faster_outlines(
+        backend: str,
+        sample_regex,  # noqa: E501
+        sample_json_schema):
+    # faster-outlines processors require special handling,
+    # since they do not begin computation of a fsm index until they have been
+    # both serialized and deserialized. More on the reason why can be found
+    # in the doc comment for the class:
+    # https://github.com/unaidedelf8777/faster-outlines/blob/main/faster_outlines/fsm/vllm_guide.py  # noqa: E501
+    import pickle
+    tokenizer = AutoTokenizer.from_pretrained('HuggingFaceH4/zephyr-7b-beta')
+    token_ids = tokenizer.encode(
+        f"Give an example IPv4 address with this regex: {sample_regex}")
+    regex_request = GuidedDecodingParams(regex=sample_regex, backend=backend)
+    regex_lp = await get_guided_decoding_logits_processor(
+        regex_request, tokenizer)
+    assert regex_lp is not None
+    serialized = pickle.dumps(regex_lp)
+    regex_lp = pickle.loads(serialized)
+    tensor = torch.rand(32000)
+    original_tensor = torch.clone(tensor)
+    tensor = regex_lp(token_ids, tensor)
+    assert tensor.shape == original_tensor.shape
+    assert not torch.allclose(tensor, original_tensor)
+
+    token_ids = tokenizer.encode(
+        f"Give an employee profile that fits this schema: {sample_json_schema}"
+    )
+    json_request = GuidedDecodingParams(json=sample_json_schema,
+                                        backend=backend)
+    json_lp = await get_guided_decoding_logits_processor(
+        json_request, tokenizer)
+    assert json_lp is not None
+    serialized = pickle.dumps(json_lp)
+    json_lp = pickle.loads(serialized)
+    tensor = torch.rand(32000)
+    original_tensor = torch.clone(tensor)
+    tensor = json_lp(token_ids, tensor)
+    assert tensor.shape == original_tensor.shape
+    assert not torch.allclose(tensor, original_tensor)
+
+
 @pytest.mark.asyncio
 @pytest.mark.parametrize("backend", ["outlines", "lm-format-enforcer"])
 async def test_guided_logits_processor_black_box(backend: str, sample_regex,
@@ -82,4 +126,4 @@ def test_multiple_guided_options_not_allowed(sample_json_schema, sample_regex):
 
     with pytest.raises(ValueError,
                        match="You can only use one kind of guided"):
-        GuidedDecodingParams(json=sample_json_schema, grammar="test grammar")
\ No newline at end of file
+        GuidedDecodingParams(json=sample_json_schema, grammar="test grammar")
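The comment in the new test deserves unpacking: a faster-outlines guide defers FSM-index computation until it has made one serialize/deserialize round trip, which is why the test pickles and unpickles each processor before calling it. The sketch below is illustrative only; it is not the faster-outlines implementation (the real guide is the linked vllm_guide.py), just the general shape of a guide that starts compiling from __setstate__:

# Illustrative sketch of the lazy pattern the test exercises; NOT the
# actual faster-outlines code. Deserialization is the trigger that
# starts FSM-index construction, here in a background thread.
import pickle
import threading


class LazyGuide:

    def __init__(self, pattern: str):
        self.pattern = pattern
        self._index = None  # deliberately not built yet

    def __getstate__(self):
        # Only the cheap, picklable description crosses process boundaries.
        return {"pattern": self.pattern}

    def __setstate__(self, state):
        self.pattern = state["pattern"]
        self._index = None
        # Kick off the expensive compilation as soon as the guide lands
        # in the worker process.
        threading.Thread(target=self._compile, daemon=True).start()

    def _compile(self):
        # Stand-in for the real FSM-index computation over the vocabulary.
        self._index = {"pattern": self.pattern}


guide = pickle.loads(pickle.dumps(LazyGuide(r"\d+")))  # compilation begins here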
4 changes: 3 additions & 1 deletion vllm/config.py
@@ -2008,7 +2008,9 @@ class DecodingConfig:
     guided_decoding_backend: str = 'outlines'
 
     def __post_init__(self):
-        valid_guided_backends = ['outlines', 'lm-format-enforcer']
+        valid_guided_backends = [
+            'outlines', 'lm-format-enforcer', 'faster-outlines'
+        ]
         backend = self.guided_decoding_backend
         if backend not in valid_guided_backends:
             raise ValueError(f"Invalid guided_decoding_backend '{backend},"
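For illustration, the effect of the widened allow-list (a hypothetical direct construction of DecodingConfig; vLLM normally builds it from engine arguments):

from vllm.config import DecodingConfig

DecodingConfig(guided_decoding_backend='faster-outlines')  # now accepted

try:
    DecodingConfig(guided_decoding_backend='faster_outlines')  # typo: underscores
except ValueError as err:
    print(err)  # rejected, message names the valid backends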
2 changes: 1 addition & 1 deletion vllm/engine/arg_utils.py
@@ -358,7 +358,7 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
             '--guided-decoding-backend',
             type=str,
             default='outlines',
-            choices=['outlines', 'lm-format-enforcer'],
+            choices=['outlines', 'lm-format-enforcer', 'faster-outlines'],
             help='Which engine will be used for guided decoding'
             ' (JSON schema / regex etc) by default. Currently support '
             'https://github.com/outlines-dev/outlines and '
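The CLI flag above sets the engine-wide default. The same default can be chosen programmatically through EngineArgs, which this parser feeds; a minimal sketch (the model name is just an example):

from vllm.engine.arg_utils import EngineArgs

# Equivalent to passing --guided-decoding-backend faster-outlines
engine_args = EngineArgs(
    model='HuggingFaceH4/zephyr-7b-beta',  # example model
    guided_decoding_backend='faster-outlines',
)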
10 changes: 10 additions & 0 deletions vllm/model_executor/guided_decoding/__init__.py
@@ -19,6 +19,11 @@ async def get_guided_decoding_logits_processor(
             get_local_lm_format_enforcer_guided_decoding_logits_processor)
         return get_local_lm_format_enforcer_guided_decoding_logits_processor(
             guided_params, tokenizer)
+    if guided_params.backend == 'faster-outlines':
+        from vllm.model_executor.guided_decoding.faster_outlines_decoding import (  # noqa
+            get_local_faster_outlines_guided_decoding_logits_processor)
+        return get_local_faster_outlines_guided_decoding_logits_processor(
+            guided_params, tokenizer)
 
     raise ValueError(
         f"Unknown guided decoding backend '{guided_params.backend}'. "
@@ -40,6 +45,11 @@ def get_local_guided_decoding_logits_processor(
             get_local_lm_format_enforcer_guided_decoding_logits_processor)
         return get_local_lm_format_enforcer_guided_decoding_logits_processor(
             guided_params, tokenizer)
+    if guided_params.backend == 'faster-outlines':
+        from vllm.model_executor.guided_decoding.faster_outlines_decoding import (  # noqa
+            get_local_faster_outlines_guided_decoding_logits_processor)
+        return get_local_faster_outlines_guided_decoding_logits_processor(
+            guided_params, tokenizer)
 
     raise ValueError(
         f"Unknown guided decoding backend '{guided_params.backend}'. "
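A minimal sketch of driving this dispatch through the synchronous entry point; it mirrors the async test earlier in this PR, including the pickle round trip that faster-outlines processors need before first use:

import pickle

from transformers import AutoTokenizer

from vllm.model_executor.guided_decoding import (
    get_local_guided_decoding_logits_processor)
from vllm.sampling_params import GuidedDecodingParams

tokenizer = AutoTokenizer.from_pretrained('HuggingFaceH4/zephyr-7b-beta')
params = GuidedDecodingParams(regex=r'\d+\.\d+\.\d+\.\d+',
                              backend='faster-outlines')
lp = get_local_guided_decoding_logits_processor(params, tokenizer)
lp = pickle.loads(pickle.dumps(lp))  # triggers lazy FSM-index construction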
115 changes: 115 additions & 0 deletions vllm/model_executor/guided_decoding/faster_outlines_decoding.py
@@ -0,0 +1,115 @@
+import math
+from json import dumps as json_dumps
+from re import escape as regex_escape
+from typing import Dict, List, Union
+
+import torch
+from faster_outlines.fsm import (Generate, LazyVLLMRegexGuide, TokenVocabulary,
+                                 Write)
+from outlines.fsm.json_schema import build_regex_from_schema
+from transformers import PreTrainedTokenizerBase
+
+from vllm.sampling_params import GuidedDecodingParams
+
+TOKENIZER_CACHE: Dict[str, TokenVocabulary] = {}
+
+
+class BaseLogitsProcessor:
+
+    def __init__(self, guide):
+        self._guide = guide
+        self.state = 0
+
+    def __call__(self, input_ids: List[int],
+                 scores: torch.Tensor) -> torch.Tensor:
+        """Use the FSM to bias the logits before sampling the next token."""
+        if len(input_ids) > 0:
+            self.state = self._guide.get_next_state(state=self.state,
+                                                    token_id=input_ids[-1])
+
+        instruction = self._guide.get_next_instruction(state=self.state)
+        if type(instruction) == Generate:  # noqa: E721
+            allowed_tokens = instruction.tokens
+        elif type(instruction) == Write:  # noqa: E721
+            # TODO: support fast forward tokens
+            allowed_tokens = [instruction.tokens[0]]
+        else:
+            raise TypeError(
+                f"Unsupported instruction type {type(instruction)}")
+
+        mask = torch.full((scores.shape[-1], ),
+                          -math.inf,
+                          device=scores.device)
+        mask[allowed_tokens] = 0
+        scores.add_(mask)
+        return scores
+
+
+class RegexLogitsProcessor(BaseLogitsProcessor):
+
+    @classmethod
+    def _get_guide(cls, regex_string: str, tokenizer: PreTrainedTokenizerBase):
+        vocab = _adapt_tokenizer(tokenizer)
+        return LazyVLLMRegexGuide(regex_string, vocab)
+
+    def __init__(self, regex_string: str, tokenizer: PreTrainedTokenizerBase):
+        """Compile the FSM that drives the regex-structured generation.
+
+        Parameters
+        ----------
+        regex_string
+            A string that represents a regular expression
+        tokenizer
+            The model's tokenizer
+
+        """
+        super().__init__(
+            RegexLogitsProcessor._get_guide(regex_string, tokenizer))
+
+
+def _adapt_tokenizer(tokenizer: PreTrainedTokenizerBase):
+    """
+    Adapt vLLM's tokenizer into a TokenVocabulary, readable by Rust.
+    """
+    if TOKENIZER_CACHE.get(tokenizer.name_or_path) is not None:
+        return TOKENIZER_CACHE[tokenizer.name_or_path]
+
+    token_vocab = TokenVocabulary(tokenizer.get_vocab(),
+                                  tokenizer.eos_token_id,
+                                  set(tokenizer.all_special_tokens))
+
+    TOKENIZER_CACHE[tokenizer.name_or_path] = token_vocab
+
+    return token_vocab
+
+
+def get_local_faster_outlines_guided_decoding_logits_processor(
+    guided_params: GuidedDecodingParams, tokenizer: PreTrainedTokenizerBase
+) -> Union[RegexLogitsProcessor, None]:
+    regex = _get_regex(guided_params)
+
+    if not regex:
+        return None
+
+    return RegexLogitsProcessor(regex, tokenizer)
+
+
+def _get_regex(guided_params: GuidedDecodingParams) -> Union[str, None]:
+    if guided_params.json:
+        if isinstance(guided_params.json, dict):
+            # turn dict into hashable string
+            json = build_regex_from_schema(json_dumps(guided_params.json))
+        else:
+            json = build_regex_from_schema(guided_params.json)
+        return json
+    elif guided_params.regex:
+        return guided_params.regex
+    elif guided_params.choice:
+        # choice just uses regex
+        choices = [
+            regex_escape(str(choice)) for choice in guided_params.choice
+        ]
+        choices_regex = "(" + "|".join(choices) + ")"
+        return choices_regex
+    return None
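The heart of BaseLogitsProcessor.__call__ is the additive mask: allowed token ids keep their logits, everything else is pushed to -inf so softmax assigns it zero probability. A toy, self-contained demonstration of that step:

import math

import torch

scores = torch.tensor([1.2, 0.3, -0.5, 2.0])  # logits over a 4-token vocab
allowed_tokens = [0, 3]  # what the FSM would permit in this state

mask = torch.full((scores.shape[-1], ), -math.inf)
mask[allowed_tokens] = 0
scores.add_(mask)

print(torch.softmax(scores, dim=-1))
# tensor([0.3100, 0.0000, 0.0000, 0.6900])  (approximately)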