Skip to content

Commit

Permalink
Add example script for ToolService with Launcher (#18)
Browse files Browse the repository at this point in the history
* wip

* put back error_on_no_tool_call=False

* cleaner shutdown and UX

* move registration of tool imn acall

* add deregister method
  • Loading branch information
nerdai authored Jun 20, 2024
1 parent 9aec369 commit 8931048
Show file tree
Hide file tree
Showing 7 changed files with 122 additions and 24 deletions.
37 changes: 27 additions & 10 deletions agentfile/launchers/local.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import asyncio
import uuid
import signal
import sys

from typing import Any, Callable, Dict, List, Optional

from agentfile.services.base import BaseService
Expand Down Expand Up @@ -70,6 +73,15 @@ async def register_consumers(
def launch_single(self, initial_task: str) -> None:
asyncio.run(self.alaunch_single(initial_task))

def get_shutdown_handler(self, tasks: List[asyncio.Task]) -> Callable:
def signal_handler(sig: Any, frame: Any) -> None:
print("\nShutting down.")
for task in tasks:
task.cancel()
sys.exit(0)

return signal_handler

async def alaunch_single(self, initial_task: str) -> None:
# register human consumer
human_consumer = HumanMessageConsumer(
Expand All @@ -79,6 +91,15 @@ async def alaunch_single(self, initial_task: str) -> None:
)
await self.register_consumers([human_consumer])

# register each service to the control plane
for service in self.services:
await self.control_plane.register_service(service.service_definition)

# start services
bg_tasks = []
for service in self.services:
bg_tasks.append(asyncio.create_task(service.launch_local()))

# publish initial task
await self.publish(
QueueMessage(
Expand All @@ -87,14 +108,10 @@ async def alaunch_single(self, initial_task: str) -> None:
data=TaskDefinition(input=initial_task).dict(),
),
)

# register each service to the control plane
for service in self.services:
await self.control_plane.register_service(service.service_definition)

# start services
for service in self.services:
asyncio.create_task(service.launch_local())

# runs until the message queue is stopped by the human consumer
await self.message_queue.start()
mq_task = asyncio.create_task(self.message_queue.start())
shutdown_handler = self.get_shutdown_handler([mq_task] + bg_tasks)
loop = asyncio.get_event_loop()
while loop.is_running():
await asyncio.sleep(0.1)
signal.signal(signal.SIGINT, shutdown_handler)
7 changes: 6 additions & 1 deletion agentfile/message_queues/simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,10 @@ async def _publish_to_consumer(self, message: QueueMessage, **kwargs: Any) -> An
consumer = self._select_consumer(message)
try:
await consumer.process_message(message, **kwargs)
except Exception:
except Exception as e:
logger.debug(
f"Failed to publish message of type '{message.type}' to consumer. Message: {str(e)}"
)
raise

async def start(self) -> None:
Expand All @@ -87,11 +90,13 @@ async def register_consumer(

if message_type_str not in self.consumers:
self.consumers[message_type_str] = {consumer.id_: consumer}
logger.info(f"Consumer {consumer.id_} has been registered.")
else:
if consumer.id_ in self.consumers[message_type_str]:
raise ValueError("Consumer has already been added.")

self.consumers[message_type_str][consumer.id_] = consumer
logger.info(f"Consumer {consumer.id_} has been registered.")

if message_type_str not in self.queues:
self.queues[message_type_str] = deque()
Expand Down
7 changes: 7 additions & 0 deletions agentfile/services/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,12 @@
CONTROL_PLANE_NAME,
)

import logging

logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
logging.basicConfig(level=logging.DEBUG)


class AgentService(BaseService):
service_name: str
Expand Down Expand Up @@ -164,6 +170,7 @@ def as_consumer(self) -> BaseMessageQueueConsumer:
)

async def launch_local(self) -> None:
logger.info(f"{self.service_name} launch_local")
asyncio.create_task(self.processing_loop())

# ---- Server based methods ----
Expand Down
8 changes: 8 additions & 0 deletions agentfile/services/tool.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
import logging
import uuid
import uvicorn
from asyncio import Lock
Expand Down Expand Up @@ -26,6 +27,10 @@
ServiceDefinition,
)

logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
logging.basicConfig(level=logging.DEBUG)


class ToolService(BaseService):
service_name: str
Expand Down Expand Up @@ -117,6 +122,9 @@ async def processing_loop(self) -> None:
self.tools, tool_call.tool_call_bundle.tool_name
)

logger.info(
f"Processing tool call id {tool_call.id_} with {tool.metadata.name}"
)
tool_output = await tool.acall(
*tool_call.tool_call_bundle.tool_args,
**tool_call.tool_call_bundle.tool_kwargs,
Expand Down
21 changes: 11 additions & 10 deletions agentfile/tools/meta_service_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,18 +26,13 @@
logging.basicConfig(level=logging.DEBUG)


class TimeoutException(Exception):
"""Raise when polling for results from message queue exceed timeout."""

pass


class MetaServiceTool(MessageQueuePublisherMixin, AsyncBaseTool, BaseModel):
tool_call_results: Dict[str, ToolCallResult] = Field(default_factory=dict)
timeout: float = Field(default=10.0, description="timeout interval in seconds.")
tool_service_name: str = Field(default_factory=str)
step_interval: float = 0.1
raise_timeout: bool = False
registered: bool = False

_message_queue: BaseMessageQueue = PrivateAttr()
_publisher_id: str = PrivateAttr()
Expand Down Expand Up @@ -132,10 +127,6 @@ def publish_callback(self) -> Optional[PublishCallback]:
def metadata(self) -> ToolMetadata:
return self._metadata

@metadata.setter
def metadata(self, value: ToolMetadata) -> None:
self._metadata = value

@property
def lock(self) -> asyncio.Lock:
return self._lock
Expand Down Expand Up @@ -174,6 +165,11 @@ async def _poll_for_tool_call_result(self, tool_call_id: str) -> ToolCallResult:
await asyncio.sleep(self.step_interval)
return tool_call_result

async def deregister(self) -> None:
"""Deregister from message queue."""
await self.message_queue.deregister_consumer(self.as_consumer())
self.registered = False

def call(self, *args: Any, **kwargs: Any) -> ToolOutput:
"""Call."""
return asyncio.run(self.acall(*args, **kwargs))
Expand All @@ -184,6 +180,11 @@ async def acall(self, *args: Any, **kwargs: Any) -> ToolOutput:
In order to get a ToolOutput result, this will poll the queue until
the result is written.
"""
if not self.registered:
# register tool to message queue
await self.message_queue.register_consumer(self.as_consumer())
self.registered = True

tool_call = ToolCall(
tool_call_bundle=ToolCallBundle(
tool_name=self.metadata.name, tool_args=args, tool_kwargs=kwargs
Expand Down
59 changes: 59 additions & 0 deletions example_scripts/agentic_toolservice_local_single.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
from agentfile.launchers.local import LocalLauncher
from agentfile.services import AgentService, ToolService
from agentfile.tools import MetaServiceTool
from agentfile.control_plane.fastapi import FastAPIControlPlane
from agentfile.message_queues.simple import SimpleMessageQueue
from agentfile.orchestrators.agent import AgentOrchestrator

from llama_index.core.agent import FunctionCallingAgentWorker
from llama_index.core.tools import FunctionTool
from llama_index.llms.openai import OpenAI


# create an agent
def get_the_secret_fact() -> str:
"""Returns the secret fact."""
return "The secret fact is: A baby llama is called a 'Cria'."


tool = FunctionTool.from_defaults(fn=get_the_secret_fact)


# create our multi-agent framework components
message_queue = SimpleMessageQueue()
tool_service = ToolService(
message_queue=message_queue,
tools=[tool],
running=True,
step_interval=0.5,
)

control_plane = FastAPIControlPlane(
message_queue=message_queue,
orchestrator=AgentOrchestrator(llm=OpenAI()),
)

meta_tool = MetaServiceTool(
tool_metadata=tool.metadata,
message_queue=message_queue,
tool_service_name=tool_service.service_name,
)
worker1 = FunctionCallingAgentWorker.from_tools(
[meta_tool],
llm=OpenAI(),
)
agent1 = worker1.as_agent()
agent_server_1 = AgentService(
agent=agent1,
message_queue=message_queue,
description="Useful for getting the secret fact.",
service_name="secret_fact_agent",
)

# launch it
launcher = LocalLauncher(
[agent_server_1, tool_service],
control_plane,
message_queue,
)
launcher.launch_single("What is the secret fact?")
7 changes: 4 additions & 3 deletions tests/test_meta_service_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ async def test_init(

# assert
assert meta_service_tool.metadata.name == "multiply"
assert not meta_service_tool.registered


@pytest.mark.asyncio()
Expand All @@ -78,6 +79,7 @@ async def test_create_from_tool_service_direct(

# assert
assert meta_service_tool.metadata.name == "multiply"
assert not meta_service_tool.registered


@pytest.mark.asyncio()
Expand Down Expand Up @@ -137,7 +139,6 @@ async def test_tool_call_output(
meta_service_tool: MetaServiceTool = await MetaServiceTool.from_tool_service(
tool_service=tool_service, message_queue=message_queue, name="multiply"
)
await message_queue.register_consumer(meta_service_tool.as_consumer())
await message_queue.register_consumer(tool_service.as_consumer())
mq_task = asyncio.create_task(message_queue.start())
ts_task = asyncio.create_task(tool_service.processing_loop())
Expand All @@ -155,6 +156,7 @@ async def test_tool_call_output(
assert tool_output.tool_name == "multiply"
assert tool_output.raw_input == {"args": (), "kwargs": {"a": 1, "b": 9}}
assert len(meta_service_tool.tool_call_results) == 0
assert meta_service_tool.registered


@pytest.mark.asyncio()
Expand All @@ -169,7 +171,6 @@ async def test_tool_call_raise_timeout(
timeout=1e-9,
raise_timeout=True,
)
await message_queue.register_consumer(meta_service_tool.as_consumer())
await message_queue.register_consumer(tool_service.as_consumer())
mq_task = asyncio.create_task(message_queue.start())
ts_task = asyncio.create_task(tool_service.processing_loop())
Expand All @@ -196,7 +197,6 @@ async def test_tool_call_reach_timeout(
timeout=1e-9,
raise_timeout=False,
)
await message_queue.register_consumer(meta_service_tool.as_consumer())
await message_queue.register_consumer(tool_service.as_consumer())
mq_task = asyncio.create_task(message_queue.start())
ts_task = asyncio.create_task(tool_service.processing_loop())
Expand All @@ -212,3 +212,4 @@ async def test_tool_call_reach_timeout(
assert tool_output.is_error
assert tool_output.raw_input == {"args": (), "kwargs": {"a": 1, "b": 9}}
assert len(meta_service_tool.tool_call_results) == 0
assert meta_service_tool.registered

0 comments on commit 8931048

Please sign in to comment.