From 14bec5ad67f1e375b809f8b402dcffcf44b4a18b Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Sun, 26 May 2024 17:50:29 -0400 Subject: [PATCH] [REFACTOR] Move GenerationConfig to protocol (#2427) This PR moves GenerationConfig to protocol. As we move towards OAI style API GenerationConfig becomes more like an internal API. This change reflects that and also removes duplicated definition of ResponseFormat and DebugConfig --- python/mlc_llm/__init__.py | 1 - python/mlc_llm/protocol/__init__.py | 11 +- python/mlc_llm/protocol/generation_config.py | 32 ++++ python/mlc_llm/serve/__init__.py | 2 +- python/mlc_llm/serve/config.py | 142 +----------------- python/mlc_llm/serve/engine.py | 11 +- python/mlc_llm/serve/engine_base.py | 3 +- python/mlc_llm/serve/engine_utils.py | 19 +-- python/mlc_llm/serve/request.py | 5 +- python/mlc_llm/serve/sync_engine.py | 5 +- python/mlc_llm/testing/debug_chat.py | 3 - tests/python/serve/evaluate_engine.py | 2 +- tests/python/serve/test_serve_async_engine.py | 18 ++- .../serve/test_serve_async_engine_spec.py | 3 +- tests/python/serve/test_serve_engine.py | 3 +- .../python/serve/test_serve_engine_grammar.py | 5 +- tests/python/serve/test_serve_engine_image.py | 3 +- .../serve/test_serve_engine_prefix_cache.py | 3 +- tests/python/serve/test_serve_engine_rnn.py | 3 +- tests/python/serve/test_serve_engine_spec.py | 3 +- tests/python/serve/test_serve_sync_engine.py | 3 +- 21 files changed, 95 insertions(+), 185 deletions(-) create mode 100644 python/mlc_llm/protocol/generation_config.py diff --git a/python/mlc_llm/__init__.py b/python/mlc_llm/__init__.py index 4843c6766d..66285cea4e 100644 --- a/python/mlc_llm/__init__.py +++ b/python/mlc_llm/__init__.py @@ -4,6 +4,5 @@ """ from . import protocol, serve -from .chat_module import ChatConfig, ChatModule, ConvConfig, GenerationConfig from .libinfo import __version__ from .serve import AsyncMLCEngine, MLCEngine diff --git a/python/mlc_llm/protocol/__init__.py b/python/mlc_llm/protocol/__init__.py index 8cd2a69ca7..b430746477 100644 --- a/python/mlc_llm/protocol/__init__.py +++ b/python/mlc_llm/protocol/__init__.py @@ -1,4 +1,9 @@ -"""Definitions of pydantic models for API entry points and configurations""" -from . import openai_api_protocol +"""Definitions of pydantic models for API entry points and configurations -RequestProtocol = openai_api_protocol.CompletionRequest +Note +---- +We use the following convention + +- filename_protocol If the classes can appear in an API endpoint +- filename_config For other config classes +""" diff --git a/python/mlc_llm/protocol/generation_config.py b/python/mlc_llm/protocol/generation_config.py new file mode 100644 index 0000000000..6cd5e82cf0 --- /dev/null +++ b/python/mlc_llm/protocol/generation_config.py @@ -0,0 +1,32 @@ +"""Low-level generation config class""" +# pylint: disable=missing-class-docstring, disable=too-many-instance-attributes +from typing import Dict, List, Optional + +from pydantic import BaseModel + +from .debug_protocol import DebugConfig +from .openai_api_protocol import RequestResponseFormat + + +class GenerationConfig(BaseModel): # pylint: + """The generation configuration dataclass. + + This is a config class used by Engine internally. 
+ """ + + n: int = 1 + temperature: Optional[float] = None + top_p: Optional[float] = None + frequency_penalty: Optional[float] = None + presence_penalty: Optional[float] = None + repetition_penalty: Optional[float] = None + logprobs: bool = False + top_logprobs: int = 0 + logit_bias: Optional[Dict[int, float]] = None + # internally we use -1 to represent infinite + max_tokens: int = -1 + seed: Optional[int] = None + stop_strs: Optional[List[str]] = None + stop_token_ids: Optional[List[int]] = None + response_format: Optional[RequestResponseFormat] = None + debug_config: Optional[Optional[DebugConfig]] = None diff --git a/python/mlc_llm/serve/__init__.py b/python/mlc_llm/serve/__init__.py index 4ef4470399..6b122bdf64 100644 --- a/python/mlc_llm/serve/__init__.py +++ b/python/mlc_llm/serve/__init__.py @@ -2,7 +2,7 @@ # Load MLC LLM library by importing base from .. import base -from .config import DebugConfig, EngineConfig, GenerationConfig +from .config import EngineConfig from .data import Data, ImageData, RequestStreamOutput, TextData, TokenData from .engine import AsyncMLCEngine, MLCEngine from .grammar import BNFGrammar, GrammarStateMatcher diff --git a/python/mlc_llm/serve/config.py b/python/mlc_llm/serve/config.py index f4fadf0dae..bf79bb672f 100644 --- a/python/mlc_llm/serve/config.py +++ b/python/mlc_llm/serve/config.py @@ -2,147 +2,7 @@ import json from dataclasses import asdict, dataclass, field -from typing import Dict, List, Literal, Optional, Tuple, Union - - -@dataclass -class ResponseFormat: - """The response format dataclass. - - Parameters - ---------- - type : Literal["text", "json_object"] - The type of response format. Default: "text". - - schema : Optional[str] - The JSON schema string for the JSON response format. If None, a legal json string without - special restrictions will be generated. - - Could be specified when the response format is "json_object". Default: None. - """ - - type: Literal["text", "json_object"] = "text" - schema: Optional[str] = None - - def __post_init__(self): - if self.schema is not None and self.type != "json_object": - raise ValueError("JSON schema is only supported in JSON response format") - - -@dataclass -class DebugConfig: - """The debug configuration dataclass.Parameters - ---------- - ignore_eos : bool - When it is true, ignore the eos token and generate tokens until `max_tokens`. - Default is set to False. - - pinned_system_prompt : bool - Whether the input and generated data pinned in engine. Default is set to False. - This can be used for system prompt or other purpose, if the data is aimed to be - kept all the time. - - special_request: Optional[string] - Special requests to send to engine - """ - - ignore_eos: bool = False - pinned_system_prompt: bool = False - special_request: Optional[Literal["query_engine_metrics"]] = None - - -@dataclass -class GenerationConfig: # pylint: disable=too-many-instance-attributes - """The generation configuration dataclass. - - Parameters - ---------- - n : int - How many chat completion choices to generate for each input message. - - temperature : Optional[float] - The value that applies to logits and modulates the next token probabilities. - - top_p : Optional[float] - In sampling, only the most probable tokens with probabilities summed up to - `top_p` are kept for sampling. - - frequency_penalty : Optional[float] - Positive values penalize new tokens based on their existing frequency - in the text so far, decreasing the model's likelihood to repeat the same - line verbatim. 
- - presence_penalty : Optional[float] - Positive values penalize new tokens based on whether they appear in the text - so far, increasing the model's likelihood to talk about new topics. - - repetition_penalty : float - The penalty term that applies to logits to control token repetition in generation. - It will be suppressed when any of frequency_penalty and presence_penalty is - non-zero. - - logprobs : bool - Whether to return log probabilities of the output tokens or not. - If true, the log probabilities of each output token will be returned. - - top_logprobs : int - An integer between 0 and 5 specifying the number of most likely - tokens to return at each token position, each with an associated - log probability. - `logprobs` must be set to True if this parameter is used. - - logit_bias : Optional[Dict[int, float]] - The bias logit value added to selected tokens prior to sampling. - - max_tokens : Optional[int] - The maximum number of generated tokens, - or None, in which case the generation will not stop - until exceeding model capability or hit any stop criteria. - - seed : Optional[int] - The random seed of the generation. - The seed will be a random value if not specified. - - stop_strs : List[str] - The list of strings that mark the end of generation. - - stop_token_ids : List[int] - The list of token ids that mark the end of generation. - - response_format : ResponseFormat - The response format of the generation output. - - debug_config : Optional[DebugConfig] - The optional debug configuration. - """ - - n: int = 1 - temperature: Optional[float] = None - top_p: Optional[float] = None - frequency_penalty: Optional[float] = None - presence_penalty: Optional[float] = None - repetition_penalty: float = 1.0 - logprobs: bool = False - top_logprobs: int = 0 - logit_bias: Optional[Dict[int, float]] = field(default_factory=dict) # type: ignore - - max_tokens: Optional[int] = 128 - seed: Optional[int] = None - stop_strs: List[str] = field(default_factory=list) - stop_token_ids: List[int] = field(default_factory=list) - - response_format: ResponseFormat = field(default_factory=ResponseFormat) - - debug_config: Optional[DebugConfig] = field(default_factory=DebugConfig) - - def asjson(self) -> str: - """Return the config in string of JSON format.""" - return json.dumps(asdict(self)) - - @staticmethod - def from_json(json_str: str) -> "GenerationConfig": - """Construct a config from JSON string.""" - return GenerationConfig(**json.loads(json_str)) +from typing import List, Literal, Optional, Tuple, Union @dataclass diff --git a/python/mlc_llm/serve/engine.py b/python/mlc_llm/serve/engine.py index e072d1028d..012f450bb2 100644 --- a/python/mlc_llm/serve/engine.py +++ b/python/mlc_llm/serve/engine.py @@ -22,8 +22,9 @@ from tvm.runtime import Device from mlc_llm.protocol import debug_protocol, openai_api_protocol +from mlc_llm.protocol.generation_config import GenerationConfig from mlc_llm.serve import data, engine_utils -from mlc_llm.serve.config import EngineConfig, GenerationConfig +from mlc_llm.serve.config import EngineConfig from mlc_llm.streamer import TextStreamer from mlc_llm.support import logging @@ -1372,7 +1373,9 @@ async def _generate( # Create the request with the given id, input data, generation # config and the created callback. 
input_data = engine_utils.convert_prompts_to_data(prompt) - request = self._ffi["create_request"](request_id, input_data, generation_config.asjson()) + request = self._ffi["create_request"]( + request_id, input_data, generation_config.model_dump_json() + ) # Create the unique async request stream of the request. stream = engine_base.AsyncRequestStream() @@ -1898,7 +1901,9 @@ def _generate( # pylint: disable=too-many-locals # Create the request with the given id, input data, generation # config and the created callback. input_data = engine_utils.convert_prompts_to_data(prompt) - request = self._ffi["create_request"](request_id, input_data, generation_config.asjson()) + request = self._ffi["create_request"]( + request_id, input_data, generation_config.model_dump_json() + ) # Record the stream in the tracker self.state.sync_output_queue = queue.Queue() diff --git a/python/mlc_llm/serve/engine_base.py b/python/mlc_llm/serve/engine_base.py index d8b1842c0b..8aa8d52b97 100644 --- a/python/mlc_llm/serve/engine_base.py +++ b/python/mlc_llm/serve/engine_base.py @@ -18,9 +18,10 @@ from mlc_llm.protocol import openai_api_protocol from mlc_llm.protocol.conversation_protocol import Conversation +from mlc_llm.protocol.generation_config import GenerationConfig from mlc_llm.protocol.mlc_chat_config import MLCChatConfig from mlc_llm.serve import data, engine_utils -from mlc_llm.serve.config import EngineConfig, GenerationConfig +from mlc_llm.serve.config import EngineConfig from mlc_llm.serve.event_trace_recorder import EventTraceRecorder from mlc_llm.streamer import TextStreamer from mlc_llm.support import download_cache, logging diff --git a/python/mlc_llm/serve/engine_utils.py b/python/mlc_llm/serve/engine_utils.py index c2d686d583..6ccbc0e621 100644 --- a/python/mlc_llm/serve/engine_utils.py +++ b/python/mlc_llm/serve/engine_utils.py @@ -3,10 +3,13 @@ import uuid from typing import Any, Callable, Dict, List, Optional, Union -from mlc_llm.protocol import RequestProtocol, error_protocol, openai_api_protocol +from mlc_llm.protocol import error_protocol, openai_api_protocol +from mlc_llm.protocol.generation_config import GenerationConfig from mlc_llm.serve import data -from .config import DebugConfig, GenerationConfig, ResponseFormat +RequestProtocol = Union[ + openai_api_protocol.CompletionRequest, openai_api_protocol.ChatCompletionRequest +] def get_unsupported_fields(request: RequestProtocol) -> List[str]: @@ -20,9 +23,7 @@ def get_unsupported_fields(request: RequestProtocol) -> List[str]: raise RuntimeError("Cannot reach here") -def openai_api_get_generation_config( - request: Union[openai_api_protocol.CompletionRequest, openai_api_protocol.ChatCompletionRequest] -) -> Dict[str, Any]: +def openai_api_get_generation_config(request: RequestProtocol) -> Dict[str, Any]: """Create the generation config from the given request.""" kwargs: Dict[str, Any] = {} arg_names = [ @@ -36,6 +37,8 @@ def openai_api_get_generation_config( "top_logprobs", "logit_bias", "seed", + "response_format", + "debug_config", ] for arg_name in arg_names: kwargs[arg_name] = getattr(request, arg_name) @@ -45,12 +48,6 @@ def openai_api_get_generation_config( kwargs["max_tokens"] = -1 if request.stop is not None: kwargs["stop_strs"] = [request.stop] if isinstance(request.stop, str) else request.stop - if request.response_format is not None: - kwargs["response_format"] = ResponseFormat( - **request.response_format.model_dump(by_alias=True) - ) - if request.debug_config is not None: - kwargs["debug_config"] = 
DebugConfig(**request.debug_config.model_dump()) return kwargs diff --git a/python/mlc_llm/serve/request.py b/python/mlc_llm/serve/request.py index d9260e6598..10c2e0577d 100644 --- a/python/mlc_llm/serve/request.py +++ b/python/mlc_llm/serve/request.py @@ -4,8 +4,9 @@ import tvm._ffi from tvm.runtime import Object +from mlc_llm.protocol.generation_config import GenerationConfig + from . import _ffi_api -from .config import GenerationConfig from .data import Data @@ -29,6 +30,6 @@ def inputs(self) -> List[Data]: @property def generation_config(self) -> GenerationConfig: """The generation config of the request.""" - return GenerationConfig.from_json( + return GenerationConfig.model_validate_json( _ffi_api.RequestGetGenerationConfigJSON(self) # type: ignore # pylint: disable=no-member ) diff --git a/python/mlc_llm/serve/sync_engine.py b/python/mlc_llm/serve/sync_engine.py index 460bc4d52e..5b5fd9cd98 100644 --- a/python/mlc_llm/serve/sync_engine.py +++ b/python/mlc_llm/serve/sync_engine.py @@ -13,8 +13,9 @@ import tvm +from mlc_llm.protocol.generation_config import GenerationConfig from mlc_llm.serve import data -from mlc_llm.serve.config import EngineConfig, GenerationConfig +from mlc_llm.serve.config import EngineConfig from mlc_llm.serve.engine_base import ( EngineMetrics, _check_engine_config, @@ -307,7 +308,7 @@ def create_request( """ if not isinstance(inputs, list): inputs = [inputs] - return self._ffi["create_request"](request_id, inputs, generation_config.asjson()) + return self._ffi["create_request"](request_id, inputs, generation_config.model_dump_json()) def add_request(self, request: Request) -> None: """Add a new request to the engine. diff --git a/python/mlc_llm/testing/debug_chat.py b/python/mlc_llm/testing/debug_chat.py index 6aacce1faf..6f25328c8f 100644 --- a/python/mlc_llm/testing/debug_chat.py +++ b/python/mlc_llm/testing/debug_chat.py @@ -373,9 +373,6 @@ def generate( generate_length : int How many tokens to generate. - - generation_config : Optional[GenerationConfig] - Will be used to override the GenerationConfig in ``mlc-chat-config.json``. 
""" out_tokens = [] diff --git a/tests/python/serve/evaluate_engine.py b/tests/python/serve/evaluate_engine.py index 608f69dd4c..7767c30abc 100644 --- a/tests/python/serve/evaluate_engine.py +++ b/tests/python/serve/evaluate_engine.py @@ -4,7 +4,7 @@ import random from typing import List, Tuple -from mlc_llm.serve import GenerationConfig +from mlc_llm.protocol.generation_config import GenerationConfig from mlc_llm.serve.sync_engine import EngineConfig, SyncMLCEngine diff --git a/tests/python/serve/test_serve_async_engine.py b/tests/python/serve/test_serve_async_engine.py index 1884359718..993e5b60b3 100644 --- a/tests/python/serve/test_serve_async_engine.py +++ b/tests/python/serve/test_serve_async_engine.py @@ -3,7 +3,8 @@ import asyncio from typing import List -from mlc_llm.serve import AsyncMLCEngine, EngineConfig, GenerationConfig +from mlc_llm.protocol.generation_config import GenerationConfig +from mlc_llm.serve import AsyncMLCEngine, EngineConfig from mlc_llm.testing import require_test_model prompts = [ @@ -20,7 +21,7 @@ ] -@require_test_model("Llama-2-7b-chat-hf-q0f16-MLC") +@require_test_model("Llama-2-7b-chat-hf-q4f16_1-MLC") async def test_engine_generate(model: str): # Create engine async_engine = AsyncMLCEngine( @@ -48,9 +49,12 @@ async def generate_task( async for delta_outputs in async_engine._generate( prompt, generation_cfg, request_id=request_id ): - assert len(delta_outputs) == generation_cfg.n - for i, delta_output in enumerate(delta_outputs): - output_texts[rid][i] += delta_output.delta_text + if len(delta_outputs) == generation_cfg.n: + for i, delta_output in enumerate(delta_outputs): + output_texts[rid][i] += delta_output.delta_text + else: + assert len(delta_outputs) == 1 + assert len(delta_outputs[0].request_final_usage_json_str) != 0 tasks = [ asyncio.create_task( @@ -75,7 +79,7 @@ async def generate_task( del async_engine -@require_test_model("Llama-2-7b-chat-hf-q0f16-MLC") +@require_test_model("Llama-2-7b-chat-hf-q4f16_1-MLC") async def test_chat_completion(model: str): # Create engine async_engine = AsyncMLCEngine( @@ -126,7 +130,7 @@ async def generate_task(prompt: str, request_id: str): del async_engine -@require_test_model("Llama-2-7b-chat-hf-q0f16-MLC") +@require_test_model("Llama-2-7b-chat-hf-q4f16_1-MLC") async def test_chat_completion_non_stream(model: str): # Create engine async_engine = AsyncMLCEngine( diff --git a/tests/python/serve/test_serve_async_engine_spec.py b/tests/python/serve/test_serve_async_engine_spec.py index c3d4c37756..476d970e1c 100644 --- a/tests/python/serve/test_serve_async_engine_spec.py +++ b/tests/python/serve/test_serve_async_engine_spec.py @@ -3,7 +3,8 @@ import asyncio from typing import List -from mlc_llm.serve import AsyncMLCEngine, EngineConfig, GenerationConfig +from mlc_llm.protocol.generation_config import GenerationConfig +from mlc_llm.serve import AsyncMLCEngine, EngineConfig from mlc_llm.testing import require_test_model prompts = [ diff --git a/tests/python/serve/test_serve_engine.py b/tests/python/serve/test_serve_engine.py index 670d33b236..899629a448 100644 --- a/tests/python/serve/test_serve_engine.py +++ b/tests/python/serve/test_serve_engine.py @@ -2,7 +2,8 @@ # pylint: disable=too-many-arguments,too-many-locals,unused-argument,unused-variable from typing import List -from mlc_llm.serve import EngineConfig, GenerationConfig, MLCEngine +from mlc_llm.protocol.generation_config import GenerationConfig +from mlc_llm.serve import EngineConfig, MLCEngine from mlc_llm.testing import require_test_model prompts = [ 
diff --git a/tests/python/serve/test_serve_engine_grammar.py b/tests/python/serve/test_serve_engine_grammar.py index d85ab8e762..13d12f5a29 100644 --- a/tests/python/serve/test_serve_engine_grammar.py +++ b/tests/python/serve/test_serve_engine_grammar.py @@ -7,8 +7,9 @@ import pytest from pydantic import BaseModel -from mlc_llm.serve import AsyncMLCEngine, GenerationConfig -from mlc_llm.serve.config import ResponseFormat +from mlc_llm.protocol.generation_config import GenerationConfig +from mlc_llm.protocol.openai_api_protocol import RequestResponseFormat as ResponseFormat +from mlc_llm.serve import AsyncMLCEngine from mlc_llm.serve.sync_engine import SyncMLCEngine from mlc_llm.testing import require_test_model diff --git a/tests/python/serve/test_serve_engine_image.py b/tests/python/serve/test_serve_engine_image.py index b1cdf1fcea..0fdf141faf 100644 --- a/tests/python/serve/test_serve_engine_image.py +++ b/tests/python/serve/test_serve_engine_image.py @@ -1,7 +1,8 @@ import json from pathlib import Path -from mlc_llm.serve import GenerationConfig, data +from mlc_llm.protocol.generation_config import GenerationConfig +from mlc_llm.serve import data from mlc_llm.serve.sync_engine import EngineConfig, SyncMLCEngine diff --git a/tests/python/serve/test_serve_engine_prefix_cache.py b/tests/python/serve/test_serve_engine_prefix_cache.py index ca55540fff..0a32c04b11 100644 --- a/tests/python/serve/test_serve_engine_prefix_cache.py +++ b/tests/python/serve/test_serve_engine_prefix_cache.py @@ -1,4 +1,5 @@ -from mlc_llm.serve import DebugConfig, GenerationConfig +from mlc_llm.protocol.debug_protocol import DebugConfig +from mlc_llm.protocol.generation_config import GenerationConfig from mlc_llm.serve.sync_engine import EngineConfig, SyncMLCEngine from mlc_llm.testing import require_test_model diff --git a/tests/python/serve/test_serve_engine_rnn.py b/tests/python/serve/test_serve_engine_rnn.py index 090c06dbc3..194e7ec35d 100644 --- a/tests/python/serve/test_serve_engine_rnn.py +++ b/tests/python/serve/test_serve_engine_rnn.py @@ -2,7 +2,8 @@ # pylint: disable=too-many-arguments,too-many-locals,unused-argument,unused-variable from typing import List -from mlc_llm.serve import EngineConfig, GenerationConfig, MLCEngine +from mlc_llm.protocol.generation_config import GenerationConfig +from mlc_llm.serve import EngineConfig, MLCEngine prompts = [ "What is the meaning of life?", diff --git a/tests/python/serve/test_serve_engine_spec.py b/tests/python/serve/test_serve_engine_spec.py index b37e7c8051..61a40476ae 100644 --- a/tests/python/serve/test_serve_engine_spec.py +++ b/tests/python/serve/test_serve_engine_spec.py @@ -4,7 +4,8 @@ import numpy as np -from mlc_llm.serve import GenerationConfig, Request, RequestStreamOutput, data +from mlc_llm.protocol.generation_config import GenerationConfig +from mlc_llm.serve import Request, RequestStreamOutput, data from mlc_llm.serve.sync_engine import EngineConfig, SyncMLCEngine from mlc_llm.testing import require_test_model diff --git a/tests/python/serve/test_serve_sync_engine.py b/tests/python/serve/test_serve_sync_engine.py index 8dbc60925e..b889628592 100644 --- a/tests/python/serve/test_serve_sync_engine.py +++ b/tests/python/serve/test_serve_sync_engine.py @@ -4,7 +4,8 @@ import numpy as np -from mlc_llm.serve import GenerationConfig, Request, RequestStreamOutput, data +from mlc_llm.protocol.generation_config import GenerationConfig +from mlc_llm.serve import Request, RequestStreamOutput, data from mlc_llm.serve.sync_engine import EngineConfig, 
SyncMLCEngine
 from mlc_llm.testing import require_test_model
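
A minimal usage sketch (not part of the patch) of what the refactor implies for callers: GenerationConfig is now imported from mlc_llm.protocol.generation_config and, as a pydantic model, is serialized with model_dump_json() where the old dataclass used asjson(). The field values below are illustrative only and assume a build of mlc_llm that includes this change.

# Sketch only: assumes an mlc_llm build containing this patch.
from mlc_llm.protocol.generation_config import GenerationConfig

# Construct the internal config the way the OAI-style entry points do.
cfg = GenerationConfig(
    temperature=0.7,
    top_p=0.95,
    max_tokens=128,       # -1 is used internally to mean "no limit"
    stop_strs=["</s>"],
)

# The engine now serializes via pydantic instead of the removed asjson():
json_str = cfg.model_dump_json()

# ...and round-trips with model_validate_json(), as Request.generation_config does.
restored = GenerationConfig.model_validate_json(json_str)
assert restored.max_tokens == 128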