Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add loop.pre_setup to allow fine-grained LitAPI validation based on inference loop #393

Merged
merged 4 commits into from
Dec 11, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 10 additions & 63 deletions src/litserve/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import json
import warnings
from abc import ABC, abstractmethod
Expand All @@ -25,6 +24,7 @@

class LitAPI(ABC):
_stream: bool = False
_max_batch_size: int = 1
_default_unbatch: callable = None
_spec: LitSpec = None
_device: Optional[str] = None
Expand Down Expand Up @@ -113,76 +113,23 @@ def device(self):
def device(self, value):
self._device = value

def _sanitize(self, max_batch_size: int, spec: Optional[LitSpec]):
@property
def max_batch_size(self):
return self._max_batch_size

@max_batch_size.setter
def max_batch_size(self, value):
self._max_batch_size = value
aniketmaurya marked this conversation as resolved.
Show resolved Hide resolved

def pre_setup(self, max_batch_size: int, spec: Optional[LitSpec]):
self.max_batch_size = max_batch_size
if self.stream:
self._default_unbatch = self._unbatch_stream
else:
self._default_unbatch = self._unbatch_no_stream

# we will sanitize regularly if no spec
# in case, we have spec then:
# case 1: spec implements a streaming API
# Case 2: spec implements a non-streaming API
if spec:
# TODO: Implement sanitization
self._spec = spec
return

original = self.unbatch.__code__ is LitAPI.unbatch.__code__
if (
self.stream
and max_batch_size > 1
and not all([
inspect.isgeneratorfunction(self.predict),
inspect.isgeneratorfunction(self.encode_response),
(original or inspect.isgeneratorfunction(self.unbatch)),
])
):
raise ValueError(
"""When `stream=True` with max_batch_size > 1, `lit_api.predict`, `lit_api.encode_response` and
`lit_api.unbatch` must generate values using `yield`.

Example:

def predict(self, inputs):
...
for i in range(max_token_length):
yield prediction

def encode_response(self, outputs):
for output in outputs:
encoded_output = ...
yield encoded_output

def unbatch(self, outputs):
for output in outputs:
unbatched_output = ...
yield unbatched_output
"""
)

if self.stream and not all([
inspect.isgeneratorfunction(self.predict),
inspect.isgeneratorfunction(self.encode_response),
]):
raise ValueError(
"""When `stream=True` both `lit_api.predict` and
`lit_api.encode_response` must generate values using `yield`.

Example:

def predict(self, inputs):
...
for i in range(max_token_length):
yield prediction

def encode_response(self, outputs):
for output in outputs:
encoded_output = ...
yield encoded_output
"""
)

def set_logger_queue(self, queue: Queue):
"""Set the queue for logging events."""
Expand Down
170 changes: 122 additions & 48 deletions src/litserve/loops.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,6 +441,9 @@ def run(

"""

def pre_setup(self, lit_api: LitAPI, spec: Optional[LitSpec]):
pass

def __call__(
self,
lit_api: LitAPI,
Expand Down Expand Up @@ -487,7 +490,109 @@ def run(
raise NotImplementedError


class SingleLoop(_BaseLoop):
class LitLoop(_BaseLoop):
def __init__(self):
self._context = {}

def get_batch_requests(self, lit_api: LitAPI, request_queue: Queue, max_batch_size: int, batch_timeout: float):
if max_batch_size <= 1:
raise ValueError("max_batch_size must be greater than 1")

batches, timed_out_uids = collate_requests(
lit_api,
request_queue,
max_batch_size,
batch_timeout,
)
return batches, timed_out_uids

def get_request(self, request_queue: Queue, timeout: float = 1.0):
aniketmaurya marked this conversation as resolved.
Show resolved Hide resolved
response_queue_id, uid, timestamp, x_enc = request_queue.get(timeout=timeout)
return response_queue_id, uid, timestamp, x_enc

def populate_context(self, lit_spec: LitSpec, request: Any):
if lit_spec and hasattr(lit_spec, "populate_context"):
lit_spec.populate_context(self._context, request)

def put_response(
self, response_queues: List[Queue], response_queue_id: int, uid: str, response_data: Any, status: LitAPIStatus
) -> None:
response_queues[response_queue_id].put((uid, (response_data, status)))

def put_error_response(
self, response_queues: List[Queue], response_queue_id: int, uid: str, error: Exception
) -> None:
response_queues[response_queue_id].put((uid, (error, LitAPIStatus.ERROR)))


class DefaultLoop(LitLoop):
def pre_setup(self, lit_api: LitAPI, spec: Optional[LitSpec]):
# we will sanitize regularly if no spec
# in case, we have spec then:
# case 1: spec implements a streaming API
# Case 2: spec implements a non-streaming API
if spec:
# TODO: Implement sanitization
lit_api._spec = spec
return

original = lit_api.unbatch.__code__ is LitAPI.unbatch.__code__
if (
lit_api.stream
and lit_api.max_batch_size > 1
and not all([
inspect.isgeneratorfunction(lit_api.predict),
inspect.isgeneratorfunction(lit_api.encode_response),
(original or inspect.isgeneratorfunction(lit_api.unbatch)),
])
):
raise ValueError(
"""When `stream=True` with max_batch_size > 1, `lit_api.predict`, `lit_api.encode_response` and
`lit_api.unbatch` must generate values using `yield`.

Example:

def predict(self, inputs):
...
for i in range(max_token_length):
aniketmaurya marked this conversation as resolved.
Show resolved Hide resolved
yield prediction

def encode_response(self, outputs):
for output in outputs:
encoded_output = ...
yield encoded_output

def unbatch(self, outputs):
for output in outputs:
unbatched_output = ...
yield unbatched_output
"""
)

if lit_api.stream and not all([
inspect.isgeneratorfunction(lit_api.predict),
inspect.isgeneratorfunction(lit_api.encode_response),
]):
raise ValueError(
"""When `stream=True` both `lit_api.predict` and
`lit_api.encode_response` must generate values using `yield`.

Example:

def predict(self, inputs):
...
for i in range(max_token_length):
yield prediction

def encode_response(self, outputs):
for output in outputs:
encoded_output = ...
yield encoded_output
"""
)


class SingleLoop(DefaultLoop):
def __call__(
self,
lit_api: LitAPI,
Expand All @@ -505,7 +610,7 @@ def __call__(
run_single_loop(lit_api, lit_spec, request_queue, response_queues, callback_runner)


class BatchedLoop(_BaseLoop):
class BatchedLoop(DefaultLoop):
aniketmaurya marked this conversation as resolved.
Show resolved Hide resolved
def __call__(
self,
lit_api: LitAPI,
Expand All @@ -531,7 +636,7 @@ def __call__(
)


class StreamingLoop(_BaseLoop):
class StreamingLoop(DefaultLoop):
def __call__(
self,
lit_api: LitAPI,
Expand All @@ -549,7 +654,7 @@ def __call__(
run_streaming_loop(lit_api, lit_spec, request_queue, response_queues, callback_runner)


class BatchedStreamingLoop(_BaseLoop):
class BatchedStreamingLoop(DefaultLoop):
def __call__(
self,
lit_api: LitAPI,
Expand Down Expand Up @@ -593,41 +698,6 @@ class Output:
status: LitAPIStatus


class LitLoop(_BaseLoop):
def __init__(self):
self._context = {}

def get_batch_requests(self, lit_api: LitAPI, request_queue: Queue, max_batch_size: int, batch_timeout: float):
if max_batch_size <= 1:
raise ValueError("max_batch_size must be greater than 1")

batches, timed_out_uids = collate_requests(
lit_api,
request_queue,
max_batch_size,
batch_timeout,
)
return batches, timed_out_uids

def get_request(self, request_queue: Queue, timeout: float = 1.0):
response_queue_id, uid, timestamp, x_enc = request_queue.get(timeout=timeout)
return response_queue_id, uid, timestamp, x_enc

def populate_context(self, lit_spec: LitSpec, request: Any):
if lit_spec and hasattr(lit_spec, "populate_context"):
lit_spec.populate_context(self._context, request)

def put_response(
self, response_queues: List[Queue], response_queue_id: int, uid: str, response_data: Any, status: LitAPIStatus
) -> None:
response_queues[response_queue_id].put((uid, (response_data, status)))

def put_error_response(
self, response_queues: List[Queue], response_queue_id: int, uid: str, error: Exception
) -> None:
response_queues[response_queue_id].put((uid, (error, LitAPIStatus.ERROR)))


class ContinuousBatchingLoop(LitLoop):
def __init__(self, max_sequence_length: int = 2048):
super().__init__()
Expand Down Expand Up @@ -840,15 +910,7 @@ def inference_worker(
logging.info(f"LitServe will use {lit_spec.__class__.__name__} spec")

if loop == "auto":
loop = (
BatchedStreamingLoop()
if stream and max_batch_size > 1
else StreamingLoop()
if stream
else BatchedLoop()
if max_batch_size > 1
else SingleLoop()
)
loop = get_default_loop(stream, max_batch_size)

loop(
lit_api,
Expand All @@ -863,3 +925,15 @@ def inference_worker(
workers_setup_status,
callback_runner,
)


def get_default_loop(stream: bool, max_batch_size: int) -> _BaseLoop:
return (
BatchedStreamingLoop()
if stream and max_batch_size > 1
else StreamingLoop()
if stream
else BatchedLoop()
if max_batch_size > 1
else SingleLoop()
)
11 changes: 7 additions & 4 deletions src/litserve/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
from litserve.callbacks.base import Callback, CallbackRunner, EventTypes
from litserve.connector import _Connector
from litserve.loggers import Logger, _LoggerConnector
from litserve.loops import _BaseLoop, inference_worker
from litserve.loops import LitLoop, get_default_loop, inference_worker
from litserve.middlewares import MaxSizeMiddleware, RequestCountMiddleware
from litserve.python_client import client_template
from litserve.specs import OpenAISpec
Expand Down Expand Up @@ -113,7 +113,7 @@ def __init__(
spec: Optional[LitSpec] = None,
max_payload_size=None,
track_requests: bool = False,
loop: Optional[Union[str, _BaseLoop]] = "auto",
loop: Optional[Union[str, LitLoop]] = "auto",
callbacks: Optional[Union[List[Callback], Callback]] = None,
middlewares: Optional[list[Union[Callable, tuple[Callable, dict]]]] = None,
loggers: Optional[Union[Logger, List[Logger]]] = None,
Expand Down Expand Up @@ -154,6 +154,8 @@ def __init__(

if isinstance(loop, str) and loop != "auto":
raise ValueError("loop must be an instance of _BaseLoop or 'auto'")
if loop == "auto":
loop = get_default_loop(stream, max_batch_size)

if middlewares is None:
middlewares = []
Expand Down Expand Up @@ -198,15 +200,16 @@ def __init__(
"but the max_batch_size parameter was not set."
)

self._loop = loop
self._loop: LitLoop = loop
self.api_path = api_path
self.healthcheck_path = healthcheck_path
self.info_path = info_path
self.track_requests = track_requests
self.timeout = timeout
lit_api.stream = stream
lit_api.request_timeout = self.timeout
lit_api._sanitize(max_batch_size, spec=spec)
lit_api.pre_setup(max_batch_size, spec=spec)
self._loop.pre_setup(lit_api, spec=spec)
self.app = FastAPI(lifespan=self.lifespan)
self.app.response_queue_id = None
self.response_queue_id = None
Expand Down
2 changes: 1 addition & 1 deletion tests/test_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ def test_max_batch_size_warning():

def test_batch_predict_string_warning():
api = ls.test_examples.SimpleBatchedAPI()
api._sanitize(2, None)
api.pre_setup(2, None)
api.predict = MagicMock(return_value="This is a string")

mock_input = torch.tensor([[1.0], [2.0]])
Expand Down
Loading
Loading