Human Service (#20)
* start human service

* add input req

* human service first unit test

* second unit test

* add 3rd unit test

* add process from queue

* nit

* start human local service script

* update human service to TaskDefinition

* kind of working example script

* use NEW_TASK

* create_human_req passing

* passing tests

* add pipeline human service example

* remove unused action type

* remove unused action type
nerdai authored Jun 22, 2024 · 1 parent 9e9df9a · commit 2e0dd52
Showing 9 changed files with 497 additions and 5 deletions.
3 changes: 2 additions & 1 deletion agentfile/__init__.py
@@ -8,11 +8,12 @@
     ServiceTool,
 )
 from agentfile.tools import MetaServiceTool
-from agentfile.services import AgentService, ToolService
+from agentfile.services import AgentService, ToolService, HumanService

 __all__ = [
     # services
     "AgentService",
+    "HumanService",
     "ToolService",
     # message queues
     "SimpleMessageQueue",
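
With this change, HumanService is exported from the package root alongside the other services. A minimal sketch of the new import surface (assuming the package is installed as agentfile):

    # the three service names this diff places in the root __all__
    from agentfile import AgentService, HumanService, ToolService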
1 change: 1 addition & 0 deletions agentfile/message_queues/simple.py
@@ -100,6 +100,7 @@ async def _publish_to_consumer(self, message: QueueMessage, **kwargs: Any) -> Any:
         consumer = self._select_consumer(message)
         try:
             await consumer.process_message(message, **kwargs)
+            logger.info(f"Successfully published message '{message.type}' to consumer.")
         except Exception as e:
             logger.debug(
                 f"Failed to publish message of type '{message.type}' to consumer. Message: {str(e)}"
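
Previously only failures were logged, and only at debug level; this adds an info-level log once the selected consumer has processed the message. A hedged sketch of the path being logged, using the consumer types from this commit — the constructor fields mirror as_consumer in human.py below, but treat the exact signatures as assumptions:

    import asyncio

    from agentfile.message_consumers.callable import CallableMessageConsumer
    from agentfile.messages.base import QueueMessage

    async def handler(message: QueueMessage, **kwargs) -> None:
        # stand-in for a service's process_message
        print(f"handled message of type '{message.type}'")

    # _publish_to_consumer selects a consumer like this one and awaits
    # process_message; on success it now emits the logger.info line above
    consumer = CallableMessageConsumer(message_type="demo", handler=handler)
    asyncio.run(consumer.process_message(QueueMessage(type="demo")))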
11 changes: 7 additions & 4 deletions agentfile/orchestrators/agent.py
@@ -11,7 +11,10 @@

 HISTORY_KEY = "chat_history"
 DEFAULT_SUMMARIZE_TMPL = "{history}\n\nThe above represents the progress so far, please condense the messages into a single message."
-DEFAULT_FOLLOWUP_TMPL = "Pick the next action to take, or return a final response if my original input is satisfied. As a reminder, the original input was: {original_input}"
+DEFAULT_FOLLOWUP_TMPL = (
+    "Pick the next action to take, or return a final response if my original "
+    "input is satisfied. As a reminder, the original input was: {original_input}"
+)


 class AgentOrchestrator(BaseOrchestrator):
@@ -25,12 +28,12 @@ def __init__(
         self.llm = llm
         self.summarize_prompt = summarize_prompt
         self.followup_prompt = followup_prompt
-        self.human_tool = ServiceTool(name="human", description=human_description)
+        self.finalize_tool = ServiceTool(name="finalize", description=human_description)

     async def get_next_messages(
         self, task_def: TaskDefinition, tools: List[BaseTool], state: Dict[str, Any]
     ) -> Tuple[List[QueueMessage], Dict[str, Any]]:
-        tools_plus_human = [self.human_tool, *tools]
+        tools_plus_human = [self.finalize_tool, *tools]

         chat_dicts = state.get(HISTORY_KEY, [])
         chat_history = [ChatMessage(**x) for x in chat_dicts]
@@ -56,7 +59,7 @@ async def get_next_messages(

         # check if there was a tool call
         queue_messages = []
-        if len(response.sources) == 0 or response.sources[0].tool_name == "human":
+        if len(response.sources) == 0 or response.sources[0].tool_name == "finalize":
             queue_messages.append(
                 QueueMessage(
                     type="human",
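
The synthetic tool the orchestrator injects is renamed from "human" to "finalize": the LLM now selects "finalize" (or makes no tool call at all) to signal that its work is done, and the orchestrator emits a message of type "human" carrying the final response. A condensed, runnable restatement of that branch — the Source dataclass is a stand-in for the tool-call sources on the agent response, not part of the diff:

    from dataclasses import dataclass

    @dataclass
    class Source:
        tool_name: str  # stand-in for a tool-call source on the response

    def is_final(sources: list) -> bool:
        # mirrors the branch above: no tool call, or the synthetic
        # "finalize" tool was chosen -> the response is final
        return len(sources) == 0 or sources[0].tool_name == "finalize"

    assert is_final([])
    assert is_final([Source("finalize")])
    assert not is_final([Source("secret_fact_agent")])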
2 changes: 2 additions & 0 deletions agentfile/services/__init__.py
@@ -1,5 +1,6 @@
 from agentfile.services.base import BaseService
 from agentfile.services.agent import AgentService
+from agentfile.services.human import HumanService
 from agentfile.services.tool import ToolService
 from agentfile.services.types import (
     _Task,
@@ -12,6 +13,7 @@
 __all__ = [
     "BaseService",
     "AgentService",
+    "HumanService",
     "ToolService",
     "_Task",
     "_TaskSate",
229 changes: 229 additions & 0 deletions agentfile/services/human.py
@@ -0,0 +1,229 @@
import asyncio
import logging
import uuid
import uvicorn
from asyncio import Lock
from contextlib import asynccontextmanager
from fastapi import FastAPI
from typing import Any, AsyncGenerator, Dict, Optional
from pydantic import BaseModel, Field, PrivateAttr

from agentfile.message_consumers.base import BaseMessageQueueConsumer
from agentfile.message_consumers.callable import CallableMessageConsumer
from agentfile.message_consumers.remote import RemoteMessageConsumer
from agentfile.message_publishers.publisher import PublishCallback
from agentfile.message_queues.base import BaseMessageQueue
from agentfile.messages.base import QueueMessage
from agentfile.services.base import BaseService
from agentfile.types import (
    ActionTypes,
    TaskDefinition,
    TaskResult,
    ServiceDefinition,
    CONTROL_PLANE_NAME,
    generate_id,
)
from llama_index.core.llms import ChatMessage, MessageRole

logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
logging.basicConfig(level=logging.DEBUG)


HELP_REQUEST_TEMPLATE_STR = (
    "Your assistance is needed. Please respond to the request "
    "provided below:\n===\n\n"
    "{input_str}\n\n===\n"
)


class HumanService(BaseService):
    service_name: str
    description: str = "Local Human Service."
    running: bool = True
    step_interval: float = 0.1
    host: Optional[str] = None
    port: Optional[int] = None

    _outstanding_human_tasks: Dict[str, "HumanTask"] = PrivateAttr()
    _message_queue: BaseMessageQueue = PrivateAttr()
    _app: FastAPI = PrivateAttr()
    _publisher_id: str = PrivateAttr()
    _publish_callback: Optional[PublishCallback] = PrivateAttr()
    _lock: Lock = PrivateAttr()

    def __init__(
        self,
        message_queue: BaseMessageQueue,
        running: bool = True,
        description: str = "Local Human Service",
        service_name: str = "default_human_service",
        publish_callback: Optional[PublishCallback] = None,
        step_interval: float = 0.1,
        host: Optional[str] = None,
        port: Optional[int] = None,
    ) -> None:
        super().__init__(
            running=running,
            description=description,
            service_name=service_name,
            step_interval=step_interval,
            host=host,
            port=port,
        )

        self._outstanding_human_tasks = {}
        self._message_queue = message_queue
        self._publisher_id = f"{self.__class__.__qualname__}-{uuid.uuid4()}"
        self._publish_callback = publish_callback
        self._lock = asyncio.Lock()
        self._app = FastAPI(lifespan=self.lifespan)

        self._app.add_api_route("/", self.home, methods=["GET"], tags=["Human Service"])

        self._app.add_api_route(
            "/help", self.create_task, methods=["POST"], tags=["Help Requests"]
        )

    @property
    def service_definition(self) -> ServiceDefinition:
        return ServiceDefinition(
            service_name=self.service_name,
            description=self.description,
            prompt=[],
        )

    @property
    def message_queue(self) -> BaseMessageQueue:
        return self._message_queue

    @property
    def publisher_id(self) -> str:
        return self._publisher_id

    @property
    def publish_callback(self) -> Optional[PublishCallback]:
        return self._publish_callback

    @property
    def lock(self) -> Lock:
        return self._lock

    class HumanTask(BaseModel):
        """Lightweight container object over TaskDefinitions.

        This is needed since orchestrators may send multiple `TaskDefinition`s
        with the same task_id. In such a case, this human service is expected
        to address these multiple (sub)tasks for the overall task. In other
        words, these subtasks are all legitimate and should be processed.
        """

        id_: str = Field(default_factory=generate_id)
        task_definition: TaskDefinition

        class Config:
            arbitrary_types_allowed = True

    async def processing_loop(self) -> None:
        while True:
            if not self.running:
                await asyncio.sleep(self.step_interval)
                continue

            async with self.lock:
                current_human_tasks = [*self._outstanding_human_tasks.values()]
            for human_task in current_human_tasks:
                task_def = human_task.task_definition
                logger.info(
                    f"Processing request for human help for task: {task_def.task_id}"
                )

                # process req
                result = input(
                    HELP_REQUEST_TEMPLATE_STR.format(input_str=task_def.input)
                )

                # create history
                history = [
                    ChatMessage(
                        role=MessageRole.ASSISTANT,
                        content=HELP_REQUEST_TEMPLATE_STR.format(
                            input_str=task_def.input
                        ),
                    ),
                    ChatMessage(role=MessageRole.USER, content=result),
                ]

                # publish the completed task
                await self.publish(
                    QueueMessage(
                        type=CONTROL_PLANE_NAME,
                        action=ActionTypes.COMPLETED_TASK,
                        data=TaskResult(
                            task_id=task_def.task_id,
                            history=history,
                            result=result,
                        ).model_dump(),
                    )
                )

                # clean up
                async with self.lock:
                    del self._outstanding_human_tasks[human_task.id_]

            await asyncio.sleep(self.step_interval)

    async def process_message(self, message: QueueMessage, **kwargs: Any) -> None:
        if message.action == ActionTypes.NEW_TASK:
            task_def = TaskDefinition(**message.data or {})
            human_task = self.HumanTask(task_definition=task_def)
            async with self.lock:
                self._outstanding_human_tasks.update({human_task.id_: human_task})
        else:
            raise ValueError(f"Unhandled action: {message.action}")

    def as_consumer(self, remote: bool = False) -> BaseMessageQueueConsumer:
        if remote:
            url = f"{self.host}:{self.port}/{self._app.url_path_for('process_message')}"
            return RemoteMessageConsumer(
                url=url,
                message_type=self.service_name,
            )

        return CallableMessageConsumer(
            message_type=self.service_name,
            handler=self.process_message,
        )

    async def launch_local(self) -> None:
        logger.info(f"{self.service_name} launch_local")
        asyncio.create_task(self.processing_loop())

    # ---- Server based methods ----

    @asynccontextmanager
    async def lifespan(self) -> AsyncGenerator[None, None]:
        """Starts the processing loop when the fastapi app starts."""
        asyncio.create_task(self.processing_loop())
        yield
        self.running = False

    async def home(self) -> Dict[str, str]:
        return {
            "service_name": self.service_name,
            "description": self.description,
            "running": str(self.running),
            "step_interval": str(self.step_interval),
        }

    async def create_task(self, task: TaskDefinition) -> Dict[str, str]:
        human_task = self.HumanTask(task_definition=task)
        async with self.lock:
            self._outstanding_human_tasks.update({human_task.id_: human_task})
        return {"task_id": task.task_id}

    def launch_server(self) -> None:
        uvicorn.run(self._app, host=self.host, port=self.port)


HumanService.model_rebuild()
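
For the server-based path, launch_server() serves the two routes registered in __init__. A hedged sketch of requesting help over HTTP, assuming the service was constructed with host="127.0.0.1" and port=8000 and that TaskDefinition accepts an input field with an auto-generated task_id (as its usage above suggests):

    import httpx  # any HTTP client would do; httpx is just an example

    resp = httpx.post(
        "http://127.0.0.1:8000/help",
        json={"input": "Please double-check: what is 5 + 5?"},
    )
    print(resp.json())  # {"task_id": "..."}, per create_task above

The processing loop then prompts a human on stdin via input() and publishes the answer back to the control plane as a COMPLETED_TASK.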
47 changes: 47 additions & 0 deletions example_scripts/agentic_human_local_single.py
@@ -0,0 +1,47 @@
from agentfile.launchers.local import LocalLauncher
from agentfile.services import HumanService, AgentService
from agentfile.control_plane.fastapi import FastAPIControlPlane
from agentfile.message_queues.simple import SimpleMessageQueue
from agentfile.orchestrators.agent import AgentOrchestrator

from llama_index.core.agent import FunctionCallingAgentWorker
from llama_index.core.tools import FunctionTool
from llama_index.llms.openai import OpenAI


# create an agent
def get_the_secret_fact() -> str:
    """Returns the secret fact."""
    return "The secret fact is: A baby llama is called a 'Cria'."


tool = FunctionTool.from_defaults(fn=get_the_secret_fact)

# create our multi-agent framework components
message_queue = SimpleMessageQueue()

worker = FunctionCallingAgentWorker.from_tools([tool], llm=OpenAI())
agent = worker.as_agent()
agent_service = AgentService(
    agent=agent,
    message_queue=message_queue,
    description="Useful for getting the secret fact.",
    service_name="secret_fact_agent",
)

human_service = HumanService(
    message_queue=message_queue, description="Answers queries about math."
)

control_plane = FastAPIControlPlane(
    message_queue=message_queue,
    orchestrator=AgentOrchestrator(llm=OpenAI()),
)

# launch it
launcher = LocalLauncher(
    [agent_service, human_service],
    control_plane,
    message_queue,
)
launcher.launch_single("What is 5 + 5?")
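
A note on running this example: both the agent worker and the orchestrator construct OpenAI(), so an OpenAI API key must be available in the environment. launch_single sends "What is 5 + 5?" through the control plane; given the HumanService's "Answers queries about math." description, the orchestrator is likely to route the question there, in which case the script prompts you on stdin (via input() in the processing loop) for the final answer.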