
Human Service #20

Merged · 17 commits · Jun 22, 2024
1 change: 1 addition & 0 deletions agentfile/message_queues/simple.py
@@ -100,6 +100,7 @@ async def _publish_to_consumer(self, message: QueueMessage, **kwargs: Any) -> An
         consumer = self._select_consumer(message)
         try:
             await consumer.process_message(message, **kwargs)
+            logger.info(f"Successfully published message '{message.type}' to consumer.")
         except Exception as e:
             logger.debug(
                 f"Failed to publish message of type '{message.type}' to consumer. Message: {str(e)}"
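
The new line logs successful deliveries at INFO while failures stay at DEBUG, so success is actually the louder of the two outcomes. Below is a minimal sketch of exercising this path, assuming SimpleMessageQueue exposes async register_consumer / launch_local / publish coroutines the way the services in this PR consume them; none of those call sites appear in this diff.

import asyncio

from agentfile.message_consumers.callable import CallableMessageConsumer
from agentfile.message_queues.simple import SimpleMessageQueue
from agentfile.messages.base import QueueMessage


async def main() -> None:
    queue = SimpleMessageQueue()

    # The consumer's message_type must match the message's `type` for
    # _select_consumer(...) to route the message to it.
    async def handler(message: QueueMessage, **kwargs) -> None:
        print(f"consumed: {message.type}")

    await queue.register_consumer(
        CallableMessageConsumer(message_type="demo", handler=handler)
    )
    await queue.launch_local()  # assumed to start the queue's processing loop

    # On delivery, the new logger.info(...) above fires; failures are still
    # only logged at debug level, so they are easy to miss at INFO.
    await queue.publish(QueueMessage(type="demo"))
    await asyncio.sleep(0.5)  # give the loop a beat to deliver


asyncio.run(main())
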
17 changes: 12 additions & 5 deletions agentfile/orchestrators/agent.py
@@ -11,7 +11,10 @@

 HISTORY_KEY = "chat_history"
 DEFAULT_SUMMARIZE_TMPL = "{history}\n\nThe above represents the progress so far, please condense the messages into a single message."
-DEFAULT_FOLLOWUP_TMPL = "Pick the next action to take, or return a final response if my original input is satisfied. As a reminder, the original input was: {original_input}"
+DEFAULT_FOLLOWUP_TMPL = (
+    "Pick the next action to take, or return a final response if my original "
+    "input is satisfied. As a reminder, the original input was: {original_input}"
+)


 class AgentOrchestrator(BaseOrchestrator):
@@ -25,12 +28,12 @@ def __init__(
         self.llm = llm
         self.summarize_prompt = summarize_prompt
         self.followup_prompt = followup_prompt
-        self.human_tool = ServiceTool(name="human", description=human_description)
+        self.finalize_tool = ServiceTool(name="finalize", description=human_description)

     async def get_next_messages(
         self, task_def: TaskDefinition, tools: List[BaseTool], state: Dict[str, Any]
     ) -> Tuple[List[QueueMessage], Dict[str, Any]]:
-        tools_plus_human = [self.human_tool, *tools]
+        tools_plus_human = [self.finalize_tool, *tools]

         chat_dicts = state.get(HISTORY_KEY, [])
         chat_history = [ChatMessage(**x) for x in chat_dicts]
@@ -56,7 +59,7 @@ async def get_next_messages

         # check if there was a tool call
         queue_messages = []
-        if len(response.sources) == 0 or response.sources[0].tool_name == "human":
+        if len(response.sources) == 0 or response.sources[0].tool_name == "finalize":
             queue_messages.append(
                 QueueMessage(
                     type="human",
@@ -73,13 +76,17 @@
                 name = source.tool_name
                 input_data = source.raw_input
                 input_str = next(iter(input_data.values()))
+                if name == "default_human_service":
+                    action = ActionTypes.REQUEST_FOR_HELP
+                else:
+                    action = ActionTypes.NEW_TASK
                 queue_messages.append(
                     QueueMessage(
                         type=name,
                         data=TaskDefinition(
                             task_id=task_def.task_id, input=input_str
                         ).model_dump(),
-                        action=ActionTypes.NEW_TASK,
+                        action=action,
                     )
                 )

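
Net effect of these hunks: the finalize tool replaces the old human tool as the terminal action, and tool calls that target the human service are now published with REQUEST_FOR_HELP instead of NEW_TASK. An illustrative restatement of the new branch follows (not code from the PR, just the routing rule in isolation):

from agentfile.types import ActionTypes


def route_action(tool_name: str) -> ActionTypes:
    # Mirrors the branch added above: calls aimed at the human service
    # become help requests; everything else stays a new task.
    if tool_name == "default_human_service":
        return ActionTypes.REQUEST_FOR_HELP
    return ActionTypes.NEW_TASK


assert route_action("default_human_service") == ActionTypes.REQUEST_FOR_HELP
assert route_action("secret_fact_agent") == ActionTypes.NEW_TASK

Because the check is against the literal default service name, a HumanService constructed with a custom service_name would still be routed as a NEW_TASK.
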
2 changes: 2 additions & 0 deletions agentfile/services/__init__.py
@@ -1,5 +1,6 @@
 from agentfile.services.base import BaseService
 from agentfile.services.agent import AgentService
+from agentfile.services.human import HumanService
 from agentfile.services.tool import ToolService
 from agentfile.services.types import (
     _Task,
@@ -12,6 +13,7 @@
 __all__ = [
     "BaseService",
     "AgentService",
+    "HumanService",
     "ToolService",
     "_Task",
     "_TaskSate",
209 changes: 209 additions & 0 deletions agentfile/services/human.py
@@ -0,0 +1,209 @@
import asyncio
import logging
import uuid
import uvicorn
from asyncio import Lock
from contextlib import asynccontextmanager
from fastapi import FastAPI
from typing import Any, AsyncGenerator, Dict, List, Optional
from pydantic import PrivateAttr

from agentfile.message_consumers.base import BaseMessageQueueConsumer
from agentfile.message_consumers.callable import CallableMessageConsumer
from agentfile.message_consumers.remote import RemoteMessageConsumer
from agentfile.message_publishers.publisher import PublishCallback
from agentfile.message_queues.base import BaseMessageQueue
from agentfile.messages.base import QueueMessage
from agentfile.services.base import BaseService
from agentfile.types import (
    ActionTypes,
    TaskDefinition,
    TaskResult,
    ServiceDefinition,
    CONTROL_PLANE_NAME,
)
from llama_index.core.llms import ChatMessage, MessageRole

logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
logging.basicConfig(level=logging.DEBUG)


HELP_REQUEST_TEMPLATE_STR = (
    "Your assistance is needed. Please respond to the request "
    "provided below:\n===\n\n"
    "{input_str}\n\n===\n"
)


class HumanService(BaseService):
    service_name: str
    description: str = "Local Human Service."
    running: bool = True
    step_interval: float = 0.1
    host: Optional[str] = None
    port: Optional[int] = None

    _outstanding_human_requests: Dict[str, TaskDefinition] = PrivateAttr()
    _message_queue: BaseMessageQueue = PrivateAttr()
    _app: FastAPI = PrivateAttr()
    _publisher_id: str = PrivateAttr()
    _publish_callback: Optional[PublishCallback] = PrivateAttr()
    _lock: Lock = PrivateAttr()

    def __init__(
        self,
        message_queue: BaseMessageQueue,
        running: bool = True,
        description: str = "Local Human Service",
        service_name: str = "default_human_service",
        publish_callback: Optional[PublishCallback] = None,
        step_interval: float = 0.1,
        host: Optional[str] = None,
        port: Optional[int] = None,
    ) -> None:
        super().__init__(
            running=running,
            description=description,
            service_name=service_name,
            step_interval=step_interval,
            host=host,
            port=port,
        )

        self._outstanding_human_requests = {}
        self._message_queue = message_queue
        self._publisher_id = f"{self.__class__.__qualname__}-{uuid.uuid4()}"
        self._publish_callback = publish_callback
        self._lock = asyncio.Lock()
        self._app = FastAPI(lifespan=self.lifespan)

        self._app.add_api_route("/", self.home, methods=["GET"], tags=["Human Service"])

        self._app.add_api_route(
            "/help", self.create_human_request, methods=["POST"], tags=["Help Requests"]
        )

    @property
    def service_definition(self) -> ServiceDefinition:
        return ServiceDefinition(
            service_name=self.service_name,
            description=self.description,
            prompt=[],
        )

    @property
    def message_queue(self) -> BaseMessageQueue:
        return self._message_queue

    @property
    def publisher_id(self) -> str:
        return self._publisher_id

    @property
    def publish_callback(self) -> Optional[PublishCallback]:
        return self._publish_callback

    @property
    def lock(self) -> Lock:
        return self._lock

    async def processing_loop(self) -> None:
        while True:
            if not self.running:
                await asyncio.sleep(self.step_interval)
                continue

            async with self.lock:
                current_requests: List[TaskDefinition] = [
                    *self._outstanding_human_requests.values()
                ]
            for task_def in current_requests:
                logger.info(
                    f"Processing request for human help for task: {task_def.task_id}"
                )

                # process req
                result = input(
                    HELP_REQUEST_TEMPLATE_STR.format(input_str=task_def.input)
                )

                # create history
                history = [
                    ChatMessage(
                        role=MessageRole.ASSISTANT,
                        content=HELP_REQUEST_TEMPLATE_STR.format(
                            input_str=task_def.input
                        ),
                    ),
                    ChatMessage(role=MessageRole.USER, content=result),
                ]

                # publish the completed task
                await self.publish(
                    QueueMessage(
                        type=CONTROL_PLANE_NAME,
                        action=ActionTypes.COMPLETED_TASK,
                        data=TaskResult(
                            task_id=task_def.task_id,
                            history=history,
                            result=result,
                        ).model_dump(),
                    )
                )

                # clean up
                async with self.lock:
                    del self._outstanding_human_requests[task_def.task_id]

            await asyncio.sleep(self.step_interval)

    async def process_message(self, message: QueueMessage, **kwargs: Any) -> None:
        if message.action == ActionTypes.REQUEST_FOR_HELP:
            task_def = TaskDefinition(**message.data or {})
            async with self.lock:
                self._outstanding_human_requests.update({task_def.task_id: task_def})
        else:
            raise ValueError(f"Unhandled action: {message.action}")

    def as_consumer(self, remote: bool = False) -> BaseMessageQueueConsumer:
        if remote:
            url = f"{self.host}:{self.port}/{self._app.url_path_for('process_message')}"
            return RemoteMessageConsumer(
                url=url,
                message_type=self.service_name,
            )

        return CallableMessageConsumer(
            message_type=self.service_name,
            handler=self.process_message,
        )

    async def launch_local(self) -> None:
        logger.info(f"{self.service_name} launch_local")
        asyncio.create_task(self.processing_loop())

    # ---- Server based methods ----

    @asynccontextmanager
    async def lifespan(self) -> AsyncGenerator[None, None]:
        """Starts the processing loop when the fastapi app starts."""
        asyncio.create_task(self.processing_loop())
        yield
        self.running = False

    async def home(self) -> Dict[str, str]:
        return {
            "service_name": self.service_name,
            "description": self.description,
            "running": str(self.running),
            "step_interval": str(self.step_interval),
        }

    async def create_human_request(self, req: TaskDefinition) -> Dict[str, str]:
        async with self.lock:
            self._outstanding_human_requests.update({req.task_id: req})
        return {"human_request_id": req.task_id}

    def launch_server(self) -> None:
        uvicorn.run(self._app, host=self.host, port=self.port)
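
For reference, a hypothetical client session against the two routes registered in __init__. The host, port, and bare task_id/input payload are assumptions; TaskDefinition may supply defaults for fields not shown here.

import httpx

BASE_URL = "http://127.0.0.1:8003"  # assumed host/port passed to HumanService

# GET / returns the metadata dict from `home`.
print(httpx.get(f"{BASE_URL}/").json())

# POST /help enqueues a TaskDefinition for the human operator; the
# processing loop then prompts on stdin via input().
resp = httpx.post(
    f"{BASE_URL}/help",
    json={"task_id": "task-123", "input": "What is 5 + 5?"},
)
print(resp.json())  # {"human_request_id": "task-123"}

Note that the processing loop's input() call blocks the event loop while waiting for the operator, which is workable for a local demo but would stall the server's other requests in the meantime.
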
12 changes: 12 additions & 0 deletions agentfile/types.py
@@ -17,6 +17,7 @@ class ActionTypes(str, Enum):
     NEW_TASK = "new_task"
     COMPLETED_TASK = "completed_task"
     REQUEST_FOR_HELP = "request_for_help"
+    COMPLETED_REQUEST_FOR_HELP = "completed_request_for_help"
     NEW_TOOL_CALL = "new_tool_call"
     COMPLETED_TOOL_CALL = "completed_tool_call"

@@ -52,6 +53,17 @@ class ToolCallResult(BaseModel):
     result: str


+class HumanRequest(BaseModel):
+    id_: str
+    input: str
+    source_id: str
+
+
+class HumanResult(BaseModel):
+    id_: str
+    result: str
+
+
 class ServiceDefinition(BaseModel):
     service_name: str = Field(description="The name of the service.")
     description: str = Field(
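
HumanRequest and HumanResult are introduced here but not yet referenced by HumanService, which passes TaskDefinition/TaskResult around instead. A small sketch of how the new types pair up (field values are illustrative):

from agentfile.types import ActionTypes, HumanRequest, HumanResult

# A request/result pair is correlated by id_; source_id records the requester.
req = HumanRequest(id_="req-1", input="What is 5 + 5?", source_id="control_plane")
res = HumanResult(id_=req.id_, result="10")

# The new enum member tags completed help requests distinctly from
# ordinary completed tasks.
assert ActionTypes.COMPLETED_REQUEST_FOR_HELP.value == "completed_request_for_help"
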
47 changes: 47 additions & 0 deletions example_scripts/human_local_service.py
@@ -0,0 +1,47 @@
from agentfile.launchers.local import LocalLauncher
from agentfile.services import HumanService, AgentService
from agentfile.control_plane.fastapi import FastAPIControlPlane
from agentfile.message_queues.simple import SimpleMessageQueue
from agentfile.orchestrators.agent import AgentOrchestrator

from llama_index.core.agent import FunctionCallingAgentWorker
from llama_index.core.tools import FunctionTool
from llama_index.llms.openai import OpenAI


# create an agent
def get_the_secret_fact() -> str:
    """Returns the secret fact."""
    return "The secret fact is: A baby llama is called a 'Cria'."


tool = FunctionTool.from_defaults(fn=get_the_secret_fact)

# create our multi-agent framework components
message_queue = SimpleMessageQueue()

worker = FunctionCallingAgentWorker.from_tools([tool], llm=OpenAI())
agent = worker.as_agent()
agent_service = AgentService(
    agent=agent,
    message_queue=message_queue,
    description="Useful for getting the secret fact.",
    service_name="secret_fact_agent",
)

human_service = HumanService(
    message_queue=message_queue, description="Answers queries about math."
)

control_plane = FastAPIControlPlane(
    message_queue=message_queue,
    orchestrator=AgentOrchestrator(llm=OpenAI()),
)

# launch it
launcher = LocalLauncher(
    [agent_service, human_service],
    control_plane,
    message_queue,
)
launcher.launch_single("What is 5 + 5?")
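
The script drives everything in-process through LocalLauncher.launch_single, so the human service's stdin prompt appears on the same terminal. Since HumanService also accepts host and port, a server-based variant is possible; a hypothetical sketch (the port choice is an assumption):

from agentfile.message_queues.simple import SimpleMessageQueue
from agentfile.services import HumanService

message_queue = SimpleMessageQueue()
human_service = HumanService(
    message_queue=message_queue,
    description="Answers queries about math.",
    host="127.0.0.1",
    port=8003,
)
human_service.launch_server()  # blocks; serves GET / and POST /help via uvicorn
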