Skip to content

Commit

Permalink
migrate hitl tests to the new client (#347)
Browse files Browse the repository at this point in the history
  • Loading branch information
masci authored Nov 5, 2024
1 parent 7455b96 commit b96db7a
Show file tree
Hide file tree
Showing 3 changed files with 85 additions and 16 deletions.
31 changes: 16 additions & 15 deletions e2e_tests/basic_hitl/test_run_client.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,22 @@
import asyncio
import pytest
import time

from llama_deploy import AsyncLlamaDeployClient, ControlPlaneConfig, LlamaDeployClient
import pytest
from llama_index.core.workflow.events import HumanResponseEvent

from llama_deploy import Client


@pytest.mark.e2ehitl
@pytest.mark.e2e
def test_run_client(services):
client = LlamaDeployClient(ControlPlaneConfig(), timeout=10)
client = Client(timeout=10)

# sanity check
sessions = client.list_sessions()
sessions = client.sync.core.sessions.list()
assert len(sessions) == 0, "Sessions list is not empty"

# create a session
session = client.create_session()
session = client.sync.core.sessions.create()

# kick off run
task_id = session.run_nowait("hitl_workflow")
Expand All @@ -35,22 +36,22 @@ def test_run_client(services):
assert final_result.result == "42", "The human's response is not consistent."

# delete the session
client.delete_session(session.session_id)
sessions = client.list_sessions()
client.sync.core.sessions.delete(session.id)
sessions = client.sync.core.sessions.list()
assert len(sessions) == 0, "Sessions list is not empty"


@pytest.mark.e2ehitl
@pytest.mark.e2e
@pytest.mark.asyncio
async def test_run_client_async(services):
client = AsyncLlamaDeployClient(ControlPlaneConfig(), timeout=10)
client = Client(timeout=10)

# sanity check
sessions = await client.list_sessions()
sessions = await client.core.sessions.list()
assert len(sessions) == 0, "Sessions list is not empty"

# create a session
session = await client.create_session()
session = await client.core.sessions.create()

# kick off run
task_id = await session.run_nowait("hitl_workflow")
Expand All @@ -66,10 +67,10 @@ async def test_run_client_async(services):
final_result = None
while final_result is None:
final_result = await session.get_task_result(task_id)
asyncio.sleep(0.1)
await asyncio.sleep(0.1)
assert final_result.result == "42", "The human's response is not consistent."

# delete the session
await client.delete_session(session.session_id)
sessions = await client.list_sessions()
await client.core.sessions.delete(session.id)
sessions = await client.core.sessions.list()
assert len(sessions) == 0, "Sessions list is not empty"
35 changes: 34 additions & 1 deletion llama_deploy/client/models/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,15 @@
from typing import Any

import httpx
from llama_index.core.workflow import Event
from llama_index.core.workflow.context_serializers import JsonSerializer

from llama_deploy.types.core import ServiceDefinition, TaskDefinition, TaskResult
from llama_deploy.types.core import (
EventDefinition,
ServiceDefinition,
TaskDefinition,
TaskResult,
)

from .model import Collection, Model

Expand All @@ -27,6 +34,15 @@ async def _get_result() -> str:

return await asyncio.wait_for(_get_result(), timeout=self.client.timeout)

async def run_nowait(self, service_name: str, **run_kwargs: Any) -> str:
"""Implements the workflow-based run API for a session, but does not wait for the task to complete."""

task_input = json.dumps(run_kwargs)
task_def = TaskDefinition(input=task_input, agent_id=service_name)
task_id = await self._do_create_task(task_def)

return task_id

async def create_task(self, task_def: TaskDefinition) -> str:
"""Create a new task in this session.
Expand Down Expand Up @@ -75,6 +91,23 @@ async def get_tasks(self) -> list[TaskDefinition]:
response = await self.client.request("GET", url)
return [TaskDefinition(**task) for task in response.json()]

async def send_event(self, service_name: str, task_id: str, ev: Event) -> None:
"""Send event to a Workflow service.
Args:
event (Event): The event to be submitted to the workflow.
Returns:
None
"""
serializer = JsonSerializer()
event_def = EventDefinition(
event_obj_str=serializer.serialize(ev), agent_id=service_name
)

url = f"{self.client.control_plane_url}/sessions/{self.id}/tasks/{task_id}/send_event"
await self.client.request("POST", url, json=event_def.model_dump())


class SessionCollection(Collection):
async def list(self) -> list[Session]: # type: ignore
Expand Down
35 changes: 35 additions & 0 deletions tests/client/models/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import httpx
import pytest
from llama_index.core.workflow import Event

from llama_deploy.client.models.core import (
Core,
Expand Down Expand Up @@ -247,3 +248,37 @@ async def test_session_get_tasks(client: mock.AsyncMock) -> None:
assert tasks[1].input == "task2 input"
assert tasks[1].agent_id == "agent2"
assert tasks[1].session_id == "test_session_id"


@pytest.mark.asyncio
async def test_session_send_event(client: mock.AsyncMock) -> None:
event = Event(event_type="test_event", payload={"key": "value"})
session = Session(client=client, id="test_session_id")

await session.send_event("test_service", "test_task_id", event)

client.request.assert_awaited_once_with(
"POST",
"http://localhost:8000/sessions/test_session_id/tasks/test_task_id/send_event",
json={"event_obj_str": mock.ANY, "agent_id": "test_service"},
)


@pytest.mark.asyncio
async def test_session_run_nowait(client: mock.AsyncMock) -> None:
client.request.return_value = mock.MagicMock(json=lambda: "test_task_id")

session = Session(client=client, id="test_session_id")
task_id = await session.run_nowait("test_service", test_param="test_value")

assert task_id == "test_task_id"
client.request.assert_awaited_once_with(
"POST",
"http://localhost:8000/sessions/test_session_id/tasks",
json={
"input": '{"test_param": "test_value"}',
"agent_id": "test_service",
"session_id": "test_session_id",
"task_id": mock.ANY,
},
)

0 comments on commit b96db7a

Please sign in to comment.