From c9f991885fa07906bb4c81c041c63e49670c24f6 Mon Sep 17 00:00:00 2001 From: Massimiliano Pippi Date: Wed, 27 Nov 2024 15:19:24 +0100 Subject: [PATCH 1/8] feat: make the KV store configurable via env vars --- llama_deploy/control_plane/config.py | 39 +++++++++++-- llama_deploy/control_plane/server.py | 59 +++++++++++-------- tests/control_plane/test_config.py | 70 +++++++++++++++++++++++ tests/control_plane/test_control_plane.py | 0 tests/control_plane/test_server.py | 10 ++++ 5 files changed, 151 insertions(+), 27 deletions(-) create mode 100644 tests/control_plane/test_config.py delete mode 100644 tests/control_plane/test_control_plane.py create mode 100644 tests/control_plane/test_server.py diff --git a/llama_deploy/control_plane/config.py b/llama_deploy/control_plane/config.py index df964a24..5325bbe5 100644 --- a/llama_deploy/control_plane/config.py +++ b/llama_deploy/control_plane/config.py @@ -1,5 +1,7 @@ from typing import List +from urllib.parse import urlparse +from llama_index.core.storage.kvstore.types import BaseKVStore from pydantic_settings import BaseSettings, SettingsConfigDict @@ -24,7 +26,36 @@ class ControlPlaneConfig(BaseSettings): @property def url(self) -> str: - if self.port: - return f"http://{self.host}:{self.port}" - else: - return f"http://{self.host}" + return f"http://{self.host}:{self.port}" + + +def parse_state_store_uri(uri: str) -> BaseKVStore: + bits = urlparse(uri) + + if bits.scheme == "redis": + try: + from llama_index.storage.kvstore.redis import RedisKVStore # type: ignore + + return RedisKVStore(uri=uri) + except ImportError: + msg = ( + f"key-value store {bits.scheme} is not available, please install the required " + "llama_index integration with 'pip install llama-index-storage-kvstore-redis'." + ) + raise ValueError(msg) + elif bits.scheme == "mongodb+srv": + try: + from llama_index.storage.kvstore.mongodb import ( # type:ignore + MongoDBKVStore, + ) + + return MongoDBKVStore(uri=uri) + except ImportError: + msg = ( + f"key-value store {bits.scheme} is not available, please install the required " + "llama_index integration with 'pip install llama-index-storage-kvstore-mongodb'." + ) + raise ValueError(msg) + else: + msg = f"key-value store '{bits.scheme}' is not supported." + raise ValueError(msg) diff --git a/llama_deploy/control_plane/server.py b/llama_deploy/control_plane/server.py index 0f85dbdf..1ff5d386 100644 --- a/llama_deploy/control_plane/server.py +++ b/llama_deploy/control_plane/server.py @@ -20,6 +20,7 @@ from llama_deploy.message_consumers.remote import RemoteMessageConsumer from llama_deploy.message_queues.base import BaseMessageQueue, PublishCallback from llama_deploy.messages.base import QueueMessage +from llama_deploy.orchestrators import SimpleOrchestrator, SimpleOrchestratorConfig from llama_deploy.orchestrators.base import BaseOrchestrator from llama_deploy.orchestrators.utils import get_result_key, get_stream_key from llama_deploy.types import ( @@ -32,7 +33,7 @@ TaskStream, ) -from .config import ControlPlaneConfig +from .config import ControlPlaneConfig, parse_state_store_uri logger = getLogger(__name__) @@ -80,13 +81,25 @@ class ControlPlaneServer(BaseControlPlane): def __init__( self, message_queue: BaseMessageQueue, - orchestrator: BaseOrchestrator, + orchestrator: BaseOrchestrator | None = None, publish_callback: PublishCallback | None = None, state_store: BaseKVStore | None = None, + state_store_uri: str | None = None, config: ControlPlaneConfig | None = None, ) -> None: - self.orchestrator = orchestrator - self.state_store = state_store or SimpleKVStore() + self._orchestrator = orchestrator or SimpleOrchestrator( + **SimpleOrchestratorConfig().model_dump() + ) + + if state_store is not None and state_store_uri is not None: + raise ValueError("Please use either 'state_store' or 'state_store_uri'.") + + if state_store: + self._state_store = state_store + elif state_store_uri: + self._state_store = parse_state_store_uri(state_store_uri) + else: + self._state_store = state_store or SimpleKVStore() self._config = config or ControlPlaneConfig() self._message_queue = message_queue @@ -286,7 +299,7 @@ async def home(self) -> Dict[str, str]: async def register_service( self, service_def: ServiceDefinition ) -> ControlPlaneConfig: - await self.state_store.aput( + await self._state_store.aput( service_def.service_name, service_def.model_dump(), collection=self._config.services_store_key, @@ -294,12 +307,12 @@ async def register_service( return self._config async def deregister_service(self, service_name: str) -> None: - await self.state_store.adelete( + await self._state_store.adelete( service_name, collection=self._config.services_store_key ) async def get_service(self, service_name: str) -> ServiceDefinition: - service_dict = await self.state_store.aget( + service_dict = await self._state_store.aget( service_name, collection=self._config.services_store_key ) if service_dict is None: @@ -308,7 +321,7 @@ async def get_service(self, service_name: str) -> ServiceDefinition: return ServiceDefinition.model_validate(service_dict) async def get_all_services(self) -> Dict[str, ServiceDefinition]: - service_dicts = await self.state_store.aget_all( + service_dicts = await self._state_store.aget_all( collection=self._config.services_store_key ) @@ -319,7 +332,7 @@ async def get_all_services(self) -> Dict[str, ServiceDefinition]: async def create_session(self) -> str: session = SessionDefinition() - await self.state_store.aput( + await self._state_store.aput( session.session_id, session.model_dump(), collection=self._config.session_store_key, @@ -328,7 +341,7 @@ async def create_session(self) -> str: return session.session_id async def get_session(self, session_id: str) -> SessionDefinition: - session_dict = await self.state_store.aget( + session_dict = await self._state_store.aget( session_id, collection=self._config.session_store_key ) if session_dict is None: @@ -337,12 +350,12 @@ async def get_session(self, session_id: str) -> SessionDefinition: return SessionDefinition.model_validate(session_dict) async def delete_session(self, session_id: str) -> None: - await self.state_store.adelete( + await self._state_store.adelete( session_id, collection=self._config.session_store_key ) async def get_all_sessions(self) -> Dict[str, SessionDefinition]: - session_dicts = await self.state_store.aget_all( + session_dicts = await self._state_store.aget_all( collection=self._config.session_store_key ) @@ -367,7 +380,7 @@ async def get_current_task(self, session_id: str) -> Optional[TaskDefinition]: async def add_task_to_session( self, session_id: str, task_def: TaskDefinition ) -> str: - session_dict = await self.state_store.aget( + session_dict = await self._state_store.aget( session_id, collection=self._config.session_store_key ) if session_dict is None: @@ -378,11 +391,11 @@ async def add_task_to_session( session = SessionDefinition(**session_dict) session.task_ids.append(task_def.task_id) - await self.state_store.aput( + await self._state_store.aput( session_id, session.model_dump(), collection=self._config.session_store_key ) - await self.state_store.aput( + await self._state_store.aput( task_def.task_id, task_def.model_dump(), collection=self._config.tasks_store_key, @@ -398,7 +411,7 @@ async def send_task_to_service(self, task_def: TaskDefinition) -> TaskDefinition session = await self.get_session(task_def.session_id) - next_messages, session_state = await self.orchestrator.get_next_messages( + next_messages, session_state = await self._orchestrator.get_next_messages( task_def, session.state ) @@ -409,7 +422,7 @@ async def send_task_to_service(self, task_def: TaskDefinition) -> TaskDefinition session.state.update(session_state) - await self.state_store.aput( + await self._state_store.aput( task_def.session_id, session.model_dump(), collection=self._config.session_store_key, @@ -427,11 +440,11 @@ async def handle_service_completion( raise ValueError(f"Task with id {task_result.task_id} has no session") session = await self.get_session(task_def.session_id) - state = await self.orchestrator.add_result_to_state(task_result, session.state) + state = await self._orchestrator.add_result_to_state(task_result, session.state) # update session state session.state.update(state) - await self.state_store.aput( + await self._state_store.aput( session.session_id, session.model_dump(), collection=self._config.session_store_key, @@ -440,14 +453,14 @@ async def handle_service_completion( # generate and send new tasks when needed task_def = await self.send_task_to_service(task_def) - await self.state_store.aput( + await self._state_store.aput( task_def.task_id, task_def.model_dump(), collection=self._config.tasks_store_key, ) async def get_task(self, task_id: str) -> TaskDefinition: - state_dict = await self.state_store.aget( + state_dict = await self._state_store.aget( task_id, collection=self._config.tasks_store_key ) if state_dict is None: @@ -506,7 +519,7 @@ async def add_stream_to_session(self, task_stream: TaskStream) -> None: session.state[get_stream_key(task_stream.task_id)] = existing_stream # update session state in store - await self.state_store.aput( + await self._state_store.aput( task_stream.session_id, session.model_dump(), collection=self._config.session_store_key, @@ -592,7 +605,7 @@ async def update_session_state( session = await self.get_session(session_id) session.state.update(state) - await self.state_store.aput( + await self._state_store.aput( session_id, session.model_dump(), collection=self._config.session_store_key ) diff --git a/tests/control_plane/test_config.py b/tests/control_plane/test_config.py new file mode 100644 index 00000000..f42201be --- /dev/null +++ b/tests/control_plane/test_config.py @@ -0,0 +1,70 @@ +from typing import Any +from unittest import mock + +import pytest + +from llama_deploy.control_plane import ControlPlaneConfig +from llama_deploy.control_plane.config import parse_state_store_uri + + +def test_config_url() -> None: + cfg = ControlPlaneConfig(host="localhost", port=4242) + assert cfg.url == "http://localhost:4242" + + +def test_parse_state_store_uri_malformed() -> None: + with pytest.raises(ValueError, match="key-value store '' is not supported."): + parse_state_store_uri("some_wrong_uri") + + with pytest.raises(ValueError, match="key-value store 'foo' is not supported."): + parse_state_store_uri("foo://user:pass@host/database") + + +def test_parse_state_store_uri_redis_not_installed(monkeypatch: Any) -> None: + try: + # Ensure the module is never available, even if the package is installed + monkeypatch.delattr("llama_index.storage.kvstore.redis") + except Exception: + pass + + with pytest.raises( + ValueError, match="pip install llama-index-storage-kvstore-redis" + ): + parse_state_store_uri("redis://localhost/") + + +def test_parse_state_store_uri_redis() -> None: + redis_mock = mock.MagicMock() + + with mock.patch.dict( + "sys.modules", {"llama_index.storage.kvstore.redis": redis_mock} + ): + parse_state_store_uri("redis://localhost/") + calls = redis_mock.mock_calls + assert len(calls) == 1 + assert calls[0].kwargs == {"uri": "redis://localhost/"} + + +def test_parse_state_store_uri_mongodb_not_installed(monkeypatch: Any) -> None: + try: + # Ensure the module is never available, even if the package is installed + monkeypatch.delattr("llama_index.storage.kvstore.mongodb") + except Exception: + pass + + with pytest.raises( + ValueError, match="pip install llama-index-storage-kvstore-mongodb" + ): + parse_state_store_uri("mongodb+srv://localhost/") + + +def test_parse_state_store_uri_mongodb() -> None: + redis_mock = mock.MagicMock() + + with mock.patch.dict( + "sys.modules", {"llama_index.storage.kvstore.mongodb": redis_mock} + ): + parse_state_store_uri("mongodb+srv://localhost/") + calls = redis_mock.mock_calls + assert len(calls) == 1 + assert calls[0].kwargs == {"uri": "mongodb+srv://localhost/"} diff --git a/tests/control_plane/test_control_plane.py b/tests/control_plane/test_control_plane.py deleted file mode 100644 index e69de29b..00000000 diff --git a/tests/control_plane/test_server.py b/tests/control_plane/test_server.py new file mode 100644 index 00000000..3d93bcdc --- /dev/null +++ b/tests/control_plane/test_server.py @@ -0,0 +1,10 @@ +from llama_deploy.control_plane import ControlPlaneServer +from llama_deploy.message_queues import SimpleMessageQueue + + +def test_control_plane_init() -> None: + cp = ControlPlaneServer(SimpleMessageQueue()) + assert cp._orchestrator is not None + assert cp._publish_callback is None + assert cp._state_store is not None + assert cp._config is not None From f6fdb1401f0e5a88bda6f6b158463523df91e815 Mon Sep 17 00:00:00 2001 From: Massimiliano Pippi Date: Wed, 27 Nov 2024 15:43:42 +0100 Subject: [PATCH 2/8] more unit tests --- llama_deploy/control_plane/base.py | 17 --------------- tests/control_plane/test_server.py | 33 ++++++++++++++++++++++++++++-- 2 files changed, 31 insertions(+), 19 deletions(-) diff --git a/llama_deploy/control_plane/base.py b/llama_deploy/control_plane/base.py index 48da02cd..27ef32ae 100644 --- a/llama_deploy/control_plane/base.py +++ b/llama_deploy/control_plane/base.py @@ -45,7 +45,6 @@ def as_consumer(self, remote: bool = False) -> BaseMessageQueueConsumer: Returns: BaseMessageQueueConsumer: Message queue consumer. """ - ... @abstractmethod async def register_service( @@ -57,7 +56,6 @@ async def register_service( Args: service_def (ServiceDefinition): Definition of the service. """ - ... @abstractmethod async def deregister_service(self, service_name: str) -> None: @@ -67,7 +65,6 @@ async def deregister_service(self, service_name: str) -> None: Args: service_name (str): Name of the service. """ - ... @abstractmethod async def get_service(self, service_name: str) -> ServiceDefinition: @@ -80,7 +77,6 @@ async def get_service(self, service_name: str) -> ServiceDefinition: Returns: ServiceDefinition: Definition of the service. """ - ... @abstractmethod async def get_all_services(self) -> Dict[str, ServiceDefinition]: @@ -90,7 +86,6 @@ async def get_all_services(self) -> Dict[str, ServiceDefinition]: Returns: dict: All services, mapped from service name to service definition. """ - ... @abstractmethod async def create_session(self) -> str: @@ -100,7 +95,6 @@ async def create_session(self) -> str: Returns: str: Session ID. """ - ... @abstractmethod async def add_task_to_session( @@ -116,7 +110,6 @@ async def add_task_to_session( Returns: str: Task ID. """ - ... @abstractmethod async def send_task_to_service(self, task_def: TaskDefinition) -> TaskDefinition: @@ -129,7 +122,6 @@ async def send_task_to_service(self, task_def: TaskDefinition) -> TaskDefinition Returns: TaskDefinition: Task definition with updated state. """ - ... @abstractmethod async def handle_service_completion( @@ -142,7 +134,6 @@ async def handle_service_completion( Args: task_result (TaskResult): Result of the task. """ - ... @abstractmethod async def get_session(self, session_id: str) -> SessionDefinition: @@ -155,7 +146,6 @@ async def get_session(self, session_id: str) -> SessionDefinition: Returns: SessionDefinition: The session definition. """ - ... @abstractmethod async def delete_session(self, session_id: str) -> None: @@ -165,7 +155,6 @@ async def delete_session(self, session_id: str) -> None: Args: session_id (str): Unique identifier of the session. """ - ... @abstractmethod async def get_all_sessions(self) -> Dict[str, SessionDefinition]: @@ -175,7 +164,6 @@ async def get_all_sessions(self) -> Dict[str, SessionDefinition]: Returns: dict: All sessions, mapped from session ID to session definition. """ - ... @abstractmethod async def get_session_tasks(self, session_id: str) -> List[TaskDefinition]: @@ -188,7 +176,6 @@ async def get_session_tasks(self, session_id: str) -> List[TaskDefinition]: Returns: List[TaskDefinition]: All tasks in the session. """ - ... @abstractmethod async def get_current_task(self, session_id: str) -> Optional[TaskDefinition]: @@ -201,7 +188,6 @@ async def get_current_task(self, session_id: str) -> Optional[TaskDefinition]: Returns: Optional[TaskDefinition]: The current task, if any. """ - ... @abstractmethod async def get_task(self, task_id: str) -> TaskDefinition: @@ -214,7 +200,6 @@ async def get_task(self, task_id: str) -> TaskDefinition: Returns: TaskDefinition: The task definition. """ - ... @abstractmethod async def get_message_queue_config(self) -> Dict[str, dict]: @@ -224,14 +209,12 @@ async def get_message_queue_config(self) -> Dict[str, dict]: Returns: Dict[str, dict]: A dict of message queue name -> config dict """ - ... @abstractmethod async def launch_server(self) -> None: """ Launch the control plane server. """ - ... @abstractmethod async def register_to_message_queue(self) -> StartConsumingCallable: diff --git a/tests/control_plane/test_server.py b/tests/control_plane/test_server.py index 3d93bcdc..86c4535a 100644 --- a/tests/control_plane/test_server.py +++ b/tests/control_plane/test_server.py @@ -1,10 +1,39 @@ +from unittest import mock + +import pytest + from llama_deploy.control_plane import ControlPlaneServer from llama_deploy.message_queues import SimpleMessageQueue def test_control_plane_init() -> None: - cp = ControlPlaneServer(SimpleMessageQueue()) + mq = SimpleMessageQueue() + cp = ControlPlaneServer(mq) assert cp._orchestrator is not None - assert cp._publish_callback is None assert cp._state_store is not None assert cp._config is not None + + assert cp.message_queue == mq + assert cp.publisher_id.startswith("ControlPlaneServer-") + assert cp.publish_callback is None + + assert cp.get_topic("msg_type") == "llama_deploy.msg_type" + + +def test_control_plane_init_state_store() -> None: + mocked_store = mock.MagicMock() + with pytest.raises(ValueError): + ControlPlaneServer( + SimpleMessageQueue(), + state_store=mocked_store, + state_store_uri="test/uri", + ) + + cp = ControlPlaneServer(SimpleMessageQueue(), state_store=mocked_store) + assert cp._state_store == mocked_store + + with mock.patch( + "llama_deploy.control_plane.server.parse_state_store_uri" + ) as mocked_parse: + ControlPlaneServer(SimpleMessageQueue(), state_store_uri="test/uri") + mocked_parse.assert_called_with("test/uri") From 15052c3c2c44336bbdc00179bd0d78aa34954da7 Mon Sep 17 00:00:00 2001 From: Massimiliano Pippi Date: Wed, 27 Nov 2024 15:59:11 +0100 Subject: [PATCH 3/8] move kvstore uri in the config --- llama_deploy/control_plane/config.py | 1 + llama_deploy/control_plane/server.py | 9 ++++----- tests/control_plane/test_server.py | 8 +++++--- 3 files changed, 10 insertions(+), 8 deletions(-) diff --git a/llama_deploy/control_plane/config.py b/llama_deploy/control_plane/config.py index 5325bbe5..64479a95 100644 --- a/llama_deploy/control_plane/config.py +++ b/llama_deploy/control_plane/config.py @@ -23,6 +23,7 @@ class ControlPlaneConfig(BaseSettings): running: bool = True cors_origins: List[str] | None = None topic_namespace: str = "llama_deploy" + state_store_uri: str | None = None @property def url(self) -> str: diff --git a/llama_deploy/control_plane/server.py b/llama_deploy/control_plane/server.py index 1ff5d386..c998700b 100644 --- a/llama_deploy/control_plane/server.py +++ b/llama_deploy/control_plane/server.py @@ -84,24 +84,23 @@ def __init__( orchestrator: BaseOrchestrator | None = None, publish_callback: PublishCallback | None = None, state_store: BaseKVStore | None = None, - state_store_uri: str | None = None, config: ControlPlaneConfig | None = None, ) -> None: self._orchestrator = orchestrator or SimpleOrchestrator( **SimpleOrchestratorConfig().model_dump() ) + self._config = config or ControlPlaneConfig() - if state_store is not None and state_store_uri is not None: + if state_store is not None and self._config.state_store_uri is not None: raise ValueError("Please use either 'state_store' or 'state_store_uri'.") if state_store: self._state_store = state_store - elif state_store_uri: - self._state_store = parse_state_store_uri(state_store_uri) + elif self._config.state_store_uri: + self._state_store = parse_state_store_uri(self._config.state_store_uri) else: self._state_store = state_store or SimpleKVStore() - self._config = config or ControlPlaneConfig() self._message_queue = message_queue self._publisher_id = f"{self.__class__.__qualname__}-{uuid.uuid4()}" self._publish_callback = publish_callback diff --git a/tests/control_plane/test_server.py b/tests/control_plane/test_server.py index 86c4535a..a61cd581 100644 --- a/tests/control_plane/test_server.py +++ b/tests/control_plane/test_server.py @@ -2,7 +2,7 @@ import pytest -from llama_deploy.control_plane import ControlPlaneServer +from llama_deploy.control_plane import ControlPlaneConfig, ControlPlaneServer from llama_deploy.message_queues import SimpleMessageQueue @@ -26,7 +26,7 @@ def test_control_plane_init_state_store() -> None: ControlPlaneServer( SimpleMessageQueue(), state_store=mocked_store, - state_store_uri="test/uri", + config=ControlPlaneConfig(state_store_uri="test/uri"), ) cp = ControlPlaneServer(SimpleMessageQueue(), state_store=mocked_store) @@ -35,5 +35,7 @@ def test_control_plane_init_state_store() -> None: with mock.patch( "llama_deploy.control_plane.server.parse_state_store_uri" ) as mocked_parse: - ControlPlaneServer(SimpleMessageQueue(), state_store_uri="test/uri") + ControlPlaneServer( + SimpleMessageQueue(), config=ControlPlaneConfig(state_store_uri="test/uri") + ) mocked_parse.assert_called_with("test/uri") From cb88a62387964bde18a467ccf8804088f5dd1926 Mon Sep 17 00:00:00 2001 From: Massimiliano Pippi Date: Wed, 27 Nov 2024 17:15:58 +0100 Subject: [PATCH 4/8] add redis example --- examples/redis_state_store/README.md | 69 +++++++++++++++++++ examples/redis_state_store/docker-compose.yml | 12 ++++ examples/redis_state_store/redis_store.yml | 15 ++++ examples/redis_state_store/requirements.txt | 1 + examples/redis_state_store/src/workflow.py | 24 +++++++ 5 files changed, 121 insertions(+) create mode 100644 examples/redis_state_store/README.md create mode 100644 examples/redis_state_store/docker-compose.yml create mode 100644 examples/redis_state_store/redis_store.yml create mode 100644 examples/redis_state_store/requirements.txt create mode 100644 examples/redis_state_store/src/workflow.py diff --git a/examples/redis_state_store/README.md b/examples/redis_state_store/README.md new file mode 100644 index 00000000..97b281f5 --- /dev/null +++ b/examples/redis_state_store/README.md @@ -0,0 +1,69 @@ +# Using Redis as State Store + +> [!NOTE] +> This example is mostly based on the [Quick Start](../quick_start/README.md), see there for more details. + +We'll be deploying a simple workflow on a local instance of Llama Deploy using Redis as a scalable storage for the +global state. This is mostly needed when you have more than one control plane running concurrently. + +Before starting Llama Deploy, use Docker compose to start the Redis container and run it in the background: + +``` +$ docker compose up -d +``` + +Make sure to install the package to support the Redis KV store in the virtual environment where we'll run Llama Deploy: + +``` +$ pip install -r requirements.txt +``` + +This is the code defining our deployment, with comments to the relevant bits: + +```yaml +name: QuickStart + +control-plane: + port: 8000 + # Here we tell the Control Plane to use Redis + state_store_uri: redis://localhost:6379 + +default-service: echo_workflow + +services: + echo_workflow: + name: Echo Workflow + source: + type: local + name: ./src + path: workflow:echo_workflow +``` + +Note how we provide a connection URI for Redis in the `state_store_uri` field of the control plane configuration. + +At this point we have all we need to run this deployment. Ideally, we would have the API server already running +somewhere in the cloud, but to get started let's start an instance locally. Run the following python script +from a shell: + +``` +$ python -m llama_deploy.apiserver +INFO: Started server process [10842] +INFO: Waiting for application startup. +INFO: Application startup complete. +INFO: Uvicorn running on http://0.0.0.0:4501 (Press CTRL+C to quit) +``` + +From another shell, use the CLI, `llamactl`, to create the deployment: + +``` +$ llamactl deploy quick_start.yml +Deployment successful: QuickStart +``` + +Our workflow is now part of the `QuickStart` deployment and ready to serve requests! We can use `llamactl` to interact +with this deployment: + +``` +$ llamactl run --deployment QuickStart --arg message 'Hello from my shell!' +Message received: Hello from my shell! +``` diff --git a/examples/redis_state_store/docker-compose.yml b/examples/redis_state_store/docker-compose.yml new file mode 100644 index 00000000..7fa9ea5f --- /dev/null +++ b/examples/redis_state_store/docker-compose.yml @@ -0,0 +1,12 @@ +services: + redis: + # Use as KV store + image: redis:latest + hostname: redis + ports: + - "6379:6379" + healthcheck: + test: redis-cli --raw incr ping + interval: 5s + timeout: 3s + retries: 5 diff --git a/examples/redis_state_store/redis_store.yml b/examples/redis_state_store/redis_store.yml new file mode 100644 index 00000000..06d85861 --- /dev/null +++ b/examples/redis_state_store/redis_store.yml @@ -0,0 +1,15 @@ +name: QuickStart + +control-plane: + port: 8000 + state_store_uri: redis://localhost:6379 + +default-service: dummy_workflow + +services: + dummy_workflow: + name: Dummy Workflow + source: + type: local + name: src + path: workflow:echo_workflow diff --git a/examples/redis_state_store/requirements.txt b/examples/redis_state_store/requirements.txt new file mode 100644 index 00000000..05f76689 --- /dev/null +++ b/examples/redis_state_store/requirements.txt @@ -0,0 +1 @@ +llama-index-storage-kvstore-redis diff --git a/examples/redis_state_store/src/workflow.py b/examples/redis_state_store/src/workflow.py new file mode 100644 index 00000000..88f2e970 --- /dev/null +++ b/examples/redis_state_store/src/workflow.py @@ -0,0 +1,24 @@ +import asyncio + +from llama_index.core.workflow import Workflow, StartEvent, StopEvent, step + + +# create a dummy workflow +class EchoWorkflow(Workflow): + """A dummy workflow with only one step sending back the input given.""" + + @step() + async def run_step(self, ev: StartEvent) -> StopEvent: + message = str(ev.get("message", "")) + return StopEvent(result=f"Message received: {message}") + + +echo_workflow = EchoWorkflow() + + +async def main(): + print(await echo_workflow.run(message="Hello!")) + + +if __name__ == "__main__": + asyncio.run(main()) From 7b3b6f4616e54b5b70b758279e30214628810983 Mon Sep 17 00:00:00 2001 From: Massimiliano Pippi Date: Wed, 27 Nov 2024 17:20:42 +0100 Subject: [PATCH 5/8] fix Redis creation --- llama_deploy/control_plane/config.py | 2 +- tests/control_plane/test_config.py | 23 +++++++---------------- 2 files changed, 8 insertions(+), 17 deletions(-) diff --git a/llama_deploy/control_plane/config.py b/llama_deploy/control_plane/config.py index 64479a95..81906168 100644 --- a/llama_deploy/control_plane/config.py +++ b/llama_deploy/control_plane/config.py @@ -37,7 +37,7 @@ def parse_state_store_uri(uri: str) -> BaseKVStore: try: from llama_index.storage.kvstore.redis import RedisKVStore # type: ignore - return RedisKVStore(uri=uri) + return RedisKVStore(redis_uri=uri) except ImportError: msg = ( f"key-value store {bits.scheme} is not available, please install the required " diff --git a/tests/control_plane/test_config.py b/tests/control_plane/test_config.py index f42201be..b3cad7d7 100644 --- a/tests/control_plane/test_config.py +++ b/tests/control_plane/test_config.py @@ -1,4 +1,3 @@ -from typing import Any from unittest import mock import pytest @@ -20,13 +19,9 @@ def test_parse_state_store_uri_malformed() -> None: parse_state_store_uri("foo://user:pass@host/database") -def test_parse_state_store_uri_redis_not_installed(monkeypatch: Any) -> None: - try: - # Ensure the module is never available, even if the package is installed - monkeypatch.delattr("llama_index.storage.kvstore.redis") - except Exception: - pass - +# Ensure the module is never available, even if the package is installed +@mock.patch.dict("sys.modules", {"llama_index.storage.kvstore.redis": None}) +def test_parse_state_store_uri_redis_not_installed() -> None: with pytest.raises( ValueError, match="pip install llama-index-storage-kvstore-redis" ): @@ -42,16 +37,12 @@ def test_parse_state_store_uri_redis() -> None: parse_state_store_uri("redis://localhost/") calls = redis_mock.mock_calls assert len(calls) == 1 - assert calls[0].kwargs == {"uri": "redis://localhost/"} - + assert calls[0].kwargs == {"redis_uri": "redis://localhost/"} -def test_parse_state_store_uri_mongodb_not_installed(monkeypatch: Any) -> None: - try: - # Ensure the module is never available, even if the package is installed - monkeypatch.delattr("llama_index.storage.kvstore.mongodb") - except Exception: - pass +# Ensure the module is never available, even if the package is installed +@mock.patch.dict("sys.modules", {"llama_index.storage.kvstore.mongodb": None}) +def test_parse_state_store_uri_mongodb_not_installed() -> None: with pytest.raises( ValueError, match="pip install llama-index-storage-kvstore-mongodb" ): From 6ec8ba8654ec1a5d1081677066c9b637f28fdbdf Mon Sep 17 00:00:00 2001 From: Massimiliano Pippi Date: Thu, 28 Nov 2024 11:58:47 +0100 Subject: [PATCH 6/8] add more docs --- .../llama_deploy/control_plane.md | 2 ++ .../llama_deploy/20_core_components.md | 15 ++++++++++++++ examples/redis_state_store/README.md | 3 ++- llama_deploy/control_plane/config.py | 20 +++++++++++++++---- 4 files changed, 35 insertions(+), 5 deletions(-) diff --git a/docs/docs/api_reference/llama_deploy/control_plane.md b/docs/docs/api_reference/llama_deploy/control_plane.md index ebd3d9ce..5b0c7b2e 100644 --- a/docs/docs/api_reference/llama_deploy/control_plane.md +++ b/docs/docs/api_reference/llama_deploy/control_plane.md @@ -1,3 +1,5 @@ # `control_plane` ::: llama_deploy.control_plane + options: + show_docstring_parameters: true diff --git a/docs/docs/module_guides/llama_deploy/20_core_components.md b/docs/docs/module_guides/llama_deploy/20_core_components.md index 75d30f6f..e85bc7df 100644 --- a/docs/docs/module_guides/llama_deploy/20_core_components.md +++ b/docs/docs/module_guides/llama_deploy/20_core_components.md @@ -56,6 +56,21 @@ The control plane is responsible for managing the state of the system, including - Handling service completion. - Launching the control plane server. +The state of the system is persisted in a key-value store that by default consists of a simple mapping in memory. +In particular, the state store contains: + +- The name and definition of the registered services. +- The active sessions and their relative tasks and event streams. +- The Context, in case the service is of type Workflow, + +In case you need a more scalable storage for the system state, you can set the `state_store_uri` field in the Control +Plane configuration to point to one of the databases we support (see +[the Python API reference](../../api_reference/llama_deploy/control_plane.md)) for more details. +Using a scalable storage for the global state is mostly needed when: +- You want to scale the control plane horizontally, and you want every instance to share the same global state. +- The control plane has to deal with high traffic (many services, sessions and tasks). +- The global state needs to be persisted across restarts (for example, workflow contexts are stored in the global state). + ## Service The general structure of a service is as follows: diff --git a/examples/redis_state_store/README.md b/examples/redis_state_store/README.md index 97b281f5..a48d8956 100644 --- a/examples/redis_state_store/README.md +++ b/examples/redis_state_store/README.md @@ -4,7 +4,8 @@ > This example is mostly based on the [Quick Start](../quick_start/README.md), see there for more details. We'll be deploying a simple workflow on a local instance of Llama Deploy using Redis as a scalable storage for the -global state. This is mostly needed when you have more than one control plane running concurrently. +global state. See [the Control Plane documentation](https://docs.llamaindex.ai/en/stable/module_guides/llama_deploy/20_core_components/#control-plane) +for an overview of what the global state consists of and when the default storage might not be enough. Before starting Llama Deploy, use Docker compose to start the Redis container and run it in the background: diff --git a/llama_deploy/control_plane/config.py b/llama_deploy/control_plane/config.py index 81906168..08f219bf 100644 --- a/llama_deploy/control_plane/config.py +++ b/llama_deploy/control_plane/config.py @@ -2,6 +2,7 @@ from urllib.parse import urlparse from llama_index.core.storage.kvstore.types import BaseKVStore +from pydantic import Field from pydantic_settings import BaseSettings, SettingsConfigDict @@ -16,14 +17,25 @@ class ControlPlaneConfig(BaseSettings): tasks_store_key: str = "tasks" session_store_key: str = "sessions" step_interval: float = 0.1 - host: str = "127.0.0.1" - port: int = 8000 + host: str = Field( + default="127.0.0.1", + description="The host where to run the control plane server", + ) + port: int = Field( + default=8000, description="The TCP port where to bind the control plane server" + ) internal_host: str | None = None internal_port: int | None = None running: bool = True cors_origins: List[str] | None = None - topic_namespace: str = "llama_deploy" - state_store_uri: str | None = None + topic_namespace: str = Field( + default="llama_deploy", + description="The prefix used in the message queue topic to namespace messages from this control plane", + ) + state_store_uri: str | None = Field( + default=None, + description="The connection URI of the database where to store state. If None, SimpleKVStore will be used", + ) @property def url(self) -> str: From 906c0c5217cf6b7356451e1336e861eddc9907aa Mon Sep 17 00:00:00 2001 From: Massimiliano Pippi Date: Thu, 28 Nov 2024 13:27:27 +0100 Subject: [PATCH 7/8] use state in the example --- examples/redis_state_store/README.md | 33 ++++++++++++++-------- examples/redis_state_store/redis_store.yml | 10 +++---- examples/redis_state_store/src/workflow.py | 16 ++++++----- 3 files changed, 35 insertions(+), 24 deletions(-) diff --git a/examples/redis_state_store/README.md b/examples/redis_state_store/README.md index a48d8956..baa858a7 100644 --- a/examples/redis_state_store/README.md +++ b/examples/redis_state_store/README.md @@ -22,22 +22,22 @@ $ pip install -r requirements.txt This is the code defining our deployment, with comments to the relevant bits: ```yaml -name: QuickStart +name: RedisStateStore control-plane: port: 8000 # Here we tell the Control Plane to use Redis state_store_uri: redis://localhost:6379 -default-service: echo_workflow +default-service: counter_workflow_service services: - echo_workflow: - name: Echo Workflow + counter_workflow_service: + name: Counter Workflow source: type: local - name: ./src - path: workflow:echo_workflow + name: src + path: workflow:counter_workflow ``` Note how we provide a connection URI for Redis in the `state_store_uri` field of the control plane configuration. @@ -57,14 +57,23 @@ INFO: Uvicorn running on http://0.0.0.0:4501 (Press CTRL+C to quit) From another shell, use the CLI, `llamactl`, to create the deployment: ``` -$ llamactl deploy quick_start.yml -Deployment successful: QuickStart +$ llamactl deploy redis_store.yml +Deployment successful: RedisStateStore ``` -Our workflow is now part of the `QuickStart` deployment and ready to serve requests! We can use `llamactl` to interact -with this deployment: +Our workflow is now part of the `RedisStateStore` deployment and ready to serve requests! Since we want to persist +a counter across workflow runs, first we manually create a session: ``` -$ llamactl run --deployment QuickStart --arg message 'Hello from my shell!' -Message received: Hello from my shell! +$ llamactl sessions create -d RedisStateStore +session_id='' task_ids=[] state={} +``` + +Then we run the workflow multiple times, always using the same session we created in the previous step: + +``` +$ lamactl run --deployment RedisStateStore --arg amount 3 -i +Current balance: 3.0 +$ lamactl run --deployment RedisStateStore --arg amount 3 -i +Current balance: 3.5 ``` diff --git a/examples/redis_state_store/redis_store.yml b/examples/redis_state_store/redis_store.yml index 06d85861..763e549e 100644 --- a/examples/redis_state_store/redis_store.yml +++ b/examples/redis_state_store/redis_store.yml @@ -1,15 +1,15 @@ -name: QuickStart +name: RedisStateStore control-plane: port: 8000 state_store_uri: redis://localhost:6379 -default-service: dummy_workflow +default-service: counter_workflow_service services: - dummy_workflow: - name: Dummy Workflow + counter_workflow_service: + name: Counter Workflow source: type: local name: src - path: workflow:echo_workflow + path: workflow:counter_workflow diff --git a/examples/redis_state_store/src/workflow.py b/examples/redis_state_store/src/workflow.py index 88f2e970..c32827f0 100644 --- a/examples/redis_state_store/src/workflow.py +++ b/examples/redis_state_store/src/workflow.py @@ -1,23 +1,25 @@ import asyncio -from llama_index.core.workflow import Workflow, StartEvent, StopEvent, step +from llama_index.core.workflow import Context, StartEvent, StopEvent, Workflow, step # create a dummy workflow -class EchoWorkflow(Workflow): +class CounterWorkflow(Workflow): """A dummy workflow with only one step sending back the input given.""" @step() - async def run_step(self, ev: StartEvent) -> StopEvent: - message = str(ev.get("message", "")) - return StopEvent(result=f"Message received: {message}") + async def run_step(self, ctx: Context, ev: StartEvent) -> StopEvent: + amount = float(ev.get("amount", 0.0)) + total = await ctx.get("total", 0.0) + amount + await ctx.set("total", total) + return StopEvent(result=f"Current balance: {total}") -echo_workflow = EchoWorkflow() +counter_workflow = CounterWorkflow() async def main(): - print(await echo_workflow.run(message="Hello!")) + print(await counter_workflow.run(message=10.0)) if __name__ == "__main__": From 08b73e32315038ed71ab23512eb829b4be04de61 Mon Sep 17 00:00:00 2001 From: Massimiliano Pippi Date: Thu, 28 Nov 2024 13:45:20 +0100 Subject: [PATCH 8/8] add sessions create command to the CLI --- llama_deploy/cli/__init__.py | 2 ++ llama_deploy/cli/run.py | 4 +++ llama_deploy/cli/sessions.py | 34 +++++++++++++++++++ llama_deploy/client/models/apiserver.py | 2 ++ tests/cli/test_run.py | 4 +-- tests/cli/test_sessions.py | 45 +++++++++++++++++++++++++ 6 files changed, 89 insertions(+), 2 deletions(-) create mode 100644 llama_deploy/cli/sessions.py create mode 100644 tests/cli/test_sessions.py diff --git a/llama_deploy/cli/__init__.py b/llama_deploy/cli/__init__.py index b49e1c99..1bdbf166 100644 --- a/llama_deploy/cli/__init__.py +++ b/llama_deploy/cli/__init__.py @@ -2,6 +2,7 @@ from .deploy import deploy as deploy_cmd from .run import run as run_cmd +from .sessions import sessions as sessions_cmd from .status import status as status_cmd @@ -35,3 +36,4 @@ def llamactl(ctx: click.Context, server: str, insecure: bool, timeout: float) -> llamactl.add_command(deploy_cmd) llamactl.add_command(run_cmd) llamactl.add_command(status_cmd) +llamactl.add_command(sessions_cmd) diff --git a/llama_deploy/cli/run.py b/llama_deploy/cli/run.py index 65e71803..0beb385a 100644 --- a/llama_deploy/cli/run.py +++ b/llama_deploy/cli/run.py @@ -20,6 +20,7 @@ help="'key value' argument to pass to the task, e.g. '-a age 30'", ) @click.option("-s", "--service", is_flag=False, help="Service name") +@click.option("-i", "--session-id", is_flag=False, help="Session ID") @click.pass_context def run( ctx: click.Context, @@ -27,6 +28,7 @@ def run( deployment: str, arg: tuple[tuple[str, str]], service: str, + session_id: str, ) -> None: server_url, disable_ssl, timeout = global_config client = Client(api_server_url=server_url, disable_ssl=disable_ssl, timeout=timeout) @@ -34,6 +36,8 @@ def run( payload = {"input": json.dumps(dict(arg))} if service: payload["agent_id"] = service + if session_id: + payload["session_id"] = session_id try: d = client.sync.apiserver.deployments.get(deployment) diff --git a/llama_deploy/cli/sessions.py b/llama_deploy/cli/sessions.py new file mode 100644 index 00000000..125bfa03 --- /dev/null +++ b/llama_deploy/cli/sessions.py @@ -0,0 +1,34 @@ +import click + +from llama_deploy import Client + + +@click.group +def sessions() -> None: + pass + + +@click.command() +@click.pass_obj # global_config +@click.option( + "-d", "--deployment", required=True, is_flag=False, help="Deployment name" +) +@click.pass_context +def create( + ctx: click.Context, + global_config: tuple, + deployment: str, +) -> None: + server_url, disable_ssl, timeout = global_config + client = Client(api_server_url=server_url, disable_ssl=disable_ssl, timeout=timeout) + + try: + d = client.sync.apiserver.deployments.get(deployment) + session_def = d.sessions.create() + except Exception as e: + raise click.ClickException(str(e)) + + click.echo(session_def) + + +sessions.add_command(create) diff --git a/llama_deploy/client/models/apiserver.py b/llama_deploy/client/models/apiserver.py index 2197d7f6..428fa069 100644 --- a/llama_deploy/client/models/apiserver.py +++ b/llama_deploy/client/models/apiserver.py @@ -118,6 +118,8 @@ async def run(self, task: TaskDefinition) -> Any: run_url = ( f"{self.client.api_server_url}/deployments/{self.deployment_id}/tasks/run" ) + if task.session_id: + run_url += f"?session_id={task.session_id}" r = await self.client.request( "POST", diff --git a/tests/cli/test_run.py b/tests/cli/test_run.py index 41b4bc80..12a525fb 100644 --- a/tests/cli/test_run.py +++ b/tests/cli/test_run.py @@ -17,7 +17,7 @@ def test_run(runner: CliRunner) -> None: result = runner.invoke( llamactl, - ["run", "-d", "deployment_name", "-s", "service_name"], + ["run", "-d", "deployment_name", "-s", "service_name", "-i", "session_id"], ) mocked_client.assert_called_with( @@ -29,7 +29,7 @@ def test_run(runner: CliRunner) -> None: expected = TaskDefinition(agent_id="service_name", input="{}") assert expected.input == actual.input assert expected.agent_id == actual.agent_id - assert actual.session_id is None + assert actual.session_id is not None assert result.exit_code == 0 diff --git a/tests/cli/test_sessions.py b/tests/cli/test_sessions.py new file mode 100644 index 00000000..65cbc2d1 --- /dev/null +++ b/tests/cli/test_sessions.py @@ -0,0 +1,45 @@ +from unittest import mock + +import httpx +from click.testing import CliRunner + +from llama_deploy.cli import llamactl + + +def test_session_create(runner: CliRunner) -> None: + with mock.patch("llama_deploy.cli.sessions.Client") as mocked_client: + mocked_deployment = mock.MagicMock() + mocked_deployment.sessions.create.return_value = mock.MagicMock( + id="test_session" + ) + mocked_client.return_value.sync.apiserver.deployments.get.return_value = ( + mocked_deployment + ) + + result = runner.invoke( + llamactl, + ["sessions", "create", "-d", "deployment_name"], + ) + + mocked_client.assert_called_with( + api_server_url="http://localhost:4501", disable_ssl=False, timeout=120.0 + ) + + mocked_deployment.sessions.create.assert_called_once() + assert result.exit_code == 0 + + +def test_sessions_create_error(runner: CliRunner) -> None: + with mock.patch("llama_deploy.cli.sessions.Client") as mocked_client: + mocked_client.return_value.sync.apiserver.deployments.get.side_effect = ( + httpx.HTTPStatusError( + "test error", response=mock.MagicMock(), request=mock.MagicMock() + ) + ) + + result = runner.invoke( + llamactl, ["sessions", "create", "-d", "deployment_name"] + ) + + assert result.exit_code == 1 + assert result.output == "Error: test error\n"