From 8845ae8772233dcbc83d4690e9f635a50d22fca5 Mon Sep 17 00:00:00 2001 From: Hongyi Jin Date: Tue, 14 Jan 2025 00:04:57 -0500 Subject: [PATCH] [Python] Add Microserving code example (#3089) This commit adds code examples for Microserving. --- cpp/serve/config.cc | 11 +- docs/index.rst | 8 + docs/microserving/tutorial.rst | 205 ++++++++++++++++++ examples/microserving/custom_router.py | 69 ++++++ python/mlc_llm/interface/router.py | 9 +- .../mlc_llm/protocol/microserving_protocol.py | 18 +- python/mlc_llm/router/router.py | 191 ++++++++-------- .../entrypoints/microserving_entrypoints.py | 16 +- 8 files changed, 406 insertions(+), 121 deletions(-) create mode 100644 docs/microserving/tutorial.rst create mode 100644 examples/microserving/custom_router.py diff --git a/cpp/serve/config.cc b/cpp/serve/config.cc index 6e6aefbf1e..f7e71e72c9 100644 --- a/cpp/serve/config.cc +++ b/cpp/serve/config.cc @@ -186,12 +186,13 @@ Result DebugConfig::FromJSON(const picojson::object& config) { } else { return TResult::Error("Unknown grammar execution mode " + grammar_execution_mode); } - Result disagg_config = - DisaggConfig::FromJSON(json::Lookup(config, "disagg_config")); - if (disagg_config.IsErr()) { - return TResult::Error(disagg_config.UnwrapErr()); + if (auto disagg_config_obj = json::LookupOptional(config, "disagg_config")) { + Result disagg_config = DisaggConfig::FromJSON(disagg_config_obj.value()); + if (disagg_config.IsErr()) { + return TResult::Error(disagg_config.UnwrapErr()); + } + res.disagg_config = disagg_config.Unwrap(); } - res.disagg_config = disagg_config.Unwrap(); return TResult::Ok(res); } diff --git a/docs/index.rst b/docs/index.rst index 19edfb102e..9670fa6244 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -65,6 +65,13 @@ Check out :ref:`introduction-to-mlc-llm` for the introduction and tutorial of a install/gpu.rst install/emcc.rst +.. toctree:: + :maxdepth: 1 + :caption: Microserving API + :hidden: + + microserving/tutorial.rst + .. toctree:: :maxdepth: 1 :caption: Community @@ -80,3 +87,4 @@ Check out :ref:`introduction-to-mlc-llm` for the introduction and tutorial of a :hidden: privacy.rst + diff --git a/docs/microserving/tutorial.rst b/docs/microserving/tutorial.rst new file mode 100644 index 0000000000..8494bb20c4 --- /dev/null +++ b/docs/microserving/tutorial.rst @@ -0,0 +1,205 @@ +Implement LLM Cross-engine Orchestration Patterns +====================================================================== + +In this tutorial, we will introduce how to implement LLM cross-engine +orchestration patterns, like prefill-decode disaggregation, in MLC-LLM +via microserving API. Aiming to make disaggregated serving programmable, +MicroServing provides a new RISC-style approach to design LLM serving +API at sub-request level. It enables programmable cross-engine serving +patterns in a few lines of python code. For more information of +microserving API, check out +https://blog.mlc.ai/2025/01/07/microserving-llm-engines. + +Below is an example of prefill-decode disaggregation implementation. An +LLM cross-engine orchestration pattern is implemented in a router, which +dispatches original OpenAI-style completion requests to a chain of +microserving API calls. In this code example, we create a subclass of +Router (which includes wrappers for calling microserving APIs), and +override ``translate_request`` function. The ``translate_request`` +function takes in a request and a unique identifier of the request +(``request_id``), and returns an AsyncGenerator of response. 
We launch +the CustomRouter and 2 engines, each of which has tensor parallel degree +2. Engine 0 is prefill engine and engine 1 is decode engine. + +.. code:: python + + from mlc_llm.router import Router + from mlc_llm.protocol import openai_api_protocol + from typing import Any, AsyncGenerator + from mlc_llm.serve.entrypoints import microserving_entrypoints + from mlc_llm.interface.router import serve + + import aiohttp + + class CustomRouter(Router): + async def translate_request(self, request: openai_api_protocol.CompletionRequest, request_id: str) -> AsyncGenerator[openai_api_protocol.CompletionResponse, Any]: + pass + + + serve( + model="/path/to/model", # replace this with actual path + model_lib="/path/to/model_lib", # replace this with actual path + router_host="127.0.0.1", + router_port=9123, + endpoint_hosts=["127.0.0.1", "127.0.0.1"], + endpoint_ports=[9124,9125], + endpoint_num_gpus=[2,2], + enable_prefix_cache=False, + router_type=CustomRouter, + ) + +In the ``translate_request`` function, we first assign ``request_id`` to +request.user, and later the request id will be passed as an argument to +the microserving API. + +.. code:: python + + # we will pass request_id as an argument in microserving API calls + request.user = request_id + + +Next, call ``prep_recv`` on the decode engine to prepare KV entries for +receiving from remote. ``end=-1`` means that we will let the prefill +engine prefill all except the last token, which makes sure that the +prefill engine does not need sampling logic. ``prep_recv`` returns +address to receive KV from remote and matched prefix length. For +simplicity, we do not enable prefix cache in the tutorial, so we only +need the kv address here. + +.. code:: python + + async with aiohttp.ClientSession( + timeout=aiohttp.ClientTimeout(total=3 * 3600), trust_env=True + ) as session: + decode_start = len(request.prompt) -1 + # 1. Ask decode engine to prepare KV entries to receive from prefill engine + prep_recv_request = microserving_entrypoints.PrepRecvRequest( + **request.model_dump(), end=decode_start + ) + ( + kv_addr_info, + _, + ) = await self.send_prepare_receive( + session=session, + request=prep_recv_request, + server_url=self.server_urls[1], # engine 0 is prefill, engine 1 is decode. Here is decode engine + ) + +Then, call ``remote_send`` on the prefill engine to compute and send KV +to decode engine. ``recv_rank=self.device_id_starts[1]`` means that we +are sending KV to engine 1 (decode engine). + +.. code:: python + + + # 2. Ask prefill engine to send KV to decode engine + remote_send_request = microserving_entrypoints.RemoteSendRequest( + **request.model_dump(), + begin=0, + end=decode_start, + kv_addr_info=kv_addr_info, + recv_rank=self.device_id_starts[1], # the rank of decode engine + ) + await self.send_remote_send( + session=session, + request=remote_send_request, + server_url=self.server_urls[0], # prefill engine + ) + +Finally, call ``start_generate`` on the decode engine to start +generating tokens. ``begin=decode_start`` means we will prefill the last +token in the prompt and start decoding. Notably, the decode process of +the request may be preempted. In such case, we yield None, so that the +router will rerun the ``translate_request`` function. + +.. code:: python + + # 3. 
Start decoding + start_generate_request = microserving_entrypoints.StartGenerateRequest( + **request.model_dump(), + begin=decode_start, + ) + async for response in self.send_start_generate( + session=session, + request=start_generate_request, + server_url=self.server_urls[1], + ): + if len(response.choices) > 0: + finish_reason = response.choices[0].finish_reason + if finish_reason == "preempt": + yield None + yield response + +Bringing everything together, the complete code is as below: + +.. code:: python + + from mlc_llm.router import Router + from mlc_llm.protocol import openai_api_protocol + from typing import Any, AsyncGenerator + from mlc_llm.serve.entrypoints import microserving_entrypoints + from mlc_llm.interface.router import serve + + import aiohttp + class CustomRouter(Router): + async def translate_request(self, request: openai_api_protocol.CompletionRequest, request_id: str) -> AsyncGenerator[openai_api_protocol.CompletionResponse, Any]: + # we will pass request_id as an argument in microserving API calls + request.user = request_id + + async with aiohttp.ClientSession( + timeout=aiohttp.ClientTimeout(total=3 * 3600), trust_env=True + ) as session: + decode_start = len(request.prompt) -1 + # 1. Ask decode engine to prepare KV entries to receive from prefill engine + prep_recv_request = microserving_entrypoints.PrepRecvRequest( + **request.model_dump(), end=decode_start + ) + ( + kv_addr_info, + _, + ) = await self.send_prepare_receive( + session=session, + request=prep_recv_request, + server_url=self.server_urls[1], # engine 0 is prefill, engine 1 is decode. Here is decode engine + ) + # 2. Ask prefill engine to send KV to decode engine + remote_send_request = microserving_entrypoints.RemoteSendRequest( + **request.model_dump(), + begin=0, + end=decode_start, + kv_addr_info=kv_addr_info, + recv_rank=self.device_id_starts[1], # the rank of decode engine + ) + await self.send_remote_send( + session=session, + request=remote_send_request, + server_url=self.server_urls[0], # prefill engine + ) + # 3. 
Start decoding + start_generate_request = microserving_entrypoints.StartGenerateRequest( + **request.model_dump(), + begin=decode_start, + ) + async for response in self.send_start_generate( + session=session, + request=start_generate_request, + server_url=self.server_urls[1], + ): + if len(response.choices) > 0: + finish_reason = response.choices[0].finish_reason + if finish_reason == "preempt": + yield None + yield response + + + serve( + model="/path/to/model", # replace this with actual path + model_lib="/path/to/model_lib", # replace this with actual path + router_host="127.0.0.1", + router_port=9123, + endpoint_hosts=["127.0.0.1", "127.0.0.1"], + endpoint_ports=[9124,9125], + endpoint_num_gpus=[2,2], + enable_prefix_cache=False, + router_type=CustomRouter, + ) diff --git a/examples/microserving/custom_router.py b/examples/microserving/custom_router.py new file mode 100644 index 0000000000..5532c3de3b --- /dev/null +++ b/examples/microserving/custom_router.py @@ -0,0 +1,69 @@ +from mlc_llm.router import Router +from mlc_llm.protocol import openai_api_protocol +from typing import Any, AsyncGenerator +from mlc_llm.serve.entrypoints import microserving_entrypoints +from mlc_llm.interface.router import serve + +import aiohttp +class CustomRouter(Router): + async def translate_request(self, request: openai_api_protocol.CompletionRequest, request_id: str) -> AsyncGenerator[openai_api_protocol.CompletionResponse, Any]: + # we will pass request_id as an argument in microserving API calls + request.user = request_id + + async with aiohttp.ClientSession( + timeout=aiohttp.ClientTimeout(total=3 * 3600), trust_env=True + ) as session: + decode_start = len(request.prompt) -1 + # 1. Ask decode engine to prepare KV entries to receive from prefill engine + prep_recv_request = microserving_entrypoints.PrepRecvRequest( + **request.model_dump(), end=decode_start + ) + ( + kv_addr_info, + _, + ) = await self.send_prepare_receive( + session=session, + request=prep_recv_request, + server_url=self.server_urls[1], # engine 0 is prefill, engine 1 is decode. Here is decode engine + ) + # 2. Ask prefill engine to send KV to decode engine + remote_send_request = microserving_entrypoints.RemoteSendRequest( + **request.model_dump(), + begin=0, + end=decode_start, + kv_addr_info=kv_addr_info, + recv_rank=self.device_id_starts[1], # the rank of decode engine + ) + await self.send_remote_send( + session=session, + request=remote_send_request, + server_url=self.server_urls[0], # prefill engine + ) + # 3. 
Start decoding + start_generate_request = microserving_entrypoints.StartGenerateRequest( + **request.model_dump(), + begin=decode_start, + ) + async for response in self.send_start_generate( + session=session, + request=start_generate_request, + server_url=self.server_urls[1], + ): + if len(response.choices) > 0: + finish_reason = response.choices[0].finish_reason + if finish_reason == "preempt": + yield None + yield response + + +serve( + model="/opt/dlami/nvme/models/Llama-3.1-8B-Instruct-q0f16-MLC", # replace this with actual path + model_lib="/opt/dlami/nvme/models/Llama-3.1-8B-Instruct-q0f16-MLC/lib_disagg.so", # replace this with actual path + router_host="127.0.0.1", + router_port=9123, + endpoint_hosts=["127.0.0.1", "127.0.0.1"], + endpoint_ports=[9124,9125], + endpoint_num_gpus=[2,2], + enable_prefix_cache=False, + router_type=CustomRouter, +) \ No newline at end of file diff --git a/python/mlc_llm/interface/router.py b/python/mlc_llm/interface/router.py index 930c67166b..c609fac138 100644 --- a/python/mlc_llm/interface/router.py +++ b/python/mlc_llm/interface/router.py @@ -2,7 +2,7 @@ # pylint: disable=fixme from http import HTTPStatus -from typing import AsyncGenerator, List, Literal, Optional +from typing import AsyncGenerator, List, Literal, Optional, Type import fastapi import uvicorn @@ -23,12 +23,13 @@ def serve( endpoint_ports: List[int], endpoint_num_gpus: List[int], enable_prefix_cache: bool, - router_mode: Literal["disagg", "round-robin"], - pd_balance_factor: float, + router_mode: Literal["disagg", "round-robin"] = "round-robin", + pd_balance_factor: float = 0.0, + router_type: Type[Router] = Router, ): # pylint: disable=too-many-arguments """Start the router with the specified configuration.""" # 1. Instantiate router - router = Router( + router = router_type( model=model, model_lib=model_lib, hosts=endpoint_hosts, diff --git a/python/mlc_llm/protocol/microserving_protocol.py b/python/mlc_llm/protocol/microserving_protocol.py index fa8cdd63c6..83bcc435ea 100644 --- a/python/mlc_llm/protocol/microserving_protocol.py +++ b/python/mlc_llm/protocol/microserving_protocol.py @@ -16,7 +16,7 @@ class PrepRecvRequest(CompletionRequest): The entries of this KV range will be allocated on the decode instance. """ - kv_window_end: int + end: int class PrepRecvResponse(BaseModel): @@ -24,9 +24,6 @@ class PrepRecvResponse(BaseModel): Attributes ---------- - prompt_length : int - The length of the request prompt in tokens. - prefix_matched_length : int The matched common prefix length on the decode instance when prefix cache is enabled, or 0 if there is no prefix cache. @@ -35,9 +32,8 @@ class PrepRecvResponse(BaseModel): The metadata of the KV range on the destination decode instance. """ - prompt_length: int - prefix_matched_length: int kv_append_metadata: str + prefix_matched_length: int class RemoteSendRequest(CompletionRequest): @@ -58,10 +54,10 @@ class RemoteSendRequest(CompletionRequest): The node group offset of the destination decode instance. """ - kv_window_begin: int - kv_window_end: int - kv_append_metadata: str - dst_group_offset: int + begin: int + end: int + kv_addr_info: str + recv_rank: int class StartGenerateRequest(CompletionRequest): @@ -73,4 +69,4 @@ class StartGenerateRequest(CompletionRequest): Denote the start of the KV range to prefill on the decode instance. 
""" - kv_window_begin: int + begin: int diff --git a/python/mlc_llm/router/router.py b/python/mlc_llm/router/router.py index 3ab8a8e210..833156da3e 100644 --- a/python/mlc_llm/router/router.py +++ b/python/mlc_llm/router/router.py @@ -118,6 +118,24 @@ async def handle_completion( """ if isinstance(request.prompt, str): request.prompt = self.tokenizer.encode(request.prompt) + # Add a debugConfig if not present + if request.debug_config is None: + request.debug_config = openai_api_protocol.DebugConfig() + completed = False + while not completed: + completed = True + async for response in self.translate_request(request, request_id): + if response is None: + completed = False + break + yield response + + async def translate_request( + self, request: openai_api_protocol.CompletionRequest, request_id: str + ) -> AsyncGenerator[openai_api_protocol.CompletionResponse, Any]: + """ + Translate OpenAI API request to microserving API calls. + """ if self.router_mode == "disagg": async for response in self._handle_completion_disagg( request, request_id, pd_balance_factor=self.pd_balance_factor @@ -156,44 +174,43 @@ async def _handle_completion_round_robin( async with aiohttp.ClientSession( timeout=aiohttp.ClientTimeout(total=3 * 3600), trust_env=True ) as session: + # pylint: disable=fixme + # todo: replace this with start_generate + # pylint: enable=fixme async with session.post( - self.server_urls[cur_endpoint], json=payload, headers=self.headers + self.server_urls[cur_endpoint] + "/v1/completions", + json=payload, + headers=self.headers, ) as response: assert response.status == 200, await response.text() - completed = False - while not completed: - if payload["stream"]: - async for chunk in response.content: - # Convert raw bytes to CompletionResponse - chunk = chunk.strip() - if not chunk or chunk == b"\n": - continue - # Get rid of the prefix "data: " and suffix "\n" - raw_data = chunk[6:].strip() - if raw_data == b"[DONE]": - continue - data = json.loads(raw_data) - # Commented because we still want usage chunk to be passed back - # if not data["choices"]: - # continue - response = openai_api_protocol.CompletionResponse.model_validate(data) - if response.choices: - reason = response.choices[0].finish_reason - if reason == "preempt": - break - if reason is not None: - completed = True - yield response - else: - data = await response.json() + if payload["stream"]: + async for chunk in response.content: + # Convert raw bytes to CompletionResponse + chunk = chunk.strip() + if not chunk or chunk == b"\n": + continue + # Get rid of the prefix "data: " and suffix "\n" + raw_data = chunk[6:].strip() + if raw_data == b"[DONE]": + continue + data = json.loads(raw_data) + # Commented because we still want usage chunk to be passed back + # if not data["choices"]: + # continue response = openai_api_protocol.CompletionResponse.model_validate(data) if response.choices: reason = response.choices[0].finish_reason if reason == "preempt": - break - if reason is not None: - completed = True + yield None yield response + else: + data = await response.json() + response = openai_api_protocol.CompletionResponse.model_validate(data) + if response.choices: + reason = response.choices[0].finish_reason + if reason == "preempt": + yield None + yield response self.num_running_requests[cur_endpoint] -= 1 # @@ -220,10 +237,6 @@ async def _handle_completion_disagg( # pylint: disable=too-many-locals prefill_server_id = 0 decode_server_id = self._pick_endpoint(range(1, self.num_servers)) - # Add a debugConfig if not present - if 
original_request.debug_config is None: - original_request.debug_config = openai_api_protocol.DebugConfig() - # Tell D to prepare metadata for prompt[0:kv_window_end]. # P does not need to sample. Ask D to treat the last # token like the first sampled token. @@ -232,69 +245,65 @@ async def _handle_completion_disagg( # pylint: disable=too-many-locals if math.fabs(pd_balance_factor) < 1e-5 else int((1 - pd_balance_factor) * len(original_request.prompt)) ) - async with aiohttp.ClientSession( timeout=aiohttp.ClientTimeout(total=3 * 3600), trust_env=True ) as session: self.num_running_requests[decode_server_id] += 1 try: - completed = False - while not completed: - # 1. Ask D to prepare metadata - prep_recv_request = microserving_entrypoints.PrepRecvRequest( - **original_request.model_dump(), kv_window_end=kv_window_end - ) - ( - prompt_length, - prefix_matched_length, - kv_append_metadata_base64, - ) = await self.send_prepare_receive( - session=session, - request=prep_recv_request, - server_url=self.server_urls[decode_server_id], - ) - - kv_window_end = ( - prompt_length + kv_window_end if kv_window_end < 0 else kv_window_end - ) - assert prefix_matched_length <= kv_window_end + # 1. Ask D to prepare metadata + prep_recv_request = microserving_entrypoints.PrepRecvRequest( + **original_request.model_dump(), end=kv_window_end + ) + ( + kv_append_metadata_base64, + prefix_matched_length, + ) = await self.send_prepare_receive( + session=session, + request=prep_recv_request, + server_url=self.server_urls[decode_server_id], + ) - # 2. Send P the prefill request and D's metadata. When it returns, it means that - # KV transfer has finished prefilling and transferring the KV of - # prompt[prefix_matched_length:kv_window_end]. So D is ready to decode. - if prefix_matched_length < kv_window_end: - remote_send_request = microserving_entrypoints.RemoteSendRequest( - **original_request.model_dump(), - kv_window_begin=prefix_matched_length, - kv_window_end=kv_window_end, - kv_append_metadata=kv_append_metadata_base64, - dst_group_offset=self.device_id_starts[decode_server_id], - ) - await self.send_remote_send( - session=session, - request=remote_send_request, - server_url=self.server_urls[prefill_server_id], - ) + kv_window_end = ( + len(original_request.prompt) + kv_window_end + if kv_window_end < 0 + else kv_window_end + ) + assert prefix_matched_length <= kv_window_end - # 3. Start decoding, receive and yield back response as a normal request - # The kv window passed through denotes the range to prefill on the - # decode server, which should be [-1:] here. - start_generate_request = microserving_entrypoints.StartGenerateRequest( + # 2. Send P the prefill request and D's metadata. When it returns, it means that + # KV transfer has finished prefilling and transferring the KV of + # prompt[prefix_matched_length:kv_window_end]. So D is ready to decode. 
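+                # Note: when the decode engine's prefix cache already covers the
+                # whole window (prefix_matched_length == kv_window_end), the KV
+                # transfer below is skipped entirely.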
+ if prefix_matched_length < kv_window_end: + remote_send_request = microserving_entrypoints.RemoteSendRequest( **original_request.model_dump(), - kv_window_begin=kv_window_end, + begin=prefix_matched_length, + end=kv_window_end, + kv_addr_info=kv_append_metadata_base64, + recv_rank=self.device_id_starts[decode_server_id], ) - async for response in self.send_start_generate( + await self.send_remote_send( session=session, - request=start_generate_request, - server_url=self.server_urls[decode_server_id], - ): - if len(response.choices) > 0: - finish_reason = response.choices[0].finish_reason - if finish_reason == "preempt": - break - if finish_reason is not None: - completed = True - yield response + request=remote_send_request, + server_url=self.server_urls[prefill_server_id], + ) + + # 3. Start decoding, receive and yield back response as a normal request + # The kv window passed through denotes the range to prefill on the + # decode server, which should be [-1:] here. + start_generate_request = microserving_entrypoints.StartGenerateRequest( + **original_request.model_dump(), + begin=kv_window_end, + ) + async for response in self.send_start_generate( + session=session, + request=start_generate_request, + server_url=self.server_urls[decode_server_id], + ): + if len(response.choices) > 0: + finish_reason = response.choices[0].finish_reason + if finish_reason == "preempt": + yield None + yield response except Exception as e: self.num_running_requests[decode_server_id] -= 1 raise e @@ -305,15 +314,14 @@ async def send_prepare_receive( session: aiohttp.ClientSession, request: openai_api_protocol.CompletionRequest, server_url: str, - ) -> Tuple[int, int, str]: + ) -> Tuple[str, int]: """ Performs step 1 of disaggregated serving: ask D to prepare metadata. Returns: The metadata received from D, which is a tuple of 2 elements: - - prompt_length, which is the raw prompt length of the request. + - kv_append_metadata_base64: str, info about KV append encoded in base64 string - prefix_matched_length: int, length of the matched prefix. i.e. prompt[0:prefix_matched_length] is the matched prefix - - kv_append_metadata_base64: str, info about KV append encoded in base64 string """ # Send request to the decode server for receive preparation. # Get the prompt length, matched prefix length and the KV metadata. @@ -326,9 +334,8 @@ async def send_prepare_receive( data = await response.json() return ( - data["prompt_length"], - data["prefix_matched_length"], data["kv_append_metadata"], + data["prefix_matched_length"], ) async def send_remote_send( diff --git a/python/mlc_llm/serve/entrypoints/microserving_entrypoints.py b/python/mlc_llm/serve/entrypoints/microserving_entrypoints.py index a9a062e57a..1371fdfad0 100644 --- a/python/mlc_llm/serve/entrypoints/microserving_entrypoints.py +++ b/python/mlc_llm/serve/entrypoints/microserving_entrypoints.py @@ -24,12 +24,12 @@ async def prep_recv(request: PrepRecvRequest, raw_request: fastapi.Request) -> P """Handle the microserving request for receive preparation. Match the prompt in the prefix cache (when enabled), allocate entries in the KV cache to prepare receiving the KV data of the prompt. - Return the prompt length, matched prefix length and the allocated KV entry metadata. + Return the matched prefix length and the allocated KV entry metadata. 
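+    The KV entry metadata is a base64-encoded string that the router passes back
+    to the prefill engine as the KV address info for remote_send.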
""" request.debug_config.disagg_config = DisaggConfig( kind="prepare_receive", kv_window_begin=0, # always zero for prepare_receive - kv_window_end=request.kv_window_end, + kv_window_end=request.end, ) request.stream_options = StreamOptions(include_usage=True) request.stream = False @@ -37,11 +37,9 @@ async def prep_recv(request: PrepRecvRequest, raw_request: fastapi.Request) -> P response = await request_completion(request=request, raw_request=raw_request) assert response.usage is not None assert response.usage.extra is not None - assert "prompt_length" in response.usage.extra assert "prefix_matched_length" in response.usage.extra assert "kv_append_metadata" in response.usage.extra return PrepRecvResponse( - prompt_length=response.usage.extra["prompt_length"], prefix_matched_length=response.usage.extra["prefix_matched_length"], kv_append_metadata=response.usage.extra["kv_append_metadata"], ) @@ -53,10 +51,10 @@ async def remote_send(request: RemoteSendRequest, raw_request: fastapi.Request): Send the KV data to the destination server.""" request.debug_config.disagg_config = DisaggConfig( kind="remote_send", - kv_window_begin=request.kv_window_begin, - kv_window_end=request.kv_window_end, - kv_append_metadata=request.kv_append_metadata, - dst_group_offset=request.dst_group_offset, + kv_window_begin=request.begin, + kv_window_end=request.end, + kv_append_metadata=request.kv_addr_info, + dst_group_offset=request.recv_rank, ) request.stream_options = StreamOptions(include_usage=True) request.stream = False @@ -70,6 +68,6 @@ async def start_generate(request: StartGenerateRequest, raw_request: fastapi.Req """Prefill the prompt in the specified KV window, and start decode.""" request.debug_config.disagg_config = DisaggConfig( kind="start_generation", - kv_window_begin=request.kv_window_begin, + kv_window_begin=request.begin, ) return await request_completion(request=request, raw_request=raw_request)