Skip to content

Commit

Permalink
refact: Simple message queue refactoring (#413)
Browse files Browse the repository at this point in the history
* cosmetics

* remove internal port concept

* derive from AbstractMessageQueue

* first draft

* rewrite unit tests

* fix tests

* skip tool tests

* fix e2e tests

* fix linter

* exclude untested code from coverage

* forgot wildcard
  • Loading branch information
masci authored Dec 20, 2024
1 parent 9a14319 commit 0022b67
Show file tree
Hide file tree
Showing 36 changed files with 591 additions and 962 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ repos:
--ignore-missing-imports,
--python-version=3.11,
]
exclude: ^(examples/|e2e_tests/)
exclude: ^(examples/|e2e_tests/|tests/tools/)

- repo: https://github.com/adamchainz/blacken-docs
rev: 1.16.0
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@

import pytest

from llama_deploy import SimpleMessageQueue
from llama_deploy import SimpleMessageQueueConfig, SimpleMessageQueueServer


@pytest.mark.asyncio
async def test_cancel_launch_server():
mq = SimpleMessageQueue(port=8009)
mq = SimpleMessageQueueServer(SimpleMessageQueueConfig(port=8009))
t = asyncio.create_task(mq.launch_server())

# Make sure the queue starts
Expand Down
7 changes: 5 additions & 2 deletions llama_deploy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@
from llama_deploy.control_plane import ControlPlaneConfig, ControlPlaneServer
from llama_deploy.deploy import deploy_core, deploy_workflow
from llama_deploy.message_consumers import CallableMessageConsumer
from llama_deploy.message_queues import SimpleMessageQueue, SimpleMessageQueueConfig
from llama_deploy.message_queues import (
SimpleMessageQueueConfig,
SimpleMessageQueueServer,
)
from llama_deploy.messages import QueueMessage
from llama_deploy.orchestrators import SimpleOrchestrator, SimpleOrchestratorConfig
from llama_deploy.services import (
Expand Down Expand Up @@ -52,7 +55,7 @@
# message consumers
"CallableMessageConsumer",
# message queues
"SimpleMessageQueue",
"SimpleMessageQueueServer",
"SimpleMessageQueueConfig",
# deployment
"deploy_core",
Expand Down
15 changes: 8 additions & 7 deletions llama_deploy/apiserver/deployment.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,16 @@
import os
import subprocess
import sys
from dotenv import dotenv_values
from multiprocessing.pool import ThreadPool
from pathlib import Path
from typing import Any

from dotenv import dotenv_values

from llama_deploy import (
Client,
ControlPlaneServer,
SimpleMessageQueue,
SimpleMessageQueueServer,
SimpleOrchestrator,
SimpleOrchestratorConfig,
WorkflowService,
Expand All @@ -24,6 +24,7 @@
KafkaMessageQueue,
RabbitMQMessageQueue,
RedisMessageQueue,
SimpleMessageQueue,
SimpleMessageQueueConfig,
SolaceMessageQueue,
)
Expand Down Expand Up @@ -56,7 +57,7 @@ def __init__(self, *, config: Config, root_path: Path) -> None:
"""
self._name = config.name
self._path = root_path / config.name
self._simple_message_queue: SimpleMessageQueue | None = None
self._simple_message_queue_server: SimpleMessageQueueServer | None = None
self._queue_client = self._load_message_queue_client(config.message_queue)
self._control_plane_config = config.control_plane
self._control_plane = ControlPlaneServer(
Expand Down Expand Up @@ -96,10 +97,10 @@ async def start(self) -> None:
tasks = []

# Spawn SimpleMessageQueue if needed
if self._simple_message_queue:
if self._simple_message_queue_server:
# If SimpleMessageQueue was selected in the config file we take care of running the task
tasks.append(
asyncio.create_task(self._simple_message_queue.launch_server())
asyncio.create_task(self._simple_message_queue_server.launch_server())
)
# the other components need the queue to run in order to start, give the queue some time to start
# FIXME: having to await a magic number of seconds is very brittle, we should rethink the bootstrap process
Expand Down Expand Up @@ -241,8 +242,8 @@ def _load_message_queue_client(
elif cfg.type == "redis":
return RedisMessageQueue(**cfg.model_dump())
elif cfg.type == "simple":
self._simple_message_queue = SimpleMessageQueue(**cfg.model_dump())
return self._simple_message_queue.client
self._simple_message_queue_server = SimpleMessageQueueServer(cfg)
return SimpleMessageQueue(cfg) # type: ignore
elif cfg.type == "solace":
return SolaceMessageQueue(**cfg.model_dump())
else:
Expand Down
11 changes: 4 additions & 7 deletions llama_deploy/deploy/deploy.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,12 @@
RabbitMQMessageQueueConfig,
RedisMessageQueue,
RedisMessageQueueConfig,
SimpleMessageQueue,
SimpleMessageQueueConfig,
SimpleMessageQueueServer,
SolaceMessageQueue,
SolaceMessageQueueConfig,
)
from llama_deploy.message_queues.simple import SimpleRemoteClientMessageQueue
from llama_deploy.message_queues.simple import SimpleMessageQueue
from llama_deploy.orchestrators.simple import (
SimpleOrchestrator,
SimpleOrchestratorConfig,
Expand All @@ -35,7 +35,7 @@


async def _deploy_local_message_queue(config: SimpleMessageQueueConfig) -> asyncio.Task:
queue = SimpleMessageQueue(**config.model_dump())
queue = SimpleMessageQueueServer(config)
task = asyncio.create_task(queue.launch_server())

# let message queue boot up
Expand All @@ -48,8 +48,6 @@ def _get_message_queue_config(config_dict: dict) -> BaseSettings:
key = next(iter(config_dict.keys()))
if key == SimpleMessageQueueConfig.__name__:
return SimpleMessageQueueConfig(**config_dict[key])
elif key == SimpleRemoteClientMessageQueue.__name__:
return SimpleMessageQueueConfig(**config_dict[key])
elif key == AWSMessageQueueConfig.__name__:
return AWSMessageQueueConfig(**config_dict[key])
elif key == KafkaMessageQueueConfig.__name__:
Expand All @@ -66,8 +64,7 @@ def _get_message_queue_config(config_dict: dict) -> BaseSettings:

def _get_message_queue_client(config: BaseSettings) -> BaseMessageQueue:
if isinstance(config, SimpleMessageQueueConfig):
queue = SimpleMessageQueue(**config.model_dump())
return queue.client
return SimpleMessageQueue(config) # type: ignore
elif isinstance(config, AWSMessageQueueConfig):
return AWSMessageQueue(**config.model_dump())
elif isinstance(config, KafkaMessageQueueConfig):
Expand Down
8 changes: 5 additions & 3 deletions llama_deploy/message_consumers/base.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
"""Message consumers."""

from abc import ABC, abstractmethod
from pydantic import BaseModel, Field, ConfigDict
from typing import Any, Callable, TYPE_CHECKING, Coroutine
from typing import TYPE_CHECKING, Any, Callable, Coroutine

from pydantic import BaseModel, ConfigDict, Field

from llama_deploy.messages.base import QueueMessage
from llama_deploy.types import generate_id
Expand Down Expand Up @@ -42,7 +43,8 @@ async def _process_message(self, message: QueueMessage, **kwargs: Any) -> Any:
async def process_message(self, message: QueueMessage, **kwargs: Any) -> Any:
"""Logic for processing message."""
if message.type != self.message_type:
raise ValueError("Consumer cannot process the given kind of Message.")
msg = f"Consumer cannot process messages of type '{message.type}'."
raise ValueError(msg)
return await self._process_message(message, **kwargs)

async def start_consuming(
Expand Down
3 changes: 2 additions & 1 deletion llama_deploy/message_consumers/remote.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Any, Optional

import httpx
from pydantic import BaseModel, Field
from typing import Any, Optional

from llama_deploy.message_consumers.base import BaseMessageQueueConsumer
from llama_deploy.messages import QueueMessage
Expand Down
8 changes: 5 additions & 3 deletions llama_deploy/message_queues/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,12 @@
from llama_deploy.message_queues.simple import (
SimpleMessageQueue,
SimpleMessageQueueConfig,
SimpleRemoteClientMessageQueue,
SimpleMessageQueueServer,
)
from llama_deploy.message_queues.solace import (
SolaceMessageQueue as SolaceMessageQueue,
)
from llama_deploy.message_queues.solace import (
SolaceMessageQueueConfig as SolaceMessageQueueConfig,
)

Expand All @@ -28,9 +30,9 @@
"RabbitMQMessageQueueConfig",
"RedisMessageQueue",
"RedisMessageQueueConfig",
"SimpleMessageQueue",
"SimpleMessageQueueServer",
"SimpleMessageQueueConfig",
"SimpleRemoteClientMessageQueue",
"SimpleMessageQueue",
"AWSMessageQueue",
"AWSMessageQueueConfig",
]
67 changes: 25 additions & 42 deletions llama_deploy/message_queues/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,44 +5,28 @@
from abc import ABC, abstractmethod
from logging import getLogger
from typing import (
TYPE_CHECKING,
Any,
Awaitable,
Callable,
Dict,
List,
Optional,
Protocol,
Sequence,
)

from pydantic import BaseModel, ConfigDict

from llama_deploy.message_consumers.base import (
BaseMessageQueueConsumer,
StartConsumingCallable,
)
from llama_deploy.messages.base import QueueMessage

if TYPE_CHECKING:
from llama_deploy.message_consumers.base import (
BaseMessageQueueConsumer,
StartConsumingCallable,
)

logger = getLogger(__name__)
AsyncProcessMessageCallable = Callable[[QueueMessage], Awaitable[Any]]


class MessageProcessor(Protocol):
"""Protocol for a callable that processes messages."""

def __call__(self, message: QueueMessage, **kwargs: Any) -> None: ...


class PublishCallback(Protocol):
"""Protocol for a callable that processes messages.
TODO: Variant for Async Publish Callback.
"""

def __call__(self, message: QueueMessage, **kwargs: Any) -> None: ...
PublishCallback = (
Callable[[QueueMessage], Any] | Callable[[QueueMessage], Awaitable[Any]]
)


class AbstractMessageQueue(ABC):
Expand All @@ -56,7 +40,7 @@ async def publish(
self,
message: QueueMessage,
topic: str,
callback: Optional[PublishCallback] = None,
callback: PublishCallback | None = None,
**kwargs: Any,
) -> Any:
"""Send message to a consumer."""
Expand All @@ -76,35 +60,22 @@ async def publish(

@abstractmethod
async def register_consumer(
self, consumer: "BaseMessageQueueConsumer", topic: str | None = None
) -> "StartConsumingCallable":
self, consumer: BaseMessageQueueConsumer, topic: str | None = None
) -> StartConsumingCallable:
"""Register consumer to start consuming messages."""

@abstractmethod
async def deregister_consumer(self, consumer: "BaseMessageQueueConsumer") -> Any:
async def deregister_consumer(self, consumer: BaseMessageQueueConsumer) -> Any:
"""Deregister consumer to stop publishing messages)."""

async def get_consumers(
self,
message_type: str,
) -> Sequence["BaseMessageQueueConsumer"]:
self, message_type: str
) -> Sequence[BaseMessageQueueConsumer]:
"""Gets list of consumers according to a message type."""
raise NotImplementedError(
"`get_consumers()` is not implemented for this class."
)

@abstractmethod
async def processing_loop(self) -> None:
"""The processing loop for the service."""

@abstractmethod
async def launch_local(self) -> asyncio.Task:
"""Launch the service in-process."""

@abstractmethod
async def launch_server(self) -> None:
"""Launch the service as a server."""

@abstractmethod
async def cleanup_local(
self, message_types: List[str], *args: Any, **kwargs: Dict[str, Any]
Expand All @@ -121,3 +92,15 @@ class BaseMessageQueue(BaseModel, AbstractMessageQueue):

def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)

@abstractmethod
async def processing_loop(self) -> None:
"""The processing loop for the service."""

@abstractmethod
async def launch_local(self) -> asyncio.Task:
"""Launch the service in-process."""

@abstractmethod
async def launch_server(self) -> None:
"""Launch the service as a server."""
8 changes: 4 additions & 4 deletions llama_deploy/message_queues/simple/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from .client import SimpleMessageQueue
from .server import (
SimpleMessageQueue,
SimpleMessageQueueConfig,
SimpleRemoteClientMessageQueue,
SimpleMessageQueueServer,
)

__all__ = [
"SimpleMessageQueue",
"SimpleMessageQueueServer",
"SimpleMessageQueueConfig",
"SimpleRemoteClientMessageQueue",
"SimpleMessageQueue",
]
Loading

0 comments on commit 0022b67

Please sign in to comment.