Skip to content

Commit

Permalink
feat: add get_task_result_stream to new client (#351)
Browse files Browse the repository at this point in the history
* feat: add get_task_result_stream to new client

* prevent using the sync client within an async loop

* remove debug print
  • Loading branch information
masci authored Nov 8, 2024
1 parent 34b9299 commit 955ef73
Show file tree
Hide file tree
Showing 7 changed files with 204 additions and 30 deletions.
32 changes: 10 additions & 22 deletions e2e_tests/basic_streaming/test_run_client.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
import pytest

from llama_deploy import AsyncLlamaDeployClient, ControlPlaneConfig, LlamaDeployClient
from llama_deploy import Client


@pytest.mark.e2e
def test_run_client(services):
client = LlamaDeployClient(ControlPlaneConfig(), timeout=10)
client = Client(timeout=10)

# sanity check
sessions = client.list_sessions()
sessions = client.sync.core.sessions.list()
assert len(sessions) == 0, "Sessions list is not empty"

# test streaming
session = client.create_session()
session = client.sync.core.sessions.create()

# kick off run
task_id = session.run_nowait("streaming_workflow", arg1="hello_world")
Expand All @@ -30,27 +30,19 @@ def test_run_client(services):

# get final result
final_result = session.get_task_result(task_id)
assert (
final_result.result == "hello_world_result_result_result" # type: ignore
), "Final result is not 'hello_world_result_result_result'"
assert final_result.result == "hello_world_result_result_result" # type: ignore

# delete everything
client.delete_session(session.session_id)
sessions = client.list_sessions()
assert len(sessions) == 0, "Sessions list is not empty"
client.sync.core.sessions.delete(session.id)


@pytest.mark.e2e
@pytest.mark.asyncio
async def test_run_client_async(services):
client = AsyncLlamaDeployClient(ControlPlaneConfig(), timeout=10)

# sanity check
sessions = await client.list_sessions()
assert len(sessions) == 0, "Sessions list is not empty"
client = Client(timeout=10)

# test streaming
session = await client.create_session()
session = await client.core.sessions.create()

# kick off run
task_id = await session.run_nowait("streaming_workflow", arg1="hello_world")
Expand All @@ -67,11 +59,7 @@ async def test_run_client_async(services):
assert event["progress"] == 0.9

final_result = await session.get_task_result(task_id)
assert (
final_result.result == "hello_world_result_result_result" # type: ignore
), "Final result is not 'hello_world_result_result_result'"
assert final_result.result == "hello_world_result_result_result" # type: ignore

# delete everything
await client.delete_session(session.session_id)
sessions = await client.list_sessions()
assert len(sessions) == 0, "Sessions list is not empty"
await client.core.sessions.delete(session.id)
9 changes: 8 additions & 1 deletion llama_deploy/client/client.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
from typing import Any

from .base import _BaseClient
Expand Down Expand Up @@ -28,7 +29,13 @@ def normal_function():
@property
def sync(self) -> "_SyncClient":
"""Returns the sync version of the client API."""
return _SyncClient(**self.model_dump())
try:
asyncio.get_running_loop()
except RuntimeError:
return _SyncClient(**self.model_dump())

msg = "You cannot use the sync client within an async event loop - just await the async methods directly."
raise RuntimeError(msg)

@property
def apiserver(self) -> ApiServer:
Expand Down
35 changes: 34 additions & 1 deletion llama_deploy/client/models/core.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio
import json
from typing import Any
import time
from typing import Any, AsyncGenerator

import httpx
from llama_index.core.workflow import Event
Expand Down Expand Up @@ -108,6 +109,38 @@ async def send_event(self, service_name: str, task_id: str, ev: Event) -> None:
url = f"{self.client.control_plane_url}/sessions/{self.id}/tasks/{task_id}/send_event"
await self.client.request("POST", url, json=event_def.model_dump())

async def get_task_result_stream(
self, task_id: str
) -> AsyncGenerator[dict[str, Any], None]:
"""Get the result of a task in this session if it has one.
Args:
task_id (str): The ID of the task to get the result for.
Returns:
AsyncGenerator[str, None, None]: A generator that yields the result of the task.
"""
url = f"{self.client.control_plane_url}/sessions/{self.id}/tasks/{task_id}/result_stream"
start_time = time.time()
while True:
try:
async with httpx.AsyncClient() as client:
async with client.stream("GET", url) as response:
response.raise_for_status()
async for line in response.aiter_lines():
json_line = json.loads(line)
yield json_line
break # Exit the function if successful
except httpx.HTTPStatusError as e:
if e.response.status_code != 404:
raise # Re-raise if it's not a 404 error
if time.time() - start_time < self.client.timeout:
await asyncio.sleep(self.client.poll_interval)
else:
raise TimeoutError(
f"Task result not available after waiting for {self.client.timeout} seconds"
)


class SessionCollection(Collection):
async def list(self) -> list[Session]: # type: ignore
Expand Down
34 changes: 29 additions & 5 deletions llama_deploy/client/models/model.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import asyncio
from typing import Any, Generic, TypeVar
import inspect
from typing import Any, AsyncGenerator, Callable, Generic, TypeVar

from asgiref.sync import async_to_sync
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr
from typing_extensions import ParamSpec

from llama_deploy.client.base import _BaseClient

Expand Down Expand Up @@ -42,15 +44,37 @@ def list(self) -> list[T]:
return [self.get(id) for id in self.items.keys()]


# Generic type for what's returned by the async generator
_G = TypeVar("_G")
# Generic parameter for the wrapped generator method
_P = ParamSpec("_P")
# Generic parameter for the wrapped generator method return value
_R = TypeVar("_R")


async def _async_gen_to_list(async_gen: AsyncGenerator[_G, None]) -> list[_G]:
return [item async for item in async_gen]


def make_sync(_class: type[T]) -> Any:
"""Wraps the methods of the given model class so that they can be called without `await`."""

class Wrapper(_class): # type: ignore
class ModelWrapper(_class): # type: ignore
_instance_is_sync: bool = True

def generator_wrapper(
func: Callable[_P, AsyncGenerator[_G, None]], /, *args: Any, **kwargs: Any
) -> Callable[_P, list[_G]]:
def new_func(*fargs: Any, **fkwargs: Any) -> list[_G]:
return asyncio.run(_async_gen_to_list(func(*fargs, **fkwargs)))

return new_func

for name, method in _class.__dict__.items():
# Only wrap async public methods
if asyncio.iscoroutinefunction(method) and not name.startswith("_"):
setattr(Wrapper, name, async_to_sync(method))
if inspect.isasyncgenfunction(method):
setattr(ModelWrapper, name, generator_wrapper(method))
elif asyncio.iscoroutinefunction(method) and not name.startswith("_"):
setattr(ModelWrapper, name, async_to_sync(method))

return Wrapper
return ModelWrapper
95 changes: 95 additions & 0 deletions tests/client/models/test_core.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
from unittest import mock

import httpx
Expand Down Expand Up @@ -282,3 +283,97 @@ async def test_session_run_nowait(client: mock.AsyncMock) -> None:
"task_id": mock.ANY,
},
)


@pytest.mark.asyncio
async def test_get_task_result_stream_success(client: mock.AsyncMock) -> None:
class MockResponse:
async def aiter_lines(self): # type: ignore
yield json.dumps({"status": "running", "progress": 0})
yield json.dumps({"status": "completed", "progress": 100})

def raise_for_status(self): # type: ignore
pass

class MockStreamClient:
async def __aenter__(self): # type: ignore
return MockResponse()

async def __aexit__(self, *args): # type: ignore
pass

class HttpxMockClient:
async def __aenter__(self): # type: ignore
return self

async def __aexit__(self, *args): # type: ignore
pass

def stream(self, *args, **kwargs): # type: ignore
return MockStreamClient()

with mock.patch("httpx.AsyncClient", return_value=HttpxMockClient()):
session = Session(client=client, id="test_session_id")

results = []
async for result in session.get_task_result_stream("test_task_id"):
results.append(result)

assert len(results) == 2
assert results[0]["status"] == "running"
assert results[1]["status"] == "completed"


@pytest.mark.asyncio
async def test_get_task_result_stream_timeout(client: mock.AsyncMock) -> None:
class Mock404Response:
status_code = 404

class HttpxMockClient:
async def __aenter__(self): # type: ignore
return self

async def __aexit__(self, *args): # type: ignore
pass

def stream(self, *args, **kwargs): # type: ignore
raise httpx.HTTPStatusError(
"404 Not Found",
request=mock.MagicMock(),
response=Mock404Response(), # type: ignore
)

with mock.patch("httpx.AsyncClient", return_value=HttpxMockClient()):
client.timeout = 1
session = Session(client=client, id="test_session_id")

with pytest.raises(TimeoutError):
async for _ in session.get_task_result_stream("test_task_id"):
pass


@pytest.mark.asyncio
async def test_get_task_result_stream_error(client: mock.AsyncMock) -> None:
class Mock500Response:
status_code = 500

class HttpxMockClient:
async def __aenter__(self): # type: ignore
return self

async def __aexit__(self, *args): # type: ignore
pass

def stream(self, *args, **kwargs): # type: ignore
raise httpx.HTTPStatusError(
"500 Internal Server Error",
request=mock.MagicMock(),
response=Mock500Response(), # type: ignore
)

with mock.patch("httpx.AsyncClient", return_value=HttpxMockClient()):
session = Session(client=client, id="test_session_id")

with pytest.raises(httpx.HTTPStatusError):
async for _ in session.get_task_result_stream("test_task_id"):
pass
19 changes: 18 additions & 1 deletion tests/client/models/test_model.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,21 @@
import asyncio
from typing import AsyncGenerator

import pytest

from llama_deploy.client import Client
from llama_deploy.client.models import Collection, Model
from llama_deploy.client.models.model import make_sync
from llama_deploy.client.models.model import _async_gen_to_list, make_sync


class SomeAsyncModel(Model):
async def method(self) -> int:
return 0

async def generator_method(self) -> AsyncGenerator:
yield 4
yield 2


def test_make_sync() -> None:
assert asyncio.iscoroutinefunction(getattr(SomeAsyncModel, "method"))
Expand All @@ -20,6 +27,7 @@ def test_make_sync_instance(client: Client) -> None:
some_sync = make_sync(SomeAsyncModel)(client=client, id="foo")
assert not asyncio.iscoroutinefunction(some_sync.method)
assert some_sync.method() + 1 == 1
assert some_sync.generator_method() == [4, 2]


def test__prepare(client: Client) -> None:
Expand All @@ -42,3 +50,12 @@ class MyCollection(Collection):
assert coll.get("foo").id == "foo"
assert coll.get("bar").id == "bar"
assert coll.list() == models_list


@pytest.mark.asyncio
async def test__async_gen_to_list() -> None:
async def aiter_lines(): # type: ignore
yield "one"
yield "two"

assert await _async_gen_to_list(aiter_lines()) == ["one", "two"]
10 changes: 10 additions & 0 deletions tests/client/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,16 @@ def test_client_sync() -> None:
assert sc.poll_interval == 0.5


@pytest.mark.asyncio
async def test_client_sync_within_loop() -> None:
c = Client()
with pytest.raises(
RuntimeError,
match="You cannot use the sync client within an async event loop - just await the async methods directly.",
):
c.sync


def test_client_attributes() -> None:
c = Client()
assert type(c.apiserver) is ApiServer
Expand Down

0 comments on commit 955ef73

Please sign in to comment.