diff --git a/requirements-common.txt b/requirements-common.txt
index ef5ed8b645158..4f9f8ffc456ac 100644
--- a/requirements-common.txt
+++ b/requirements-common.txt
@@ -19,6 +19,7 @@ prometheus-fastapi-instrumentator >= 7.0.0
 tiktoken >= 0.6.0 # Required for DBRX tokenizer
 lm-format-enforcer == 0.10.6
 outlines >= 0.0.43, < 0.1
+faster-outlines >= 2024.11.14
 typing_extensions >= 4.10
 filelock >= 3.10.4 # filelock starts to support `mode` argument from 3.10.4
 partial-json-parser # used for parsing partial JSON outputs
diff --git a/tests/model_executor/test_guided_processors.py b/tests/model_executor/test_guided_processors.py
index 45fab8e96b968..34c4fde2975d7 100644
--- a/tests/model_executor/test_guided_processors.py
+++ b/tests/model_executor/test_guided_processors.py
@@ -35,6 +35,50 @@ def test_guided_logits_processors(sample_regex, sample_json_schema):
     assert not torch.allclose(tensor, original_tensor)
 
 
+@pytest.mark.asyncio
+@pytest.mark.parametrize("backend", ["faster-outlines"])
+async def test_guided_logits_processor_black_box_faster_outlines(
+        backend: str,
+        sample_regex,  # noqa: E501
+        sample_json_schema):
+    # faster-outlines processors require special handling,
+    # since they do not begin computation of an FSM index until they have
+    # been both serialized and deserialized. More on the reason why can be
+    # found in the doc comment for the class:
+    # https://github.com/unaidedelf8777/faster-outlines/blob/main/faster_outlines/fsm/vllm_guide.py  # noqa: E501
+    import pickle
+    tokenizer = AutoTokenizer.from_pretrained('HuggingFaceH4/zephyr-7b-beta')
+    token_ids = tokenizer.encode(
+        f"Give an example IPv4 address with this regex: {sample_regex}")
+    regex_request = GuidedDecodingParams(regex=sample_regex, backend=backend)
+    regex_lp = await get_guided_decoding_logits_processor(
+        regex_request, tokenizer)
+    assert regex_lp is not None
+    serialized = pickle.dumps(regex_lp)
+    regex_lp = pickle.loads(serialized)
+    tensor = torch.rand(32000)
+    original_tensor = torch.clone(tensor)
+    tensor = regex_lp(token_ids, tensor)
+    assert tensor.shape == original_tensor.shape
+    assert not torch.allclose(tensor, original_tensor)
+
+    token_ids = tokenizer.encode(
+        f"Give an employee profile that fits this schema: {sample_json_schema}"
+    )
+    json_request = GuidedDecodingParams(json=sample_json_schema,
+                                        backend=backend)
+    json_lp = await get_guided_decoding_logits_processor(
+        json_request, tokenizer)
+    assert json_lp is not None
+    serialized = pickle.dumps(json_lp)
+    json_lp = pickle.loads(serialized)
+    tensor = torch.rand(32000)
+    original_tensor = torch.clone(tensor)
+    tensor = json_lp(token_ids, tensor)
+    assert tensor.shape == original_tensor.shape
+    assert not torch.allclose(tensor, original_tensor)
+
+
 @pytest.mark.asyncio
 @pytest.mark.parametrize("backend", ["outlines", "lm-format-enforcer"])
 async def test_guided_logits_processor_black_box(backend: str, sample_regex,
@@ -82,4 +126,4 @@ def test_multiple_guided_options_not_allowed(sample_json_schema, sample_regex):
     with pytest.raises(ValueError,
                        match="You can only use one kind of guided"):
-        GuidedDecodingParams(json=sample_json_schema, grammar="test grammar")
+        GuidedDecodingParams(json=sample_json_schema, grammar="test grammar")
\ No newline at end of file
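For context, the serialize/deserialize contract the new test exercises can be shown in isolation. A minimal sketch, not part of the diff, assuming the same public helpers the test imports (the IPv4-style pattern and the asyncio wrapper are illustrative):

    import asyncio
    import pickle

    import torch
    from transformers import AutoTokenizer

    from vllm.model_executor.guided_decoding import (
        get_guided_decoding_logits_processor)
    from vllm.sampling_params import GuidedDecodingParams


    async def main():
        tokenizer = AutoTokenizer.from_pretrained(
            'HuggingFaceH4/zephyr-7b-beta')
        params = GuidedDecodingParams(regex=r"\d{1,3}(\.\d{1,3}){3}",
                                      backend="faster-outlines")
        lp = await get_guided_decoding_logits_processor(params, tokenizer)
        # The guide is lazy: it only starts building its FSM index after a
        # serialize/deserialize round trip, so pickle before first use.
        lp = pickle.loads(pickle.dumps(lp))
        # Disallowed token ids are pushed to -inf in the returned scores.
        scores = lp([], torch.rand(32000))


    asyncio.run(main())
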
diff --git a/vllm/config.py b/vllm/config.py
index 5ba1c41fcaac1..6087e60bd3e9f 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -2008,7 +2008,9 @@ class DecodingConfig:
     guided_decoding_backend: str = 'outlines'
 
     def __post_init__(self):
-        valid_guided_backends = ['outlines', 'lm-format-enforcer']
+        valid_guided_backends = [
+            'outlines', 'lm-format-enforcer', 'faster-outlines'
+        ]
         backend = self.guided_decoding_backend
         if backend not in valid_guided_backends:
             raise ValueError(f"Invalid guided_decoding_backend '{backend},"
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index 27f62b0008578..e9ab3cb143bf8 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -358,7 +358,7 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
             '--guided-decoding-backend',
             type=str,
             default='outlines',
-            choices=['outlines', 'lm-format-enforcer'],
+            choices=['outlines', 'lm-format-enforcer', 'faster-outlines'],
             help='Which engine will be used for guided decoding'
             ' (JSON schema / regex etc) by default. Currently support '
             'https://github.com/outlines-dev/outlines and '
diff --git a/vllm/model_executor/guided_decoding/__init__.py b/vllm/model_executor/guided_decoding/__init__.py
index d7b67425fcbc0..0383964ce74bd 100644
--- a/vllm/model_executor/guided_decoding/__init__.py
+++ b/vllm/model_executor/guided_decoding/__init__.py
@@ -19,6 +19,11 @@ async def get_guided_decoding_logits_processor(
             get_local_lm_format_enforcer_guided_decoding_logits_processor)
         return get_local_lm_format_enforcer_guided_decoding_logits_processor(
             guided_params, tokenizer)
+    if guided_params.backend == 'faster-outlines':
+        from vllm.model_executor.guided_decoding.faster_outlines_decoding import (  # noqa
+            get_local_faster_outlines_guided_decoding_logits_processor)
+        return get_local_faster_outlines_guided_decoding_logits_processor(
+            guided_params, tokenizer)
 
     raise ValueError(
         f"Unknown guided decoding backend '{guided_params.backend}'. "
@@ -40,6 +45,11 @@ def get_local_guided_decoding_logits_processor(
             get_local_lm_format_enforcer_guided_decoding_logits_processor)
         return get_local_lm_format_enforcer_guided_decoding_logits_processor(
             guided_params, tokenizer)
+    if guided_params.backend == 'faster-outlines':
+        from vllm.model_executor.guided_decoding.faster_outlines_decoding import (  # noqa
+            get_local_faster_outlines_guided_decoding_logits_processor)
+        return get_local_faster_outlines_guided_decoding_logits_processor(
+            guided_params, tokenizer)
 
     raise ValueError(
         f"Unknown guided decoding backend '{guided_params.backend}'. "
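With the dispatch wired up above, the backend can be chosen per request through GuidedDecodingParams, or engine-wide through the --guided-decoding-backend flag added to arg_utils. A short sketch of the per-request path via the synchronous entry point (illustrative; the choice values are arbitrary):

    from transformers import AutoTokenizer

    from vllm.model_executor.guided_decoding import (
        get_local_guided_decoding_logits_processor)
    from vllm.sampling_params import GuidedDecodingParams

    tokenizer = AutoTokenizer.from_pretrained('HuggingFaceH4/zephyr-7b-beta')
    params = GuidedDecodingParams(choice=["yes", "no", "maybe"],
                                  backend="faster-outlines")
    # Dispatches to the new module below, which lowers the choice list to
    # the regex "(yes|no|maybe)" and returns a RegexLogitsProcessor.
    lp = get_local_guided_decoding_logits_processor(params, tokenizer)
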
" diff --git a/vllm/model_executor/guided_decoding/faster_outlines_decoding.py b/vllm/model_executor/guided_decoding/faster_outlines_decoding.py new file mode 100644 index 0000000000000..0dc7139e4a605 --- /dev/null +++ b/vllm/model_executor/guided_decoding/faster_outlines_decoding.py @@ -0,0 +1,115 @@ +import math +from json import dumps as json_dumps +from re import escape as regex_escape +from typing import Dict, List, Union + +import torch +from faster_outlines.fsm import (Generate, LazyVLLMRegexGuide, TokenVocabulary, + Write) +from outlines.fsm.json_schema import build_regex_from_schema +from transformers import PreTrainedTokenizerBase + +from vllm.sampling_params import GuidedDecodingParams + +TOKENIZER_CACHE: Dict[str, TokenVocabulary] = {} + + +class BaseLogitsProcessor: + + def __init__(self, guide): + self._guide = guide + self.state = 0 + + def __call__(self, input_ids: List[int], + scores: torch.Tensor) -> torch.Tensor: + """Use the FSM to bias the logits before sampling the next token.""" + if len(input_ids) > 0: + self.state = self._guide.get_next_state(state=self.state, + token_id=input_ids[-1]) + + instruction = self._guide.get_next_instruction(state=self.state) + if type(instruction) == Generate: # noqa: E721 + allowed_tokens = instruction.tokens + elif type(instruction) == Write: # noqa: E721 + # TODO: support fast forward tokens + allowed_tokens = [instruction.tokens[0]] + else: + raise TypeError( + f"Unsupported instruction type {type(instruction)}") + + mask = torch.full((scores.shape[-1], ), + -math.inf, + device=scores.device) + mask[allowed_tokens] = 0 + scores.add_(mask) + return scores + + +class RegexLogitsProcessor(BaseLogitsProcessor): + + @classmethod + def _get_guide(cls, regex_string: str, tokenizer: PreTrainedTokenizerBase): + vocab = _adapt_tokenizer(tokenizer) + return LazyVLLMRegexGuide(regex_string, vocab) + + def __init__(self, regex_string: str, tokenizer: PreTrainedTokenizerBase): + """Compile the FSM that drives the regex-structured generation. + + Parameters + ---------- + regex_string + A string that represents a regular expression + tokenizer + The model's tokenizer + + """ + super().__init__( + RegexLogitsProcessor._get_guide(regex_string, tokenizer)) + + +def _adapt_tokenizer(tokenizer: PreTrainedTokenizerBase): + """ + Adapt VLLM's tokenizer into a TokenVocabulary, readable by Rust. 
+ """ + if TOKENIZER_CACHE.get(tokenizer.name_or_path) is not None: + + return TOKENIZER_CACHE[tokenizer.name_or_path] + + token_vocab = TokenVocabulary(tokenizer.get_vocab(), + tokenizer.eos_token_id, + set(tokenizer.all_special_tokens)) + + TOKENIZER_CACHE[tokenizer.name_or_path] = token_vocab + + return token_vocab + + +def get_local_faster_outlines_guided_decoding_logits_processor( + guided_params: GuidedDecodingParams, tokenizer: PreTrainedTokenizerBase +) -> Union[RegexLogitsProcessor, None]: + regex = _get_regex(guided_params) + + if not regex: + return None + + return RegexLogitsProcessor(regex, tokenizer) + + +def _get_regex(guided_params: GuidedDecodingParams) -> Union[str, None]: + if guided_params.json: + if isinstance(guided_params.json, dict): + # turn dict into hashable string + json = build_regex_from_schema(json_dumps(guided_params.json)) + else: + json = build_regex_from_schema(guided_params.json) + return json + elif guided_params.regex: + return guided_params.regex + elif guided_params.choice: + # choice just uses regex + choices = [ + regex_escape(str(choice)) for choice in guided_params.choice + ] + choices_regex = "(" + "|".join(choices) + ")" + return choices_regex + return None