Feat/evict req on client disconnect streaming case #223

Draft: wants to merge 32 commits into base: main

Changes from 6 commits

Commits (32)
d613565
chore: Add SimpleDelayedStreamAPI for delayed streaming of output
bhimrazy Aug 26, 2024
371cf56
add test_stream_client_disconnection
bhimrazy Aug 26, 2024
9e7f841
add request_evicted_status param to run_streaming_loop
bhimrazy Aug 26, 2024
7ce49ac
update test_stream_client_disconnection
bhimrazy Aug 26, 2024
56c8587
adds functionality to evict the request if disconnected before comple…
bhimrazy Aug 26, 2024
f5961c4
Merge branch 'main' into feat/evict-req-on-client-disconnect-streamin…
bhimrazy Aug 26, 2024
9330997
update exception
aniketmaurya Aug 26, 2024
1f0bfe5
fix test
aniketmaurya Aug 26, 2024
d41db3c
Update src/litserve/server.py
aniketmaurya Aug 26, 2024
4344720
Merge branch 'main' into feat/evict-req-on-client-disconnect-streamin…
aniketmaurya Aug 26, 2024
f177fcb
Merge branch 'main' into feat/evict-req-on-client-disconnect-streamin…
bhimrazy Aug 27, 2024
4e5045a
Merge branch 'main' into feat/evict-req-on-client-disconnect-streamin…
bhimrazy Aug 28, 2024
1d4677c
Merge branch 'main' into feat/evict-req-on-client-disconnect-streamin…
bhimrazy Aug 29, 2024
ca6fbc2
reverted changes to new updates
bhimrazy Aug 31, 2024
e61cdab
update
bhimrazy Aug 31, 2024
6c2e0c6
update
bhimrazy Aug 31, 2024
6668cc8
Merge branch 'main' into feat/evict-req-on-client-disconnect-streamin…
bhimrazy Aug 31, 2024
3448ef3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 31, 2024
2cfd68e
chore: Add test for streaming client disconnection
bhimrazy Aug 31, 2024
c95ee45
handle client disconnection streaming nonbatched case
bhimrazy Aug 31, 2024
bac5534
chore: Optimize streaming loop performance by checking for client dis…
bhimrazy Aug 31, 2024
f08ed4b
chore: Update streaming loop to include request eviction status
bhimrazy Aug 31, 2024
e060e39
Merge branch 'main' into feat/evict-req-on-client-disconnect-streamin…
bhimrazy Sep 21, 2024
2b11fc7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 21, 2024
5323c51
Refactor inference_worker function to remove optional parameters and …
bhimrazy Sep 21, 2024
4368b57
update
bhimrazy Sep 21, 2024
611a751
update
bhimrazy Sep 21, 2024
5cc0f77
add missing param
bhimrazy Sep 21, 2024
8d4a05d
add missing param
bhimrazy Sep 21, 2024
bd68b6c
add missing param for run streaming loop
bhimrazy Sep 21, 2024
56f1076
test by removing the check interval
bhimrazy Sep 21, 2024
49bed55
so there is performance drop with this check,
bhimrazy Sep 21, 2024
59 changes: 39 additions & 20 deletions src/litserve/server.py
@@ -217,7 +217,13 @@ def run_batched_loop(
response_queues[response_queue_id].put((uid, (err_pkl, LitAPIStatus.ERROR)))


def run_streaming_loop(lit_api: LitAPI, lit_spec: LitSpec, request_queue: Queue, response_queues: List[Queue]):
def run_streaming_loop(
lit_api: LitAPI,
lit_spec: LitSpec,
request_queue: Queue,
response_queues: List[Queue],
request_evicted_status: Dict[str, bool],
):
while True:
try:
response_queue_id, uid, timestamp, x_enc = request_queue.get(timeout=1.0)
@@ -256,6 +262,9 @@ def run_streaming_loop(lit_api: LitAPI, lit_spec: LitSpec, request_queue: Queue,
y_gen,
)
for y_enc in y_enc_gen:
if request_evicted_status.get(uid):
request_evicted_status.pop(uid)
break
y_enc = lit_api.format_encoded_response(y_enc)
response_queues[response_queue_id].put((uid, (y_enc, LitAPIStatus.OK)))
response_queues[response_queue_id].put((uid, ("", LitAPIStatus.FINISH_STREAMING)))
@@ -342,6 +351,7 @@ def inference_worker(
batch_timeout: float,
stream: bool,
workers_setup_status: Dict[str, bool] = None,
request_evicted_status: Dict[str, bool] = None,
):
lit_api.setup(device)
lit_api.device = device
@@ -357,7 +367,7 @@
if max_batch_size > 1:
run_batched_streaming_loop(lit_api, lit_spec, request_queue, response_queues, max_batch_size, batch_timeout)
else:
run_streaming_loop(lit_api, lit_spec, request_queue, response_queues)
run_streaming_loop(lit_api, lit_spec, request_queue, response_queues, request_evicted_status)
return

if max_batch_size > 1:
@@ -397,7 +407,7 @@ async def response_queue_to_buffer(
await asyncio.sleep(0.0001)
continue
q, event = buffer[uid]
q.append(payload)
q.append((uid, payload))
event.set()

else:
@@ -498,6 +508,7 @@ def launch_inference_worker(self, num_uvicorn_servers: int):
manager = mp.Manager()
self.workers_setup_status = manager.dict()
self.request_queue = manager.Queue()
self.request_evicted_status = manager.dict()

self.response_queues = []
for _ in range(num_uvicorn_servers):
@@ -535,6 +546,7 @@ def launch_inference_worker(self, num_uvicorn_servers: int):
self.batch_timeout,
self.stream,
self.workers_setup_status,
self.request_evicted_status,
),
)
process.start()
@@ -568,26 +580,33 @@ def device_identifiers(self, accelerator, device):
return [f"{accelerator}:{device}"]

async def data_streamer(self, q: deque, data_available: asyncio.Event, send_status: bool = False):
uid = None
while True:
await data_available.wait()
while len(q) > 0:
data, status = q.popleft()
if status == LitAPIStatus.FINISH_STREAMING:
return

if status == LitAPIStatus.ERROR:
logger.error(
"Error occurred while streaming outputs from the inference worker. "
"Please check the above traceback."
)
try:
await data_available.wait()
while len(q) > 0:
uid, (data, status) = q.popleft()
if status == LitAPIStatus.FINISH_STREAMING:
return

if status == LitAPIStatus.ERROR:
logger.error(
"Error occurred while streaming outputs from the inference worker. "
"Please check the above traceback."
)
if send_status:
yield data, status
return
if send_status:
yield data, status
return
if send_status:
yield data, status
else:
yield data
data_available.clear()
else:
yield data
data_available.clear()
except asyncio.CancelledError:
if uid is not None:
self.request_evicted_status[uid] = True
logger.error("Request evicted for the uid=%s", uid)
break

def setup_server(self):
workers_ready = False
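
For context, here is a minimal standalone sketch of the eviction pattern this diff wires together (illustrative names and structure, not the LitServe API): a manager dict shared between processes lets the consumer flag a uid as evicted when the client cancels, and the streaming worker checks that flag between chunks and stops early.

```python
import multiprocessing as mp
import time


def streaming_worker(uid, request_evicted_status, out_queue):
    # Produce chunks until done, or until the consumer flags the uid as evicted.
    for i in range(50):
        if request_evicted_status.get(uid):
            request_evicted_status.pop(uid, None)
            break
        out_queue.put((uid, f"chunk-{i}"))
        time.sleep(0.1)
    out_queue.put((uid, None))  # sentinel: stream finished or evicted


if __name__ == "__main__":
    manager = mp.Manager()
    evicted = manager.dict()
    q = manager.Queue()
    worker = mp.Process(target=streaming_worker, args=("uid-1", evicted, q))
    worker.start()

    for _ in range(3):       # consume a few chunks...
        print(q.get())
    evicted["uid-1"] = True  # ...then simulate what data_streamer does on asyncio.CancelledError
    worker.join()
```
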
13 changes: 13 additions & 0 deletions tests/conftest.py
@@ -55,6 +55,14 @@ def encode_response(self, output: Generator) -> Generator:
yield out.lower()


class SimpleDelayedStreamAPI(SimpleStreamAPI):
def encode_response(self, output: Generator) -> Generator:
delay = 0.2
for out in output:
time.sleep(delay)
yield out.lower()


class SimpleBatchedStreamAPI(LitAPI):
def setup(self, device) -> None:
self.sentence = "LitServe is streaming output"
@@ -98,6 +106,11 @@ def simple_batched_stream_api():
return SimpleBatchedStreamAPI()


@pytest.fixture()
def simple_delayed_stream_api():
return SimpleDelayedStreamAPI()


@pytest.fixture()
def lit_server(simple_litapi):
server = LitServer(simple_litapi, accelerator="cpu", devices=1, timeout=10)
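
For orientation, a rough standalone equivalent of what SimpleDelayedStreamAPI adds (illustrative, not the conftest class): the 0.2 s pause per token keeps the response streaming long enough for the disconnection test to cancel it mid-flight.

```python
import time
from typing import Generator, Iterable


def delayed_encode(tokens: Iterable[str], delay: float = 0.2) -> Generator[str, None, None]:
    # Pause between tokens so the stream stays open long enough to be cancelled.
    for tok in tokens:
        time.sleep(delay)
        yield tok.lower()


if __name__ == "__main__":
    start = time.monotonic()
    print(list(delayed_encode(["LitServe", "is", "streaming", "output"])))
    print(f"elapsed ~{time.monotonic() - start:.1f}s")  # roughly 0.8 s for 4 tokens
```
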
25 changes: 24 additions & 1 deletion tests/test_lit_server.py
@@ -13,6 +13,7 @@
# limitations under the License.
import asyncio
import inspect
import logging
import pickle
import re
from asgi_lifespan import LifespanManager
@@ -120,6 +121,27 @@ async def test_stream(simple_stream_api):
), "Server returns input prompt and generated output which didn't match."


@pytest.mark.asyncio()
async def test_stream_client_disconnection(simple_delayed_stream_api, caplog):
server = LitServer(simple_delayed_stream_api, stream=True, timeout=10)

with wrap_litserve_start(server) as server, caplog.at_level(logging.DEBUG):
async with LifespanManager(server.app) as manager, AsyncClient(app=manager.app, base_url="http://test") as ac:
task = asyncio.create_task(ac.post("/predict", json={"prompt": "Hey, How are you doing?"}, timeout=10))
await asyncio.sleep(2)
task.cancel() # simulate client disconnection
await asyncio.sleep(1) # wait for the task to stop
with pytest.raises(asyncio.CancelledError):
await task
assert "Request evicted for the uid=" in caplog.text
# TODO: also check if the task actually stopped in the server

caplog.clear()
task = asyncio.create_task(ac.post("/predict", json={"prompt": "Hey, How are you doing?"}, timeout=10))
await task
assert "Request evicted for the uid=" not in caplog.text


@pytest.mark.asyncio()
async def test_batched_stream_server(simple_batched_stream_api):
server = LitServer(simple_batched_stream_api, stream=True, max_batch_size=4, batch_timeout=2, timeout=30)
@@ -177,9 +199,10 @@ def fake_encode(output):
requests_queue = Queue()
requests_queue.put((0, "UUID-1234", time.monotonic(), {"prompt": "Hello"}))
response_queues = [FakeStreamResponseQueue(num_streamed_outputs)]
request_evicted_status = {}

with pytest.raises(StopIteration, match="exit loop"):
run_streaming_loop(fake_stream_api, fake_stream_api, requests_queue, response_queues)
run_streaming_loop(fake_stream_api, fake_stream_api, requests_queue, response_queues, request_evicted_status)

fake_stream_api.predict.assert_called_once_with("Hello")
fake_stream_api.encode_response.assert_called_once()
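
As a hypothetical manual check (not part of this PR), the same behavior can be reproduced against a locally running server, assuming it was started with SimpleDelayedStreamAPI, stream=True, and server.run(port=8000): dropping the connection partway through the stream should make the server log "Request evicted for the uid=...".

```python
import httpx

url = "http://127.0.0.1:8000/predict"  # assumed local server address
with httpx.stream("POST", url, json={"prompt": "Hey, How are you doing?"}, timeout=10) as response:
    # Read a few streamed chunks, then leave the `with` block early;
    # exiting the context manager closes the connection mid-stream.
    for i, chunk in enumerate(response.iter_text()):
        print(chunk)
        if i == 2:
            break
```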