Skip to content

Commit

Permalink
Client refactoring
Browse files Browse the repository at this point in the history
try

checkpoint

add asgiref to support async-to-sync:

fix unit tests

remove pydantic warning

add unit tests for client

fix mock path

test model

added tests and fix discovered bugs

fix connection error handling

fix bugs surfaced in end-to-end

use explicity properties

fix awaitable checks

add instance method

revert to return sync class

extract base model

use instance() method on models

add e2e tests

working state

fix unit tests

more fixes
  • Loading branch information
masci committed Oct 23, 2024
1 parent 3c8410b commit e04269e
Show file tree
Hide file tree
Showing 30 changed files with 1,032 additions and 27 deletions.
Empty file added e2e_tests/apiserver/__init__.py
Empty file.
29 changes: 29 additions & 0 deletions e2e_tests/apiserver/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import multiprocessing
import time

import pytest
import uvicorn

from llama_deploy.client import Client
from llama_deploy.client.client_settings import ClientSettings


def run_async_apiserver():
uvicorn.run("llama_deploy.apiserver:app", host="127.0.0.1", port=4501)


@pytest.fixture(scope="module")
def apiserver():
p = multiprocessing.Process(target=run_async_apiserver)
p.start()
time.sleep(3)

yield

p.kill()


@pytest.fixture
def client():
s = ClientSettings(api_server_url="http://localhost:4501")
return Client(**s.model_dump())
15 changes: 15 additions & 0 deletions e2e_tests/apiserver/deployments/deployment1.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
name: TestDeployment1

control-plane: {}

default-service: dummy_workflow

services:
test-workflow:
name: Test Workflow
port: 8002
host: localhost
source:
type: git
name: https://github.com/run-llama/llama_deploy.git
path: tests/apiserver/data/workflow:my_workflow
15 changes: 15 additions & 0 deletions e2e_tests/apiserver/deployments/deployment2.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
name: TestDeployment2

control-plane: {}

default-service: dummy_workflow

services:
test-workflow:
name: Test Workflow
port: 8002
host: localhost
source:
type: git
name: https://github.com/run-llama/llama_deploy.git
path: tests/apiserver/data/workflow:my_workflow
14 changes: 14 additions & 0 deletions e2e_tests/apiserver/deployments/deployment_streaming.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
name: Streaming

control-plane:
port: 8000

default-service: streaming_workflow

services:
streaming_workflow:
name: Streaming Workflow
source:
type: local
name: ./e2e_tests/apiserver/deployments/src
path: workflow:streaming_workflow
41 changes: 41 additions & 0 deletions e2e_tests/apiserver/deployments/src/workflow.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import asyncio

from llama_index.core.workflow import (
Context,
Event,
StartEvent,
StopEvent,
Workflow,
step,
)


class Message(Event):
text: str


class EchoWorkflow(Workflow):
"""A dummy workflow streaming three events."""

@step()
async def run_step(self, ctx: Context, ev: StartEvent) -> StopEvent:
for i in range(3):
ctx.write_event_to_stream(Message(text=f"message number {i+1}"))
await asyncio.sleep(0.5)

return StopEvent(result="Done.")


streaming_workflow = EchoWorkflow()


async def main():
h = streaming_workflow.run(message="Hello!")
async for ev in h.stream_events():
if type(ev) is Message:
print(ev.text)
print(await h)


if __name__ == "__main__":
asyncio.run(main())
23 changes: 23 additions & 0 deletions e2e_tests/apiserver/test_deploy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from pathlib import Path

import pytest


@pytest.mark.asyncio
async def test_deploy(apiserver, client):
here = Path(__file__).parent
deployments = await client.apiserver.deployments()
with open(here / "deployments" / "deployment1.yml") as f:
await deployments.create(f)

status = await client.apiserver.status()
assert "TestDeployment1" in status.deployments


def test_deploy_sync(apiserver, client):
here = Path(__file__).parent
deployments = client.sync.apiserver.deployments()
with open(here / "deployments" / "deployment2.yml") as f:
deployments.create(f)

assert "TestDeployment2" in client.sync.apiserver.status().deployments
23 changes: 23 additions & 0 deletions e2e_tests/apiserver/test_status.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import pytest


@pytest.mark.asyncio
async def test_status_down(client):
res = await client.apiserver.status()
assert res.status.value == "Down"


def test_status_down_sync(client):
res = client.sync.apiserver.status()
assert res.status.value == "Down"


@pytest.mark.asyncio
async def test_status_up(apiserver, client):
res = await client.apiserver.status()
assert res.status.value == "Healthy"


def test_status_up_sync(apiserver, client):
res = client.sync.apiserver.status()
assert res.status.value == "Healthy"
21 changes: 21 additions & 0 deletions e2e_tests/apiserver/test_streaming.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import asyncio
from pathlib import Path

import pytest

from llama_deploy.types import TaskDefinition


@pytest.mark.asyncio
async def test_stream(apiserver, client):
here = Path(__file__).parent

with open(here / "deployments" / "deployment_streaming.yml") as f:
deployments = await client.apiserver.deployments()
deployment = await deployments.create(f)
await asyncio.sleep(5)

tasks = await deployment.tasks()
task = await tasks.create(TaskDefinition(input='{"a": "b"}'))
async for ev in task.events():
print(ev)
16 changes: 5 additions & 11 deletions llama_deploy/apiserver/deployment.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,32 +7,26 @@
from typing import Any

from llama_deploy import (
AsyncLlamaDeployClient,
ControlPlaneServer,
SimpleMessageQueue,
SimpleOrchestratorConfig,
SimpleOrchestrator,
SimpleOrchestratorConfig,
WorkflowService,
WorkflowServiceConfig,
AsyncLlamaDeployClient,
)
from llama_deploy.message_queues import (
BaseMessageQueue,
SimpleMessageQueueConfig,
AWSMessageQueue,
BaseMessageQueue,
KafkaMessageQueue,
RabbitMQMessageQueue,
RedisMessageQueue,
SimpleMessageQueueConfig,
)

from .config_parser import (
Config,
SourceType,
Service,
MessageQueueConfig,
)
from .config_parser import Config, MessageQueueConfig, Service, SourceType
from .source_managers import GitSourceManager, LocalSourceManager, SourceManager


SOURCE_MANAGERS: dict[SourceType, SourceManager] = {
SourceType.git: GitSourceManager(),
SourceType.local: LocalSourceManager(),
Expand Down
24 changes: 20 additions & 4 deletions llama_deploy/apiserver/routers/deployments.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
import json
from typing import AsyncGenerator

from fastapi import APIRouter, File, UploadFile, HTTPException
from fastapi import APIRouter, File, HTTPException, UploadFile
from fastapi.responses import JSONResponse, StreamingResponse
from typing import AsyncGenerator

from llama_deploy.apiserver.server import manager
from llama_deploy.apiserver.config_parser import Config
from llama_deploy.apiserver.server import manager
from llama_deploy.types import TaskDefinition


deployments_router = APIRouter(
prefix="/deployments",
)
Expand Down Expand Up @@ -144,6 +143,23 @@ async def get_task_result(
return JSONResponse(result.result if result else "")


@deployments_router.get("/{deployment_name}/tasks")
async def get_tasks(
deployment_name: str,
) -> JSONResponse:
"""Get the active sessions in a deployment and service."""
deployment = manager.get_deployment(deployment_name)
if deployment is None:
raise HTTPException(status_code=404, detail="Deployment not found")

tasks: list[TaskDefinition] = []
for session_def in await deployment.client.list_sessions():
session = await deployment.client.get_session(session_id=session_def.session_id)
for task_def in await session.get_tasks():
tasks.append(task_def)
return JSONResponse(tasks)


@deployments_router.get("/{deployment_name}/sessions")
async def get_sessions(
deployment_name: str,
Expand Down
7 changes: 4 additions & 3 deletions llama_deploy/client/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from llama_deploy.client.async_client import AsyncLlamaDeployClient
from llama_deploy.client.sync_client import LlamaDeployClient
from .async_client import AsyncLlamaDeployClient
from .sync_client import LlamaDeployClient
from .client import Client

__all__ = ["AsyncLlamaDeployClient", "LlamaDeployClient"]
__all__ = ["AsyncLlamaDeployClient", "Client", "LlamaDeployClient"]
22 changes: 22 additions & 0 deletions llama_deploy/client/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from typing import Any

import httpx

from .client_settings import ClientSettings


class _BaseClient:
"""Base type for clients, to be used in Pydantic models to avoid circular imports."""

def __init__(self, **kwargs: Any) -> None:
self.settings = ClientSettings(**kwargs)

async def request(
self, method: str, url: str | httpx.URL, *args: Any, **kwargs: Any
) -> httpx.Response:
"""Performs an async HTTP request using httpx."""
verify = kwargs.pop("verify", True)
async with httpx.AsyncClient(verify=verify) as client:
response = await client.request(method, url, *args, **kwargs)
response.raise_for_status()
return response
23 changes: 23 additions & 0 deletions llama_deploy/client/client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from .base import _BaseClient
from .models import ApiServer


class Client(_BaseClient):
"""Fixme.
Fixme.
"""

@property
def sync(self) -> "Client":
return _SyncClient(**self.settings.model_dump())

@property
def apiserver(self) -> ApiServer:
return ApiServer.instance(client=self, id="apiserver")


class _SyncClient(Client):
@property
def apiserver(self) -> ApiServer:
return ApiServer.instance(make_sync=True, client=self, id="apiserver")
10 changes: 10 additions & 0 deletions llama_deploy/client/client_settings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from pydantic_settings import BaseSettings, SettingsConfigDict


class ClientSettings(BaseSettings):
model_config = SettingsConfigDict(env_prefix="LLAMA_DEPLOY_")

api_server_url: str = "http://localhost:4501"
disable_ssl: bool = False
timeout: float = 120.0
poll_interval: float = 0.5
4 changes: 4 additions & 0 deletions llama_deploy/client/models/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from .apiserver import ApiServer
from .model import Collection, Model

__all__ = ["ApiServer", "Collection", "Model"]
Loading

0 comments on commit e04269e

Please sign in to comment.