Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refact: Simple message queue refactoring #413

Merged
merged 11 commits into from
Dec 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ repos:
--ignore-missing-imports,
--python-version=3.11,
]
exclude: ^(examples/|e2e_tests/)
exclude: ^(examples/|e2e_tests/|tests/tools/)

- repo: https://github.com/adamchainz/blacken-docs
rev: 1.16.0
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@

import pytest

from llama_deploy import SimpleMessageQueue
from llama_deploy import SimpleMessageQueueConfig, SimpleMessageQueueServer


@pytest.mark.asyncio
async def test_cancel_launch_server():
mq = SimpleMessageQueue(port=8009)
mq = SimpleMessageQueueServer(SimpleMessageQueueConfig(port=8009))
t = asyncio.create_task(mq.launch_server())

# Make sure the queue starts
Expand Down
7 changes: 5 additions & 2 deletions llama_deploy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@
from llama_deploy.control_plane import ControlPlaneConfig, ControlPlaneServer
from llama_deploy.deploy import deploy_core, deploy_workflow
from llama_deploy.message_consumers import CallableMessageConsumer
from llama_deploy.message_queues import SimpleMessageQueue, SimpleMessageQueueConfig
from llama_deploy.message_queues import (
SimpleMessageQueueConfig,
SimpleMessageQueueServer,
)
from llama_deploy.messages import QueueMessage
from llama_deploy.orchestrators import SimpleOrchestrator, SimpleOrchestratorConfig
from llama_deploy.services import (
Expand Down Expand Up @@ -52,7 +55,7 @@
# message consumers
"CallableMessageConsumer",
# message queues
"SimpleMessageQueue",
"SimpleMessageQueueServer",
"SimpleMessageQueueConfig",
# deployment
"deploy_core",
Expand Down
15 changes: 8 additions & 7 deletions llama_deploy/apiserver/deployment.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,16 @@
import os
import subprocess
import sys
from dotenv import dotenv_values
from multiprocessing.pool import ThreadPool
from pathlib import Path
from typing import Any

from dotenv import dotenv_values

from llama_deploy import (
Client,
ControlPlaneServer,
SimpleMessageQueue,
SimpleMessageQueueServer,
SimpleOrchestrator,
SimpleOrchestratorConfig,
WorkflowService,
Expand All @@ -24,6 +24,7 @@
KafkaMessageQueue,
RabbitMQMessageQueue,
RedisMessageQueue,
SimpleMessageQueue,
SimpleMessageQueueConfig,
SolaceMessageQueue,
)
Expand Down Expand Up @@ -56,7 +57,7 @@ def __init__(self, *, config: Config, root_path: Path) -> None:
"""
self._name = config.name
self._path = root_path / config.name
self._simple_message_queue: SimpleMessageQueue | None = None
self._simple_message_queue_server: SimpleMessageQueueServer | None = None
self._queue_client = self._load_message_queue_client(config.message_queue)
self._control_plane_config = config.control_plane
self._control_plane = ControlPlaneServer(
Expand Down Expand Up @@ -96,10 +97,10 @@ async def start(self) -> None:
tasks = []

# Spawn SimpleMessageQueue if needed
if self._simple_message_queue:
if self._simple_message_queue_server:
# If SimpleMessageQueue was selected in the config file we take care of running the task
tasks.append(
asyncio.create_task(self._simple_message_queue.launch_server())
asyncio.create_task(self._simple_message_queue_server.launch_server())
)
# the other components need the queue to run in order to start, give the queue some time to start
# FIXME: having to await a magic number of seconds is very brittle, we should rethink the bootstrap process
Expand Down Expand Up @@ -241,8 +242,8 @@ def _load_message_queue_client(
elif cfg.type == "redis":
return RedisMessageQueue(**cfg.model_dump())
elif cfg.type == "simple":
self._simple_message_queue = SimpleMessageQueue(**cfg.model_dump())
return self._simple_message_queue.client
self._simple_message_queue_server = SimpleMessageQueueServer(cfg)
return SimpleMessageQueue(cfg) # type: ignore
elif cfg.type == "solace":
return SolaceMessageQueue(**cfg.model_dump())
else:
Expand Down
11 changes: 4 additions & 7 deletions llama_deploy/deploy/deploy.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,12 @@
RabbitMQMessageQueueConfig,
RedisMessageQueue,
RedisMessageQueueConfig,
SimpleMessageQueue,
SimpleMessageQueueConfig,
SimpleMessageQueueServer,
SolaceMessageQueue,
SolaceMessageQueueConfig,
)
from llama_deploy.message_queues.simple import SimpleRemoteClientMessageQueue
from llama_deploy.message_queues.simple import SimpleMessageQueue
from llama_deploy.orchestrators.simple import (
SimpleOrchestrator,
SimpleOrchestratorConfig,
Expand All @@ -35,7 +35,7 @@


async def _deploy_local_message_queue(config: SimpleMessageQueueConfig) -> asyncio.Task:
queue = SimpleMessageQueue(**config.model_dump())
queue = SimpleMessageQueueServer(config)
task = asyncio.create_task(queue.launch_server())

# let message queue boot up
Expand All @@ -48,8 +48,6 @@ def _get_message_queue_config(config_dict: dict) -> BaseSettings:
key = next(iter(config_dict.keys()))
if key == SimpleMessageQueueConfig.__name__:
return SimpleMessageQueueConfig(**config_dict[key])
elif key == SimpleRemoteClientMessageQueue.__name__:
return SimpleMessageQueueConfig(**config_dict[key])
elif key == AWSMessageQueueConfig.__name__:
return AWSMessageQueueConfig(**config_dict[key])
elif key == KafkaMessageQueueConfig.__name__:
Expand All @@ -66,8 +64,7 @@ def _get_message_queue_config(config_dict: dict) -> BaseSettings:

def _get_message_queue_client(config: BaseSettings) -> BaseMessageQueue:
if isinstance(config, SimpleMessageQueueConfig):
queue = SimpleMessageQueue(**config.model_dump())
return queue.client
return SimpleMessageQueue(config) # type: ignore
elif isinstance(config, AWSMessageQueueConfig):
return AWSMessageQueue(**config.model_dump())
elif isinstance(config, KafkaMessageQueueConfig):
Expand Down
8 changes: 5 additions & 3 deletions llama_deploy/message_consumers/base.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
"""Message consumers."""

from abc import ABC, abstractmethod
from pydantic import BaseModel, Field, ConfigDict
from typing import Any, Callable, TYPE_CHECKING, Coroutine
from typing import TYPE_CHECKING, Any, Callable, Coroutine

from pydantic import BaseModel, ConfigDict, Field

from llama_deploy.messages.base import QueueMessage
from llama_deploy.types import generate_id
Expand Down Expand Up @@ -42,7 +43,8 @@ async def _process_message(self, message: QueueMessage, **kwargs: Any) -> Any:
async def process_message(self, message: QueueMessage, **kwargs: Any) -> Any:
"""Logic for processing message."""
if message.type != self.message_type:
raise ValueError("Consumer cannot process the given kind of Message.")
msg = f"Consumer cannot process messages of type '{message.type}'."
raise ValueError(msg)
return await self._process_message(message, **kwargs)

async def start_consuming(
Expand Down
3 changes: 2 additions & 1 deletion llama_deploy/message_consumers/remote.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Any, Optional

import httpx
from pydantic import BaseModel, Field
from typing import Any, Optional

from llama_deploy.message_consumers.base import BaseMessageQueueConsumer
from llama_deploy.messages import QueueMessage
Expand Down
8 changes: 5 additions & 3 deletions llama_deploy/message_queues/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,12 @@
from llama_deploy.message_queues.simple import (
SimpleMessageQueue,
SimpleMessageQueueConfig,
SimpleRemoteClientMessageQueue,
SimpleMessageQueueServer,
)
from llama_deploy.message_queues.solace import (
SolaceMessageQueue as SolaceMessageQueue,
)
from llama_deploy.message_queues.solace import (
SolaceMessageQueueConfig as SolaceMessageQueueConfig,
)

Expand All @@ -28,9 +30,9 @@
"RabbitMQMessageQueueConfig",
"RedisMessageQueue",
"RedisMessageQueueConfig",
"SimpleMessageQueue",
"SimpleMessageQueueServer",
"SimpleMessageQueueConfig",
"SimpleRemoteClientMessageQueue",
"SimpleMessageQueue",
"AWSMessageQueue",
"AWSMessageQueueConfig",
]
67 changes: 25 additions & 42 deletions llama_deploy/message_queues/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,44 +5,28 @@
from abc import ABC, abstractmethod
from logging import getLogger
from typing import (
TYPE_CHECKING,
Any,
Awaitable,
Callable,
Dict,
List,
Optional,
Protocol,
Sequence,
)

from pydantic import BaseModel, ConfigDict

from llama_deploy.message_consumers.base import (
BaseMessageQueueConsumer,
StartConsumingCallable,
)
from llama_deploy.messages.base import QueueMessage

if TYPE_CHECKING:
from llama_deploy.message_consumers.base import (
BaseMessageQueueConsumer,
StartConsumingCallable,
)

logger = getLogger(__name__)
AsyncProcessMessageCallable = Callable[[QueueMessage], Awaitable[Any]]


class MessageProcessor(Protocol):
"""Protocol for a callable that processes messages."""

def __call__(self, message: QueueMessage, **kwargs: Any) -> None: ...


class PublishCallback(Protocol):
"""Protocol for a callable that processes messages.

TODO: Variant for Async Publish Callback.
"""

def __call__(self, message: QueueMessage, **kwargs: Any) -> None: ...
PublishCallback = (
Callable[[QueueMessage], Any] | Callable[[QueueMessage], Awaitable[Any]]
)


class AbstractMessageQueue(ABC):
Expand All @@ -56,7 +40,7 @@ async def publish(
self,
message: QueueMessage,
topic: str,
callback: Optional[PublishCallback] = None,
callback: PublishCallback | None = None,
**kwargs: Any,
) -> Any:
"""Send message to a consumer."""
Expand All @@ -76,35 +60,22 @@ async def publish(

@abstractmethod
async def register_consumer(
self, consumer: "BaseMessageQueueConsumer", topic: str | None = None
) -> "StartConsumingCallable":
self, consumer: BaseMessageQueueConsumer, topic: str | None = None
) -> StartConsumingCallable:
"""Register consumer to start consuming messages."""

@abstractmethod
async def deregister_consumer(self, consumer: "BaseMessageQueueConsumer") -> Any:
async def deregister_consumer(self, consumer: BaseMessageQueueConsumer) -> Any:
"""Deregister consumer to stop publishing messages)."""

async def get_consumers(
self,
message_type: str,
) -> Sequence["BaseMessageQueueConsumer"]:
self, message_type: str
) -> Sequence[BaseMessageQueueConsumer]:
"""Gets list of consumers according to a message type."""
raise NotImplementedError(
"`get_consumers()` is not implemented for this class."
)

@abstractmethod
async def processing_loop(self) -> None:
"""The processing loop for the service."""

@abstractmethod
async def launch_local(self) -> asyncio.Task:
"""Launch the service in-process."""

@abstractmethod
async def launch_server(self) -> None:
"""Launch the service as a server."""

@abstractmethod
async def cleanup_local(
self, message_types: List[str], *args: Any, **kwargs: Dict[str, Any]
Expand All @@ -121,3 +92,15 @@ class BaseMessageQueue(BaseModel, AbstractMessageQueue):

def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)

@abstractmethod
async def processing_loop(self) -> None:
"""The processing loop for the service."""

@abstractmethod
async def launch_local(self) -> asyncio.Task:
"""Launch the service in-process."""

@abstractmethod
async def launch_server(self) -> None:
"""Launch the service as a server."""
8 changes: 4 additions & 4 deletions llama_deploy/message_queues/simple/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from .client import SimpleMessageQueue
from .server import (
SimpleMessageQueue,
SimpleMessageQueueConfig,
SimpleRemoteClientMessageQueue,
SimpleMessageQueueServer,
)

__all__ = [
"SimpleMessageQueue",
"SimpleMessageQueueServer",
"SimpleMessageQueueConfig",
"SimpleRemoteClientMessageQueue",
"SimpleMessageQueue",
]
Loading
Loading