Skip to content

Commit

Permalink
use agent id in apis (autogenhub#85)
Browse files Browse the repository at this point in the history
  • Loading branch information
jackgerrits authored Jun 17, 2024
1 parent c29218b commit edb939f
Show file tree
Hide file tree
Showing 19 changed files with 137 additions and 102 deletions.
2 changes: 1 addition & 1 deletion examples/assistant.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ def assistant_chat(runtime: AgentRuntime) -> UserProxyAgent: # type: ignore
description="A group chat manager.",
runtime=runtime,
memory=BufferedChatMemory(buffer_size=10),
participants=[assistant, user],
participants=[assistant.id, user.id],
)
return user

Expand Down
2 changes: 1 addition & 1 deletion examples/chess_game.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ def get_board_text() -> Annotated[str, "The current board state"]:
description="A chess game between two agents.",
runtime=runtime,
memory=BufferedChatMemory(buffer_size=10),
participants=[white, black], # white goes first
participants=[white.id, black.id], # white goes first
)


Expand Down
2 changes: 1 addition & 1 deletion examples/coder_reviewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def coder_reviewer(runtime: AgentRuntime) -> None:
name="Manager",
description="A manager that orchestrates a back-and-forth converation between a coder and a reviewer.",
runtime=runtime,
participants=[coder, reviewer], # The order of the participants indicates the order of speaking.
participants=[coder.id, reviewer.id], # The order of the participants indicates the order of speaking.
memory=BufferedChatMemory(buffer_size=10),
termination_word="APPROVE",
on_message_received=lambda message: print(f"{'-'*80}\n{message.source}: {message.content}"),
Expand Down
22 changes: 11 additions & 11 deletions examples/futures.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from agnext.application import SingleThreadedAgentRuntime
from agnext.components import TypeRoutedAgent, message_handler
from agnext.core import Agent, AgentRuntime, CancellationToken
from agnext.core import AgentProxy, AgentRuntime, CancellationToken


@dataclass
Expand All @@ -13,35 +13,35 @@ class MessageType:


class Inner(TypeRoutedAgent): # type: ignore
def __init__(self, name: str, router: AgentRuntime) -> None: # type: ignore
super().__init__(name, "The inner agent", router)
def __init__(self, name: str, runtime: AgentRuntime) -> None: # type: ignore
super().__init__(name, "The inner agent", runtime)

@message_handler() # type: ignore
async def on_new_message(self, message: MessageType, cancellation_token: CancellationToken) -> MessageType: # type: ignore
return MessageType(body=f"Inner: {message.body}", sender=self.metadata["name"])


class Outer(TypeRoutedAgent): # type: ignore
def __init__(self, name: str, router: AgentRuntime, inner: Agent) -> None: # type: ignore
super().__init__(name, "The outter agent", router)
def __init__(self, name: str, runtime: AgentRuntime, inner: AgentProxy) -> None: # type: ignore
super().__init__(name, "The outter agent", runtime)
self._inner = inner

@message_handler() # type: ignore
async def on_new_message(self, message: MessageType, cancellation_token: CancellationToken) -> MessageType: # type: ignore
inner_response = self._send_message(message, self._inner)
inner_response = self._send_message(message, self._inner.id)
inner_message = await inner_response
assert isinstance(inner_message, MessageType)
return MessageType(body=f"Outer: {inner_message.body}", sender=self.metadata["name"])


async def main() -> None:
router = SingleThreadedAgentRuntime()
inner = Inner("inner", router)
outer = Outer("outer", router, inner)
response = router.send_message(MessageType(body="Hello", sender="external"), outer)
runtime = SingleThreadedAgentRuntime()
inner = Inner("inner", runtime)
outer = Outer("outer", runtime, AgentProxy(inner, runtime))
response = runtime.send_message(MessageType(body="Hello", sender="external"), outer)

while not response.done():
await router.process_next()
await runtime.process_next()

print(await response)

Expand Down
6 changes: 3 additions & 3 deletions examples/orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,9 +140,9 @@ def software_development(runtime: AgentRuntime) -> OrchestratorChat: # type: ig
"OrchestratorChat",
"A software development team.",
runtime,
orchestrator=orchestrator,
planner=planner,
specialists=[developer, product_manager, tester],
orchestrator=orchestrator.id,
planner=planner.id,
specialists=[developer.id, product_manager.id, tester.id],
)


Expand Down
2 changes: 1 addition & 1 deletion examples/software_consultancy.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ def software_consultancy(runtime: AgentRuntime, user_agent: Agent) -> None: # t
runtime=runtime,
memory=HeadAndTailChatMemory(head_size=1, tail_size=10),
# model_client=OpenAI(model="gpt-4-turbo"),
participants=[developer, product_manager, ux_designer, illustrator, user_agent],
participants=[developer.id, product_manager.id, ux_designer.id, illustrator.id, user_agent.id],
)


Expand Down
45 changes: 31 additions & 14 deletions src/agnext/application/_single_threaded_agent_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from dataclasses import dataclass
from typing import Any, Awaitable, Dict, List, Mapping, Set

from ..core import Agent, AgentMetadata, AgentRuntime, CancellationToken
from ..core import Agent, AgentId, AgentMetadata, AgentRuntime, CancellationToken
from ..core.exceptions import MessageDroppedException
from ..core.intervention import DropMessage, InterventionHandler

Expand Down Expand Up @@ -77,18 +77,14 @@ def unprocessed_messages(
def send_message(
self,
message: Any,
recipient: Agent,
recipient: Agent | AgentId,
*,
sender: Agent | None = None,
sender: Agent | AgentId | None = None,
cancellation_token: CancellationToken | None = None,
) -> Future[Any | None]:
if cancellation_token is None:
cancellation_token = CancellationToken()

logger.info(
f"Sending message of type {type(message).__name__} to {recipient.metadata['name']}: {message.__dict__}"
)

# event_logger.info(
# MessageEvent(
# payload=message,
Expand All @@ -99,6 +95,14 @@ def send_message(
# )
# )

recipient = self._get_agent(recipient)
if sender is not None:
sender = self._get_agent(sender)

logger.info(
f"Sending message of type {type(message).__name__} to {recipient.metadata['name']}: {message.__dict__}"
)

future = asyncio.get_event_loop().create_future()
if recipient not in self._agents:
future.set_exception(Exception("Recipient not found"))
Expand All @@ -119,12 +123,15 @@ def publish_message(
self,
message: Any,
*,
sender: Agent | None = None,
sender: Agent | AgentId | None = None,
cancellation_token: CancellationToken | None = None,
) -> Future[None]:
if cancellation_token is None:
cancellation_token = CancellationToken()

if sender is not None:
sender = self._get_agent(sender)

logger.info(f"Publishing message of type {type(message).__name__} to all subscribers: {message.__dict__}")

# event_logger.info(
Expand Down Expand Up @@ -300,11 +307,21 @@ async def process_next(self) -> None:
# Yield control to the message loop to allow other tasks to run
await asyncio.sleep(0)

def agent_metadata(self, agent: Agent) -> AgentMetadata:
return agent.metadata
def agent_metadata(self, agent: Agent | AgentId) -> AgentMetadata:
return self._get_agent(agent).metadata

def agent_save_state(self, agent: Agent | AgentId) -> Mapping[str, Any]:
return self._get_agent(agent).save_state()

def agent_save_state(self, agent: Agent) -> Mapping[str, Any]:
return agent.save_state()
def agent_load_state(self, agent: Agent | AgentId, state: Mapping[str, Any]) -> None:
self._get_agent(agent).load_state(state)

def _get_agent(self, agent_or_id: Agent | AgentId) -> Agent:
if isinstance(agent_or_id, Agent):
return agent_or_id

for agent in self._agents:
if agent.metadata["name"] == agent_or_id.name:
return agent

def agent_load_state(self, agent: Agent, state: Mapping[str, Any]) -> None:
agent.load_state(state)
raise ValueError(f"Agent with name {agent_or_id} not found")
4 changes: 2 additions & 2 deletions src/agnext/chat/agents/chat_completion_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ async def _generate_response(
# Send a function call message to itself.
response = await self._send_message(
message=FunctionCallMessage(content=response.content, source=self.metadata["name"]),
recipient=self,
recipient=self.id,
cancellation_token=cancellation_token,
)
# Make an assistant message from the response.
Expand Down Expand Up @@ -232,7 +232,7 @@ async def _execute_function(
)
approval_response = await self._send_message(
message=approval_request,
recipient=self._tool_approver,
recipient=self._tool_approver.id,
cancellation_token=cancellation_token,
)
if not isinstance(approval_response, ToolApprovalResponse):
Expand Down
4 changes: 2 additions & 2 deletions src/agnext/chat/patterns/group_chat.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Any, List, Protocol, Sequence

from ...components import TypeRoutedAgent, message_handler
from ...core import Agent, AgentRuntime, CancellationToken
from ...core import AgentId, AgentRuntime, CancellationToken
from ..types import Reset, RespondNow, TextMessage


Expand All @@ -19,7 +19,7 @@ def __init__(
name: str,
description: str,
runtime: AgentRuntime,
participants: Sequence[Agent],
participants: Sequence[AgentId],
num_rounds: int,
output: GroupChatOutput,
) -> None:
Expand Down
26 changes: 14 additions & 12 deletions src/agnext/chat/patterns/group_chat_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from ...components import TypeRoutedAgent, message_handler
from ...components.models import ChatCompletionClient
from ...core import Agent, AgentRuntime, CancellationToken
from ...core import AgentId, AgentProxy, AgentRuntime, CancellationToken
from ..memory import ChatMemory
from ..types import (
PublishNow,
Expand Down Expand Up @@ -40,29 +40,31 @@ def __init__(
name: str,
description: str,
runtime: AgentRuntime,
participants: List[Agent],
participants: List[AgentId],
memory: ChatMemory,
model_client: ChatCompletionClient | None = None,
termination_word: str = "TERMINATE",
transitions: Mapping[Agent, List[Agent]] = {},
transitions: Mapping[AgentId, List[AgentId]] = {},
on_message_received: Callable[[TextMessage], None] | None = None,
):
super().__init__(name, description, runtime)
self._memory = memory
self._client = model_client
self._participants = participants
self._participants = [AgentProxy(x, runtime) for x in participants]
self._termination_word = termination_word
for key, value in transitions.items():
proxy = AgentProxy(key, runtime)
if not value:
# Make sure no empty transitions are provided.
raise ValueError(f"Empty transition list provided for {key.metadata['name']}.")
if key not in participants:
raise ValueError(f"Empty transition list provided for {proxy.metadata['name']}.")
if proxy.id not in participants:
# Make sure all keys are in the list of participants.
raise ValueError(f"Transition key {key.metadata['name']} not found in participants.")
raise ValueError(f"Transition key {proxy.metadata['name']} not found in participants.")
for v in value:
if v not in participants:
proxy = AgentProxy(v, runtime)
if proxy.id not in participants:
# Make sure all values are in the list of participants.
raise ValueError(f"Transition value {v.metadata['name']} not found in participants.")
raise ValueError(f"Transition value {proxy.metadata['name']} not found in participants.")
self._tranistiions = transitions
self._on_message_received = on_message_received

Expand Down Expand Up @@ -107,15 +109,15 @@ async def on_text_message(self, message: TextMessage, cancellation_token: Cancel
candidates = self._participants
if last_speaker_index is not None:
last_speaker = self._participants[last_speaker_index]
if self._tranistiions.get(last_speaker) is not None:
candidates = self._tranistiions[last_speaker]
if self._tranistiions.get(last_speaker.id) is not None:
candidates = [AgentProxy(x, self._runtime) for x in self._tranistiions[last_speaker.id]]
if len(candidates) == 1:
speaker = candidates[0]
else:
speaker = await select_speaker(self._memory, self._client, candidates)

# Send the message to the selected speaker to ask it to publish a response.
await self._send_message(PublishNow(), speaker)
await self._send_message(PublishNow(), speaker.id)

def save_state(self) -> Mapping[str, Any]:
return {
Expand Down
6 changes: 3 additions & 3 deletions src/agnext/chat/patterns/group_chat_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@
from typing import Dict, List

from ...components.models import ChatCompletionClient, SystemMessage
from ...core import Agent
from ...core import AgentProxy
from ..memory import ChatMemory
from ..types import TextMessage


async def select_speaker(memory: ChatMemory, client: ChatCompletionClient, agents: List[Agent]) -> Agent:
async def select_speaker(memory: ChatMemory, client: ChatCompletionClient, agents: List[AgentProxy]) -> AgentProxy:
"""Selects the next speaker in a group chat using a ChatCompletion client."""
# TODO: Handle multi-modal messages.

Expand Down Expand Up @@ -47,7 +47,7 @@ async def select_speaker(memory: ChatMemory, client: ChatCompletionClient, agent
return agent


def mentioned_agents(message_content: str, agents: List[Agent]) -> Dict[str, int]:
def mentioned_agents(message_content: str, agents: List[AgentProxy]) -> Dict[str, int]:
"""Counts the number of times each agent is mentioned in the provided message content.
Agent names will match under any of the following conditions (all case-sensitive):
- Exact name match
Expand Down
Loading

0 comments on commit edb939f

Please sign in to comment.