From e43a578be2c741423f28d919bd02beec4e8e5fd2 Mon Sep 17 00:00:00 2001 From: Hongyi Jin Date: Fri, 13 Dec 2024 22:08:12 +0000 Subject: [PATCH] fix format --- python/mlc_llm/cli/router.py | 13 ++ python/mlc_llm/interface/compiler_flags.py | 2 +- python/mlc_llm/interface/help.py | 4 + python/mlc_llm/interface/router.py | 171 ++++++++++---------- python/mlc_llm/protocol/debug_protocol.py | 1 + python/mlc_llm/router/router.py | 22 +-- python/mlc_llm/serve/engine_base.py | 4 +- python/mlc_llm/serve/server/popen_server.py | 5 +- 8 files changed, 121 insertions(+), 101 deletions(-) diff --git a/python/mlc_llm/cli/router.py b/python/mlc_llm/cli/router.py index 6a696dd6d3..1d69bdb604 100644 --- a/python/mlc_llm/cli/router.py +++ b/python/mlc_llm/cli/router.py @@ -1,9 +1,12 @@ +"""Command line entrypoint of router.""" + from mlc_llm.interface.help import HELP from mlc_llm.interface.router import serve from mlc_llm.support.argparse import ArgumentParser def main(argv): + """Parse command line arguments and call `mlc_llm.interface.router`.""" # Define a custom argument type for a list of strings def list_of_strings(arg): return arg.split(",") @@ -12,54 +15,64 @@ def list_of_strings(arg): parser.add_argument( "model", type=str, + help=HELP["model"] + " (required)", ) parser.add_argument( "--model-lib", type=str, default=None, + help=HELP["model_lib"] + ' (default: "%(default)s")', ) parser.add_argument( "--router-mode", type=str, choices=["disagg", "round-robin"], default="disagg", + help="router mode" + ' (default: "%(default)s")', ) parser.add_argument( "--router-host", type=str, default="127.0.0.1", + help="router host" + ' (default: "%(default)s")', ) parser.add_argument( "--router-port", type=int, default=8000, + help="router port" + ' (default: "%(default)s")', ) parser.add_argument( "--endpoint-hosts", type=list_of_strings, default="127.0.0.1", + help="Host of each endpoint, seperated by comma." + ' (default: "%(default)s")', ) parser.add_argument( "--endpoint-ports", nargs="*", type=int, default=[8080], + help="Port of each endpoint, separated by space." + ' (default: "%(default)s")', ) parser.add_argument( "--endpoint-num-gpus", nargs="*", type=int, default=[1], + help="Number of GPUs of each endpoint, separated by space." + ' (default: "%(default)s")', ) parser.add_argument( "--enable-prefix-cache", default=False, action="store_true", + help="whether to enable prefix cache" + ' (default: "%(default)s")', ) parser.add_argument( "--pd-balance-factor", type=float, default=0.0, + help=HELP["pd_balance_factor"] + ' (default: "%(default)s")', ) parsed = parser.parse_args(argv) serve( diff --git a/python/mlc_llm/interface/compiler_flags.py b/python/mlc_llm/interface/compiler_flags.py index 9bb7d10061..94d66cccea 100644 --- a/python/mlc_llm/interface/compiler_flags.py +++ b/python/mlc_llm/interface/compiler_flags.py @@ -138,7 +138,7 @@ def _cudagraph(target) -> bool: @dataclasses.dataclass -class ModelConfigOverride(ConfigOverrideBase): +class ModelConfigOverride(ConfigOverrideBase): # pylint: disable=too-many-instance-attributes """Flags for overriding model config.""" context_window_size: Optional[int] = None diff --git a/python/mlc_llm/interface/help.py b/python/mlc_llm/interface/help.py index 31bb77cce4..540d909bc3 100644 --- a/python/mlc_llm/interface/help.py +++ b/python/mlc_llm/interface/help.py @@ -262,4 +262,8 @@ """.strip(), "seed_calibrate": """ The seed to sample the calibration dataset.""", + "pd_balance_factor": """ +How much prefill to move to decode engine. 
For example, +0.1 means the last 10 percent tokens are prefilled by decode engine. + """.strip(), } diff --git a/python/mlc_llm/interface/router.py b/python/mlc_llm/interface/router.py index a4c956e1bc..30fc860945 100644 --- a/python/mlc_llm/interface/router.py +++ b/python/mlc_llm/interface/router.py @@ -1,3 +1,5 @@ +"""Python entrypoint of router.""" + from http import HTTPStatus from typing import AsyncGenerator, List, Literal, Optional @@ -13,90 +15,6 @@ # # Global variables # -router_app = fastapi.APIRouter() -router = None - -# -# Define APIs -# - - -@router_app.post("/v1/completions") -async def request_completion(request: CompletionRequest, raw_request: fastapi.Request): - """OpenAI-compatible completion API. - API reference: https://platform.openai.com/docs/api-reference/completions/create - """ - global router - if router is None: - return error_protocol.create_error_response( - HTTPStatus.BAD_REQUEST, message="Router is not initialized." - ) - request_id = f"cmpl-{engine_utils.random_uuid()}" - - # Streaming response. - if request.stream: - # We manually get the first response from generator to - # capture potential exceptions in this scope, rather then - # the StreamingResponse scope. - stream_generator = router.handle_completion( # pylint: disable=protected-access - request, request_id - ) - first_response = await anext( # type: ignore # pylint: disable=undefined-variable - stream_generator - ) - - async def completion_stream_generator() -> AsyncGenerator[str, None]: - if isinstance(first_response, StopAsyncIteration): - yield "data: [DONE]\n\n" - return - yield f"data: {first_response.model_dump_json(by_alias=True)}\n\n" - async for response in stream_generator: - yield f"data: {response.model_dump_json(by_alias=True)}\n\n" - yield "data: [DONE]\n\n" - - return fastapi.responses.StreamingResponse( - completion_stream_generator(), media_type="text/event-stream" - ) - - # Normal response. - request_final_usage = None - output_texts = [""] * request.n - finish_reasons: List[Optional[str]] = [None] * request.n - logprob_results: List[Optional[CompletionLogProbs]] = [None] * request.n - - async for response in router.handle_completion( # pylint: disable=protected-access - request, request_id - ): - if await raw_request.is_disconnected(): - # In non-streaming cases, the engine will not be notified - # when the request is disconnected. - # Therefore, we check if it is disconnected each time, - # and explicitly return. - # Note that requesta abort is triggered when the async for and funciton scope ends. - return error_protocol.create_error_response( - HTTPStatus.BAD_REQUEST, message="The request has disconnected" - ) - # TODO(Charlie): This is copied from engine.py -- why is it here? Non-streaming only has a single chunk right? 
- # this is the final chunk - # if response.usage is not None: - # request_final_usage = response.usage - # continue - for choice in response.choices: - output_texts[choice.index] += choice.text - if choice.finish_reason is not None and finish_reasons[choice.index] is None: - finish_reasons[choice.index] = choice.finish_reason - if choice.logprobs is not None: - logprob_results[choice.index] = choice.logprobs - - assert all(finish_reason is not None for finish_reason in finish_reasons) - return engine_base.wrap_completion_response( - request_id=request_id, - model=request.model, - output_texts=output_texts, - finish_reasons=finish_reasons, - logprob_results=logprob_results, - usage=request_final_usage, - ) def serve( @@ -110,9 +28,8 @@ def serve( enable_prefix_cache: bool, router_mode: Literal["disagg", "round-robin"], pd_balance_factor: float, -): +): # pylint: disable=too-many-arguments # 1. Instantiate router - global router router = Router( model=model, model_lib=model_lib, @@ -124,6 +41,85 @@ def serve( pd_balance_factor=pd_balance_factor, ) + router_app = fastapi.APIRouter() + + @router_app.post("/v1/completions") + async def request_completion(request: CompletionRequest, raw_request: fastapi.Request): + """OpenAI-compatible completion API. + API reference: https://platform.openai.com/docs/api-reference/completions/create + """ + if router is None: + return error_protocol.create_error_response( + HTTPStatus.BAD_REQUEST, message="Router is not initialized." + ) + request_id = f"cmpl-{engine_utils.random_uuid()}" + + # Streaming response. + if request.stream: + # We manually get the first response from generator to + # capture potential exceptions in this scope, rather then + # the StreamingResponse scope. + stream_generator = router.handle_completion( # pylint: disable=protected-access + request, request_id + ) + first_response = await anext( # type: ignore # pylint: disable=undefined-variable + stream_generator + ) + + async def completion_stream_generator() -> AsyncGenerator[str, None]: + if isinstance(first_response, StopAsyncIteration): + yield "data: [DONE]\n\n" + return + yield f"data: {first_response.model_dump_json(by_alias=True)}\n\n" + async for response in stream_generator: + yield f"data: {response.model_dump_json(by_alias=True)}\n\n" + yield "data: [DONE]\n\n" + + return fastapi.responses.StreamingResponse( + completion_stream_generator(), media_type="text/event-stream" + ) + + # Normal response. + request_final_usage = None + output_texts = [""] * request.n + finish_reasons: List[Optional[str]] = [None] * request.n + logprob_results: List[Optional[CompletionLogProbs]] = [None] * request.n + + async for response in router.handle_completion( # pylint: disable=protected-access + request, request_id + ): + if await raw_request.is_disconnected(): + # In non-streaming cases, the engine will not be notified + # when the request is disconnected. + # Therefore, we check if it is disconnected each time, + # and explicitly return. + # Note that requesta abort is triggered when the async for and funciton scope ends. + return error_protocol.create_error_response( + HTTPStatus.BAD_REQUEST, message="The request has disconnected" + ) + # TODO(Charlie): This is copied from engine.py -- + # why is it here? Non-streaming only has a single chunk right? 
+ # this is the final chunk + # if response.usage is not None: + # request_final_usage = response.usage + # continue + for choice in response.choices: + output_texts[choice.index] += choice.text + if choice.finish_reason is not None and finish_reasons[choice.index] is None: + finish_reasons[choice.index] = choice.finish_reason + if choice.logprobs is not None: + logprob_results[choice.index] = choice.logprobs + + assert all(finish_reason is not None for finish_reason in finish_reasons) + return engine_base.wrap_completion_response( + request_id=request_id, + model=request.model, + output_texts=output_texts, + finish_reasons=finish_reasons, + logprob_results=logprob_results, + usage=request_final_usage, + ) + # 2. Set up app app = fastapi.FastAPI() app.add_middleware(CORSMiddleware) @@ -132,6 +128,3 @@ def serve( # 3. Run uvicorn.run(app, host=router_host, port=router_port, log_level="info") - - # TODO(Charlie): How to properly close the engines? We need to call terminate in each - # underlying endpoint of router diff --git a/python/mlc_llm/protocol/debug_protocol.py b/python/mlc_llm/protocol/debug_protocol.py index bdf34cf776..32fcaefb99 100644 --- a/python/mlc_llm/protocol/debug_protocol.py +++ b/python/mlc_llm/protocol/debug_protocol.py @@ -6,6 +6,7 @@ class DisaggConfig(BaseModel): + """The class of metadata used in microserving APIs.""" kind: Optional[Literal["prepare_prefill", "remote_prefill", "start_decode"]] = None # "kv_append_metadata" is base64-encoded and is thus a string. kv_append_metadata: Optional[str] = None diff --git a/python/mlc_llm/router/router.py b/python/mlc_llm/router/router.py index acdc9bf157..610ae4f4c9 100644 --- a/python/mlc_llm/router/router.py +++ b/python/mlc_llm/router/router.py @@ -1,9 +1,11 @@ +""" Programmable router for dispatching OpenAI API to Microserving API""" + import json import math import threading from typing import Any, AsyncGenerator, Iterable, List, Literal, Optional, Tuple -import aiohttp +import aiohttp # pylint: disable=import-outside-toplevel,import-error import tvm from mlc_llm.protocol import debug_protocol, openai_api_protocol @@ -11,7 +13,8 @@ from mlc_llm.tokenizers import Tokenizer -class Router: +class Router: # pylint: disable=too-many-instance-attributes + """Programmable Router Implementation""" def __init__( self, @@ -23,7 +26,7 @@ def __init__( enable_prefix_cache: bool = False, router_mode: Literal["disagg", "round-robin"] = "disagg", pd_balance_factor: float = 0.0, - ): + ): # pylint: disable=too-many-arguments,too-many-locals,dangerous-default-value """ Spawn len(host_list) server endpoints with Popen. """ @@ -113,7 +116,7 @@ async def handle_completion( ): yield response elif self.router_mode == "round-robin": - async for response in self._handle_completion_round_robin(request, request_id): + async for response in self._handle_completion_round_robin(request): yield response else: raise ValueError("Cannot reach here") @@ -132,7 +135,6 @@ def _pick_endpoint(self, endpoint_ids: Iterable[int]) -> int: async def _handle_completion_round_robin( self, request: openai_api_protocol.CompletionRequest, - request_id: str, ) -> AsyncGenerator[openai_api_protocol.CompletionResponse, Any]: """ Handle a completion request from API. 
Given a streaming request, yields multiple response @@ -171,7 +173,7 @@ async def _handle_completion_round_robin( reason = response.choices[0].finish_reason if reason == "preempt": break - elif reason != None: + if reason is not None: completed = True yield response else: @@ -181,7 +183,7 @@ async def _handle_completion_round_robin( reason = response.choices[0].finish_reason if reason == "preempt": break - elif reason != None: + if reason is not None: completed = True yield response self.num_running_requests[cur_endpoint] -= 1 @@ -196,7 +198,9 @@ async def _handle_completion_disagg( original_request: openai_api_protocol.CompletionRequest, request_id: str, pd_balance_factor=0, - ) -> AsyncGenerator[openai_api_protocol.CompletionResponse, Any]: + ) -> AsyncGenerator[ + openai_api_protocol.CompletionResponse, Any + ]: # pylint: disable=too-many-locals """ Handle a completion request from API with disaggregated scheduling. Given two servers P (prefill) and D (decode), the router does the following: @@ -290,7 +294,7 @@ async def _handle_completion_disagg( reason = response_json["choices"][0]["finish_reason"] if reason == "preempt": break - elif reason != None: + if reason is not None: completed = True yield response except Exception as e: diff --git a/python/mlc_llm/serve/engine_base.py b/python/mlc_llm/serve/engine_base.py index 2051ce4b51..8b2a57581f 100644 --- a/python/mlc_llm/serve/engine_base.py +++ b/python/mlc_llm/serve/engine_base.py @@ -667,7 +667,9 @@ def terminate(self): if hasattr(self, "_background_stream_back_loop_thread"): self._background_stream_back_loop_thread.join() - def _debug_call_func_on_all_worker(self, func_name: str, func_args: Optional[str] = None) -> None: + def _debug_call_func_on_all_worker( + self, func_name: str, func_args: Optional[str] = None + ) -> None: """Call the given global function on all workers. Only for debug purpose.""" self._ffi["debug_call_func_on_all_worker"](func_name, func_args) diff --git a/python/mlc_llm/serve/server/popen_server.py b/python/mlc_llm/serve/server/popen_server.py index c9f9b2e81d..d192dbf317 100644 --- a/python/mlc_llm/serve/server/popen_server.py +++ b/python/mlc_llm/serve/server/popen_server.py @@ -56,10 +56,13 @@ def __init__( # pylint: disable=too-many-arguments self.base_url = "" self.openai_v1_base_url = "" - def start(self, extra_env={}) -> None: # pylint: disable=too-many-branches,too-many-statements + def start( + self, extra_env=None + ) -> None: # pylint: disable=too-many-branches,too-many-statements """Launch the server in a popen subprocess. Wait until the server becomes ready before return. """ + extra_env = extra_env or {} cmd = [sys.executable] cmd += ["-m", "mlc_llm", "serve", self.model] if self.model_lib is not None:
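
Usage sketch (illustrative, not part of the patch): the new command line entrypoint in python/mlc_llm/cli/router.py can be driven through its main() function with the arguments it defines above. A minimal launcher, assuming a hypothetical model identifier and two local endpoints, one prefill and one decode, which is what the disagg handler in router/router.py expects:

# Illustrative launcher built on the argument set defined in cli/router.py above.
# The model identifier is a hypothetical placeholder, and running two local
# endpoints (ports 8080/8081) is an assumption for disagg mode, which pairs a
# prefill endpoint with a decode endpoint.
from mlc_llm.cli.router import main

if __name__ == "__main__":
    main(
        [
            "HF://mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC",  # hypothetical model id
            "--router-mode", "disagg",
            "--router-host", "127.0.0.1",
            "--router-port", "8000",
            "--endpoint-hosts", "127.0.0.1,127.0.0.1",
            "--endpoint-ports", "8080", "8081",
            "--endpoint-num-gpus", "1", "1",
            "--pd-balance-factor", "0.1",
        ]
    )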
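
Client sketch (illustrative): once the router is running, it serves the OpenAI-compatible /v1/completions route on router_host:router_port, as registered in interface/router.py. The sketch below assumes the router from the previous example is listening on 127.0.0.1:8000 and that the requests package is installed; the model id and prompt are placeholders.

# Illustrative streaming client for the /v1/completions route added above.
import json

import requests

with requests.post(
    "http://127.0.0.1:8000/v1/completions",
    json={
        "model": "HF://mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC",  # hypothetical model id
        "prompt": "What does disaggregated prefill/decode serving mean?",
        "max_tokens": 64,
        "stream": True,
    },
    stream=True,
    timeout=600,
) as response:
    # Streaming output arrives as server-sent events: "data: {...}" chunks
    # terminated by "data: [DONE]", matching completion_stream_generator above.
    for line in response.iter_lines():
        if not line:
            continue
        payload = line.decode("utf-8")
        if not payload.startswith("data: "):
            continue
        payload = payload[len("data: "):]
        if payload == "[DONE]":
            break
        chunk = json.loads(payload)
        print(chunk["choices"][0]["text"], end="", flush=True)
    print()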