update
aniketmaurya committed Dec 11, 2024
1 parent 35129f7 commit 021f190
Showing 2 changed files with 36 additions and 19 deletions.
3 changes: 1 addition & 2 deletions src/litserve/api.py
@@ -56,10 +56,9 @@ def batch(self, inputs):

        return inputs

-    @abstractmethod
    def predict(self, x, **kwargs):
        """Run the model on the input and return or yield the output."""
-        pass
+        raise NotImplementedError("predict is not implemented")

    def _unbatch_no_stream(self, output):
        if isinstance(output, str):
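
With @abstractmethod removed, a LitAPI subclass no longer has to implement predict just to be instantiable; calling the inherited method now raises NotImplementedError. A minimal sketch of that behaviour, assuming setup is still the abstract method a subclass must provide (the class name here is hypothetical, not part of this commit):

# Hypothetical subclass: setup is provided, predict is left to the inherited implementation.
import litserve as ls


class NoPredictAPI(ls.LitAPI):
    def setup(self, device):
        pass  # a real API would load its model onto `device` here


api = NoPredictAPI()  # instantiation succeeds once predict is no longer abstract
try:
    api.predict("hello")
except NotImplementedError as err:
    print(err)  # -> predict is not implemented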
52 changes: 35 additions & 17 deletions src/litserve/loops.py
@@ -700,18 +700,32 @@ class Output:

class ContinuousBatchingLoop(LitLoop):
    def __init__(self, max_sequence_length: int = 2048):
+        """Runs continuous batching loop. This loop handles adding new requests, processing them in batches, and
+        managing the state of active sequences.
+
+        The loop requires the following methods to be implemented in the LitAPI:
+            - setup: sets up the model on the device
+            - decode_request: decodes the client request into a format that can be processed by the model
+            - step: generates a new token for each sequence
+            - encode_response: encodes the response into a format that can be sent to the client
+            - has_finished: checks if the sequence has finished generating
+
+        Args:
+            max_sequence_length (int): The maximum sequence length allowed for any active sequence.
+
+        """
        super().__init__()
-        self.active_sequences: Dict[str, Dict] = {}  # uid -> {input, current_length, generated_tokens}
+        self.active_sequences: Dict[str, Dict] = {}  # uid -> {input, current_length, generated_sequence}
        self.max_sequence_length = max_sequence_length
        self.response_queue_ids: Dict[str, int] = {}  # uid -> response_queue_id

    def add_request(self, uid: str, request: Any, lit_api: LitAPI, lit_spec: Optional[LitSpec]) -> None:
-        """Add a new sequence to active sequences."""
+        """Add a new sequence to active sequences and perform any action before prediction such as filling the cache."""
        decoded_request = lit_api.decode_request(request)
-        self.active_sequences[uid] = {"input": decoded_request, "current_length": 0, "generated_tokens": []}
+        self.active_sequences[uid] = {"input": decoded_request, "current_length": 0, "generated_sequence": []}

    def mark_completed(self, uid: str) -> None:
-        """Mark a sequence as completed."""
+        """Mark a request as completed and remove it from the tracked state."""
        logger.info(f"Marking sequence {uid} as completed")
        del self.active_sequences[uid]
        del self.response_queue_ids[uid]
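
The new docstring above lists the LitAPI hooks this loop relies on. As an illustration only (not part of this commit, and with a toy character-echo "model" standing in for real generation), a LitAPI matching that interface could look roughly like this:

# Hypothetical LitAPI sketch for ContinuousBatchingLoop. Only the hook names and
# call shapes (decode_request, predict(inputs, generated), has_finished,
# encode_response) are taken from this diff; everything else is a placeholder.
import litserve as ls


class ToyContinuousAPI(ls.LitAPI):
    def setup(self, device):
        # A real API would load a model/tokenizer onto `device` here.
        self.eos_token = "<eos>"

    def decode_request(self, request):
        # Becomes the "input" field stored in active_sequences.
        return request["prompt"]

    def predict(self, inputs, generated):
        # Batched single-token step: return one new token per active sequence.
        # `inputs` and `generated` are parallel lists, one entry per sequence.
        new_tokens = []
        for prompt, seq in zip(inputs, generated):
            pos = len(seq)
            new_tokens.append(prompt[pos] if pos < len(prompt) else self.eos_token)
        return new_tokens

    def has_finished(self, uid, token, max_sequence_length):
        # The loop passes its max_sequence_length; this toy model only checks EOS.
        return token == self.eos_token

    def encode_response(self, output):
        # In this loop each output is a single streamed token (or "" for the
        # FINISH_STREAMING marker); wrap it as a response chunk.
        return {"token": output}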
@@ -725,36 +739,39 @@ def has_capacity(self, lit_api: LitAPI) -> bool:
        )
        return capacity

-    def step(
-        self, prev_outputs: Optional[List[Output]], lit_api: LitAPI, lit_spec: Optional[LitSpec]
-    ) -> List[Tuple[str, Tuple[Any, LitAPIStatus]]]:
+    def step(self, prev_outputs: Optional[List[Output]], lit_api: LitAPI, lit_spec: Optional[LitSpec]) -> List[Output]:
        """Process one token generation step for all active sequences."""
+        if hasattr(lit_api, "step"):
+            return lit_api.step(prev_outputs)
+
+        if not self.active_sequences:
+            return []
+
        # Batch forward pass for all active sequences
        inputs = [seq["input"] for seq in self.active_sequences.values()]
-        generated = [seq["generated_tokens"] for seq in self.active_sequences.values()]
+        generated = [seq["generated_sequence"] for seq in self.active_sequences.values()]

        try:
            # Assume lit_api.predict handles batched token generation
-            new_tokens = lit_api.predict(inputs, generated)
+            new_tokens: List[Any] = lit_api.predict(inputs, generated)

-            responses = []
+            responses: List[Output] = []

            # Process each sequence's new token
            for uid, token in zip(self.active_sequences.keys(), new_tokens):
                seq = self.active_sequences[uid]
-                seq["generated_tokens"].append(token)
+                seq["generated_sequence"].append(token)
                seq["current_length"] += 1

                step_output = Output(uid, token, LitAPIStatus.OK)
                responses.append(step_output)

                # Check completion conditions
-                is_finished = lit_api.is_finished(uid, token, self.max_sequence_length)
+                is_finished = lit_api.has_finished(uid, token, self.max_sequence_length)

                if is_finished:
                    # Encode final response for completed sequence
-                    response = lit_api.encode_response(seq["generated_tokens"])
-                    step_output = Output(uid, response, LitAPIStatus.FINISH_STREAMING)
+                    step_output = Output(uid, "", LitAPIStatus.FINISH_STREAMING)
                    responses.append(step_output)

            return responses
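
To make the bookkeeping in step concrete: each pass generates one token per active sequence, emits an OK output per token, and appends a FINISH_STREAMING marker when a sequence completes. A standalone simulation of that flow with plain dicts and tuples (an illustration of the algorithm, not LitServe code):

# Standalone illustration of the per-step bookkeeping in ContinuousBatchingLoop.step.
def run_steps(prompts, max_sequence_length=8):
    """Drive the continuous-batching bookkeeping until every sequence finishes."""
    active = {uid: {"input": p, "current_length": 0, "generated_sequence": []}
              for uid, p in prompts.items()}
    emitted = []
    while active:
        # Stand-in for the batched lit_api.predict(inputs, generated) call:
        # next character of each prompt, or "<eos>" once the prompt is exhausted.
        new_tokens = []
        for seq in active.values():
            pos = len(seq["generated_sequence"])
            new_tokens.append(seq["input"][pos] if pos < len(seq["input"]) else "<eos>")

        finished = []
        for uid, token in zip(list(active.keys()), new_tokens):
            seq = active[uid]
            seq["generated_sequence"].append(token)
            seq["current_length"] += 1
            emitted.append((uid, token, "OK"))                 # streamed to the client
            if token == "<eos>" or seq["current_length"] >= max_sequence_length:
                emitted.append((uid, "", "FINISH_STREAMING"))  # end-of-stream marker
                finished.append(uid)
        for uid in finished:
            del active[uid]  # mark_completed: stop tracking the sequence
    return emitted


# Two concurrent "requests"; their tokens interleave step by step.
print(run_steps({"req-1": "hi", "req-2": "hey"}))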
@@ -815,12 +832,13 @@ def run(
        workers_setup_status: Dict[int, str],
        callback_runner: CallbackRunner,
    ):
+        """Main loop that processes batches of requests."""
+
        if not lit_api.stream:
            raise ValueError(
                "Continuous batching loop requires streaming to be enabled. Please set LitServe(..., stream=True)"
            )

-        """Main loop that processes batches of requests."""
        pending_requests = self.prefill(
            [],
            lit_api,
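
The check above means this loop only runs when the server is created with streaming enabled. A minimal sketch of that setup, reusing the hypothetical ToyContinuousAPI from the earlier sketch (how this particular loop is selected is not shown in the lines of this commit):

# Minimal server setup with streaming enabled, as the ValueError above requires.
# ToyContinuousAPI is the hypothetical API sketched earlier.
import litserve as ls

if __name__ == "__main__":
    server = ls.LitServer(ToyContinuousAPI(), stream=True)
    server.run(port=8000)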
@@ -846,7 +864,7 @@
                for step_output in responses:
                    logger.debug(f"Processing response: {step_output}")
                    status = step_output.status
-                    response_data = step_output.output
+                    response_data = lit_api.encode_response(step_output.output)
                    uid = step_output.uid
                    response_queue_id = self.response_queue_ids[uid]

@@ -870,7 +888,7 @@
                    )

        except Exception as e:
-            logger.exception("Error in continuous batching loop")
+            logger.exception(f"Error in continuous batching loop: {e}")
            # Handle any errors by sending error responses for all tracked requests
            for uid, response_queue_id in self.response_queue_ids.items():
                self.put_error_response(response_queues, response_queue_id, uid, e)