feat: add get_task_result_stream to new client #351

Merged (3 commits) on Nov 8, 2024.
32 changes: 10 additions & 22 deletions e2e_tests/basic_streaming/test_run_client.py
@@ -1,18 +1,18 @@
 import pytest

-from llama_deploy import AsyncLlamaDeployClient, ControlPlaneConfig, LlamaDeployClient
+from llama_deploy import Client


 @pytest.mark.e2e
 def test_run_client(services):
-    client = LlamaDeployClient(ControlPlaneConfig(), timeout=10)
+    client = Client(timeout=10)

     # sanity check
-    sessions = client.list_sessions()
+    sessions = client.sync.core.sessions.list()
     assert len(sessions) == 0, "Sessions list is not empty"

     # test streaming
-    session = client.create_session()
+    session = client.sync.core.sessions.create()

     # kick off run
     task_id = session.run_nowait("streaming_workflow", arg1="hello_world")
@@ -30,27 +30,19 @@ def test_run_client(services):

     # get final result
     final_result = session.get_task_result(task_id)
-    assert (
-        final_result.result == "hello_world_result_result_result"  # type: ignore
-    ), "Final result is not 'hello_world_result_result_result'"
+    assert final_result.result == "hello_world_result_result_result"  # type: ignore

     # delete everything
-    client.delete_session(session.session_id)
-    sessions = client.list_sessions()
-    assert len(sessions) == 0, "Sessions list is not empty"
+    client.sync.core.sessions.delete(session.id)


 @pytest.mark.e2e
 @pytest.mark.asyncio
 async def test_run_client_async(services):
-    client = AsyncLlamaDeployClient(ControlPlaneConfig(), timeout=10)
-
-    # sanity check
-    sessions = await client.list_sessions()
-    assert len(sessions) == 0, "Sessions list is not empty"
+    client = Client(timeout=10)

     # test streaming
-    session = await client.create_session()
+    session = await client.core.sessions.create()

     # kick off run
     task_id = await session.run_nowait("streaming_workflow", arg1="hello_world")
@@ -67,11 +59,7 @@ async def test_run_client_async(services):
         assert event["progress"] == 0.9

     final_result = await session.get_task_result(task_id)
-    assert (
-        final_result.result == "hello_world_result_result_result"  # type: ignore
-    ), "Final result is not 'hello_world_result_result_result'"
+    assert final_result.result == "hello_world_result_result_result"  # type: ignore

     # delete everything
-    await client.delete_session(session.session_id)
-    sessions = await client.list_sessions()
-    assert len(sessions) == 0, "Sessions list is not empty"
+    await client.core.sessions.delete(session.id)
9 changes: 8 additions & 1 deletion llama_deploy/client/client.py
@@ -1,3 +1,4 @@
+import asyncio
 from typing import Any

 from .base import _BaseClient
@@ -28,7 +29,13 @@ def normal_function():
     @property
     def sync(self) -> "_SyncClient":
         """Returns the sync version of the client API."""
-        return _SyncClient(**self.model_dump())
+        try:
+            asyncio.get_running_loop()
+        except RuntimeError:
+            return _SyncClient(**self.model_dump())
+
+        msg = "You cannot use the sync client within an async event loop - just await the async methods directly."
+        raise RuntimeError(msg)

     @property
     def apiserver(self) -> ApiServer:

Review comment (Contributor): Oh, was this something we were running into before? Anyway, nice catch!

Reply (Member Author): This problem always existed, but I didn't run into it until @logan-markewich made me notice!
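For reference, the guard above relies on asyncio.get_running_loop() raising RuntimeError when called outside a running event loop. A minimal standalone sketch of the same pattern (the in_event_loop helper below is illustrative only, not part of the PR):

import asyncio


def in_event_loop() -> bool:
    """Return True when called from inside a running asyncio event loop."""
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        return False
    return True


async def _inside() -> bool:
    return in_event_loop()


# Outside a loop the sync facade is safe; inside a loop it should be refused.
assert in_event_loop() is False
assert asyncio.run(_inside()) is True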
35 changes: 34 additions & 1 deletion llama_deploy/client/models/core.py
@@ -1,6 +1,7 @@
 import asyncio
 import json
-from typing import Any
+import time
+from typing import Any, AsyncGenerator

 import httpx
 from llama_index.core.workflow import Event
@@ -108,6 +109,38 @@ async def send_event(self, service_name: str, task_id: str, ev: Event) -> None:
         url = f"{self.client.control_plane_url}/sessions/{self.id}/tasks/{task_id}/send_event"
         await self.client.request("POST", url, json=event_def.model_dump())

+    async def get_task_result_stream(
+        self, task_id: str
+    ) -> AsyncGenerator[dict[str, Any], None]:
+        """Get the result of a task in this session if it has one.
+
+        Args:
+            task_id (str): The ID of the task to get the result for.
+
+        Returns:
+            AsyncGenerator[dict[str, Any], None]: A generator that yields the result of the task.
+        """
+        url = f"{self.client.control_plane_url}/sessions/{self.id}/tasks/{task_id}/result_stream"
+        start_time = time.time()
+        while True:
+            try:
+                async with httpx.AsyncClient() as client:
+                    async with client.stream("GET", url) as response:
+                        response.raise_for_status()
+                        async for line in response.aiter_lines():
+                            json_line = json.loads(line)
+                            yield json_line
+                break  # Exit the polling loop once the stream has been consumed
+            except httpx.HTTPStatusError as e:
+                if e.response.status_code != 404:
+                    raise  # Re-raise if it's not a 404 error
+                if time.time() - start_time < self.client.timeout:
+                    await asyncio.sleep(self.client.poll_interval)
+                else:
+                    raise TimeoutError(
+                        f"Task result not available after waiting for {self.client.timeout} seconds"
+                    )
+

 class SessionCollection(Collection):
     async def list(self) -> list[Session]:  # type: ignore
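As a usage sketch, mirroring the e2e test above: create a session through the new client, kick off the workflow without waiting, and iterate over get_task_result_stream. This assumes an API server with a deployed streaming_workflow service is already running; it is not part of the diff itself.

import asyncio

from llama_deploy import Client


async def stream_task_result() -> None:
    client = Client(timeout=10)
    session = await client.core.sessions.create()

    # Kick off the run without waiting for the final result.
    task_id = await session.run_nowait("streaming_workflow", arg1="hello_world")

    # Consume intermediate events as the control plane publishes them.
    async for event in session.get_task_result_stream(task_id):
        print(event)

    # Clean up the session once the stream is exhausted.
    await client.core.sessions.delete(session.id)


# Requires a running deployment:
# asyncio.run(stream_task_result())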
34 changes: 29 additions & 5 deletions llama_deploy/client/models/model.py
@@ -1,8 +1,10 @@
 import asyncio
-from typing import Any, Generic, TypeVar
+import inspect
+from typing import Any, AsyncGenerator, Callable, Generic, TypeVar

 from asgiref.sync import async_to_sync
 from pydantic import BaseModel, ConfigDict, Field, PrivateAttr
+from typing_extensions import ParamSpec

 from llama_deploy.client.base import _BaseClient
@@ -42,15 +44,37 @@ def list(self) -> list[T]:
         return [self.get(id) for id in self.items.keys()]


+# Generic type for what's returned by the async generator
+_G = TypeVar("_G")
+# Generic parameter for the wrapped generator method
+_P = ParamSpec("_P")
+# Generic parameter for the wrapped generator method return value
+_R = TypeVar("_R")
+
+
+async def _async_gen_to_list(async_gen: AsyncGenerator[_G, None]) -> list[_G]:
+    return [item async for item in async_gen]
+
+
 def make_sync(_class: type[T]) -> Any:
     """Wraps the methods of the given model class so that they can be called without `await`."""

-    class Wrapper(_class):  # type: ignore
+    class ModelWrapper(_class):  # type: ignore
         _instance_is_sync: bool = True

+    def generator_wrapper(
+        func: Callable[_P, AsyncGenerator[_G, None]], /, *args: Any, **kwargs: Any
+    ) -> Callable[_P, list[_G]]:
+        def new_func(*fargs: Any, **fkwargs: Any) -> list[_G]:
+            return asyncio.run(_async_gen_to_list(func(*fargs, **fkwargs)))
+
+        return new_func
+
     for name, method in _class.__dict__.items():
         # Only wrap async public methods
-        if asyncio.iscoroutinefunction(method) and not name.startswith("_"):
-            setattr(Wrapper, name, async_to_sync(method))
+        if inspect.isasyncgenfunction(method):
+            setattr(ModelWrapper, name, generator_wrapper(method))
+        elif asyncio.iscoroutinefunction(method) and not name.startswith("_"):
+            setattr(ModelWrapper, name, async_to_sync(method))

-    return Wrapper
+    return ModelWrapper

Review comment (Contributor), on the added _R = TypeVar("_R") line: Possibly a dumb question, as perhaps I just missed this, but where are we using this generic param?

Reply (Member Author): Not dumb, it's a leftover! I'll remove it, good catch!
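To illustrate what the new generator_wrapper does, here is a small self-contained sketch of turning an async generator method into a sync method that returns a list via asyncio.run. The Reader and make_sync_demo names are illustrative only; they are not part of the PR.

import asyncio
import inspect
from typing import Any, AsyncGenerator


async def _collect(agen: AsyncGenerator[Any, None]) -> list[Any]:
    return [item async for item in agen]


class Reader:
    async def lines(self) -> AsyncGenerator[str, None]:
        yield "one"
        yield "two"


def make_sync_demo(cls: type) -> type:
    class Wrapper(cls):  # type: ignore[misc, valid-type]
        pass

    for name, method in cls.__dict__.items():
        if inspect.isasyncgenfunction(method):
            # Bind `method` via a keyword-only default so each wrapper keeps its own function.
            def sync_method(self, *args: Any, _m=method, **kwargs: Any) -> list[Any]:
                return asyncio.run(_collect(_m(self, *args, **kwargs)))

            setattr(Wrapper, name, sync_method)
    return Wrapper


SyncReader = make_sync_demo(Reader)
assert SyncReader().lines() == ["one", "two"]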
95 changes: 95 additions & 0 deletions tests/client/models/test_core.py
@@ -1,3 +1,4 @@
+import json
 from unittest import mock

 import httpx
@@ -282,3 +283,97 @@ async def test_session_run_nowait(client: mock.AsyncMock) -> None:
             "task_id": mock.ANY,
         },
     )
+
+
+@pytest.mark.asyncio
+async def test_get_task_result_stream_success(client: mock.AsyncMock) -> None:
+    class MockResponse:
+        async def aiter_lines(self):  # type: ignore
+            yield json.dumps({"status": "running", "progress": 0})
+            yield json.dumps({"status": "completed", "progress": 100})
+
+        def raise_for_status(self):  # type: ignore
+            pass
+
+    class MockStreamClient:
+        async def __aenter__(self):  # type: ignore
+            return MockResponse()
+
+        async def __aexit__(self, *args):  # type: ignore
+            pass
+
+    class HttpxMockClient:
+        async def __aenter__(self):  # type: ignore
+            return self
+
+        async def __aexit__(self, *args):  # type: ignore
+            pass
+
+        def stream(self, *args, **kwargs):  # type: ignore
+            return MockStreamClient()
+
+    with mock.patch("httpx.AsyncClient", return_value=HttpxMockClient()):
+        session = Session(client=client, id="test_session_id")
+
+        results = []
+        async for result in session.get_task_result_stream("test_task_id"):
+            results.append(result)
+
+        assert len(results) == 2
+        assert results[0]["status"] == "running"
+        assert results[1]["status"] == "completed"
+
+
+@pytest.mark.asyncio
+async def test_get_task_result_stream_timeout(client: mock.AsyncMock) -> None:
+    class Mock404Response:
+        status_code = 404
+
+    class HttpxMockClient:
+        async def __aenter__(self):  # type: ignore
+            return self
+
+        async def __aexit__(self, *args):  # type: ignore
+            pass
+
+        def stream(self, *args, **kwargs):  # type: ignore
+            raise httpx.HTTPStatusError(
+                "404 Not Found",
+                request=mock.MagicMock(),
+                response=Mock404Response(),  # type: ignore
+            )
+
+    with mock.patch("httpx.AsyncClient", return_value=HttpxMockClient()):
+        client.timeout = 1
+        session = Session(client=client, id="test_session_id")
+
+        with pytest.raises(TimeoutError):
+            async for _ in session.get_task_result_stream("test_task_id"):
+                pass
+
+
+@pytest.mark.asyncio
+async def test_get_task_result_stream_error(client: mock.AsyncMock) -> None:
+    class Mock500Response:
+        status_code = 500
+
+    class HttpxMockClient:
+        async def __aenter__(self):  # type: ignore
+            return self
+
+        async def __aexit__(self, *args):  # type: ignore
+            pass
+
+        def stream(self, *args, **kwargs):  # type: ignore
+            raise httpx.HTTPStatusError(
+                "500 Internal Server Error",
+                request=mock.MagicMock(),
+                response=Mock500Response(),  # type: ignore
+            )
+
+    with mock.patch("httpx.AsyncClient", return_value=HttpxMockClient()):
+        session = Session(client=client, id="test_session_id")
+
+        with pytest.raises(httpx.HTTPStatusError):
+            async for _ in session.get_task_result_stream("test_task_id"):
+                pass
19 changes: 18 additions & 1 deletion tests/client/models/test_model.py
@@ -1,14 +1,21 @@
 import asyncio
+from typing import AsyncGenerator

 import pytest

 from llama_deploy.client import Client
 from llama_deploy.client.models import Collection, Model
-from llama_deploy.client.models.model import make_sync
+from llama_deploy.client.models.model import _async_gen_to_list, make_sync


 class SomeAsyncModel(Model):
     async def method(self) -> int:
         return 0

+    async def generator_method(self) -> AsyncGenerator:
+        yield 4
+        yield 2
+

 def test_make_sync() -> None:
     assert asyncio.iscoroutinefunction(getattr(SomeAsyncModel, "method"))
@@ -20,6 +27,7 @@ def test_make_sync_instance(client: Client) -> None:
     some_sync = make_sync(SomeAsyncModel)(client=client, id="foo")
     assert not asyncio.iscoroutinefunction(some_sync.method)
     assert some_sync.method() + 1 == 1
+    assert some_sync.generator_method() == [4, 2]


 def test__prepare(client: Client) -> None:
@@ -42,3 +50,12 @@ class MyCollection(Collection):
     assert coll.get("foo").id == "foo"
     assert coll.get("bar").id == "bar"
     assert coll.list() == models_list
+
+
+@pytest.mark.asyncio
+async def test__async_gen_to_list() -> None:
+    async def aiter_lines():  # type: ignore
+        yield "one"
+        yield "two"
+
+    assert await _async_gen_to_list(aiter_lines()) == ["one", "two"]
10 changes: 10 additions & 0 deletions tests/client/test_client.py
@@ -30,6 +30,16 @@ def test_client_sync() -> None:
     assert sc.poll_interval == 0.5


+@pytest.mark.asyncio
+async def test_client_sync_within_loop() -> None:
+    c = Client()
+    with pytest.raises(
+        RuntimeError,
+        match="You cannot use the sync client within an async event loop - just await the async methods directly.",
+    ):
+        c.sync
+
+
 def test_client_attributes() -> None:
     c = Client()
     assert type(c.apiserver) is ApiServer