add continuous batching loop 1/n #387

Merged
merged 20 commits on Dec 11, 2024
2 changes: 1 addition & 1 deletion src/litserve/__about__.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
__version__ = "0.2.6.dev0"
__version__ = "0.2.6.dev1"
__author__ = "Lightning-AI et al."
__author_email__ = "[email protected]"
__license__ = "Apache-2.0"
1 change: 1 addition & 0 deletions src/litserve/api.py
@@ -114,6 +114,7 @@ def device(self, value):
self._device = value

def _sanitize(self, max_batch_size: int, spec: Optional[LitSpec]):
self.max_batch_size = max_batch_size
if self.stream:
self._default_unbatch = self._unbatch_stream
else:
244 changes: 239 additions & 5 deletions src/litserve/loops.py
@@ -18,8 +18,9 @@
import sys
import time
from abc import ABC
from dataclasses import dataclass
from queue import Empty, Queue
from typing import Dict, List, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Union

from fastapi import HTTPException
from starlette.formparsers import MultiPartParser
@@ -37,10 +38,11 @@
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

except ImportError:
print(
"uvloop is not installed. Falling back to the default asyncio event loop. "
"Please install uvloop for better performance using `pip install uvloop`."
)
if sys.platform != "win32":
print(
"uvloop is not installed. Falling back to the default asyncio event loop. "
"Please install uvloop for better performance using `pip install uvloop`."
)

logger = logging.getLogger(__name__)

@@ -573,6 +575,238 @@ def __call__(
)


def notify_timed_out_requests(
response_queues: List[Queue],
timed_out_uids: List[Tuple[int, str]],
):
for response_queue_id, uid in timed_out_uids:
logger.error(f"Request {uid} was waiting in the queue for too long and has been timed out.")
response_queues[response_queue_id].put((uid, (HTTPException(504, "Request timed out"), LitAPIStatus.ERROR)))


@dataclass
class Output:
"""Outputs from a single step of the loop."""

uid: str
output: Any
status: LitAPIStatus


class LitLoop(_BaseLoop):
def __init__(self):
self._context = {}

def get_batch_requests(self, lit_api: LitAPI, request_queue: Queue, max_batch_size: int, batch_timeout: float):
if max_batch_size <= 1:
raise ValueError("max_batch_size must be greater than 1")

batches, timed_out_uids = collate_requests(
lit_api,
request_queue,
max_batch_size,
batch_timeout,
)
return batches, timed_out_uids

def get_request(self, request_queue: Queue, timeout: float = 1.0):
response_queue_id, uid, timestamp, x_enc = request_queue.get(timeout=timeout)
return response_queue_id, uid, timestamp, x_enc

def populate_context(self, lit_spec: LitSpec, request: Any):
if lit_spec and hasattr(lit_spec, "populate_context"):
lit_spec.populate_context(self._context, request)

def put_response(
self, response_queues: List[Queue], response_queue_id: int, uid: str, response_data: Any, status: LitAPIStatus
) -> None:
response_queues[response_queue_id].put((uid, (response_data, status)))

def put_error_response(
self, response_queues: List[Queue], response_queue_id: int, uid: str, error: Exception
) -> None:
response_queues[response_queue_id].put((uid, (error, LitAPIStatus.ERROR)))


class ContinuousBatchingLoop(LitLoop):
def __init__(self, max_sequence_length: int = 2048):
super().__init__()
self.active_sequences: Dict[str, Dict] = {} # uid -> {input, current_length, generated_tokens}
self.max_sequence_length = max_sequence_length
self.response_queue_ids: Dict[str, int] = {} # uid -> response_queue_id

def add_request(self, uid: str, request: Any, lit_api: LitAPI, lit_spec: Optional[LitSpec]) -> None:
"""Add a new sequence to active sequences."""
decoded_request = lit_api.decode_request(request)
self.active_sequences[uid] = {"input": decoded_request, "current_length": 0, "generated_tokens": []}

def mark_completed(self, uid: str) -> None:
"""Mark a sequence as completed."""
logger.info(f"Marking sequence {uid} as completed")
del self.active_sequences[uid]
del self.response_queue_ids[uid]

def has_capacity(self, lit_api: LitAPI) -> bool:
"""Check if we can add more sequences based on current batch."""
capacity = len(self.active_sequences) < lit_api.max_batch_size
if not capacity:
logger.info(
f"No capacity: {len(self.active_sequences)} active sequences, max batch size: {lit_api.max_batch_size}"
)
return capacity

def step(
self, prev_outputs: Optional[List[Output]], lit_api: LitAPI, lit_spec: Optional[LitSpec]
    ) -> List[Output]:
"""Process one token generation step for all active sequences."""
if not self.active_sequences:
return []

# Batch forward pass for all active sequences
inputs = [seq["input"] for seq in self.active_sequences.values()]
generated = [seq["generated_tokens"] for seq in self.active_sequences.values()]

try:
# Assume lit_api.predict handles batched token generation
new_tokens = lit_api.predict(inputs, generated)

responses = []

# Process each sequence's new token
for uid, token in zip(self.active_sequences.keys(), new_tokens):
seq = self.active_sequences[uid]
seq["generated_tokens"].append(token)
seq["current_length"] += 1

# Check completion conditions
is_finished = lit_api.is_finished(uid, token, self.max_sequence_length)

if is_finished:
# Encode final response for completed sequence
response = lit_api.encode_response(seq["generated_tokens"])
step_output = Output(uid, response, LitAPIStatus.FINISH_STREAMING)
responses.append(step_output)

return responses

except Exception as e:
logger.exception("Error during batch token generation")
            # On error, emit an error Output for every active sequence; run() will
            # send the error responses and drop the sequences via mark_completed().
            responses = [Output(uid, e, LitAPIStatus.ERROR) for uid in self.active_sequences]
            return responses

def prefill(
self,
        pending_requests: List[Tuple[int, str, Any]],
lit_api: LitAPI,
lit_spec: Optional[LitSpec],
request_queue: Queue,
max_batch_size: int,
response_queues: List[Queue],
    ) -> List[Tuple[int, str, Any]]:
"""Fill available capacity with pending and new requests."""
# First process existing pending requests
while pending_requests and self.has_capacity(lit_api):
response_queue_id, uid, input = pending_requests.pop(0)
self.add_request(uid, input, lit_api, lit_spec)
self.response_queue_ids[uid] = response_queue_id

# Then check for new requests if we still have capacity
if self.has_capacity(lit_api):
new_batches, timed_out_uids = self.get_batch_requests(
lit_api, request_queue, max_batch_size, batch_timeout=0.0001
)
notify_timed_out_requests(response_queues, timed_out_uids)

if new_batches:
# Add new requests to pending_requests and try to process them
for response_queue_id, uid, input in new_batches:
logger.info(f"New request: {uid}, {input}")
if self.has_capacity(lit_api):
self.add_request(uid, input, lit_api, lit_spec)
self.response_queue_ids[uid] = response_queue_id
else:
pending_requests.append((response_queue_id, uid, input))

return pending_requests

def run(
self,
lit_api: LitAPI,
lit_spec: Optional[LitSpec],
device: str,
worker_id: int,
request_queue: Queue,
response_queues: List[Queue],
max_batch_size: int,
batch_timeout: float,
stream: bool,
workers_setup_status: Dict[int, str],
callback_runner: CallbackRunner,
):
        """Main loop that processes batches of requests."""
        if not lit_api.stream:
            raise ValueError(
                "Continuous batching loop requires streaming to be enabled. Please set LitServer(..., stream=True)"
            )

        pending_requests = self.prefill(
[],
lit_api,
lit_spec,
request_queue,
max_batch_size,
response_queues,
)
try:
prev_outputs = None
while pending_requests or self.active_sequences:
# Process one step for all active sequences
responses = self.step(prev_outputs, lit_api, lit_spec)
logger.debug(f"Responses from step(): {responses}")
if len(responses) == 0:
raise HTTPException(500, "No responses from step()")
if responses and not isinstance(responses[0], Output):
raise HTTPException(500, "Expected StepOutput from step()")

prev_outputs = responses

# Send responses for all sequences (both streaming and completed)
for step_output in responses:
logger.debug(f"Processing response: {step_output}")
status = step_output.status
response_data = step_output.output
uid = step_output.uid
response_queue_id = self.response_queue_ids[uid]

if status == LitAPIStatus.ERROR:
self.put_error_response(response_queues, response_queue_id, uid, response_data)
self.mark_completed(uid)
elif status == LitAPIStatus.FINISH_STREAMING:
self.put_response(response_queues, response_queue_id, uid, response_data, status)
self.mark_completed(uid)
else:
self.put_response(response_queues, response_queue_id, uid, response_data, status)

# Fill available capacity with both pending and new requests
pending_requests = self.prefill(
pending_requests,
lit_api,
lit_spec,
request_queue,
max_batch_size,
response_queues,
)

except Exception as e:
logger.exception("Error in continuous batching loop")
# Handle any errors by sending error responses for all tracked requests
for uid, response_queue_id in self.response_queue_ids.items():
self.put_error_response(response_queues, response_queue_id, uid, e)
self.response_queue_ids.clear()


def inference_worker(
lit_api: LitAPI,
lit_spec: Optional[LitSpec],
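
For orientation, below is a minimal sketch of the `LitAPI` surface that `ContinuousBatchingLoop` exercises, inferred only from the calls in this diff: `decode_request` in `add_request`, a batched `predict(inputs, generated)` that returns one new token per active sequence in `step`, `is_finished(uid, token, max_sequence_length)` to decide completion, and `encode_response(generated_tokens)` for finished sequences. The class name `ToyTokenAPI`, the `<eos>` marker, and the token logic are illustrative assumptions, not part of this PR, and the wiring that attaches this loop to a server is not shown in the hunk, so it is omitted here.

```python
# Hypothetical sketch (not from this PR): a LitAPI shaped to match what
# ContinuousBatchingLoop calls in the diff above.
from typing import Any, Dict, List

import litserve as ls


class ToyTokenAPI(ls.LitAPI):
    def setup(self, device: str) -> None:
        # Assumed end-of-sequence marker for this toy example.
        self.eos = "<eos>"

    def decode_request(self, request: Any) -> str:
        # Called once per request from ContinuousBatchingLoop.add_request().
        return request["prompt"]

    def predict(self, inputs: List[str], generated: List[List[str]]) -> List[str]:
        # One decoding step for every active sequence: return one new token each.
        # A real model would run a batched forward pass here.
        return ["tok" if len(g) < 4 else self.eos for g in generated]

    def is_finished(self, uid: str, token: str, max_sequence_length: int) -> bool:
        # Called by step() after each generated token to decide completion.
        return token == self.eos

    def encode_response(self, generated_tokens: List[str]) -> Dict[str, str]:
        # Called only for finished sequences; the loop forwards the result with
        # LitAPIStatus.FINISH_STREAMING and then drops the sequence.
        return {"text": " ".join(t for t in generated_tokens if t != self.eos)}
```

Note that `is_finished` does not appear in the `LitAPI` changes included in this PR, so it is assumed here to be a method the continuous batching loop expects user code to provide.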