Commit
add continuous batching loop 1/n (#387)
* add continuous batching loop

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update

* update

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* bump version

* fixes

* update

* fix

* update

* remove rich

* attach batch_size

* update

* update

* update

* fix

* update

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
aniketmaurya and pre-commit-ci[bot] authored Dec 11, 2024
1 parent 4a41f7f commit 75c6d0e
Showing 3 changed files with 241 additions and 6 deletions.
2 changes: 1 addition & 1 deletion src/litserve/__about__.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
__version__ = "0.2.6.dev0"
__version__ = "0.2.6.dev1"
__author__ = "Lightning-AI et al."
__author_email__ = "[email protected]"
__license__ = "Apache-2.0"
1 change: 1 addition & 0 deletions src/litserve/api.py
@@ -114,6 +114,7 @@ def device(self, value):
        self._device = value

    def _sanitize(self, max_batch_size: int, spec: Optional[LitSpec]):
+       self.max_batch_size = max_batch_size
        if self.stream:
            self._default_unbatch = self._unbatch_stream
        else:
244 changes: 239 additions & 5 deletions src/litserve/loops.py
@@ -18,8 +18,9 @@
import sys
import time
from abc import ABC
+from dataclasses import dataclass
from queue import Empty, Queue
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union

from fastapi import HTTPException
from starlette.formparsers import MultiPartParser
@@ -37,10 +38,11 @@
    asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

except ImportError:
-    print(
-        "uvloop is not installed. Falling back to the default asyncio event loop. "
-        "Please install uvloop for better performance using `pip install uvloop`."
-    )
+    if sys.platform != "win32":
+        print(
+            "uvloop is not installed. Falling back to the default asyncio event loop. "
+            "Please install uvloop for better performance using `pip install uvloop`."
+        )

logger = logging.getLogger(__name__)

@@ -573,6 +575,238 @@ def __call__(
)


def notify_timed_out_requests(
    response_queues: List[Queue],
    timed_out_uids: List[Tuple[int, str]],
):
    for response_queue_id, uid in timed_out_uids:
        logger.error(f"Request {uid} was waiting in the queue for too long and has been timed out.")
        response_queues[response_queue_id].put((uid, (HTTPException(504, "Request timed out"), LitAPIStatus.ERROR)))


@dataclass
class Output:
    """Outputs from a single step of the loop."""

    uid: str
    output: Any
    status: LitAPIStatus


class LitLoop(_BaseLoop):
    def __init__(self):
        self._context = {}

    def get_batch_requests(self, lit_api: LitAPI, request_queue: Queue, max_batch_size: int, batch_timeout: float):
        if max_batch_size <= 1:
            raise ValueError("max_batch_size must be greater than 1")

        batches, timed_out_uids = collate_requests(
            lit_api,
            request_queue,
            max_batch_size,
            batch_timeout,
        )
        return batches, timed_out_uids

    def get_request(self, request_queue: Queue, timeout: float = 1.0):
        response_queue_id, uid, timestamp, x_enc = request_queue.get(timeout=timeout)
        return response_queue_id, uid, timestamp, x_enc

    def populate_context(self, lit_spec: LitSpec, request: Any):
        if lit_spec and hasattr(lit_spec, "populate_context"):
            lit_spec.populate_context(self._context, request)

    def put_response(
        self, response_queues: List[Queue], response_queue_id: int, uid: str, response_data: Any, status: LitAPIStatus
    ) -> None:
        response_queues[response_queue_id].put((uid, (response_data, status)))

    def put_error_response(
        self, response_queues: List[Queue], response_queue_id: int, uid: str, error: Exception
    ) -> None:
        response_queues[response_queue_id].put((uid, (error, LitAPIStatus.ERROR)))


class ContinuousBatchingLoop(LitLoop):
    def __init__(self, max_sequence_length: int = 2048):
        super().__init__()
        self.active_sequences: Dict[str, Dict] = {}  # uid -> {input, current_length, generated_tokens}
        self.max_sequence_length = max_sequence_length
        self.response_queue_ids: Dict[str, int] = {}  # uid -> response_queue_id

    def add_request(self, uid: str, request: Any, lit_api: LitAPI, lit_spec: Optional[LitSpec]) -> None:
        """Add a new sequence to active sequences."""
        decoded_request = lit_api.decode_request(request)
        self.active_sequences[uid] = {"input": decoded_request, "current_length": 0, "generated_tokens": []}

    def mark_completed(self, uid: str) -> None:
        """Mark a sequence as completed."""
        logger.info(f"Marking sequence {uid} as completed")
        del self.active_sequences[uid]
        del self.response_queue_ids[uid]

    def has_capacity(self, lit_api: LitAPI) -> bool:
        """Check if we can add more sequences based on current batch."""
        capacity = len(self.active_sequences) < lit_api.max_batch_size
        if not capacity:
            logger.info(
                f"No capacity: {len(self.active_sequences)} active sequences, max batch size: {lit_api.max_batch_size}"
            )
        return capacity

    def step(
        self, prev_outputs: Optional[List[Output]], lit_api: LitAPI, lit_spec: Optional[LitSpec]
    ) -> List[Tuple[str, Tuple[Any, LitAPIStatus]]]:
        """Process one token generation step for all active sequences."""
        if not self.active_sequences:
            return []

        # Batch forward pass for all active sequences
        inputs = [seq["input"] for seq in self.active_sequences.values()]
        generated = [seq["generated_tokens"] for seq in self.active_sequences.values()]

        try:
            # Assume lit_api.predict handles batched token generation
            new_tokens = lit_api.predict(inputs, generated)

            responses = []

            # Process each sequence's new token
            for uid, token in zip(self.active_sequences.keys(), new_tokens):
                seq = self.active_sequences[uid]
                seq["generated_tokens"].append(token)
                seq["current_length"] += 1

                # Check completion conditions
                is_finished = lit_api.is_finished(uid, token, self.max_sequence_length)

                if is_finished:
                    # Encode final response for completed sequence
                    response = lit_api.encode_response(seq["generated_tokens"])
                    step_output = Output(uid, response, LitAPIStatus.FINISH_STREAMING)
                    responses.append(step_output)

            return responses

        except Exception as e:
            logger.exception("Error during batch token generation")
            # On error, terminate all active sequences
            responses = [(uid, (e, LitAPIStatus.ERROR)) for uid in self.active_sequences]
            self.active_sequences.clear()
            return responses

    def prefill(
        self,
        pending_requests: List[Tuple[str, Any]],
        lit_api: LitAPI,
        lit_spec: Optional[LitSpec],
        request_queue: Queue,
        max_batch_size: int,
        response_queues: List[Queue],
    ) -> List[Tuple[str, Any]]:
        """Fill available capacity with pending and new requests."""
        # First process existing pending requests
        while pending_requests and self.has_capacity(lit_api):
            response_queue_id, uid, input = pending_requests.pop(0)
            self.add_request(uid, input, lit_api, lit_spec)
            self.response_queue_ids[uid] = response_queue_id

        # Then check for new requests if we still have capacity
        if self.has_capacity(lit_api):
            new_batches, timed_out_uids = self.get_batch_requests(
                lit_api, request_queue, max_batch_size, batch_timeout=0.0001
            )
            notify_timed_out_requests(response_queues, timed_out_uids)

            if new_batches:
                # Add new requests to pending_requests and try to process them
                for response_queue_id, uid, input in new_batches:
                    logger.info(f"New request: {uid}, {input}")
                    if self.has_capacity(lit_api):
                        self.add_request(uid, input, lit_api, lit_spec)
                        self.response_queue_ids[uid] = response_queue_id
                    else:
                        pending_requests.append((response_queue_id, uid, input))

        return pending_requests

    def run(
        self,
        lit_api: LitAPI,
        lit_spec: Optional[LitSpec],
        device: str,
        worker_id: int,
        request_queue: Queue,
        response_queues: List[Queue],
        max_batch_size: int,
        batch_timeout: float,
        stream: bool,
        workers_setup_status: Dict[int, str],
        callback_runner: CallbackRunner,
    ):
        """Main loop that processes batches of requests."""
        if not lit_api.stream:
            raise ValueError(
                "Continuous batching loop requires streaming to be enabled. Please set LitServe(..., stream=True)"
            )

        pending_requests = self.prefill(
            [],
            lit_api,
            lit_spec,
            request_queue,
            max_batch_size,
            response_queues,
        )
        try:
            prev_outputs = None
            while pending_requests or self.active_sequences:
                # Process one step for all active sequences
                responses = self.step(prev_outputs, lit_api, lit_spec)
                logger.debug(f"Responses from step(): {responses}")
                if len(responses) == 0:
                    raise HTTPException(500, "No responses from step()")
                if responses and not isinstance(responses[0], Output):
                    raise HTTPException(500, "Expected Output from step()")

                prev_outputs = responses

                # Send responses for all sequences (both streaming and completed)
                for step_output in responses:
                    logger.debug(f"Processing response: {step_output}")
                    status = step_output.status
                    response_data = step_output.output
                    uid = step_output.uid
                    response_queue_id = self.response_queue_ids[uid]

                    if status == LitAPIStatus.ERROR:
                        self.put_error_response(response_queues, response_queue_id, uid, response_data)
                        self.mark_completed(uid)
                    elif status == LitAPIStatus.FINISH_STREAMING:
                        self.put_response(response_queues, response_queue_id, uid, response_data, status)
                        self.mark_completed(uid)
                    else:
                        self.put_response(response_queues, response_queue_id, uid, response_data, status)

                # Fill available capacity with both pending and new requests
                pending_requests = self.prefill(
                    pending_requests,
                    lit_api,
                    lit_spec,
                    request_queue,
                    max_batch_size,
                    response_queues,
                )

        except Exception as e:
            logger.exception("Error in continuous batching loop")
            # Handle any errors by sending error responses for all tracked requests
            for uid, response_queue_id in self.response_queue_ids.items():
                self.put_error_response(response_queues, response_queue_id, uid, e)
            self.response_queue_ids.clear()


def inference_worker(
    lit_api: LitAPI,
    lit_spec: Optional[LitSpec],

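A minimal sketch, not part of this commit, of a LitAPI that satisfies the interface ContinuousBatchingLoop assumes: predict(inputs, generated) returning one new token per active sequence, is_finished(uid, token, max_sequence_length) to signal completion, and encode_response(generated_tokens) to build the final payload. The class name EchoCharAPI and its <EOS> sentinel are invented for illustration, and the wiring of this loop into a LitServer is not shown here because this PR is only part 1/n.

import litserve as ls


class EchoCharAPI(ls.LitAPI):
    """Toy API: "generates" the prompt back one character per step."""

    def setup(self, device):
        self.eos = "<EOS>"  # hypothetical sentinel marking the end of a sequence

    def decode_request(self, request):
        # Becomes seq["input"] inside ContinuousBatchingLoop.add_request
        return request["prompt"]

    def predict(self, inputs, generated):
        # One new "token" per active sequence per step
        return [
            prompt[len(tokens)] if len(tokens) < len(prompt) else self.eos
            for prompt, tokens in zip(inputs, generated)
        ]

    def is_finished(self, uid, token, max_sequence_length):
        # The loop also passes its max_sequence_length; this toy only checks for EOS
        return token == self.eos

    def encode_response(self, generated_tokens):
        return {"output": "".join(t for t in generated_tokens if t != self.eos)}

With such an API, step() appends one token to every active sequence, a finished sequence is flushed to its response queue with FINISH_STREAMING and removed via mark_completed(), and prefill() keeps topping the batch up to lit_api.max_batch_size, the attribute that the api.py change above now sets in _sanitize().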