diff --git a/src/litserve/server.py b/src/litserve/server.py
index e1716106..feb44e26 100644
--- a/src/litserve/server.py
+++ b/src/litserve/server.py
@@ -25,9 +25,9 @@
 import warnings
 from collections import deque
 from concurrent.futures import ThreadPoolExecutor
-from contextlib import asynccontextmanager
+from contextlib import AsyncExitStack, asynccontextmanager
 from queue import Empty
-from typing import Callable, Dict, Optional, Sequence, Tuple, Union
+from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union
 
 import uvicorn
 from fastapi import Depends, FastAPI, HTTPException, Request, Response
@@ -251,7 +251,7 @@ async def lifespan(self, app: FastAPI):
                 "the LitServer class to initialize the response queues."
             )
 
-        response_queue = self.response_queues[app.response_queue_id]
+        response_queue = self.response_queues[self.app.response_queue_id]
         response_executor = ThreadPoolExecutor(max_workers=len(self.inference_workers))
         future = response_queue_to_buffer(response_queue, self.response_buffer, self.stream, response_executor)
         task = loop.create_task(future)
@@ -468,3 +468,94 @@ def setup_auth(self):
         if LIT_SERVER_API_KEY:
             return api_key_auth
         return no_auth
+
+
+@asynccontextmanager
+async def multi_server_lifespan(app: FastAPI, servers: List[LitServer]):
+    """Context manager to handle the lifespan events of multiple FastAPI servers."""
+    # Start lifespan events for each server
+    async with AsyncExitStack() as stack:
+        for server in servers:
+            await stack.enter_async_context(server.lifespan(server.app))
+        yield
+
+
+def run_all(
+    servers: List[LitServer],
+    port: Union[str, int] = 8000,
+    num_api_servers: Optional[int] = 1,
+    log_level: str = "info",
+    generate_client_file: bool = True,
+    api_server_worker_type: Optional[str] = None,
+    **kwargs,
+):
+    """Run multiple LitServers on the same port."""
+
+    if any(not isinstance(server, LitServer) for server in servers):
+        raise ValueError("All elements in the servers list must be instances of LitServer")
+
+    if generate_client_file:
+        LitServer.generate_client_file()
+
+    port_msg = f"port must be a value from 1024 to 65535 but got {port}"
+    try:
+        port = int(port)
+    except ValueError:
+        raise ValueError(port_msg)
+    if not (1024 <= port <= 65535):
+        raise ValueError(port_msg)
+
+    if num_api_servers < 1:
+        raise ValueError("num_api_servers must be greater than 0")
+
+    if sys.platform == "win32":
+        print("Windows does not support forking. Using threads; api_server_worker_type will be set to 'thread'")
+        api_server_worker_type = "thread"
+    elif api_server_worker_type is None:
+        api_server_worker_type = "process"
+
+    # Create the main FastAPI app
+    app = FastAPI(lifespan=lambda app: multi_server_lifespan(app, servers))
+    config = uvicorn.Config(app=app, host="0.0.0.0", port=port, log_level=log_level, **kwargs)
+    sockets = [config.bind_socket()]
+
+    managers, inference_workers = [], []
+    try:
+        for server in servers:
+            manager, workers = server.launch_inference_worker(num_api_servers)
+            managers.append(manager)
+            inference_workers.extend(workers)
+
+            # include routes from each litserver's app into the main app
+            app.include_router(server.app.router)
+
+        server_processes = []
+        for response_queue_id in range(num_api_servers):
+            for server in servers:
+                server.app.response_queue_id = response_queue_id
+                if server.lit_spec:
+                    server.lit_spec.response_queue_id = response_queue_id
+
+            app = copy.copy(app)
+            config = uvicorn.Config(app=app, host="0.0.0.0", port=port, log_level=log_level, **kwargs)
+            uvicorn_server = uvicorn.Server(config=config)
+
+            if api_server_worker_type == "process":
+                ctx = mp.get_context("fork")
+                worker = ctx.Process(target=uvicorn_server.run, args=(sockets,))
+            elif api_server_worker_type == "thread":
+                worker = threading.Thread(target=uvicorn_server.run, args=(sockets,))
+            else:
+                raise ValueError("Invalid value for api_server_worker_type. Must be 'process' or 'thread'")
+            worker.start()
+            server_processes.append(worker)
+        print(f"Swagger UI is available at http://0.0.0.0:{port}/docs")
+        for process in server_processes:
+            process.join()
+    finally:
+        print("Shutting down LitServe")
+        for worker in inference_workers:
+            worker.terminate()
+            worker.join()
+        for manager in managers:
+            manager.shutdown()
diff --git a/tests/e2e/test_e2e.py b/tests/e2e/test_e2e.py
index f2618f13..ea9360bf 100644
--- a/tests/e2e/test_e2e.py
+++ b/tests/e2e/test_e2e.py
@@ -57,6 +57,17 @@ def test_run():
     os.remove("client.py")
 
 
+@e2e_from_file("tests/multiple_litserver.py")
+def test_e2e_combined_multiple_litserver():
+    assert os.path.exists("client.py"), f"Expected client file to be created at {os.getcwd()} after starting the server"
+    for i in range(1, 5):
+        resp = requests.post(f"http://127.0.0.1:8000/predict-{i}", json={"input": 4.0}, headers=None)
+        assert resp.status_code == 200, f"Expected response to be 200 but got {resp.status_code}"
+        assert resp.json() == {
+            "output": 4.0**i
+        }, "tests/multiple_litserver.py didn't return expected output"
+
+
 @e2e_from_file("tests/e2e/default_api.py")
 def test_e2e_default_api():
     resp = requests.post("http://127.0.0.1:8000/predict", json={"input": 4.0}, headers=None)
diff --git a/tests/multiple_litserver.py b/tests/multiple_litserver.py
new file mode 100644
index 00000000..b10d267e
--- /dev/null
+++ b/tests/multiple_litserver.py
@@ -0,0 +1,43 @@
+# Copyright The Lightning AI team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from litserve.server import LitServer, run_all
+from litserve.test_examples import SimpleLitAPI
+
+
+class MultipleLitServerAPI1(SimpleLitAPI):
+    def setup(self, device):
+        self.model = lambda x: x**1
+
+
+class MultipleLitServerAPI2(SimpleLitAPI):
+    def setup(self, device):
+        self.model = lambda x: x**2
+
+
+class MultipleLitServerAPI3(SimpleLitAPI):
+    def setup(self, device):
+        self.model = lambda x: x**3
+
+
+class MultipleLitServerAPI4(SimpleLitAPI):
+    def setup(self, device):
+        self.model = lambda x: x**4
+
+
+if __name__ == "__main__":
+    server1 = LitServer(MultipleLitServerAPI1(), api_path="/predict-1")
+    server2 = LitServer(MultipleLitServerAPI2(), api_path="/predict-2")
+    server3 = LitServer(MultipleLitServerAPI3(), api_path="/predict-3")
+    server4 = LitServer(MultipleLitServerAPI4(), api_path="/predict-4")
+    run_all([server1, server2, server3, server4], port=8000)
diff --git a/tests/test_lit_server.py b/tests/test_lit_server.py
index 74dc854a..d85c05c6 100644
--- a/tests/test_lit_server.py
+++ b/tests/test_lit_server.py
@@ -15,27 +15,26 @@
 import pickle
 import re
 import sys
+from unittest.mock import MagicMock, patch
 
-from asgi_lifespan import LifespanManager
-from litserve import LitAPI
-from fastapi import Request, Response, HTTPException
+import pytest
 import torch
 import torch.nn as nn
+from asgi_lifespan import LifespanManager
+from fastapi import HTTPException, Request, Response
+from fastapi.testclient import TestClient
 from httpx import AsyncClient
-from litserve.utils import wrap_litserve_start
+from starlette.middleware.base import BaseHTTPMiddleware
 from starlette.middleware.httpsredirect import HTTPSRedirectMiddleware
 from starlette.middleware.trustedhost import TrustedHostMiddleware
+from starlette.types import ASGIApp
 
-from unittest.mock import patch, MagicMock
-import pytest
-
-from litserve.connector import _Connector
-
-from litserve.server import LitServer
 import litserve as ls
-from fastapi.testclient import TestClient
-from starlette.types import ASGIApp
-from starlette.middleware.base import BaseHTTPMiddleware
+from litserve import LitAPI
+from litserve.connector import _Connector
+from litserve.server import LitServer, multi_server_lifespan, run_all
+from litserve.test_examples.openai_spec_example import TestAPI
+from litserve.utils import wrap_litserve_start
 
 
 def test_index(sync_testclient):
@@ -429,3 +428,40 @@ def test_middlewares_inputs():
 
     with pytest.raises(ValueError, match="middlewares must be a list of tuples"):
         ls.LitServer(ls.test_examples.SimpleLitAPI(), middlewares=(RequestIdMiddleware, {"length": 5}))
+
+
+@pytest.mark.asyncio
+@patch("litserve.server.LitServer")
+async def test_multi_server_lifespan(mock_litserver):
+    # List of servers
+    servers = [mock_litserver, mock_litserver]
+    # Use the async context manager
+    async with multi_server_lifespan(MagicMock(), servers):
+        # Check if the lifespan method was called for each server
+        assert mock_litserver.lifespan.call_count == 2
+    assert mock_litserver.lifespan.return_value.__aexit__.call_count == 2
+
+
+@patch("litserve.server.uvicorn")
+def test_run_all_litservers(mock_uvicorn):
+    server1 = LitServer(SimpleLitAPI(), api_path="/predict-1")
+    server2 = LitServer(SimpleLitAPI(), api_path="/predict-2")
+    server3 = LitServer(TestAPI(), spec=ls.OpenAISpec())
+
+    with pytest.raises(ValueError, match="All elements in the servers list must be instances of LitServer"):
+        run_all([server1, "server2"])
+
+    with pytest.raises(ValueError, match="port must be a value from 1024 to 65535 but got"):
+        run_all([server1, server2], port="invalid port")
+
+    with pytest.raises(ValueError, match="port must be a value from 1024 to 65535 but got"):
+        run_all([server1, server2], port=65536)
+
+    with pytest.raises(ValueError, match="num_api_servers must be greater than 0"):
+        run_all([server1, server2], num_api_servers=0)
+
+    run_all([server1, server2, server3], port=8000)
+    mock_uvicorn.Config.assert_called()
+    mock_uvicorn.reset_mock()
+    run_all([server1, server2, server3], port="8001")
+    mock_uvicorn.Config.assert_called()
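
Note on the pattern introduced by this patch: multi_server_lifespan composes one lifespan context manager per LitServer into a single lifespan for the combined FastAPI app, using contextlib.AsyncExitStack so that every server's startup runs before the app serves traffic and every shutdown runs, in reverse order, when the app stops. The sketch below illustrates that composition in isolation; it is a minimal, self-contained example, and fake_server_lifespan / combined_lifespan are hypothetical stand-ins rather than LitServe APIs.

import asyncio
from contextlib import AsyncExitStack, asynccontextmanager


@asynccontextmanager
async def fake_server_lifespan(name: str):
    # Stand-in for a single server's lifespan: set up, yield, tear down.
    print(f"{name}: setup")
    yield
    print(f"{name}: teardown")


@asynccontextmanager
async def combined_lifespan(names):
    # Enter one lifespan per server; AsyncExitStack unwinds them in reverse order on exit.
    async with AsyncExitStack() as stack:
        for name in names:
            await stack.enter_async_context(fake_server_lifespan(name))
        yield  # the combined app would serve requests while all lifespans are open


async def main():
    async with combined_lifespan(["server-1", "server-2"]):
        print("all servers are up")
    # Leaving the block prints the teardown messages in reverse order.


if __name__ == "__main__":
    asyncio.run(main())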