From 75c6d0eb57fe3e22a7d5dae6ddda38caa061de82 Mon Sep 17 00:00:00 2001 From: Aniket Maurya Date: Wed, 11 Dec 2024 12:54:04 +0000 Subject: [PATCH] add continuous batching loop 1/n (#387) * add continuous batching loop * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * update * update * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * bump version * fixes * update * fix * update * remove rich * attach batch_size * update * update * update * fix * update --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- src/litserve/__about__.py | 2 +- src/litserve/api.py | 1 + src/litserve/loops.py | 244 +++++++++++++++++++++++++++++++++++++- 3 files changed, 241 insertions(+), 6 deletions(-) diff --git a/src/litserve/__about__.py b/src/litserve/__about__.py index 899d1a3f..cef9982b 100644 --- a/src/litserve/__about__.py +++ b/src/litserve/__about__.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "0.2.6.dev0" +__version__ = "0.2.6.dev1" __author__ = "Lightning-AI et al." __author_email__ = "community@lightning.ai" __license__ = "Apache-2.0" diff --git a/src/litserve/api.py b/src/litserve/api.py index 321ea9c2..c54a7b80 100644 --- a/src/litserve/api.py +++ b/src/litserve/api.py @@ -114,6 +114,7 @@ def device(self, value): self._device = value def _sanitize(self, max_batch_size: int, spec: Optional[LitSpec]): + self.max_batch_size = max_batch_size if self.stream: self._default_unbatch = self._unbatch_stream else: diff --git a/src/litserve/loops.py b/src/litserve/loops.py index 28e061cf..8e10653f 100644 --- a/src/litserve/loops.py +++ b/src/litserve/loops.py @@ -18,8 +18,9 @@ import sys import time from abc import ABC +from dataclasses import dataclass from queue import Empty, Queue -from typing import Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union from fastapi import HTTPException from starlette.formparsers import MultiPartParser @@ -37,10 +38,11 @@ asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) except ImportError: - print( - "uvloop is not installed. Falling back to the default asyncio event loop. " - "Please install uvloop for better performance using `pip install uvloop`." - ) + if sys.platform != "win32": + print( + "uvloop is not installed. Falling back to the default asyncio event loop. " + "Please install uvloop for better performance using `pip install uvloop`." + ) logger = logging.getLogger(__name__) @@ -573,6 +575,238 @@ def __call__( ) +def notify_timed_out_requests( + response_queues: List[Queue], + timed_out_uids: List[Tuple[int, str]], +): + for response_queue_id, uid in timed_out_uids: + logger.error(f"Request {uid} was waiting in the queue for too long and has been timed out.") + response_queues[response_queue_id].put((uid, (HTTPException(504, "Request timed out"), LitAPIStatus.ERROR))) + + +@dataclass +class Output: + """Outputs from a single step of the loop.""" + + uid: str + output: Any + status: LitAPIStatus + + +class LitLoop(_BaseLoop): + def __init__(self): + self._context = {} + + def get_batch_requests(self, lit_api: LitAPI, request_queue: Queue, max_batch_size: int, batch_timeout: float): + if max_batch_size <= 1: + raise ValueError("max_batch_size must be greater than 1") + + batches, timed_out_uids = collate_requests( + lit_api, + request_queue, + max_batch_size, + batch_timeout, + ) + return batches, timed_out_uids + + def get_request(self, request_queue: Queue, timeout: float = 1.0): + response_queue_id, uid, timestamp, x_enc = request_queue.get(timeout=timeout) + return response_queue_id, uid, timestamp, x_enc + + def populate_context(self, lit_spec: LitSpec, request: Any): + if lit_spec and hasattr(lit_spec, "populate_context"): + lit_spec.populate_context(self._context, request) + + def put_response( + self, response_queues: List[Queue], response_queue_id: int, uid: str, response_data: Any, status: LitAPIStatus + ) -> None: + response_queues[response_queue_id].put((uid, (response_data, status))) + + def put_error_response( + self, response_queues: List[Queue], response_queue_id: int, uid: str, error: Exception + ) -> None: + response_queues[response_queue_id].put((uid, (error, LitAPIStatus.ERROR))) + + +class ContinuousBatchingLoop(LitLoop): + def __init__(self, max_sequence_length: int = 2048): + super().__init__() + self.active_sequences: Dict[str, Dict] = {} # uid -> {input, current_length, generated_tokens} + self.max_sequence_length = max_sequence_length + self.response_queue_ids: Dict[str, int] = {} # uid -> response_queue_id + + def add_request(self, uid: str, request: Any, lit_api: LitAPI, lit_spec: Optional[LitSpec]) -> None: + """Add a new sequence to active sequences.""" + decoded_request = lit_api.decode_request(request) + self.active_sequences[uid] = {"input": decoded_request, "current_length": 0, "generated_tokens": []} + + def mark_completed(self, uid: str) -> None: + """Mark a sequence as completed.""" + logger.info(f"Marking sequence {uid} as completed") + del self.active_sequences[uid] + del self.response_queue_ids[uid] + + def has_capacity(self, lit_api: LitAPI) -> bool: + """Check if we can add more sequences based on current batch.""" + capacity = len(self.active_sequences) < lit_api.max_batch_size + if not capacity: + logger.info( + f"No capacity: {len(self.active_sequences)} active sequences, max batch size: {lit_api.max_batch_size}" + ) + return capacity + + def step( + self, prev_outputs: Optional[List[Output]], lit_api: LitAPI, lit_spec: Optional[LitSpec] + ) -> List[Tuple[str, Tuple[Any, LitAPIStatus]]]: + """Process one token generation step for all active sequences.""" + if not self.active_sequences: + return [] + + # Batch forward pass for all active sequences + inputs = [seq["input"] for seq in self.active_sequences.values()] + generated = [seq["generated_tokens"] for seq in self.active_sequences.values()] + + try: + # Assume lit_api.predict handles batched token generation + new_tokens = lit_api.predict(inputs, generated) + + responses = [] + + # Process each sequence's new token + for uid, token in zip(self.active_sequences.keys(), new_tokens): + seq = self.active_sequences[uid] + seq["generated_tokens"].append(token) + seq["current_length"] += 1 + + # Check completion conditions + is_finished = lit_api.is_finished(uid, token, self.max_sequence_length) + + if is_finished: + # Encode final response for completed sequence + response = lit_api.encode_response(seq["generated_tokens"]) + step_output = Output(uid, response, LitAPIStatus.FINISH_STREAMING) + responses.append(step_output) + + return responses + + except Exception as e: + logger.exception("Error during batch token generation") + # On error, terminate all active sequences + responses = [(uid, (e, LitAPIStatus.ERROR)) for uid in self.active_sequences] + self.active_sequences.clear() + return responses + + def prefill( + self, + pending_requests: List[Tuple[str, Any]], + lit_api: LitAPI, + lit_spec: Optional[LitSpec], + request_queue: Queue, + max_batch_size: int, + response_queues: List[Queue], + ) -> List[Tuple[str, Any]]: + """Fill available capacity with pending and new requests.""" + # First process existing pending requests + while pending_requests and self.has_capacity(lit_api): + response_queue_id, uid, input = pending_requests.pop(0) + self.add_request(uid, input, lit_api, lit_spec) + self.response_queue_ids[uid] = response_queue_id + + # Then check for new requests if we still have capacity + if self.has_capacity(lit_api): + new_batches, timed_out_uids = self.get_batch_requests( + lit_api, request_queue, max_batch_size, batch_timeout=0.0001 + ) + notify_timed_out_requests(response_queues, timed_out_uids) + + if new_batches: + # Add new requests to pending_requests and try to process them + for response_queue_id, uid, input in new_batches: + logger.info(f"New request: {uid}, {input}") + if self.has_capacity(lit_api): + self.add_request(uid, input, lit_api, lit_spec) + self.response_queue_ids[uid] = response_queue_id + else: + pending_requests.append((response_queue_id, uid, input)) + + return pending_requests + + def run( + self, + lit_api: LitAPI, + lit_spec: Optional[LitSpec], + device: str, + worker_id: int, + request_queue: Queue, + response_queues: List[Queue], + max_batch_size: int, + batch_timeout: float, + stream: bool, + workers_setup_status: Dict[int, str], + callback_runner: CallbackRunner, + ): + if not lit_api.stream: + raise ValueError( + "Continuous batching loop requires streaming to be enabled. Please set LitServe(..., stream=True)" + ) + + """Main loop that processes batches of requests.""" + pending_requests = self.prefill( + [], + lit_api, + lit_spec, + request_queue, + max_batch_size, + response_queues, + ) + try: + prev_outputs = None + while pending_requests or self.active_sequences: + # Process one step for all active sequences + responses = self.step(prev_outputs, lit_api, lit_spec) + logger.debug(f"Responses from step(): {responses}") + if len(responses) == 0: + raise HTTPException(500, "No responses from step()") + if responses and not isinstance(responses[0], Output): + raise HTTPException(500, "Expected StepOutput from step()") + + prev_outputs = responses + + # Send responses for all sequences (both streaming and completed) + for step_output in responses: + logger.debug(f"Processing response: {step_output}") + status = step_output.status + response_data = step_output.output + uid = step_output.uid + response_queue_id = self.response_queue_ids[uid] + + if status == LitAPIStatus.ERROR: + self.put_error_response(response_queues, response_queue_id, uid, response_data) + self.mark_completed(uid) + elif status == LitAPIStatus.FINISH_STREAMING: + self.put_response(response_queues, response_queue_id, uid, response_data, status) + self.mark_completed(uid) + else: + self.put_response(response_queues, response_queue_id, uid, response_data, status) + + # Fill available capacity with both pending and new requests + pending_requests = self.prefill( + pending_requests, + lit_api, + lit_spec, + request_queue, + max_batch_size, + response_queues, + ) + + except Exception as e: + logger.exception("Error in continuous batching loop") + # Handle any errors by sending error responses for all tracked requests + for uid, response_queue_id in self.response_queue_ids.items(): + self.put_error_response(response_queues, response_queue_id, uid, e) + self.response_queue_ids.clear() + + def inference_worker( lit_api: LitAPI, lit_spec: Optional[LitSpec],