Commit

fix format

jinhongyii committed Dec 13, 2024
1 parent 67e8c57 commit e43a578
Showing 8 changed files with 121 additions and 101 deletions.
13 changes: 13 additions & 0 deletions python/mlc_llm/cli/router.py
@@ -1,9 +1,12 @@
"""Command line entrypoint of router."""

from mlc_llm.interface.help import HELP
from mlc_llm.interface.router import serve
from mlc_llm.support.argparse import ArgumentParser


def main(argv):
"""Parse command line arguments and call `mlc_llm.interface.router`."""
# Define a custom argument type for a list of strings
def list_of_strings(arg):
return arg.split(",")
@@ -12,54 +15,64 @@ def list_of_strings(arg):
parser.add_argument(
"model",
type=str,
help=HELP["model"] + " (required)",
)
parser.add_argument(
"--model-lib",
type=str,
default=None,
help=HELP["model_lib"] + ' (default: "%(default)s")',
)
parser.add_argument(
"--router-mode",
type=str,
choices=["disagg", "round-robin"],
default="disagg",
help="router mode" + ' (default: "%(default)s")',
)
parser.add_argument(
"--router-host",
type=str,
default="127.0.0.1",
help="router host" + ' (default: "%(default)s")',
)
parser.add_argument(
"--router-port",
type=int,
default=8000,
help="router port" + ' (default: "%(default)s")',
)
parser.add_argument(
"--endpoint-hosts",
type=list_of_strings,
default="127.0.0.1",
help="Host of each endpoint, seperated by comma." + ' (default: "%(default)s")',
)
parser.add_argument(
"--endpoint-ports",
nargs="*",
type=int,
default=[8080],
help="Port of each endpoint, separated by space." + ' (default: "%(default)s")',
)
parser.add_argument(
"--endpoint-num-gpus",
nargs="*",
type=int,
default=[1],
help="Number of GPUs of each endpoint, separated by space." + ' (default: "%(default)s")',
)
parser.add_argument(
"--enable-prefix-cache",
default=False,
action="store_true",
help="whether to enable prefix cache" + ' (default: "%(default)s")',
)
parser.add_argument(
"--pd-balance-factor",
type=float,
default=0.0,
help=HELP["pd_balance_factor"] + ' (default: "%(default)s")',
)
parsed = parser.parse_args(argv)
serve(
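For reference, a minimal sketch of how this entrypoint could be invoked programmatically. The model reference, hosts, and ports below are placeholders and not values from this commit; note that --endpoint-hosts takes a single comma-separated string (parsed by list_of_strings), while the port and GPU-count flags use nargs="*".

from mlc_llm.cli.router import main

main([
    "HF://mlc-ai/SomeModel-q4f16_1-MLC",        # hypothetical model reference
    "--router-mode", "disagg",
    "--router-host", "127.0.0.1",
    "--router-port", "8000",
    "--endpoint-hosts", "127.0.0.1,127.0.0.1",  # one comma-separated argument
    "--endpoint-ports", "8080", "8081",         # space-separated integers
    "--endpoint-num-gpus", "1", "1",
    "--pd-balance-factor", "0.1",
])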
2 changes: 1 addition & 1 deletion python/mlc_llm/interface/compiler_flags.py
@@ -138,7 +138,7 @@ def _cudagraph(target) -> bool:


@dataclasses.dataclass
class ModelConfigOverride(ConfigOverrideBase):
class ModelConfigOverride(ConfigOverrideBase): # pylint: disable=too-many-instance-attributes
"""Flags for overriding model config."""

context_window_size: Optional[int] = None
4 changes: 4 additions & 0 deletions python/mlc_llm/interface/help.py
@@ -262,4 +262,8 @@
""".strip(),
"seed_calibrate": """
The seed to sample the calibration dataset.""",
"pd_balance_factor": """
How much of the prefill to move to the decode engine. For example,
0.1 means the last 10 percent of tokens are prefilled by the decode engine.
""".strip(),
}
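To make the new pd_balance_factor entry concrete, here is an illustrative sketch of the split it describes. The helper name and the exact rounding are assumptions; the routing code that actually computes the split is not shown in this hunk.

def split_prompt(prompt_tokens, pd_balance_factor):
    # Hypothetical helper: the leading (1 - factor) share of the prompt goes to the
    # prefill engine, the trailing share to the decode engine.
    cut = int(len(prompt_tokens) * (1 - pd_balance_factor))
    return prompt_tokens[:cut], prompt_tokens[cut:]

prefill_part, decode_part = split_prompt(list(range(100)), 0.1)
assert len(decode_part) == 10  # the last 10 percent is prefilled by the decode engine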
171 changes: 82 additions & 89 deletions python/mlc_llm/interface/router.py
@@ -1,3 +1,5 @@
"""Python entrypoint of router."""

from http import HTTPStatus
from typing import AsyncGenerator, List, Literal, Optional

@@ -13,90 +15,6 @@
#
# Global variables
#
router_app = fastapi.APIRouter()
router = None

#
# Define APIs
#


@router_app.post("/v1/completions")
async def request_completion(request: CompletionRequest, raw_request: fastapi.Request):
"""OpenAI-compatible completion API.
API reference: https://platform.openai.com/docs/api-reference/completions/create
"""
global router
if router is None:
return error_protocol.create_error_response(
HTTPStatus.BAD_REQUEST, message="Router is not initialized."
)
request_id = f"cmpl-{engine_utils.random_uuid()}"

# Streaming response.
if request.stream:
# We manually get the first response from generator to
# capture potential exceptions in this scope, rather than
# the StreamingResponse scope.
stream_generator = router.handle_completion( # pylint: disable=protected-access
request, request_id
)
first_response = await anext( # type: ignore # pylint: disable=undefined-variable
stream_generator
)

async def completion_stream_generator() -> AsyncGenerator[str, None]:
if isinstance(first_response, StopAsyncIteration):
yield "data: [DONE]\n\n"
return
yield f"data: {first_response.model_dump_json(by_alias=True)}\n\n"
async for response in stream_generator:
yield f"data: {response.model_dump_json(by_alias=True)}\n\n"
yield "data: [DONE]\n\n"

return fastapi.responses.StreamingResponse(
completion_stream_generator(), media_type="text/event-stream"
)

# Normal response.
request_final_usage = None
output_texts = [""] * request.n
finish_reasons: List[Optional[str]] = [None] * request.n
logprob_results: List[Optional[CompletionLogProbs]] = [None] * request.n

async for response in router.handle_completion( # pylint: disable=protected-access
request, request_id
):
if await raw_request.is_disconnected():
# In non-streaming cases, the engine will not be notified
# when the request is disconnected.
# Therefore, we check if it is disconnected each time,
# and explicitly return.
# Note that request abort is triggered when the async for and function scope ends.
return error_protocol.create_error_response(
HTTPStatus.BAD_REQUEST, message="The request has disconnected"
)
# TODO(Charlie): This is copied from engine.py -- why is it here? Non-streaming only has a single chunk right?
# this is the final chunk
# if response.usage is not None:
# request_final_usage = response.usage
# continue
for choice in response.choices:
output_texts[choice.index] += choice.text
if choice.finish_reason is not None and finish_reasons[choice.index] is None:
finish_reasons[choice.index] = choice.finish_reason
if choice.logprobs is not None:
logprob_results[choice.index] = choice.logprobs

assert all(finish_reason is not None for finish_reason in finish_reasons)
return engine_base.wrap_completion_response(
request_id=request_id,
model=request.model,
output_texts=output_texts,
finish_reasons=finish_reasons,
logprob_results=logprob_results,
usage=request_final_usage,
)


def serve(
@@ -110,9 +28,8 @@ def serve(
enable_prefix_cache: bool,
router_mode: Literal["disagg", "round-robin"],
pd_balance_factor: float,
):
): # pylint: disable=too-many-arguments
# 1. Instantiate router
global router
router = Router(
model=model,
model_lib=model_lib,
@@ -124,6 +41,85 @@ def serve(
pd_balance_factor=pd_balance_factor,
)

router_app = fastapi.APIRouter()

@router_app.post("/v1/completions")
async def request_completion(request: CompletionRequest, raw_request: fastapi.Request):
"""OpenAI-compatible completion API.
API reference: https://platform.openai.com/docs/api-reference/completions/create
"""
if router is None:
return error_protocol.create_error_response(
HTTPStatus.BAD_REQUEST, message="Router is not initialized."
)
request_id = f"cmpl-{engine_utils.random_uuid()}"

# Streaming response.
if request.stream:
# We manually get the first response from generator to
# capture potential exceptions in this scope, rather than
# the StreamingResponse scope.
stream_generator = router.handle_completion( # pylint: disable=protected-access
request, request_id
)
first_response = await anext( # type: ignore # pylint: disable=undefined-variable
stream_generator
)

async def completion_stream_generator() -> AsyncGenerator[str, None]:
if isinstance(first_response, StopAsyncIteration):
yield "data: [DONE]\n\n"
return
yield f"data: {first_response.model_dump_json(by_alias=True)}\n\n"
async for response in stream_generator:
yield f"data: {response.model_dump_json(by_alias=True)}\n\n"
yield "data: [DONE]\n\n"

return fastapi.responses.StreamingResponse(
completion_stream_generator(), media_type="text/event-stream"
)

# Normal response.
request_final_usage = None
output_texts = [""] * request.n
finish_reasons: List[Optional[str]] = [None] * request.n
logprob_results: List[Optional[CompletionLogProbs]] = [None] * request.n

async for response in router.handle_completion( # pylint: disable=protected-access
request, request_id
):
if await raw_request.is_disconnected():
# In non-streaming cases, the engine will not be notified
# when the request is disconnected.
# Therefore, we check if it is disconnected each time,
# and explicitly return.
# Note that request abort is triggered when the async for and function scope ends.
return error_protocol.create_error_response(
HTTPStatus.BAD_REQUEST, message="The request has disconnected"
)
# TODO(Charlie): This is copied from engine.py --
# why is it here? Non-streaming only has a single chunk right?
# this is the final chunk
# if response.usage is not None:
# request_final_usage = response.usage
# continue
for choice in response.choices:
output_texts[choice.index] += choice.text
if choice.finish_reason is not None and finish_reasons[choice.index] is None:
finish_reasons[choice.index] = choice.finish_reason
if choice.logprobs is not None:
logprob_results[choice.index] = choice.logprobs

assert all(finish_reason is not None for finish_reason in finish_reasons)
return engine_base.wrap_completion_response(
request_id=request_id,
model=request.model,
output_texts=output_texts,
finish_reasons=finish_reasons,
logprob_results=logprob_results,
usage=request_final_usage,
)

# 2. Set up app
app = fastapi.FastAPI()
app.add_middleware(CORSMiddleware)
@@ -132,6 +128,3 @@ def serve(

# 3. Run
uvicorn.run(app, host=router_host, port=router_port, log_level="info")

# TODO(Charlie): How to properly close the engines? We need to call terminate in each
# underlying endpoint of router
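Not part of this commit, but a small client sketch may help illustrate the streaming path served above. The URL, port, and model name are placeholders; the script reads the "data: ..." SSE lines until the "[DONE]" sentinel emitted by completion_stream_generator.

import asyncio
import json

import aiohttp


async def stream_completion() -> None:
    payload = {"model": "placeholder-model", "prompt": "Hello", "stream": True}
    async with aiohttp.ClientSession() as session:
        async with session.post(
            "http://127.0.0.1:8000/v1/completions", json=payload
        ) as resp:
            async for raw_line in resp.content:
                line = raw_line.decode("utf-8").strip()
                if not line.startswith("data: "):
                    continue
                chunk = line[len("data: "):]
                if chunk == "[DONE]":
                    break
                # Print the incremental completion text from each streamed chunk.
                print(json.loads(chunk)["choices"][0]["text"], end="", flush=True)


asyncio.run(stream_completion())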
1 change: 1 addition & 0 deletions python/mlc_llm/protocol/debug_protocol.py
Expand Up @@ -6,6 +6,7 @@


class DisaggConfig(BaseModel):
"""The class of metadata used in microserving APIs."""
kind: Optional[Literal["prepare_prefill", "remote_prefill", "start_decode"]] = None
# "kv_append_metadata" is base64-encoded and is thus a string.
kv_append_metadata: Optional[str] = None
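As a rough illustration of how these "kind" values map to the disaggregated flow (an assumption, since the router code that builds the payloads is collapsed in this diff): the decode engine is first asked to prepare KV space, the returned kv_append_metadata is forwarded to the prefill engine, and decoding then starts. The payload shape and metadata value below are hypothetical.

from typing import Optional


def make_debug_config(kind: str, kv_append_metadata: Optional[str] = None) -> dict:
    # Mirrors the DisaggConfig fields above; kv_append_metadata is base64-encoded.
    config = {"kind": kind}
    if kv_append_metadata is not None:
        config["kv_append_metadata"] = kv_append_metadata
    return config


# 1. Decode engine D allocates KV entries and returns append metadata.
prepare = make_debug_config("prepare_prefill")
# 2. Prefill engine P computes the prompt KV and ships it to D using that metadata.
remote_prefill = make_debug_config("remote_prefill", kv_append_metadata="<base64 handle>")
# 3. D starts token generation once the KV transfer completes.
start_decode = make_debug_config("start_decode")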
22 changes: 13 additions & 9 deletions python/mlc_llm/router/router.py
@@ -1,17 +1,20 @@
""" Programmable router for dispatching OpenAI API to Microserving API"""

import json
import math
import threading
from typing import Any, AsyncGenerator, Iterable, List, Literal, Optional, Tuple

import aiohttp
import aiohttp # pylint: disable=import-outside-toplevel,import-error
import tvm

from mlc_llm.protocol import debug_protocol, openai_api_protocol
from mlc_llm.serve import EngineConfig, PopenServer
from mlc_llm.tokenizers import Tokenizer


class Router:
class Router: # pylint: disable=too-many-instance-attributes
"""Programmable Router Implementation"""

def __init__(
self,
@@ -23,7 +26,7 @@ def __init__(
enable_prefix_cache: bool = False,
router_mode: Literal["disagg", "round-robin"] = "disagg",
pd_balance_factor: float = 0.0,
):
): # pylint: disable=too-many-arguments,too-many-locals,dangerous-default-value
"""
Spawn len(host_list) server endpoints with Popen.
"""
@@ -113,7 +116,7 @@ async def handle_completion(
):
yield response
elif self.router_mode == "round-robin":
async for response in self._handle_completion_round_robin(request, request_id):
async for response in self._handle_completion_round_robin(request):
yield response
else:
raise ValueError("Cannot reach here")
@@ -132,7 +135,6 @@ def _pick_endpoint(self, endpoint_ids: Iterable[int]) -> int:
async def _handle_completion_round_robin(
self,
request: openai_api_protocol.CompletionRequest,
request_id: str,
) -> AsyncGenerator[openai_api_protocol.CompletionResponse, Any]:
"""
Handle a completion request from API. Given a streaming request, yields multiple response
@@ -171,7 +173,7 @@ async def _handle_completion_round_robin(
reason = response.choices[0].finish_reason
if reason == "preempt":
break
elif reason != None:
if reason is not None:
completed = True
yield response
else:
@@ -181,7 +183,7 @@
reason = response.choices[0].finish_reason
if reason == "preempt":
break
elif reason != None:
if reason is not None:
completed = True
yield response
self.num_running_requests[cur_endpoint] -= 1
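The body of _pick_endpoint is not visible in this diff; a plausible least-loaded sketch based on the num_running_requests counters used above could look like the following (illustrative only, not the committed implementation).

def pick_least_loaded(num_running_requests, endpoint_ids):
    # Choose the endpoint currently serving the fewest requests.
    return min(endpoint_ids, key=lambda i: num_running_requests[i])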
@@ -196,7 +198,9 @@ async def _handle_completion_disagg(
original_request: openai_api_protocol.CompletionRequest,
request_id: str,
pd_balance_factor=0,
) -> AsyncGenerator[openai_api_protocol.CompletionResponse, Any]:
) -> AsyncGenerator[
openai_api_protocol.CompletionResponse, Any
]: # pylint: disable=too-many-locals
"""
Handle a completion request from API with disaggregated scheduling. Given two servers
P (prefill) and D (decode), the router does the following:
@@ -290,7 +294,7 @@ async def _handle_completion_disagg(
reason = response_json["choices"][0]["finish_reason"]
if reason == "preempt":
break
elif reason != None:
if reason is not None:
completed = True
yield response
except Exception as e: