Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add example script for ToolService with Launcher #18

Merged
merged 6 commits into from
Jun 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 27 additions & 10 deletions agentfile/launchers/local.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import asyncio
import uuid
import signal
import sys

from typing import Any, Callable, Dict, List, Optional

from agentfile.services.base import BaseService
Expand Down Expand Up @@ -70,6 +73,15 @@ async def register_consumers(
def launch_single(self, initial_task: str) -> None:
asyncio.run(self.alaunch_single(initial_task))

def get_shutdown_handler(self, tasks: List[asyncio.Task]) -> Callable:
def signal_handler(sig: Any, frame: Any) -> None:
print("\nShutting down.")
for task in tasks:
task.cancel()
sys.exit(0)

return signal_handler

async def alaunch_single(self, initial_task: str) -> None:
# register human consumer
human_consumer = HumanMessageConsumer(
Expand All @@ -79,6 +91,15 @@ async def alaunch_single(self, initial_task: str) -> None:
)
await self.register_consumers([human_consumer])

# register each service to the control plane
for service in self.services:
await self.control_plane.register_service(service.service_definition)

# start services
bg_tasks = []
for service in self.services:
bg_tasks.append(asyncio.create_task(service.launch_local()))

# publish initial task
await self.publish(
QueueMessage(
Expand All @@ -87,14 +108,10 @@ async def alaunch_single(self, initial_task: str) -> None:
data=TaskDefinition(input=initial_task).dict(),
),
)

# register each service to the control plane
for service in self.services:
await self.control_plane.register_service(service.service_definition)

# start services
for service in self.services:
asyncio.create_task(service.launch_local())

# runs until the message queue is stopped by the human consumer
await self.message_queue.start()
mq_task = asyncio.create_task(self.message_queue.start())
shutdown_handler = self.get_shutdown_handler([mq_task] + bg_tasks)
loop = asyncio.get_event_loop()
while loop.is_running():
await asyncio.sleep(0.1)
signal.signal(signal.SIGINT, shutdown_handler)
7 changes: 6 additions & 1 deletion agentfile/message_queues/simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,10 @@ async def _publish_to_consumer(self, message: QueueMessage, **kwargs: Any) -> An
consumer = self._select_consumer(message)
try:
await consumer.process_message(message, **kwargs)
except Exception:
except Exception as e:
logger.debug(
f"Failed to publish message of type '{message.type}' to consumer. Message: {str(e)}"
)
raise

async def start(self) -> None:
Expand All @@ -87,11 +90,13 @@ async def register_consumer(

if message_type_str not in self.consumers:
self.consumers[message_type_str] = {consumer.id_: consumer}
logger.info(f"Consumer {consumer.id_} has been registered.")
else:
if consumer.id_ in self.consumers[message_type_str]:
raise ValueError("Consumer has already been added.")

self.consumers[message_type_str][consumer.id_] = consumer
logger.info(f"Consumer {consumer.id_} has been registered.")

if message_type_str not in self.queues:
self.queues[message_type_str] = deque()
Expand Down
7 changes: 7 additions & 0 deletions agentfile/services/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,12 @@
CONTROL_PLANE_NAME,
)

import logging

logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
logging.basicConfig(level=logging.DEBUG)


class AgentService(BaseService):
service_name: str
Expand Down Expand Up @@ -164,6 +170,7 @@ def as_consumer(self) -> BaseMessageQueueConsumer:
)

async def launch_local(self) -> None:
logger.info(f"{self.service_name} launch_local")
asyncio.create_task(self.processing_loop())

# ---- Server based methods ----
Expand Down
8 changes: 8 additions & 0 deletions agentfile/services/tool.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
import logging
import uuid
import uvicorn
from asyncio import Lock
Expand Down Expand Up @@ -26,6 +27,10 @@
ServiceDefinition,
)

logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
logging.basicConfig(level=logging.DEBUG)


class ToolService(BaseService):
service_name: str
Expand Down Expand Up @@ -117,6 +122,9 @@ async def processing_loop(self) -> None:
self.tools, tool_call.tool_call_bundle.tool_name
)

logger.info(
f"Processing tool call id {tool_call.id_} with {tool.metadata.name}"
)
tool_output = await tool.acall(
*tool_call.tool_call_bundle.tool_args,
**tool_call.tool_call_bundle.tool_kwargs,
Expand Down
21 changes: 11 additions & 10 deletions agentfile/tools/meta_service_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,18 +26,13 @@
logging.basicConfig(level=logging.DEBUG)


class TimeoutException(Exception):
"""Raise when polling for results from message queue exceed timeout."""

pass


class MetaServiceTool(MessageQueuePublisherMixin, AsyncBaseTool, BaseModel):
tool_call_results: Dict[str, ToolCallResult] = Field(default_factory=dict)
timeout: float = Field(default=10.0, description="timeout interval in seconds.")
tool_service_name: str = Field(default_factory=str)
step_interval: float = 0.1
raise_timeout: bool = False
registered: bool = False

_message_queue: BaseMessageQueue = PrivateAttr()
_publisher_id: str = PrivateAttr()
Expand Down Expand Up @@ -132,10 +127,6 @@ def publish_callback(self) -> Optional[PublishCallback]:
def metadata(self) -> ToolMetadata:
return self._metadata

@metadata.setter
def metadata(self, value: ToolMetadata) -> None:
self._metadata = value

@property
def lock(self) -> asyncio.Lock:
return self._lock
Expand Down Expand Up @@ -174,6 +165,11 @@ async def _poll_for_tool_call_result(self, tool_call_id: str) -> ToolCallResult:
await asyncio.sleep(self.step_interval)
return tool_call_result

async def deregister(self) -> None:
"""Deregister from message queue."""
await self.message_queue.deregister_consumer(self.as_consumer())
self.registered = False

def call(self, *args: Any, **kwargs: Any) -> ToolOutput:
"""Call."""
return asyncio.run(self.acall(*args, **kwargs))
Expand All @@ -184,6 +180,11 @@ async def acall(self, *args: Any, **kwargs: Any) -> ToolOutput:
In order to get a ToolOutput result, this will poll the queue until
the result is written.
"""
if not self.registered:
# register tool to message queue
await self.message_queue.register_consumer(self.as_consumer())
self.registered = True

tool_call = ToolCall(
tool_call_bundle=ToolCallBundle(
tool_name=self.metadata.name, tool_args=args, tool_kwargs=kwargs
Expand Down
59 changes: 59 additions & 0 deletions example_scripts/agentic_toolservice_local_single.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
from agentfile.launchers.local import LocalLauncher
from agentfile.services import AgentService, ToolService
from agentfile.tools import MetaServiceTool
from agentfile.control_plane.fastapi import FastAPIControlPlane
from agentfile.message_queues.simple import SimpleMessageQueue
from agentfile.orchestrators.agent import AgentOrchestrator

from llama_index.core.agent import FunctionCallingAgentWorker
from llama_index.core.tools import FunctionTool
from llama_index.llms.openai import OpenAI


# create an agent
def get_the_secret_fact() -> str:
"""Returns the secret fact."""
return "The secret fact is: A baby llama is called a 'Cria'."


tool = FunctionTool.from_defaults(fn=get_the_secret_fact)


# create our multi-agent framework components
message_queue = SimpleMessageQueue()
tool_service = ToolService(
message_queue=message_queue,
tools=[tool],
running=True,
step_interval=0.5,
)

control_plane = FastAPIControlPlane(
message_queue=message_queue,
orchestrator=AgentOrchestrator(llm=OpenAI()),
)

meta_tool = MetaServiceTool(
tool_metadata=tool.metadata,
message_queue=message_queue,
tool_service_name=tool_service.service_name,
)
worker1 = FunctionCallingAgentWorker.from_tools(
[meta_tool],
llm=OpenAI(),
)
agent1 = worker1.as_agent()
agent_server_1 = AgentService(
agent=agent1,
message_queue=message_queue,
description="Useful for getting the secret fact.",
service_name="secret_fact_agent",
)

# launch it
launcher = LocalLauncher(
[agent_server_1, tool_service],
control_plane,
message_queue,
)
launcher.launch_single("What is the secret fact?")
7 changes: 4 additions & 3 deletions tests/test_meta_service_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ async def test_init(

# assert
assert meta_service_tool.metadata.name == "multiply"
assert not meta_service_tool.registered


@pytest.mark.asyncio()
Expand All @@ -78,6 +79,7 @@ async def test_create_from_tool_service_direct(

# assert
assert meta_service_tool.metadata.name == "multiply"
assert not meta_service_tool.registered


@pytest.mark.asyncio()
Expand Down Expand Up @@ -137,7 +139,6 @@ async def test_tool_call_output(
meta_service_tool: MetaServiceTool = await MetaServiceTool.from_tool_service(
tool_service=tool_service, message_queue=message_queue, name="multiply"
)
await message_queue.register_consumer(meta_service_tool.as_consumer())
await message_queue.register_consumer(tool_service.as_consumer())
mq_task = asyncio.create_task(message_queue.start())
ts_task = asyncio.create_task(tool_service.processing_loop())
Expand All @@ -155,6 +156,7 @@ async def test_tool_call_output(
assert tool_output.tool_name == "multiply"
assert tool_output.raw_input == {"args": (), "kwargs": {"a": 1, "b": 9}}
assert len(meta_service_tool.tool_call_results) == 0
assert meta_service_tool.registered


@pytest.mark.asyncio()
Expand All @@ -169,7 +171,6 @@ async def test_tool_call_raise_timeout(
timeout=1e-9,
raise_timeout=True,
)
await message_queue.register_consumer(meta_service_tool.as_consumer())
await message_queue.register_consumer(tool_service.as_consumer())
mq_task = asyncio.create_task(message_queue.start())
ts_task = asyncio.create_task(tool_service.processing_loop())
Expand All @@ -196,7 +197,6 @@ async def test_tool_call_reach_timeout(
timeout=1e-9,
raise_timeout=False,
)
await message_queue.register_consumer(meta_service_tool.as_consumer())
await message_queue.register_consumer(tool_service.as_consumer())
mq_task = asyncio.create_task(message_queue.start())
ts_task = asyncio.create_task(tool_service.processing_loop())
Expand All @@ -212,3 +212,4 @@ async def test_tool_call_reach_timeout(
assert tool_output.is_error
assert tool_output.raw_input == {"args": (), "kwargs": {"a": 1, "b": 9}}
assert len(meta_service_tool.tool_call_results) == 0
assert meta_service_tool.registered
Loading