refact: Make topic explicit in message queue API (#358)
* Make topic explicit in message queue API

* fix unit tests

* more fixes

* fix quotes in e2e code

* better naming
masci authored Nov 13, 2024
1 parent 1fcfa2f commit 4449d9f
Showing 39 changed files with 470 additions and 353 deletions.
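In short, producers and consumers now name the topic they use on every call instead of relying on a topic baked into the queue configuration. Below is a minimal sketch of the new call shape, pieced together from the Kafka e2e test changes in this commit. It assumes a local Kafka broker such as the docker-compose service the tests spin up; the topic name and payload are purely illustrative.

import asyncio

from llama_deploy.message_consumers.callable import CallableMessageConsumer
from llama_deploy.message_queues import KafkaMessageQueue, KafkaMessageQueueConfig
from llama_deploy.messages import QueueMessage


async def main() -> None:
    mq = KafkaMessageQueue(KafkaMessageQueueConfig())
    received: list[QueueMessage] = []
    consumer = CallableMessageConsumer(message_type="test_message", handler=received.append)

    # The topic is now an explicit argument on both ends of the queue.
    start_consuming = await mq.register_consumer(consumer, topic="test")
    await mq.publish(QueueMessage(type="test_message", data={"message": "hi"}), topic="test")

    task = asyncio.create_task(start_consuming())
    await asyncio.sleep(0.5)  # give the consumer a moment to pick the message up
    task.cancel()
    assert len(received) == 1


asyncio.run(main())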
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -18,7 +18,7 @@ repos:
- id: trailing-whitespace

- repo: https://github.com/charliermarsh/ruff-pre-commit
rev: v0.1.5
rev: v0.7.3

hooks:
- id: ruff
3 changes: 1 addition & 2 deletions e2e_tests/apiserver/test_streaming.py
@@ -11,8 +11,7 @@ async def test_stream(apiserver, client):
here = Path(__file__).parent

with open(here / "deployments" / "deployment_streaming.yml") as f:
deployments = await client.apiserver.deployments()
deployment = await deployments.create(f)
deployment = await client.apiserver.deployments.create(f)
await asyncio.sleep(5)

tasks = await deployment.tasks()
103 changes: 103 additions & 0 deletions e2e_tests/message_queues/message_queue_kafka/conftest.py
@@ -0,0 +1,103 @@
import asyncio
import multiprocessing
import subprocess
import time
from pathlib import Path

import pytest

from llama_deploy import (
ControlPlaneConfig,
WorkflowServiceConfig,
deploy_core,
deploy_workflow,
)
from llama_deploy.message_queues import KafkaMessageQueue, KafkaMessageQueueConfig

from .workflow import BasicWorkflow


@pytest.fixture(scope="package")
def kafka_service():
compose_file = Path(__file__).resolve().parent / "docker-compose.yml"
proc = subprocess.Popen(
["docker", "compose", "-f", f"{compose_file}", "up", "-d", "--wait"]
)
proc.communicate()
yield
subprocess.Popen(["docker", "compose", "-f", f"{compose_file}", "down"])


@pytest.fixture
def mq(kafka_service):
return KafkaMessageQueue(KafkaMessageQueueConfig())


def run_workflow_one():
asyncio.run(
deploy_workflow(
BasicWorkflow(timeout=10, name="Workflow one"),
WorkflowServiceConfig(
host="127.0.0.1",
port=8003,
service_name="basic",
),
ControlPlaneConfig(topic_namespace="core_one", port=8001),
)
)


def run_workflow_two():
asyncio.run(
deploy_workflow(
BasicWorkflow(timeout=10, name="Workflow two"),
WorkflowServiceConfig(
host="127.0.0.1",
port=8004,
service_name="basic",
),
ControlPlaneConfig(topic_namespace="core_two", port=8002),
)
)


def run_core_one():
asyncio.run(
deploy_core(
ControlPlaneConfig(topic_namespace="core_one", port=8001),
KafkaMessageQueueConfig(),
)
)


def run_core_two():
asyncio.run(
deploy_core(
ControlPlaneConfig(topic_namespace="core_two", port=8002),
KafkaMessageQueueConfig(),
)
)


@pytest.fixture
def control_planes(kafka_service):
p1 = multiprocessing.Process(target=run_core_one)
p1.start()

p2 = multiprocessing.Process(target=run_core_two)
p2.start()

time.sleep(3)

p3 = multiprocessing.Process(target=run_workflow_one)
p3.start()

p4 = multiprocessing.Process(target=run_workflow_two)
p4.start()

yield

p1.kill()
p2.kill()
p3.kill()
p4.kill()
41 changes: 20 additions & 21 deletions e2e_tests/message_queues/message_queue_kafka/test_message_queue.py
@@ -1,30 +1,12 @@
import asyncio
import subprocess
from pathlib import Path

import pytest

from llama_deploy import Client
from llama_deploy.message_consumers.callable import CallableMessageConsumer
from llama_deploy.message_queues import KafkaMessageQueue, KafkaMessageQueueConfig
from llama_deploy.messages import QueueMessage


@pytest.fixture
def kafka_service():
compose_file = Path(__file__).resolve().parent / "docker-compose.yml"
proc = subprocess.Popen(
["docker", "compose", "-f", f"{compose_file}", "up", "-d", "--wait"]
)
proc.communicate()
yield
subprocess.Popen(["docker", "compose", "-f", f"{compose_file}", "down"])


@pytest.fixture
def mq(kafka_service):
return KafkaMessageQueue(KafkaMessageQueueConfig(topic_name="test_message"))


@pytest.mark.e2e
@pytest.mark.asyncio
async def test_roundtrip(mq):
@@ -37,13 +19,13 @@ def message_handler(message: QueueMessage) -> None:
test_consumer = CallableMessageConsumer(
message_type="test_message", handler=message_handler
)
start_consuming_callable = await mq.register_consumer(test_consumer)
start_consuming_callable = await mq.register_consumer(test_consumer, topic="test")

# produce a message
test_message = QueueMessage(type="test_message", data={"message": "this is a test"})

# await asyncio.gather(start_consuming_callable(), mq.publish(test_message))
await mq.publish(test_message)
await mq.publish(test_message, topic="test")
t = asyncio.create_task(start_consuming_callable())
await asyncio.sleep(0.5)
# at this point the message should have arrived
@@ -52,3 +34,20 @@ def message_handler(message: QueueMessage) -> None:

assert len(received_messages) == 1
assert test_message in received_messages


@pytest.mark.e2e
@pytest.mark.asyncio
async def test_multiple_control_planes(control_planes):
c1 = Client(control_plane_url="http://localhost:8001")
c2 = Client(control_plane_url="http://localhost:8002")

session = await c1.core.sessions.create()
r1 = await session.run("basic", arg="Hello One!")
await c1.core.sessions.delete(session.id)
assert r1 == "Workflow one received Hello One!"

session = await c2.core.sessions.create()
r2 = await session.run("basic", arg="Hello Two!")
await c2.core.sessions.delete(session.id)
assert r2 == "Workflow two received Hello Two!"
12 changes: 12 additions & 0 deletions e2e_tests/message_queues/message_queue_kafka/workflow.py
@@ -0,0 +1,12 @@
from llama_index.core.workflow import Context, StartEvent, StopEvent, Workflow, step


class BasicWorkflow(Workflow):
def __init__(self, *args, **kwargs):
self._name = kwargs.pop("name")
super().__init__(*args, **kwargs)

@step()
async def run_step(self, ctx: Context, ev: StartEvent) -> StopEvent:
received = ev.get("arg")
return StopEvent(result=f"{self._name} received {received}")
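Outside the deployment machinery the workflow behaves like any llama_index Workflow: keyword arguments passed to run() surface on the StartEvent, which is why ev.get("arg") works above. A quick sketch of running it directly; the import path assumes the e2e_tests package is importable, e.g. from the repository root.

import asyncio

from e2e_tests.message_queues.message_queue_kafka.workflow import BasicWorkflow


async def main() -> None:
    wf = BasicWorkflow(timeout=10, name="Workflow one")
    result = await wf.run(arg="Hello One!")
    print(result)  # -> "Workflow one received Hello One!"


asyncio.run(main())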
10 changes: 4 additions & 6 deletions llama_deploy/apiserver/deployment.py
@@ -33,8 +33,7 @@
}


class DeploymentError(Exception):
...
class DeploymentError(Exception): ...


class Deployment:
@@ -55,10 +54,11 @@ def __init__(self, *, config: Config, root_path: Path) -> None:
self._path = root_path / config.name
self._simple_message_queue: SimpleMessageQueue | None = None
self._queue_client = self._load_message_queue_client(config.message_queue)
self._control_plane_config = config.control_plane
self._control_plane = ControlPlaneServer(
self._queue_client,
SimpleOrchestrator(**SimpleOrchestratorConfig().model_dump()),
**config.control_plane.model_dump(),
config=config.control_plane,
)
self._workflow_services: list[WorkflowService] = self._load_services(config)
self._client = AsyncLlamaDeployClient(config.control_plane)
@@ -111,9 +111,7 @@ async def start(self) -> None:
service_task = asyncio.create_task(wfs.launch_server())
tasks.append(service_task)
consumer_fn = await wfs.register_to_message_queue()
control_plane_url = (
f"http://{self._control_plane.host}:{self._control_plane.port}"
)
control_plane_url = f"http://{self._control_plane_config.host}:{self._control_plane_config.port}"
await wfs.register_to_control_plane(control_plane_url)
consumer_task = asyncio.create_task(consumer_fn())
tasks.append(consumer_task)
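The server is now handed the whole ControlPlaneConfig instead of having its fields splatted as keyword arguments, and the deployment keeps the config around so it can build the control plane URL later. Roughly, with the queue and orchestrator arguments as stand-ins and the SimpleOrchestrator import path assumed:

from llama_deploy import ControlPlaneConfig
from llama_deploy.control_plane import ControlPlaneServer
from llama_deploy.message_queues import SimpleMessageQueue  # stand-in queue client
from llama_deploy.orchestrators import SimpleOrchestrator  # assumed import path

cp_config = ControlPlaneConfig(host="127.0.0.1", port=8000)

control_plane = ControlPlaneServer(
    SimpleMessageQueue(),
    SimpleOrchestrator(),
    config=cp_config,  # previously: **cp_config.model_dump()
)

control_plane_url = f"http://{cp_config.host}:{cp_config.port}"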
32 changes: 14 additions & 18 deletions llama_deploy/client/models/apiserver.py
@@ -226,9 +226,9 @@ async def create(self, config: TextIO) -> Deployment:
model_class = self._prepare(Deployment)
return model_class(client=self.client, id=r.json().get("name"))

async def get(self, deployment_id: str) -> Deployment:
async def get(self, id: str) -> Deployment:
"""Gets a deployment by id."""
get_url = f"{self.client.api_server_url}/deployments/{deployment_id}"
get_url = f"{self.client.api_server_url}/deployments/{id}"
# Current version of apiserver doesn't return anything useful in this endpoint, let's just ignore it
await self.client.request(
"GET",
@@ -237,7 +237,14 @@ async def get(self, deployment_id: str) -> Deployment:
timeout=self.client.timeout,
)
model_class = self._prepare(Deployment)
return model_class(client=self.client, id=deployment_id)
return model_class(client=self.client, id=id)

async def list(self) -> list[Deployment]:
deployments_url = f"{self.client.api_server_url}/deployments/"
r = await self.client.request("GET", deployments_url)
model_class = self._prepare(Deployment)
deployments = [model_class(client=self.client, id=name) for name in r.json()]
return deployments


class ApiServer(Model):
@@ -280,19 +287,8 @@ async def status(self) -> Status:
deployments=deployments,
)

async def deployments(self) -> DeploymentCollection:
@property
def deployments(self) -> DeploymentCollection:
"""Returns a collection of deployments currently active in the API Server."""
status_url = f"{self.client.api_server_url}/deployments/"

r = await self.client.request(
"GET",
status_url,
verify=not self.client.disable_ssl,
timeout=self.client.timeout,
)
model_class = self._prepare(Deployment)
deployments = {
"id": model_class(client=self.client, id=name) for name in r.json()
}
coll_model_class = self._prepare(DeploymentCollection)
return coll_model_class(client=self.client, items=deployments)
model_class = self._prepare(DeploymentCollection)
return model_class(client=self.client, items={})
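With deployments exposed as a property, callers chain straight into the collection, and list() (now async, see the model change below) is what actually hits the API server. A sketch of the resulting client-side flow; the API server URL and the deployment file name are illustrative, and the constructor is assumed to accept api_server_url the same way it accepts control_plane_url elsewhere in this commit.

import asyncio

from llama_deploy import Client


async def main() -> None:
    client = Client(api_server_url="http://localhost:4501")  # URL is an assumption

    with open("deployment_streaming.yml") as f:
        deployment = await client.apiserver.deployments.create(f)

    all_deployments = await client.apiserver.deployments.list()
    same_one = await client.apiserver.deployments.get(deployment.id)
    print([d.id for d in all_deployments], same_one.id)


asyncio.run(main())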
2 changes: 1 addition & 1 deletion llama_deploy/client/models/model.py
@@ -39,7 +39,7 @@ def get(self, id: str) -> T:
"""Returns an item from the collection."""
return self.items[id]

def list(self) -> list[T]:
async def list(self) -> list[T]:
"""Returns a list of all the items in the collection."""
return [self.get(id) for id in self.items.keys()]

5 changes: 3 additions & 2 deletions llama_deploy/control_plane/__init__.py
@@ -1,4 +1,5 @@
from llama_deploy.control_plane.base import BaseControlPlane
from llama_deploy.control_plane.server import ControlPlaneServer, ControlPlaneConfig
from .base import BaseControlPlane
from .config import ControlPlaneConfig
from .server import ControlPlaneServer

__all__ = ["BaseControlPlane", "ControlPlaneServer", "ControlPlaneConfig"]
10 changes: 7 additions & 3 deletions llama_deploy/control_plane/base.py
@@ -1,19 +1,21 @@
from abc import ABC, abstractmethod
from typing import Dict, List, Optional

from llama_deploy.message_queues.base import BaseMessageQueue
from llama_deploy.message_consumers.base import (
BaseMessageQueueConsumer,
StartConsumingCallable,
)
from llama_deploy.message_publishers.publisher import MessageQueuePublisherMixin
from llama_deploy.message_queues.base import BaseMessageQueue
from llama_deploy.types import (
ServiceDefinition,
TaskDefinition,
SessionDefinition,
TaskDefinition,
TaskResult,
)

from .config import ControlPlaneConfig


class BaseControlPlane(MessageQueuePublisherMixin, ABC):
"""The control plane for the system.
@@ -46,7 +48,9 @@ def as_consumer(self, remote: bool = False) -> BaseMessageQueueConsumer:
...

@abstractmethod
async def register_service(self, service_def: ServiceDefinition) -> None:
async def register_service(
self, service_def: ServiceDefinition
) -> ControlPlaneConfig:
"""
Register a service with the control plane.
30 changes: 30 additions & 0 deletions llama_deploy/control_plane/config.py
@@ -0,0 +1,30 @@
from typing import List

from pydantic_settings import BaseSettings, SettingsConfigDict


class ControlPlaneConfig(BaseSettings):
"""Control plane configuration."""

model_config = SettingsConfigDict(
env_prefix="CONTROL_PLANE_", arbitrary_types_allowed=True
)

services_store_key: str = "services"
tasks_store_key: str = "tasks"
session_store_key: str = "sessions"
step_interval: float = 0.1
host: str = "127.0.0.1"
port: int = 8000
internal_host: str | None = None
internal_port: int | None = None
running: bool = True
cors_origins: List[str] | None = None
topic_namespace: str = "llama_deploy"

@property
def url(self) -> str:
if self.port:
return f"http://{self.host}:{self.port}"
else:
return f"http://{self.host}"
