[REFACTOR] Move GenerationConfig to protocol (#2427)
This PR moves GenerationConfig to protocol.
As we move towards an OpenAI-style API, GenerationConfig becomes more of an internal API.

This change reflects that and also removes the duplicated definitions of ResponseFormat
and DebugConfig.
tqchen authored May 26, 2024
1 parent ff91749 commit 14bec5a
Showing 21 changed files with 95 additions and 185 deletions.
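Net effect for callers: GenerationConfig now lives in the protocol package rather than in serve.config. A minimal sketch of the new import path, assuming the module layout shown in the diff below; the field values are illustrative only:

```python
# New import path after this PR; previously GenerationConfig was exposed via
# mlc_llm.serve.config (and re-exported from mlc_llm.serve).
from mlc_llm.protocol.generation_config import GenerationConfig

# Illustrative values only; -1 for max_tokens is the internal "no limit" marker.
config = GenerationConfig(temperature=0.7, top_p=0.95, max_tokens=-1)
```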
1 change: 0 additions & 1 deletion python/mlc_llm/__init__.py
@@ -4,6 +4,5 @@
"""

from . import protocol, serve
from .chat_module import ChatConfig, ChatModule, ConvConfig, GenerationConfig
from .libinfo import __version__
from .serve import AsyncMLCEngine, MLCEngine
11 changes: 8 additions & 3 deletions python/mlc_llm/protocol/__init__.py
@@ -1,4 +1,9 @@
"""Definitions of pydantic models for API entry points and configurations"""
from . import openai_api_protocol

RequestProtocol = openai_api_protocol.CompletionRequest
"""Definitions of pydantic models for API entry points and configurations

Note
----
We use the following convention

- filename_protocol If the classes can appear in an API endpoint
- filename_config For other config classes
"""
32 changes: 32 additions & 0 deletions python/mlc_llm/protocol/generation_config.py
@@ -0,0 +1,32 @@
"""Low-level generation config class"""
# pylint: disable=missing-class-docstring, too-many-instance-attributes
from typing import Dict, List, Optional

from pydantic import BaseModel

from .debug_protocol import DebugConfig
from .openai_api_protocol import RequestResponseFormat


class GenerationConfig(BaseModel):  # pylint: disable=too-many-instance-attributes
"""The generation configuration dataclass.
This is a config class used by Engine internally.
"""

n: int = 1
temperature: Optional[float] = None
top_p: Optional[float] = None
frequency_penalty: Optional[float] = None
presence_penalty: Optional[float] = None
repetition_penalty: Optional[float] = None
logprobs: bool = False
top_logprobs: int = 0
logit_bias: Optional[Dict[int, float]] = None
# internally we use -1 to represent infinite
max_tokens: int = -1
seed: Optional[int] = None
stop_strs: Optional[List[str]] = None
stop_token_ids: Optional[List[int]] = None
response_format: Optional[RequestResponseFormat] = None
debug_config: Optional[DebugConfig] = None
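Because GenerationConfig is now a pydantic BaseModel, the hand-written asjson()/from_json() helpers on the old dataclass (removed from serve/config.py further down) give way to pydantic's built-in serialization, which is what the engine and Request changes in this PR call. A small round-trip sketch with illustrative values:

```python
from mlc_llm.protocol.generation_config import GenerationConfig

config = GenerationConfig(temperature=0.7, max_tokens=-1, stop_token_ids=[2])

# What create_request now sends over the FFI boundary:
json_str = config.model_dump_json()

# What Request.generation_config rebuilds on the way back:
restored = GenerationConfig.model_validate_json(json_str)
assert restored == config
```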
2 changes: 1 addition & 1 deletion python/mlc_llm/serve/__init__.py
@@ -2,7 +2,7 @@

# Load MLC LLM library by importing base
from .. import base
from .config import DebugConfig, EngineConfig, GenerationConfig
from .config import EngineConfig
from .data import Data, ImageData, RequestStreamOutput, TextData, TokenData
from .engine import AsyncMLCEngine, MLCEngine
from .grammar import BNFGrammar, GrammarStateMatcher
142 changes: 1 addition & 141 deletions python/mlc_llm/serve/config.py
@@ -2,147 +2,7 @@

import json
from dataclasses import asdict, dataclass, field
from typing import Dict, List, Literal, Optional, Tuple, Union


@dataclass
class ResponseFormat:
"""The response format dataclass.
Parameters
----------
type : Literal["text", "json_object"]
The type of response format. Default: "text".
schema : Optional[str]
The JSON schema string for the JSON response format. If None, a legal json string without
special restrictions will be generated.
Could be specified when the response format is "json_object". Default: None.
"""

type: Literal["text", "json_object"] = "text"
schema: Optional[str] = None

def __post_init__(self):
if self.schema is not None and self.type != "json_object":
raise ValueError("JSON schema is only supported in JSON response format")


@dataclass
class DebugConfig:
"""The debug configuration dataclass.Parameters
----------
ignore_eos : bool
When it is true, ignore the eos token and generate tokens until `max_tokens`.
Default is set to False.
pinned_system_prompt : bool
Whether the input and generated data pinned in engine. Default is set to False.
This can be used for system prompt or other purpose, if the data is aimed to be
kept all the time.
special_request : Optional[str]
Special requests to send to engine
"""

ignore_eos: bool = False
pinned_system_prompt: bool = False
special_request: Optional[Literal["query_engine_metrics"]] = None


@dataclass
class GenerationConfig: # pylint: disable=too-many-instance-attributes
"""The generation configuration dataclass.
Parameters
----------
n : int
How many chat completion choices to generate for each input message.
temperature : Optional[float]
The value that applies to logits and modulates the next token probabilities.
top_p : Optional[float]
In sampling, only the most probable tokens with probabilities summed up to
`top_p` are kept for sampling.
frequency_penalty : Optional[float]
Positive values penalize new tokens based on their existing frequency
in the text so far, decreasing the model's likelihood to repeat the same
line verbatim.
presence_penalty : Optional[float]
Positive values penalize new tokens based on whether they appear in the text
so far, increasing the model's likelihood to talk about new topics.
repetition_penalty : float
The penalty term that applies to logits to control token repetition in generation.
It will be suppressed when any of frequency_penalty and presence_penalty is
non-zero.
logprobs : bool
Whether to return log probabilities of the output tokens or not.
If true, the log probabilities of each output token will be returned.
top_logprobs : int
An integer between 0 and 5 specifying the number of most likely
tokens to return at each token position, each with an associated
log probability.
`logprobs` must be set to True if this parameter is used.
logit_bias : Optional[Dict[int, float]]
The bias logit value added to selected tokens prior to sampling.
max_tokens : Optional[int]
The maximum number of generated tokens,
or None, in which case the generation will not stop
until exceeding model capability or hit any stop criteria.
seed : Optional[int]
The random seed of the generation.
The seed will be a random value if not specified.
stop_strs : List[str]
The list of strings that mark the end of generation.
stop_token_ids : List[int]
The list of token ids that mark the end of generation.
response_format : ResponseFormat
The response format of the generation output.
debug_config : Optional[DebugConfig]
The optional debug configuration.
"""

n: int = 1
temperature: Optional[float] = None
top_p: Optional[float] = None
frequency_penalty: Optional[float] = None
presence_penalty: Optional[float] = None
repetition_penalty: float = 1.0
logprobs: bool = False
top_logprobs: int = 0
logit_bias: Optional[Dict[int, float]] = field(default_factory=dict) # type: ignore

max_tokens: Optional[int] = 128
seed: Optional[int] = None
stop_strs: List[str] = field(default_factory=list)
stop_token_ids: List[int] = field(default_factory=list)

response_format: ResponseFormat = field(default_factory=ResponseFormat)

debug_config: Optional[DebugConfig] = field(default_factory=DebugConfig)

def asjson(self) -> str:
"""Return the config in string of JSON format."""
return json.dumps(asdict(self))

@staticmethod
def from_json(json_str: str) -> "GenerationConfig":
"""Construct a config from JSON string."""
return GenerationConfig(**json.loads(json_str))
from typing import List, Literal, Optional, Tuple, Union


@dataclass
Expand Down
11 changes: 8 additions & 3 deletions python/mlc_llm/serve/engine.py
@@ -22,8 +22,9 @@
from tvm.runtime import Device

from mlc_llm.protocol import debug_protocol, openai_api_protocol
from mlc_llm.protocol.generation_config import GenerationConfig
from mlc_llm.serve import data, engine_utils
from mlc_llm.serve.config import EngineConfig, GenerationConfig
from mlc_llm.serve.config import EngineConfig
from mlc_llm.streamer import TextStreamer
from mlc_llm.support import logging

@@ -1372,7 +1373,9 @@ async def _generate(
# Create the request with the given id, input data, generation
# config and the created callback.
input_data = engine_utils.convert_prompts_to_data(prompt)
request = self._ffi["create_request"](request_id, input_data, generation_config.asjson())
request = self._ffi["create_request"](
request_id, input_data, generation_config.model_dump_json()
)

# Create the unique async request stream of the request.
stream = engine_base.AsyncRequestStream()
@@ -1898,7 +1901,9 @@ def _generate( # pylint: disable=too-many-locals
# Create the request with the given id, input data, generation
# config and the created callback.
input_data = engine_utils.convert_prompts_to_data(prompt)
request = self._ffi["create_request"](request_id, input_data, generation_config.asjson())
request = self._ffi["create_request"](
request_id, input_data, generation_config.model_dump_json()
)

# Record the stream in the tracker
self.state.sync_output_queue = queue.Queue()
3 changes: 2 additions & 1 deletion python/mlc_llm/serve/engine_base.py
@@ -18,9 +18,10 @@

from mlc_llm.protocol import openai_api_protocol
from mlc_llm.protocol.conversation_protocol import Conversation
from mlc_llm.protocol.generation_config import GenerationConfig
from mlc_llm.protocol.mlc_chat_config import MLCChatConfig
from mlc_llm.serve import data, engine_utils
from mlc_llm.serve.config import EngineConfig, GenerationConfig
from mlc_llm.serve.config import EngineConfig
from mlc_llm.serve.event_trace_recorder import EventTraceRecorder
from mlc_llm.streamer import TextStreamer
from mlc_llm.support import download_cache, logging
19 changes: 8 additions & 11 deletions python/mlc_llm/serve/engine_utils.py
@@ -3,10 +3,13 @@
import uuid
from typing import Any, Callable, Dict, List, Optional, Union

from mlc_llm.protocol import RequestProtocol, error_protocol, openai_api_protocol
from mlc_llm.protocol import error_protocol, openai_api_protocol
from mlc_llm.protocol.generation_config import GenerationConfig
from mlc_llm.serve import data

from .config import DebugConfig, GenerationConfig, ResponseFormat
RequestProtocol = Union[
openai_api_protocol.CompletionRequest, openai_api_protocol.ChatCompletionRequest
]


def get_unsupported_fields(request: RequestProtocol) -> List[str]:
@@ -20,9 +20,7 @@ def get_unsupported_fields(request: RequestProtocol) -> List[str]:
raise RuntimeError("Cannot reach here")


def openai_api_get_generation_config(
request: Union[openai_api_protocol.CompletionRequest, openai_api_protocol.ChatCompletionRequest]
) -> Dict[str, Any]:
def openai_api_get_generation_config(request: RequestProtocol) -> Dict[str, Any]:
"""Create the generation config from the given request."""
kwargs: Dict[str, Any] = {}
arg_names = [
@@ -36,6 +37,8 @@ def openai_api_get_generation_config(
"top_logprobs",
"logit_bias",
"seed",
"response_format",
"debug_config",
]
for arg_name in arg_names:
kwargs[arg_name] = getattr(request, arg_name)
@@ -45,12 +48,6 @@ def openai_api_get_generation_config(request: RequestProtocol) -> Dict[str, Any]:
kwargs["max_tokens"] = -1
if request.stop is not None:
kwargs["stop_strs"] = [request.stop] if isinstance(request.stop, str) else request.stop
if request.response_format is not None:
kwargs["response_format"] = ResponseFormat(
**request.response_format.model_dump(by_alias=True)
)
if request.debug_config is not None:
kwargs["debug_config"] = DebugConfig(**request.debug_config.model_dump())
return kwargs


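With the duplicated ResponseFormat and DebugConfig dataclasses gone, response_format and debug_config are already the right pydantic types on the request, so they can be copied into the kwargs like any other field. A hedged sketch of the resulting flow; the request argument is a placeholder for an already-parsed OpenAI-style request, and constructing GenerationConfig from the kwargs this way is an assumption rather than code taken from this diff:

```python
from mlc_llm.protocol.generation_config import GenerationConfig
from mlc_llm.serve import engine_utils

def build_generation_config(request) -> GenerationConfig:
    # `request` is assumed to be an openai_api_protocol.CompletionRequest or
    # ChatCompletionRequest that has already been validated by the server layer.
    kwargs = engine_utils.openai_api_get_generation_config(request)
    return GenerationConfig(**kwargs)
```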
5 changes: 3 additions & 2 deletions python/mlc_llm/serve/request.py
@@ -4,8 +4,9 @@
import tvm._ffi
from tvm.runtime import Object

from mlc_llm.protocol.generation_config import GenerationConfig

from . import _ffi_api
from .config import GenerationConfig
from .data import Data


@@ -29,6 +30,6 @@ def inputs(self) -> List[Data]:
@property
def generation_config(self) -> GenerationConfig:
"""The generation config of the request."""
return GenerationConfig.from_json(
return GenerationConfig.model_validate_json(
_ffi_api.RequestGetGenerationConfigJSON(self) # type: ignore # pylint: disable=no-member
)
5 changes: 3 additions & 2 deletions python/mlc_llm/serve/sync_engine.py
@@ -13,8 +13,9 @@

import tvm

from mlc_llm.protocol.generation_config import GenerationConfig
from mlc_llm.serve import data
from mlc_llm.serve.config import EngineConfig, GenerationConfig
from mlc_llm.serve.config import EngineConfig
from mlc_llm.serve.engine_base import (
EngineMetrics,
_check_engine_config,
@@ -307,7 +308,7 @@ def create_request(
"""
if not isinstance(inputs, list):
inputs = [inputs]
return self._ffi["create_request"](request_id, inputs, generation_config.asjson())
return self._ffi["create_request"](request_id, inputs, generation_config.model_dump_json())

def add_request(self, request: Request) -> None:
"""Add a new request to the engine.
3 changes: 0 additions & 3 deletions python/mlc_llm/testing/debug_chat.py
@@ -373,9 +373,6 @@ def generate(
generate_length : int
How many tokens to generate.
generation_config : Optional[GenerationConfig]
Will be used to override the GenerationConfig in ``mlc-chat-config.json``.
"""
out_tokens = []

2 changes: 1 addition & 1 deletion tests/python/serve/evaluate_engine.py
@@ -4,7 +4,7 @@
import random
from typing import List, Tuple

from mlc_llm.serve import GenerationConfig
from mlc_llm.protocol.generation_config import GenerationConfig
from mlc_llm.serve.sync_engine import EngineConfig, SyncMLCEngine

