feat: add apiserver support to Python SDK #327

Merged 8 commits on Oct 24, 2024
Empty file added e2e_tests/apiserver/__init__.py
Empty file.
27 changes: 27 additions & 0 deletions e2e_tests/apiserver/conftest.py
@@ -0,0 +1,27 @@
import multiprocessing
import time

import pytest
import uvicorn

from llama_deploy.client import Client


def run_async_apiserver():
uvicorn.run("llama_deploy.apiserver:app", host="127.0.0.1", port=4501)


@pytest.fixture(scope="module")
def apiserver():
p = multiprocessing.Process(target=run_async_apiserver)
p.start()
time.sleep(3)

yield

p.kill()


@pytest.fixture
def client():
return Client(api_server_url="http://localhost:4501")
15 changes: 15 additions & 0 deletions e2e_tests/apiserver/deployments/deployment1.yml
@@ -0,0 +1,15 @@
name: TestDeployment1

control-plane: {}

default-service: dummy_workflow

services:
test-workflow:
name: Test Workflow
port: 8002
host: localhost
source:
type: git
name: https://github.com/run-llama/llama_deploy.git
path: tests/apiserver/data/workflow:my_workflow
15 changes: 15 additions & 0 deletions e2e_tests/apiserver/deployments/deployment2.yml
@@ -0,0 +1,15 @@
name: TestDeployment2

control-plane: {}

default-service: dummy_workflow

services:
test-workflow:
name: Test Workflow
port: 8002
host: localhost
source:
type: git
name: https://github.com/run-llama/llama_deploy.git
path: tests/apiserver/data/workflow:my_workflow
14 changes: 14 additions & 0 deletions e2e_tests/apiserver/deployments/deployment_streaming.yml
@@ -0,0 +1,14 @@
name: Streaming

control-plane:
port: 8000

default-service: streaming_workflow

services:
streaming_workflow:
name: Streaming Workflow
source:
type: local
name: ./e2e_tests/apiserver/deployments/src
path: workflow:streaming_workflow
41 changes: 41 additions & 0 deletions e2e_tests/apiserver/deployments/src/workflow.py
@@ -0,0 +1,41 @@
import asyncio

from llama_index.core.workflow import (
Context,
Event,
StartEvent,
StopEvent,
Workflow,
step,
)


class Message(Event):
text: str


class EchoWorkflow(Workflow):
"""A dummy workflow streaming three events."""

@step()
async def run_step(self, ctx: Context, ev: StartEvent) -> StopEvent:
for i in range(3):
ctx.write_event_to_stream(Message(text=f"message number {i+1}"))
await asyncio.sleep(0.5)

return StopEvent(result="Done.")


streaming_workflow = EchoWorkflow()


async def main():
h = streaming_workflow.run(message="Hello!")
async for ev in h.stream_events():
if type(ev) is Message:
print(ev.text)
print(await h)


if __name__ == "__main__":
asyncio.run(main())
23 changes: 23 additions & 0 deletions e2e_tests/apiserver/test_deploy.py
@@ -0,0 +1,23 @@
from pathlib import Path

import pytest


@pytest.mark.asyncio
async def test_deploy(apiserver, client):
here = Path(__file__).parent
deployments = await client.apiserver.deployments()
with open(here / "deployments" / "deployment1.yml") as f:
await deployments.create(f)

status = await client.apiserver.status()
assert "TestDeployment1" in status.deployments


def test_deploy_sync(apiserver, client):
here = Path(__file__).parent
deployments = client.sync.apiserver.deployments()
with open(here / "deployments" / "deployment2.yml") as f:
deployments.create(f)

assert "TestDeployment2" in client.sync.apiserver.status().deployments
23 changes: 23 additions & 0 deletions e2e_tests/apiserver/test_status.py
@@ -0,0 +1,23 @@
import pytest


@pytest.mark.asyncio
async def test_status_down(client):
res = await client.apiserver.status()
assert res.status.value == "Down"


def test_status_down_sync(client):
res = client.sync.apiserver.status()
assert res.status.value == "Down"


@pytest.mark.asyncio
async def test_status_up(apiserver, client):
    res = await client.apiserver.status()
assert res.status.value == "Healthy"


def test_status_up_sync(apiserver, client):
res = client.sync.apiserver.status()
assert res.status.value == "Healthy"
27 changes: 27 additions & 0 deletions e2e_tests/apiserver/test_streaming.py
@@ -0,0 +1,27 @@
import asyncio
from pathlib import Path

import pytest

from llama_deploy.types import TaskDefinition


@pytest.mark.asyncio
async def test_stream(apiserver, client):
here = Path(__file__).parent

with open(here / "deployments" / "deployment_streaming.yml") as f:
deployments = await client.apiserver.deployments()
deployment = await deployments.create(f)
await asyncio.sleep(5)

tasks = await deployment.tasks()
task = await tasks.create(TaskDefinition(input='{"a": "b"}'))
read_events = []
async for ev in task.events():
if "text" in ev:
read_events.append(ev)
assert len(read_events) == 3
# the workflow produces events sequentially, so here we can assume events arrived in order
for i, ev in enumerate(read_events):
assert ev["text"] == f"message number {i+1}"
27 changes: 14 additions & 13 deletions llama_deploy/__init__.py
@@ -1,28 +1,28 @@
-from llama_deploy.client import AsyncLlamaDeployClient, LlamaDeployClient
-from llama_deploy.control_plane import ControlPlaneServer, ControlPlaneConfig
+# configure logger
+import logging
+
+from llama_deploy.client import AsyncLlamaDeployClient, Client, LlamaDeployClient
+from llama_deploy.control_plane import ControlPlaneConfig, ControlPlaneServer
 from llama_deploy.deploy import deploy_core, deploy_workflow
 from llama_deploy.message_consumers import CallableMessageConsumer
 from llama_deploy.message_queues import SimpleMessageQueue, SimpleMessageQueueConfig
 from llama_deploy.messages import QueueMessage
 from llama_deploy.orchestrators import SimpleOrchestrator, SimpleOrchestratorConfig
+from llama_deploy.services import (
+    AgentService,
+    ComponentService,
+    HumanService,
+    ToolService,
+    WorkflowService,
+    WorkflowServiceConfig,
+)
 from llama_deploy.tools import (
     AgentServiceTool,
     MetaServiceTool,
     ServiceAsTool,
     ServiceComponent,
     ServiceTool,
 )
-from llama_deploy.services import (
-    AgentService,
-    ToolService,
-    HumanService,
-    ComponentService,
-    WorkflowService,
-    WorkflowServiceConfig,
-)
-
-# configure logger
-import logging
 
 root_logger = logging.getLogger("llama_deploy")
 
@@ -39,6 +39,7 @@
# clients
"LlamaDeployClient",
"AsyncLlamaDeployClient",
"Client",
# services
"AgentService",
"HumanService",
16 changes: 5 additions & 11 deletions llama_deploy/apiserver/deployment.py
@@ -7,32 +7,26 @@
 from typing import Any
 
 from llama_deploy import (
+    AsyncLlamaDeployClient,
     ControlPlaneServer,
     SimpleMessageQueue,
-    SimpleOrchestratorConfig,
     SimpleOrchestrator,
+    SimpleOrchestratorConfig,
     WorkflowService,
     WorkflowServiceConfig,
-    AsyncLlamaDeployClient,
 )
 from llama_deploy.message_queues import (
-    BaseMessageQueue,
-    SimpleMessageQueueConfig,
     AWSMessageQueue,
+    BaseMessageQueue,
     KafkaMessageQueue,
     RabbitMQMessageQueue,
     RedisMessageQueue,
+    SimpleMessageQueueConfig,
 )
 
-from .config_parser import (
-    Config,
-    SourceType,
-    Service,
-    MessageQueueConfig,
-)
+from .config_parser import Config, MessageQueueConfig, Service, SourceType
 from .source_managers import GitSourceManager, LocalSourceManager, SourceManager
 
-
 SOURCE_MANAGERS: dict[SourceType, SourceManager] = {
     SourceType.git: GitSourceManager(),
     SourceType.local: LocalSourceManager(),
24 changes: 20 additions & 4 deletions llama_deploy/apiserver/routers/deployments.py
@@ -1,14 +1,13 @@
 import json
+from typing import AsyncGenerator
 
-from fastapi import APIRouter, File, UploadFile, HTTPException
+from fastapi import APIRouter, File, HTTPException, UploadFile
 from fastapi.responses import JSONResponse, StreamingResponse
-from typing import AsyncGenerator
 
-from llama_deploy.apiserver.server import manager
 from llama_deploy.apiserver.config_parser import Config
+from llama_deploy.apiserver.server import manager
 from llama_deploy.types import TaskDefinition
 
-
 deployments_router = APIRouter(
     prefix="/deployments",
 )
@@ -144,6 +143,23 @@ async def get_task_result(
return JSONResponse(result.result if result else "")


@deployments_router.get("/{deployment_name}/tasks")
async def get_tasks(
deployment_name: str,
) -> JSONResponse:
"""Get all the tasks from all the sessions in a given deployment."""
deployment = manager.get_deployment(deployment_name)
if deployment is None:
raise HTTPException(status_code=404, detail="Deployment not found")

tasks: list[TaskDefinition] = []
for session_def in await deployment.client.list_sessions():
session = await deployment.client.get_session(session_id=session_def.session_id)
for task_def in await session.get_tasks():
tasks.append(task_def)
return JSONResponse(tasks)
Collaborator: Should this be a dict of session_id -> list[task]? (Just thinking about building a UI that would use this function.)

Member Author: Good point, many of the apiserver API's return payloads are not well thought out; we can start fixing them from here.

Member Author: I had a look and realised that if we change this we would need to update the CLI as well, and the PR would grow considerably. I propose we do that in a follow-up.

Member Author: Tracked in #337.

Collaborator: Sounds good to me!


@deployments_router.get("/{deployment_name}/sessions")
async def get_sessions(
deployment_name: str,
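For illustration, here is one possible shape for the follow-up discussed in the thread above: returning the tasks grouped by session ID instead of a flat list. This is a hypothetical sketch, not part of this PR's diff; it reuses the same `manager` and `deployment.client` calls as the endpoint above, and it assumes `TaskDefinition` is a Pydantic v2 model exposing `model_dump()`.

# Hypothetical follow-up sketch (see thread above; not part of this PR's diff):
# return tasks grouped by session_id instead of a flat list.
from fastapi import APIRouter, HTTPException
from fastapi.responses import JSONResponse

from llama_deploy.apiserver.server import manager

deployments_router = APIRouter(prefix="/deployments")


@deployments_router.get("/{deployment_name}/tasks")
async def get_tasks_by_session(deployment_name: str) -> JSONResponse:
    """Get the tasks of a deployment, keyed by the session that owns them."""
    deployment = manager.get_deployment(deployment_name)
    if deployment is None:
        raise HTTPException(status_code=404, detail="Deployment not found")

    tasks_by_session: dict[str, list[dict]] = {}
    for session_def in await deployment.client.list_sessions():
        session = await deployment.client.get_session(session_id=session_def.session_id)
        # Assumption: TaskDefinition is a Pydantic v2 model, so model_dump() yields
        # a JSON-serializable dict that JSONResponse can encode.
        tasks_by_session[session_def.session_id] = [
            task_def.model_dump() for task_def in await session.get_tasks()
        ]
    return JSONResponse(tasks_by_session)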
7 changes: 4 additions & 3 deletions llama_deploy/client/__init__.py
@@ -1,4 +1,5 @@
-from llama_deploy.client.async_client import AsyncLlamaDeployClient
-from llama_deploy.client.sync_client import LlamaDeployClient
+from .async_client import AsyncLlamaDeployClient
+from .sync_client import LlamaDeployClient
+from .client import Client
 
-__all__ = ["AsyncLlamaDeployClient", "LlamaDeployClient"]
+__all__ = ["AsyncLlamaDeployClient", "Client", "LlamaDeployClient"]
29 changes: 29 additions & 0 deletions llama_deploy/client/base.py
@@ -0,0 +1,29 @@
from typing import Any

import httpx
from pydantic_settings import BaseSettings, SettingsConfigDict


class _BaseClient(BaseSettings):
"""Base type for clients, to be used in Pydantic models to avoid circular imports.

Settings can be passed to the Client constructor when creating an instance, or defined with environment variables
having names prefixed with the string `LLAMA_DEPLOY_`, e.g. `LLAMA_DEPLOY_DISABLE_SSL`.
"""

model_config = SettingsConfigDict(env_prefix="LLAMA_DEPLOY_")

api_server_url: str = "http://localhost:4501"
disable_ssl: bool = False
timeout: float = 120.0
poll_interval: float = 0.5

async def request(
self, method: str, url: str | httpx.URL, *args: Any, **kwargs: Any
) -> httpx.Response:
"""Performs an async HTTP request using httpx."""
verify = kwargs.pop("verify", True)
async with httpx.AsyncClient(verify=verify) as client:
response = await client.request(method, url, *args, **kwargs)
response.raise_for_status()
return response
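To round out the picture, a minimal usage sketch of the `Client` introduced in this PR, pieced together from the e2e tests above and the `_BaseClient` docstring. The environment variable name is inferred from the `LLAMA_DEPLOY_` prefix and the `api_server_url` field; treat the details as assumptions rather than documented API.

# Minimal usage sketch based on the e2e tests in this PR (not part of the diff).
import asyncio

from llama_deploy.client import Client

# Settings come from the constructor or from env vars with the LLAMA_DEPLOY_ prefix,
# e.g. LLAMA_DEPLOY_API_SERVER_URL (name inferred from the api_server_url field).
client = Client(api_server_url="http://localhost:4501")


async def main() -> None:
    # Async interface, as exercised by test_status.py and test_deploy.py
    status = await client.apiserver.status()
    print(status.status.value)  # "Healthy" once the apiserver is up

    deployments = await client.apiserver.deployments()
    with open("deployments/deployment1.yml") as f:  # path is illustrative
        await deployments.create(f)


if __name__ == "__main__":
    asyncio.run(main())
    # The sync interface mirrors the async one under client.sync:
    print(client.sync.apiserver.status().status.value)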