From 021f190afdaab9b5ce461ad6c6a051617715a4f0 Mon Sep 17 00:00:00 2001
From: Aniket Maurya
Date: Wed, 11 Dec 2024 21:34:14 +0000
Subject: [PATCH] update

---
 src/litserve/api.py   |  3 +--
 src/litserve/loops.py | 52 +++++++++++++++++++++++++++++--------------
 2 files changed, 36 insertions(+), 19 deletions(-)

diff --git a/src/litserve/api.py b/src/litserve/api.py
index e2dd9eea..9f29ffe3 100644
--- a/src/litserve/api.py
+++ b/src/litserve/api.py
@@ -56,10 +56,9 @@ def batch(self, inputs):
 
         return inputs
 
-    @abstractmethod
     def predict(self, x, **kwargs):
         """Run the model on the input and return or yield the output."""
-        pass
+        raise NotImplementedError("predict is not implemented")
 
     def _unbatch_no_stream(self, output):
         if isinstance(output, str):
diff --git a/src/litserve/loops.py b/src/litserve/loops.py
index 46a309c8..936c1f06 100644
--- a/src/litserve/loops.py
+++ b/src/litserve/loops.py
@@ -700,18 +700,32 @@ class Output:
 
 class ContinuousBatchingLoop(LitLoop):
     def __init__(self, max_sequence_length: int = 2048):
+        """Runs a continuous batching loop. This loop handles adding new requests, processing them in batches, and
+        managing the state of active sequences.
+
+        The loop requires the following methods to be implemented in the LitAPI:
+        - setup: sets up the model on the device
+        - decode_request: decodes the client request into a format that can be processed by the model
+        - step: generates a new token for each sequence
+        - encode_response: encodes the response into a format that can be sent to the client
+        - has_finished: checks if the sequence has finished generating
+
+        Args:
+            max_sequence_length (int): The maximum sequence length allowed for any active sequence.
+
+        """
         super().__init__()
-        self.active_sequences: Dict[str, Dict] = {}  # uid -> {input, current_length, generated_tokens}
+        self.active_sequences: Dict[str, Dict] = {}  # uid -> {input, current_length, generated_sequence}
         self.max_sequence_length = max_sequence_length
         self.response_queue_ids: Dict[str, int] = {}  # uid -> response_queue_id
 
     def add_request(self, uid: str, request: Any, lit_api: LitAPI, lit_spec: Optional[LitSpec]) -> None:
-        """Add a new sequence to active sequences."""
+        """Add a new sequence to active sequences and perform any action before prediction, such as filling the cache."""
         decoded_request = lit_api.decode_request(request)
-        self.active_sequences[uid] = {"input": decoded_request, "current_length": 0, "generated_tokens": []}
+        self.active_sequences[uid] = {"input": decoded_request, "current_length": 0, "generated_sequence": []}
 
     def mark_completed(self, uid: str) -> None:
-        """Mark a sequence as completed."""
+        """Mark a request as completed and remove it from the tracked state."""
         logger.info(f"Marking sequence {uid} as completed")
         del self.active_sequences[uid]
         del self.response_queue_ids[uid]
@@ -725,36 +739,39 @@ def has_capacity(self, lit_api: LitAPI) -> bool:
             )
         return capacity
 
-    def step(
-        self, prev_outputs: Optional[List[Output]], lit_api: LitAPI, lit_spec: Optional[LitSpec]
-    ) -> List[Tuple[str, Tuple[Any, LitAPIStatus]]]:
+    def step(self, prev_outputs: Optional[List[Output]], lit_api: LitAPI, lit_spec: Optional[LitSpec]) -> List[Output]:
         """Process one token generation step for all active sequences."""
+        if hasattr(lit_api, "step"):
+            return lit_api.step(prev_outputs)
+
         if not self.active_sequences:
             return []
 
         # Batch forward pass for all active sequences
         inputs = [seq["input"] for seq in self.active_sequences.values()]
-        generated = [seq["generated_tokens"] for seq in self.active_sequences.values()]
+        generated = [seq["generated_sequence"] for seq in self.active_sequences.values()]
 
         try:
             # Assume lit_api.predict handles batched token generation
-            new_tokens = lit_api.predict(inputs, generated)
+            new_tokens: List[Any] = lit_api.predict(inputs, generated)
 
-            responses = []
+            responses: List[Output] = []
 
             # Process each sequence's new token
             for uid, token in zip(self.active_sequences.keys(), new_tokens):
                 seq = self.active_sequences[uid]
-                seq["generated_tokens"].append(token)
+                seq["generated_sequence"].append(token)
                 seq["current_length"] += 1
 
+                step_output = Output(uid, token, LitAPIStatus.OK)
+                responses.append(step_output)
+
                 # Check completion conditions
-                is_finished = lit_api.is_finished(uid, token, self.max_sequence_length)
+                is_finished = lit_api.has_finished(uid, token, self.max_sequence_length)
 
                 if is_finished:
                     # Encode final response for completed sequence
-                    response = lit_api.encode_response(seq["generated_tokens"])
-                    step_output = Output(uid, response, LitAPIStatus.FINISH_STREAMING)
+                    step_output = Output(uid, "", LitAPIStatus.FINISH_STREAMING)
                     responses.append(step_output)
 
             return responses
@@ -815,12 +832,13 @@ def run(
         workers_setup_status: Dict[int, str],
         callback_runner: CallbackRunner,
     ):
+        """Main loop that processes batches of requests."""
+
         if not lit_api.stream:
             raise ValueError(
                 "Continuous batching loop requires streaming to be enabled. Please set LitServe(..., stream=True)"
             )
 
-        """Main loop that processes batches of requests."""
         pending_requests = self.prefill(
             [],
             lit_api,
@@ -846,7 +864,7 @@ def run(
                 for step_output in responses:
                     logger.debug(f"Processing response: {step_output}")
                     status = step_output.status
-                    response_data = step_output.output
+                    response_data = lit_api.encode_response(step_output.output)
                     uid = step_output.uid
                     response_queue_id = self.response_queue_ids[uid]
 
@@ -870,7 +888,7 @@ def run(
                     )
 
         except Exception as e:
-            logger.exception("Error in continuous batching loop")
+            logger.exception(f"Error in continuous batching loop: {e}")
             # Handle any errors by sending error responses for all tracked requests
             for uid, response_queue_id in self.response_queue_ids.items():
                 self.put_error_response(response_queues, response_queue_id, uid, e)
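
For illustration, a minimal LitAPI sketch that satisfies the contract this loop exercises is shown below. It is hypothetical and not part of this patch: only the call shapes are taken from the diff (decode_request, a batched predict(inputs, generated), has_finished(uid, token, max_sequence_length), and encode_response); the EchoLitAPI name, its character-echo "model", and the request/response fields are invented for the example.

# Hypothetical example (not part of this patch): a toy LitAPI shaped to what
# ContinuousBatchingLoop calls in this diff. The "model" just replays the prompt
# one character at a time.
from typing import Any, List

import litserve as ls


class EchoLitAPI(ls.LitAPI):
    def setup(self, device: str) -> None:
        # No real weights to load; "generation" below only echoes the input.
        self.device = device

    def decode_request(self, request: dict) -> str:
        # The loop stores this value as active_sequences[uid]["input"].
        return request["prompt"]

    def predict(self, inputs: List[str], generated: List[List[str]]) -> List[str]:
        # Batched step: one new "token" per active sequence, "" once a sequence
        # has nothing left to emit.
        return [
            prompt[len(tokens)] if len(tokens) < len(prompt) else ""
            for prompt, tokens in zip(inputs, generated)
        ]

    def has_finished(self, uid: str, token: str, max_sequence_length: int) -> bool:
        # Called by the loop after every generated token.
        return token == ""

    def encode_response(self, output: Any) -> dict:
        # run() calls this for each step Output before putting it on the response queue.
        return {"token": output}

Because step() first checks hasattr(lit_api, "step"), an API that defines its own step(prev_outputs) would bypass this default per-token path entirely; the sketch assumes the default path is taken and relies on run() calling encode_response on every step output, as the diff now does.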