diff --git a/e2e_tests/apiserver/__init__.py b/e2e_tests/apiserver/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/e2e_tests/apiserver/conftest.py b/e2e_tests/apiserver/conftest.py new file mode 100644 index 00000000..9cb57aca --- /dev/null +++ b/e2e_tests/apiserver/conftest.py @@ -0,0 +1,29 @@ +import multiprocessing +import time + +import pytest +import uvicorn + +from llama_deploy.client import Client +from llama_deploy.client.client_settings import ClientSettings + + +def run_async_apiserver(): + uvicorn.run("llama_deploy.apiserver:app", host="127.0.0.1", port=4501) + + +@pytest.fixture(scope="module") +def apiserver(): + p = multiprocessing.Process(target=run_async_apiserver) + p.start() + time.sleep(3) + + yield + + p.kill() + + +@pytest.fixture +def client(): + s = ClientSettings(api_server_url="http://localhost:4501") + return Client(**s.model_dump()) diff --git a/e2e_tests/apiserver/deployments/deployment1.yml b/e2e_tests/apiserver/deployments/deployment1.yml new file mode 100644 index 00000000..e63acffc --- /dev/null +++ b/e2e_tests/apiserver/deployments/deployment1.yml @@ -0,0 +1,15 @@ +name: TestDeployment1 + +control-plane: {} + +default-service: dummy_workflow + +services: + test-workflow: + name: Test Workflow + port: 8002 + host: localhost + source: + type: git + name: https://github.com/run-llama/llama_deploy.git + path: tests/apiserver/data/workflow:my_workflow diff --git a/e2e_tests/apiserver/deployments/deployment2.yml b/e2e_tests/apiserver/deployments/deployment2.yml new file mode 100644 index 00000000..1699d78f --- /dev/null +++ b/e2e_tests/apiserver/deployments/deployment2.yml @@ -0,0 +1,15 @@ +name: TestDeployment2 + +control-plane: {} + +default-service: dummy_workflow + +services: + test-workflow: + name: Test Workflow + port: 8002 + host: localhost + source: + type: git + name: https://github.com/run-llama/llama_deploy.git + path: tests/apiserver/data/workflow:my_workflow diff --git a/e2e_tests/apiserver/deployments/deployment_streaming.yml b/e2e_tests/apiserver/deployments/deployment_streaming.yml new file mode 100644 index 00000000..4d0c6ecf --- /dev/null +++ b/e2e_tests/apiserver/deployments/deployment_streaming.yml @@ -0,0 +1,14 @@ +name: Streaming + +control-plane: + port: 8000 + +default-service: streaming_workflow + +services: + streaming_workflow: + name: Streaming Workflow + source: + type: local + name: ./e2e_tests/apiserver/deployments/src + path: workflow:streaming_workflow diff --git a/e2e_tests/apiserver/deployments/src/workflow.py b/e2e_tests/apiserver/deployments/src/workflow.py new file mode 100644 index 00000000..ac3f47ad --- /dev/null +++ b/e2e_tests/apiserver/deployments/src/workflow.py @@ -0,0 +1,41 @@ +import asyncio + +from llama_index.core.workflow import ( + Context, + Event, + StartEvent, + StopEvent, + Workflow, + step, +) + + +class Message(Event): + text: str + + +class EchoWorkflow(Workflow): + """A dummy workflow streaming three events.""" + + @step() + async def run_step(self, ctx: Context, ev: StartEvent) -> StopEvent: + for i in range(3): + ctx.write_event_to_stream(Message(text=f"message number {i+1}")) + await asyncio.sleep(0.5) + + return StopEvent(result="Done.") + + +streaming_workflow = EchoWorkflow() + + +async def main(): + h = streaming_workflow.run(message="Hello!") + async for ev in h.stream_events(): + if type(ev) is Message: + print(ev.text) + print(await h) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/e2e_tests/apiserver/test_deploy.py b/e2e_tests/apiserver/test_deploy.py new file mode 100644 index 00000000..fc836068 --- /dev/null +++ b/e2e_tests/apiserver/test_deploy.py @@ -0,0 +1,23 @@ +from pathlib import Path + +import pytest + + +@pytest.mark.asyncio +async def test_deploy(apiserver, client): + here = Path(__file__).parent + deployments = await client.apiserver.deployments() + with open(here / "deployments" / "deployment1.yml") as f: + await deployments.create(f) + + status = await client.apiserver.status() + assert "TestDeployment1" in status.deployments + + +def test_deploy_sync(apiserver, client): + here = Path(__file__).parent + deployments = client.sync.apiserver.deployments() + with open(here / "deployments" / "deployment2.yml") as f: + deployments.create(f) + + assert "TestDeployment2" in client.sync.apiserver.status().deployments diff --git a/e2e_tests/apiserver/test_status.py b/e2e_tests/apiserver/test_status.py new file mode 100644 index 00000000..d8f7759d --- /dev/null +++ b/e2e_tests/apiserver/test_status.py @@ -0,0 +1,23 @@ +import pytest + + +@pytest.mark.asyncio +async def test_status_down(client): + res = await client.apiserver.status() + assert res.status.value == "Down" + + +def test_status_down_sync(client): + res = client.sync.apiserver.status() + assert res.status.value == "Down" + + +@pytest.mark.asyncio +async def test_status_up(apiserver, client): + res = await client.apiserver.status() + assert res.status.value == "Healthy" + + +def test_status_up_sync(apiserver, client): + res = client.sync.apiserver.status() + assert res.status.value == "Healthy" diff --git a/e2e_tests/apiserver/test_streaming.py b/e2e_tests/apiserver/test_streaming.py new file mode 100644 index 00000000..ac06caf5 --- /dev/null +++ b/e2e_tests/apiserver/test_streaming.py @@ -0,0 +1,21 @@ +import asyncio +from pathlib import Path + +import pytest + +from llama_deploy.types import TaskDefinition + + +@pytest.mark.asyncio +async def test_stream(apiserver, client): + here = Path(__file__).parent + + with open(here / "deployments" / "deployment_streaming.yml") as f: + deployments = await client.apiserver.deployments() + deployment = await deployments.create(f) + await asyncio.sleep(5) + + tasks = await deployment.tasks() + task = await tasks.create(TaskDefinition(input='{"a": "b"}')) + async for ev in task.events(): + print(ev) diff --git a/llama_deploy/apiserver/deployment.py b/llama_deploy/apiserver/deployment.py index fdf70027..f4b20223 100644 --- a/llama_deploy/apiserver/deployment.py +++ b/llama_deploy/apiserver/deployment.py @@ -7,32 +7,26 @@ from typing import Any from llama_deploy import ( + AsyncLlamaDeployClient, ControlPlaneServer, SimpleMessageQueue, - SimpleOrchestratorConfig, SimpleOrchestrator, + SimpleOrchestratorConfig, WorkflowService, WorkflowServiceConfig, - AsyncLlamaDeployClient, ) from llama_deploy.message_queues import ( - BaseMessageQueue, - SimpleMessageQueueConfig, AWSMessageQueue, + BaseMessageQueue, KafkaMessageQueue, RabbitMQMessageQueue, RedisMessageQueue, + SimpleMessageQueueConfig, ) -from .config_parser import ( - Config, - SourceType, - Service, - MessageQueueConfig, -) +from .config_parser import Config, MessageQueueConfig, Service, SourceType from .source_managers import GitSourceManager, LocalSourceManager, SourceManager - SOURCE_MANAGERS: dict[SourceType, SourceManager] = { SourceType.git: GitSourceManager(), SourceType.local: LocalSourceManager(), diff --git a/llama_deploy/apiserver/routers/deployments.py b/llama_deploy/apiserver/routers/deployments.py index 2406047a..c7081c0a 100644 --- a/llama_deploy/apiserver/routers/deployments.py +++ b/llama_deploy/apiserver/routers/deployments.py @@ -1,14 +1,13 @@ import json +from typing import AsyncGenerator -from fastapi import APIRouter, File, UploadFile, HTTPException +from fastapi import APIRouter, File, HTTPException, UploadFile from fastapi.responses import JSONResponse, StreamingResponse -from typing import AsyncGenerator -from llama_deploy.apiserver.server import manager from llama_deploy.apiserver.config_parser import Config +from llama_deploy.apiserver.server import manager from llama_deploy.types import TaskDefinition - deployments_router = APIRouter( prefix="/deployments", ) @@ -144,6 +143,23 @@ async def get_task_result( return JSONResponse(result.result if result else "") +@deployments_router.get("/{deployment_name}/tasks") +async def get_tasks( + deployment_name: str, +) -> JSONResponse: + """Get the active sessions in a deployment and service.""" + deployment = manager.get_deployment(deployment_name) + if deployment is None: + raise HTTPException(status_code=404, detail="Deployment not found") + + tasks: list[TaskDefinition] = [] + for session_def in await deployment.client.list_sessions(): + session = await deployment.client.get_session(session_id=session_def.session_id) + for task_def in await session.get_tasks(): + tasks.append(task_def) + return JSONResponse(tasks) + + @deployments_router.get("/{deployment_name}/sessions") async def get_sessions( deployment_name: str, diff --git a/llama_deploy/client/__init__.py b/llama_deploy/client/__init__.py index 9679b7a7..de8fbf04 100644 --- a/llama_deploy/client/__init__.py +++ b/llama_deploy/client/__init__.py @@ -1,4 +1,5 @@ -from llama_deploy.client.async_client import AsyncLlamaDeployClient -from llama_deploy.client.sync_client import LlamaDeployClient +from .async_client import AsyncLlamaDeployClient +from .sync_client import LlamaDeployClient +from .client import Client -__all__ = ["AsyncLlamaDeployClient", "LlamaDeployClient"] +__all__ = ["AsyncLlamaDeployClient", "Client", "LlamaDeployClient"] diff --git a/llama_deploy/client/base.py b/llama_deploy/client/base.py new file mode 100644 index 00000000..032fd192 --- /dev/null +++ b/llama_deploy/client/base.py @@ -0,0 +1,22 @@ +from typing import Any + +import httpx + +from .client_settings import ClientSettings + + +class _BaseClient: + """Base type for clients, to be used in Pydantic models to avoid circular imports.""" + + def __init__(self, **kwargs: Any) -> None: + self.settings = ClientSettings(**kwargs) + + async def request( + self, method: str, url: str | httpx.URL, *args: Any, **kwargs: Any + ) -> httpx.Response: + """Performs an async HTTP request using httpx.""" + verify = kwargs.pop("verify", True) + async with httpx.AsyncClient(verify=verify) as client: + response = await client.request(method, url, *args, **kwargs) + response.raise_for_status() + return response diff --git a/llama_deploy/client/client.py b/llama_deploy/client/client.py new file mode 100644 index 00000000..90dd928e --- /dev/null +++ b/llama_deploy/client/client.py @@ -0,0 +1,23 @@ +from .base import _BaseClient +from .models import ApiServer + + +class Client(_BaseClient): + """Fixme. + + Fixme. + """ + + @property + def sync(self) -> "Client": + return _SyncClient(**self.settings.model_dump()) + + @property + def apiserver(self) -> ApiServer: + return ApiServer.instance(client=self, id="apiserver") + + +class _SyncClient(Client): + @property + def apiserver(self) -> ApiServer: + return ApiServer.instance(make_sync=True, client=self, id="apiserver") diff --git a/llama_deploy/client/client_settings.py b/llama_deploy/client/client_settings.py new file mode 100644 index 00000000..41837fb4 --- /dev/null +++ b/llama_deploy/client/client_settings.py @@ -0,0 +1,10 @@ +from pydantic_settings import BaseSettings, SettingsConfigDict + + +class ClientSettings(BaseSettings): + model_config = SettingsConfigDict(env_prefix="LLAMA_DEPLOY_") + + api_server_url: str = "http://localhost:4501" + disable_ssl: bool = False + timeout: float = 120.0 + poll_interval: float = 0.5 diff --git a/llama_deploy/client/models/__init__.py b/llama_deploy/client/models/__init__.py new file mode 100644 index 00000000..ac7104c9 --- /dev/null +++ b/llama_deploy/client/models/__init__.py @@ -0,0 +1,4 @@ +from .apiserver import ApiServer +from .model import Collection, Model + +__all__ = ["ApiServer", "Collection", "Model"] diff --git a/llama_deploy/client/models/apiserver.py b/llama_deploy/client/models/apiserver.py new file mode 100644 index 00000000..d242e705 --- /dev/null +++ b/llama_deploy/client/models/apiserver.py @@ -0,0 +1,256 @@ +import asyncio +import json +from typing import Any, AsyncGenerator, TextIO + +import httpx + +from llama_deploy.types.apiserver import Status, StatusEnum +from llama_deploy.types.core import TaskDefinition, TaskResult + +from .model import Collection, Model + +DEFAULT_POLL_INTERVAL = 0.5 + + +class Session(Model): + pass + + +class SessionCollection(Collection): + deployment_id: str + + async def delete(self, session_id: str) -> None: + settings = self.client.settings + delete_url = f"{settings.api_server_url}/deployments/{self.deployment_id}/sessions/delete" + + await self.client.request( + "POST", + delete_url, + params={"session_id": session_id}, + verify=not settings.disable_ssl, + timeout=settings.timeout, + ) + + +class Task(Model): + deployment_id: str + session_id: str + + async def results(self, session_id: str) -> TaskResult: + settings = self.client.settings + results_url = f"{settings.api_server_url}/deployments/{self.deployment_id}/tasks/{self.id}/results" + + r = await self.client.request( + "GET", + results_url, + verify=not settings.disable_ssl, + params={"session_id": session_id}, + timeout=settings.timeout, + ) + return TaskResult.model_validate_json(r.json()) + + async def events(self) -> AsyncGenerator[dict[str, Any], None]: # pragma: no cover + settings = self.client.settings + events_url = f"{settings.api_server_url}/deployments/{self.deployment_id}/tasks/{self.id}/events" + + while True: + try: + async with httpx.AsyncClient(verify=not settings.disable_ssl) as client: + async with client.stream( + "GET", events_url, params={"session_id": self.session_id} + ) as response: + response.raise_for_status() + async for line in response.aiter_lines(): + json_line = json.loads(line) + yield json_line + break # Exit the function if successful + except httpx.HTTPStatusError as e: + if e.response.status_code != 404: + raise # Re-raise if it's not a 404 error + await asyncio.sleep(DEFAULT_POLL_INTERVAL) + + +class TaskCollection(Collection): + deployment_id: str + + async def run(self, task: TaskDefinition) -> Any: + settings = self.client.settings + run_url = ( + f"{settings.api_server_url}/deployments/{self.deployment_id}/tasks/run" + ) + + r = await self.client.request( + "POST", + run_url, + verify=not settings.disable_ssl, + json=task.model_dump(), + timeout=settings.timeout, + ) + + return r.json() + + async def create(self, task: TaskDefinition) -> Task: + settings = self.client.settings + create_url = ( + f"{settings.api_server_url}/deployments/{self.deployment_id}/tasks/create" + ) + + r = await self.client.request( + "POST", + create_url, + verify=not settings.disable_ssl, + json=task.model_dump(), + timeout=settings.timeout, + ) + response_fields = r.json() + + return Task.instance( + make_sync=self._instance_is_sync, + client=self.client, + deployment_id=self.deployment_id, + id=response_fields["task_id"], + session_id=response_fields["session_id"], + ) + + +class Deployment(Model): + async def tasks(self) -> TaskCollection: + settings = self.client.settings + tasks_url = f"{settings.api_server_url}/deployments/{self.id}/tasks" + r = await self.client.request( + "GET", + tasks_url, + verify=not settings.disable_ssl, + timeout=settings.timeout, + ) + items = { + "id": Task.instance( + make_sync=self._instance_is_sync, + client=self.client, + id=task_def.task_id, + session_id=task_def.session_id, + deployment_id=self.id, + ) + for task_def in r.json() + } + return TaskCollection.instance( + make_sync=self._instance_is_sync, + client=self.client, + deployment_id=self.id, + items=items, + ) + + async def sessions(self) -> SessionCollection: + settings = self.client.settings + sessions_url = f"{settings.api_server_url}/deployments/{self.id}/sessions" + r = await self.client.request( + "GET", + sessions_url, + verify=not settings.disable_ssl, + timeout=settings.timeout, + ) + items = { + "id": Session.instance( + make_sync=self._instance_is_sync, + client=self.client, + id=session_def.session_id, + ) + for session_def in r.json() + } + return SessionCollection.instance( + make_sync=self._instance_is_sync, + client=self.client, + deployment_id=self.id, + items=items, + ) + + +class DeploymentCollection(Collection): + async def create(self, config: TextIO) -> Deployment: + """Creates a deployment""" + settings = self.client.settings + create_url = f"{settings.api_server_url}/deployments/create" + + files = {"config_file": config.read()} + r = await self.client.request( + "POST", + create_url, + files=files, + verify=not settings.disable_ssl, + timeout=settings.timeout, + ) + + return Deployment.instance( + make_sync=self._instance_is_sync, + client=self.client, + id=r.json().get("name"), + ) + + async def get(self, deployment_id: str) -> Deployment: + """Get a deployment by id""" + settings = self.client.settings + get_url = f"{settings.api_server_url}/deployments/{deployment_id}" + # Current version of apiserver doesn't returns anything useful in this endpoint, let's just ignore it + await self.client.request( + "GET", get_url, verify=not settings.disable_ssl, timeout=settings.timeout + ) + return Deployment.instance( + client=self.client, make_sync=self._instance_is_sync, id=deployment_id + ) + + +class ApiServer(Model): + async def status(self) -> Status: + """Returns the status of the API Server.""" + settings = self.client.settings + status_url = f"{settings.api_server_url}/status/" + + try: + r = await self.client.request( + "GET", + status_url, + verify=not settings.disable_ssl, + timeout=settings.timeout, + ) + except httpx.ConnectError: + return Status( + status=StatusEnum.DOWN, + status_message="API Server is down", + ) + + if r.status_code >= 400: + body = r.json() + return Status(status=StatusEnum.UNHEALTHY, status_message=r.text) + + description = "Llama Deploy is up and running." + body = r.json() + deployments = body.get("deployments") or [] + if deployments: + description += "\nActive deployments:" + for d in deployments: + description += f"\n- {d}" + else: + description += "\nCurrently there are no active deployments" + + return Status( + status=StatusEnum.HEALTHY, + status_message=description, + deployments=deployments, + ) + + async def deployments(self) -> DeploymentCollection: + settings = self.client.settings + status_url = f"{settings.api_server_url}/deployments/" + + r = await self.client.request( + "GET", status_url, verify=not settings.disable_ssl, timeout=settings.timeout + ) + deployments = { + "id": Deployment.instance( + make_sync=self._instance_is_sync, client=self.client, id=name + ) + for name in r.json() + } + return DeploymentCollection.instance( + make_sync=self._instance_is_sync, client=self.client, items=deployments + ) diff --git a/llama_deploy/client/models/model.py b/llama_deploy/client/models/model.py new file mode 100644 index 00000000..638e547b --- /dev/null +++ b/llama_deploy/client/models/model.py @@ -0,0 +1,55 @@ +import asyncio +from typing import Any, Generic, TypeVar, cast + +from asgiref.sync import async_to_sync +from pydantic import BaseModel, ConfigDict, Field, PrivateAttr +from typing_extensions import Self + +from llama_deploy.client.base import _BaseClient + + +class _Base(BaseModel): + client: _BaseClient = Field(exclude=True) + _instance_is_sync: bool = PrivateAttr(default=False) + + model_config = ConfigDict(arbitrary_types_allowed=True) + + def __new__(cls, *args, **kwargs): # type: ignore[no-untyped-def] + raise TypeError("Please use instance() instead of direct instantiation") + + @classmethod + def instance(cls, make_sync: bool = False, **kwargs: Any) -> Self: + if make_sync: + cls = _make_sync(cls) + + inst = super(_Base, cls).__new__(cls) + inst.__init__(**kwargs) # type: ignore[misc] + inst._instance_is_sync = make_sync + return inst + + +T = TypeVar("T", bound=_Base) + + +class Model(_Base): + id: str + + +class Collection(_Base, Generic[T]): + items: dict[str, T] + + def get(self, id: str) -> T: + return self.items[id] + + def list(self) -> list[T]: + return [self.get(id) for id in self.items.keys()] + + +def _make_sync(_class: type[T]) -> type[T]: + class Wrapper(_class): # type: ignore + pass + + for name, method in _class.__dict__.items(): + if asyncio.iscoroutinefunction(method) and not name.startswith("_"): + setattr(Wrapper, name, async_to_sync(method)) + return cast(type[T], Wrapper) diff --git a/llama_deploy/sdk/__init__.py b/llama_deploy/sdk/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/llama_deploy/types/__init__.py b/llama_deploy/types/__init__.py new file mode 100644 index 00000000..96b47106 --- /dev/null +++ b/llama_deploy/types/__init__.py @@ -0,0 +1,35 @@ +from .core import ( + CONTROL_PLANE_NAME, + ActionTypes, + ChatMessage, + HumanResponse, + MessageRole, + PydanticValidatedUrl, + ServiceDefinition, + SessionDefinition, + TaskDefinition, + TaskResult, + TaskStream, + ToolCall, + ToolCallBundle, + ToolCallResult, + generate_id, +) + +__all__ = [ + "CONTROL_PLANE_NAME", + "ActionTypes", + "ChatMessage", + "HumanResponse", + "MessageRole", + "PydanticValidatedUrl", + "ServiceDefinition", + "SessionDefinition", + "TaskDefinition", + "TaskResult", + "TaskStream", + "ToolCall", + "ToolCallBundle", + "ToolCallResult", + "generate_id", +] diff --git a/llama_deploy/types/apiserver.py b/llama_deploy/types/apiserver.py new file mode 100644 index 00000000..ea977d69 --- /dev/null +++ b/llama_deploy/types/apiserver.py @@ -0,0 +1,16 @@ +from enum import Enum + +from pydantic import BaseModel + + +class StatusEnum(Enum): + HEALTHY = "Healthy" + UNHEALTHY = "Unhealthy" + DOWN = "Down" + + +class Status(BaseModel): + status: StatusEnum + status_message: str + max_deployments: int | None = None + deployments: list[str] | None = None diff --git a/llama_deploy/types.py b/llama_deploy/types/core.py similarity index 100% rename from llama_deploy/types.py rename to llama_deploy/types/core.py diff --git a/poetry.lock b/poetry.lock index 1948f10e..66f32e75 100644 --- a/poetry.lock +++ b/poetry.lock @@ -288,6 +288,23 @@ doc = ["Sphinx (>=7.4,<8.0)", "packaging", "sphinx-autodoc-typehints (>=1.2.0)", test = ["anyio[trio]", "coverage[toml] (>=7)", "exceptiongroup (>=1.2.0)", "hypothesis (>=4.0)", "psutil (>=5.9)", "pytest (>=7.0)", "pytest-mock (>=3.6.1)", "trustme", "truststore (>=0.9.1)", "uvloop (>=0.21.0b1)"] trio = ["trio (>=0.26.1)"] +[[package]] +name = "asgiref" +version = "3.8.1" +description = "ASGI specs, helper code, and adapters" +optional = false +python-versions = ">=3.8" +files = [ + {file = "asgiref-3.8.1-py3-none-any.whl", hash = "sha256:3e1e3ecc849832fe52ccf2cb6686b7a55f82bb1d6aee72a58826471390335e47"}, + {file = "asgiref-3.8.1.tar.gz", hash = "sha256:c343bd80a0bec947a9860adb4c432ffa7db769836c64238fc34bdc3fec84d590"}, +] + +[package.dependencies] +typing-extensions = {version = ">=4", markers = "python_version < \"3.11\""} + +[package.extras] +tests = ["mypy (>=0.800)", "pytest", "pytest-asyncio"] + [[package]] name = "async-timeout" version = "4.0.3" @@ -3004,4 +3021,4 @@ redis = ["redis"] [metadata] lock-version = "2.0" python-versions = ">=3.10,<4.0" -content-hash = "c223e362ffa6f52cfd0c7292961f7f0d38538e362eff496da63f0d6c8fa3dea9" +content-hash = "1063f9f5a1883f755d70e7af6480c831382c661b28518d2c70fa249ab52a09e9" diff --git a/pyproject.toml b/pyproject.toml index 29b186e9..d08ec64c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,6 +40,7 @@ types-aiobotocore = {version = "^2.14.0", optional = true, extras = ["sqs", "sns gitpython = "^3.1.43" python-multipart = "^0.0.10" typing_extensions = "^4.0.0" +asgiref = "^3.8.1" [tool.poetry.extras] kafka = ["aiokafka", "kafka-python-ng"] diff --git a/tests/client/models/__init__.py b/tests/client/models/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/client/models/conftest.py b/tests/client/models/conftest.py new file mode 100644 index 00000000..8816371b --- /dev/null +++ b/tests/client/models/conftest.py @@ -0,0 +1,14 @@ +from typing import Any, Iterator +from unittest import mock + +import pytest + +from llama_deploy.client import Client + + +@pytest.fixture +def client(monkeypatch: Any) -> Iterator[Client]: + with mock.patch("llama_deploy.client.Client", spec=True): + c = Client() + monkeypatch.setattr(c, "request", mock.AsyncMock()) + yield c diff --git a/tests/client/models/test_apiserver.py b/tests/client/models/test_apiserver.py new file mode 100644 index 00000000..cd5683cb --- /dev/null +++ b/tests/client/models/test_apiserver.py @@ -0,0 +1,265 @@ +import io +from typing import Any +from unittest import mock + +import httpx +import pytest + +from llama_deploy.client.models.apiserver import ( + ApiServer, + Deployment, + DeploymentCollection, + Session, + SessionCollection, + Task, + TaskCollection, +) +from llama_deploy.types import SessionDefinition, TaskDefinition, TaskResult + + +@pytest.mark.asyncio +async def test_session_collection_delete(client: Any) -> None: + coll = SessionCollection.instance( + client=client, + items={"a_session": Session.instance(id="a_session", client=client)}, + deployment_id="a_deployment", + ) + await coll.delete("a_session") + client.request.assert_awaited_with( + "POST", + "http://localhost:4501/deployments/a_deployment/sessions/delete", + params={"session_id": "a_session"}, + timeout=120.0, + verify=True, + ) + + +@pytest.mark.asyncio +async def test_task_results(client: Any) -> None: + res = TaskResult(task_id="a_result", history=[], result="some_text", data={}) + client.request.return_value = mock.MagicMock(json=lambda: res.model_dump_json()) + + t = Task.instance( + client=client, + id="a_task", + deployment_id="a_deployment", + session_id="a_session", + ) + await t.results(session_id="a_session") + + client.request.assert_awaited_with( + "GET", + "http://localhost:4501/deployments/a_deployment/tasks/a_task/results", + verify=True, + params={"session_id": "a_session"}, + timeout=120.0, + ) + + +@pytest.mark.asyncio +async def test_task_collection_run(client: Any) -> None: + client.request.return_value = mock.MagicMock(json=lambda: "some result") + coll = TaskCollection.instance( + client=client, + items={ + "a_session": Task.instance( + id="a_session", + client=client, + deployment_id="a_deployment", + session_id="a_session", + ) + }, + deployment_id="a_deployment", + ) + await coll.run(TaskDefinition(input="some input", task_id="test_id")) + client.request.assert_awaited_with( + "POST", + "http://localhost:4501/deployments/a_deployment/tasks/run", + verify=True, + json={ + "input": "some input", + "task_id": "test_id", + "session_id": None, + "agent_id": None, + }, + timeout=120.0, + ) + + +@pytest.mark.asyncio +async def test_task_collection_create(client: Any) -> None: + client.request.return_value = mock.MagicMock( + json=lambda: {"session_id": "a_session", "task_id": "test_id"} + ) + coll = TaskCollection.instance( + client=client, + items={ + "a_session": Task.instance( + id="a_session", + client=client, + deployment_id="a_deployment", + session_id="a_session", + ) + }, + deployment_id="a_deployment", + ) + await coll.create(TaskDefinition(input='{"arg": "test_input"}', task_id="test_id")) + client.request.assert_awaited_with( + "POST", + "http://localhost:4501/deployments/a_deployment/tasks/create", + verify=True, + json={ + "input": '{"arg": "test_input"}', + "task_id": "test_id", + "session_id": None, + "agent_id": None, + }, + timeout=120.0, + ) + + +@pytest.mark.asyncio +async def test_task_deployment_tasks(client: Any) -> None: + d = Deployment.instance(client=client, id="a_deployment") + res: list[TaskDefinition] = [ + TaskDefinition( + input='{"arg": "input"}', task_id="a_task", session_id="a_session" + ) + ] + client.request.return_value = mock.MagicMock(json=lambda: res) + + await d.tasks() + + client.request.assert_awaited_with( + "GET", + "http://localhost:4501/deployments/a_deployment/tasks", + verify=True, + timeout=120.0, + ) + + +@pytest.mark.asyncio +async def test_task_deployment_sessions(client: Any) -> None: + d = Deployment.instance(client=client, id="a_deployment") + res: list[SessionDefinition] = [SessionDefinition(session_id="a_session")] + client.request.return_value = mock.MagicMock(json=lambda: res) + + await d.sessions() + + client.request.assert_awaited_with( + "GET", + "http://localhost:4501/deployments/a_deployment/sessions", + verify=True, + timeout=120.0, + ) + + +@pytest.mark.asyncio +async def test_task_deployment_collection_create(client: Any) -> None: + client.request.return_value = mock.MagicMock(json=lambda: {"name": "deployment"}) + + coll = DeploymentCollection.instance(client=client, items={}) + await coll.create(io.StringIO("some config")) + + client.request.assert_awaited_with( + "POST", + "http://localhost:4501/deployments/create", + files={"config_file": "some config"}, + verify=True, + timeout=120.0, + ) + + +@pytest.mark.asyncio +async def test_task_deployment_collection_get(client: Any) -> None: + d = Deployment.instance(client=client, id="a_deployment") + coll = DeploymentCollection.instance(client=client, items={"a_deployment": d}) + client.request.return_value = mock.MagicMock(json=lambda: {"a_deployment": "Up!"}) + + await coll.get("a_deployment") + + client.request.assert_awaited_with( + "GET", + "http://localhost:4501/deployments/a_deployment", + verify=True, + timeout=120.0, + ) + + +@pytest.mark.asyncio +async def test_status_down(client: Any) -> None: + client.request.side_effect = httpx.ConnectError(message="connection error") + + apis = ApiServer.instance(client=client, id="apiserver") + res = await apis.status() + + client.request.assert_awaited_with( + "GET", "http://localhost:4501/status/", verify=True, timeout=120.0 + ) + assert res.status.value == "Down" + + +@pytest.mark.asyncio +async def test_status_unhealthy(client: Any) -> None: + client.request.return_value = mock.MagicMock( + status_code=400, text="This is a drill." + ) + + apis = ApiServer.instance(client=client, id="apiserver") + res = await apis.status() + + client.request.assert_awaited_with( + "GET", "http://localhost:4501/status/", verify=True, timeout=120.0 + ) + assert res.status.value == "Unhealthy" + assert res.status_message == "This is a drill." + + +@pytest.mark.asyncio +async def test_status_healthy_no_deployments(client: Any) -> None: + client.request.return_value = mock.MagicMock( + status_code=200, text="", json=lambda: {} + ) + + apis = ApiServer.instance(client=client, id="apiserver") + res = await apis.status() + + client.request.assert_awaited_with( + "GET", "http://localhost:4501/status/", verify=True, timeout=120.0 + ) + assert res.status.value == "Healthy" + assert ( + res.status_message + == "Llama Deploy is up and running.\nCurrently there are no active deployments" + ) + + +@pytest.mark.asyncio +async def test_status_healthy(client: Any) -> None: + client.request.return_value = mock.MagicMock( + status_code=200, text="", json=lambda: {"deployments": ["foo", "bar"]} + ) + + apis = ApiServer.instance(client=client, id="apiserver") + res = await apis.status() + + client.request.assert_awaited_with( + "GET", "http://localhost:4501/status/", verify=True, timeout=120.0 + ) + assert res.status.value == "Healthy" + assert ( + res.status_message + == "Llama Deploy is up and running.\nActive deployments:\n- foo\n- bar" + ) + + +@pytest.mark.asyncio +async def test_deployments(client: Any) -> None: + client.request.return_value = mock.MagicMock( + status_code=200, text="", json=lambda: {"deployments": ["foo", "bar"]} + ) + apis = ApiServer.instance(client=client, id="apiserver") + await apis.deployments() + client.request.assert_awaited_with( + "GET", "http://localhost:4501/deployments/", verify=True, timeout=120.0 + ) diff --git a/tests/client/models/test_model.py b/tests/client/models/test_model.py new file mode 100644 index 00000000..f01739e3 --- /dev/null +++ b/tests/client/models/test_model.py @@ -0,0 +1,41 @@ +import asyncio + +import pytest + +from llama_deploy.client import Client +from llama_deploy.client.models import Collection, Model +from llama_deploy.client.models.model import _make_sync + + +class SomeAsyncModel(Model): + async def method(self) -> None: + pass + + +def test_no_init(client: Client) -> None: + with pytest.raises( + TypeError, match=r"Please use instance\(\) instead of direct instantiation" + ): + SomeAsyncModel(id="foo", client=client) + + +def test_make_sync() -> None: + assert asyncio.iscoroutinefunction(getattr(SomeAsyncModel, "method")) + some_sync = _make_sync(SomeAsyncModel) + assert not asyncio.iscoroutinefunction(getattr(some_sync, "method")) + + +def test_collection_get() -> None: + class MyCollection(Collection): + pass + + c = Client() + models_list = [ + SomeAsyncModel.instance(client=c, id="foo"), + SomeAsyncModel.instance(client=c, id="bar"), + ] + + coll = MyCollection.instance(client=c, items={m.id: m for m in models_list}) + assert coll.get("foo").id == "foo" + assert coll.get("bar").id == "bar" + assert coll.list() == models_list diff --git a/tests/client/test_client.py b/tests/client/test_client.py new file mode 100644 index 00000000..659e2c12 --- /dev/null +++ b/tests/client/test_client.py @@ -0,0 +1,52 @@ +from unittest import mock + +import pytest + +from llama_deploy.client import Client +from llama_deploy.client.client import _SyncClient +from llama_deploy.client.models.apiserver import ApiServer + + +def test_client_init_default() -> None: + c = Client() + settings = c.settings + assert settings.api_server_url == "http://localhost:4501" + assert settings.disable_ssl is False + assert settings.timeout == 120.0 + assert settings.poll_interval == 0.5 + + +def test_client_init_settings() -> None: + c = Client(api_server_url="test") + assert c.settings.api_server_url == "test" + + +def test_client_sync() -> None: + c = Client() + sc = c.sync + assert type(sc) is _SyncClient + settings = sc.settings + assert settings.api_server_url == "http://localhost:4501" + assert settings.disable_ssl is False + assert settings.timeout == 120.0 + assert settings.poll_interval == 0.5 + + +def test_client_attributes() -> None: + c = Client() + assert type(c.apiserver) is ApiServer + assert issubclass(type(c.sync.apiserver), ApiServer) + + +@pytest.mark.asyncio +async def test_client_request() -> None: + with mock.patch("llama_deploy.client.base.httpx") as _httpx: + mocked_response = mock.MagicMock() + _httpx.AsyncClient.return_value.__aenter__.return_value.request.return_value = ( + mocked_response + ) + + c = Client() + await c.request("GET", "http://example.com", verify=False) + _httpx.AsyncClient.assert_called_with(verify=False) + mocked_response.raise_for_status.assert_called_once() diff --git a/tests/services/test_human_service.py b/tests/services/test_human_service.py index 07ca14e9..803cf8f6 100644 --- a/tests/services/test_human_service.py +++ b/tests/services/test_human_service.py @@ -1,18 +1,20 @@ import asyncio -import pytest -from pydantic import PrivateAttr, ValidationError from typing import Any, List from unittest.mock import MagicMock, patch -from llama_deploy.services import HumanService -from llama_deploy.services.human import HELP_REQUEST_TEMPLATE_STR -from llama_deploy.message_queues.simple import SimpleMessageQueue + +import pytest +from pydantic import PrivateAttr, ValidationError + from llama_deploy.message_consumers.base import BaseMessageQueueConsumer +from llama_deploy.message_queues.simple import SimpleMessageQueue from llama_deploy.messages.base import QueueMessage +from llama_deploy.services import HumanService +from llama_deploy.services.human import HELP_REQUEST_TEMPLATE_STR from llama_deploy.types import ( - TaskDefinition, - ActionTypes, CONTROL_PLANE_NAME, + ActionTypes, ChatMessage, + TaskDefinition, ) @@ -71,7 +73,7 @@ def test_invalid_human_prompt_raises_validation_error() -> None: @pytest.mark.asyncio() -@patch("llama_deploy.types.uuid") +@patch("llama_deploy.types.core.uuid") async def test_create_task(mock_uuid: MagicMock) -> None: # arrange human_service = HumanService(