Skip to content

Commit

Permalink
Refactoring, nested chats now create agents for each
Browse files Browse the repository at this point in the history
Signed-off-by: Mark Sze <[email protected]>
  • Loading branch information
marklysze committed Dec 2, 2024
1 parent 57bfab1 commit 3cddfcb
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 132 deletions.
212 changes: 85 additions & 127 deletions autogen/agentchat/contrib/swarm_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,25 +43,15 @@ def __post_init__(self):

@dataclass
class ON_CONDITION:
agent: Optional["SwarmAgent"] = None
nested_chat: Optional[Dict[str, Any]] = None
target: Union["SwarmAgent", Dict[str, Any]] = None
condition: str = ""

def __post_init__(self):
# Ensure valid types
if self.agent is not None:
assert isinstance(self.agent, SwarmAgent), "'agent' must be a SwarmAgent"

if self.nested_chat is not None:
assert isinstance(self.nested_chat, Dict), "'nested_chat' must be a Dict"

# Ensure they have an agent or nested_chat
assert self.agent is not None or self.nested_chat is not None, "'agent' or 'nested_chat' must be provided"

# Ensure they don't have both an agent and a nested_chat
assert not (
self.agent is not None and self.nested_chat is not None
), "'agent' and 'nested_chat' cannot both be provided"
if self.target is not None:
assert isinstance(self.target, SwarmAgent) or isinstance(
self.target, Dict
), "'target' must be a SwarmAgent or a Dict"

# Ensure they have a condition
assert isinstance(self.condition, str) and self.condition.strip(), "'condition' must be a non-empty string"
Expand Down Expand Up @@ -113,18 +103,12 @@ def custom_afterwork_func(last_speaker: SwarmAgent, messages: List[Dict[str, Any
if isinstance(messages, str):
messages = [{"role": "user", "content": messages}]

swarm_agent_names = [agent.name for agent in agents]

tool_execution = SwarmAgent(
name="Tool_Execution",
system_message="Tool Execution",
)
tool_execution._set_to_tool_execution(context_variables=context_variables)

# Update tool execution agent with all the functions from all the agents
for agent in agents:
tool_execution._function_map.update(agent._function_map)

INIT_AGENT_USED = False

def swarm_transition(last_speaker: SwarmAgent, groupchat: GroupChat):
Expand Down Expand Up @@ -179,6 +163,43 @@ def swarm_transition(last_speaker: SwarmAgent, groupchat: GroupChat):
else:
raise ValueError("Invalid After Work condition")

def create_nested_chats(agent: SwarmAgent, nested_chat_agents: List[SwarmAgent]):
"""Create nested chat agents and register nested chats"""
for i, nested_chat_handoff in enumerate(agent._nested_chat_handoffs):
nested_chats: Dict[str, Any] = nested_chat_handoff["nested_chats"]
condition = nested_chat_handoff["condition"]

# Create a nested chat agent specifically for this nested chat
nested_chat_agent = SwarmAgent(name=f"nested_chat_{agent.name}_{i + 1}")

nested_chat_agent.register_nested_chats(
nested_chats["chat_queue"],
reply_func_from_nested_chats=nested_chats.get("reply_func_from_nested_chats")
or "summary_from_nested_chats",
config=nested_chats.get("config", None),
trigger=lambda sender: True,
position=0,
use_async=nested_chats.get("use_async", False),
)

# After the nested chat is complete, transfer back to the parent agent
nested_chat_agent.register_hand_off(AFTER_WORK(agent=agent))

nested_chat_agents.append(nested_chat_agent)

# Nested chat is triggered through an agent transfer to this nested chat agent
agent.register_hand_off(ON_CONDITION(nested_chat_agent, condition))

nested_chat_agents = []
for agent in agents:
create_nested_chats(agent, nested_chat_agents)

# Update tool execution agent with all the functions from all the agents
for agent in agents + nested_chat_agents:
tool_execution._function_map.update(agent._function_map)

swarm_agent_names = [agent.name for agent in agents + nested_chat_agents]

# If there's only one message and there's no identified swarm agent
# Start with a user proxy agent, creating one if they haven't passed one in
if len(messages) == 1 and "name" not in messages[0] and not user_agent:
Expand All @@ -187,20 +208,17 @@ def swarm_transition(last_speaker: SwarmAgent, groupchat: GroupChat):
temp_user_proxy = []

groupchat = GroupChat(
agents=[tool_execution] + agents + ([user_agent] if user_agent is not None else temp_user_proxy),
agents=[tool_execution]
+ agents
+ nested_chat_agents
+ ([user_agent] if user_agent is not None else temp_user_proxy),
messages=[], # Set to empty. We will resume the conversation with the messages
max_round=max_rounds,
speaker_selection_method=swarm_transition,
)
manager = GroupChatManager(groupchat)
clear_history = True

# We associate the groupchat manager with SwarmAgents
# to be able to access group messages, tool executor context variables
for agent in agents:
if isinstance(agent, SwarmAgent):
agent.associate_groupchat(manager)

if len(messages) > 1:
last_agent, last_message = manager.resume(messages=messages)
clear_history = False
Expand Down Expand Up @@ -311,6 +329,10 @@ def __init__(
self._context_variables = {}
self._next_agent = None

# Store nested chats hand offs as we'll establish these in the initiate_swarm_chat
# List of Dictionaries containing the nested_chats and condition
self._nested_chat_handoffs = []

def _set_to_tool_execution(self, context_variables: Optional[Dict[str, Any]] = None):
"""Set to a special instance of SwarmAgent that is responsible for executing tool calls from other swarm agents.
This agent will be used internally and should not be visible to the user.
Expand Down Expand Up @@ -355,97 +377,24 @@ def transfer_to_agent_name() -> SwarmAgent:
self.after_work = transit
elif isinstance(transit, ON_CONDITION):

if transit.agent:
if isinstance(transit.target, SwarmAgent):
# Transition to agent

# Create closure with current loop transit value
# to ensure the condition matches the one in the loop
def make_transfer_function(current_transit: ON_CONDITION):
def transfer_to_agent() -> "SwarmAgent":
return current_transit.agent
return current_transit.target

return transfer_to_agent

transfer_func = make_transfer_function(transit)
self.add_single_function(transfer_func, f"transfer_to_{transit.agent.name}", transit.condition)
self.add_single_function(transfer_func, f"transfer_to_{transit.target.name}", transit.condition)

else:
elif isinstance(transit.target, Dict):
# Transition to a nested chat

# Create closure (see above note)
def make_transfer_nested_function(
chat_queue: List[Dict[str, Any]],
config: Optional[Any],
reply_func_from_nested_chats: Union[str, Callable],
use_async: bool,
):
# _reply_func = reply_func_from_nested_chats # Explicitly store parameter

def transfer_to_nested_chat() -> str:

# All messages, excluding the tool call message for swarm
base_messages = copy.deepcopy(self.chat_messages[self._groupchatmanager])
base_messages.pop()

# Note: This flow is based on ConversableAgent.register_nested_chats as we are doing this instead of registering a nested chat

if use_async:
for chat in chat_queue:
if chat.get("chat_id") is None:
raise ValueError("chat_id is required for async nested chats")

if use_async:
if callable(reply_func_from_nested_chats):
_reply_func = (
reply_func_from_nested_chats # Have to re-assign in this nested function
)
elif reply_func_from_nested_chats == "summary_from_nested_chats":
_reply_func = self._a_summary_from_nested_chats

if not callable(_reply_func) or not inspect.iscoroutinefunction(_reply_func):
raise ValueError("reply_func_from_nested_chats must be a callable and a coroutine")

else:
if callable(reply_func_from_nested_chats):
_reply_func = (
reply_func_from_nested_chats # Have to re-assign in this nested function
)
elif reply_func_from_nested_chats == "summary_from_nested_chats":
_reply_func = self._summary_from_nested_chats
if not callable(_reply_func):
raise ValueError("reply_func_from_nested_chats must be a callable")

# Run the summary_from_nested_chats, or equivalent callable, to get the final output of the nested chat
# Recipient will be the SwarmAgent the function is registered to.
_, reply_str = _reply_func(
chat_queue=chat_queue,
recipient=self,
messages=base_messages,
sender=self._groupchatmanager,
config=config,
)

return reply_str

return transfer_to_nested_chat

# Extract the nested chat configuration
chat_queue = transit.nested_chat["chat_queue"]
config = transit.nested_chat.get("config", None)
config_reply_func_from_nested_chats = transit.nested_chat.get("reply_func_from_nested_chats", None)
if not config_reply_func_from_nested_chats:
config_reply_func_from_nested_chats = "summary_from_nested_chats"
use_async = transit.nested_chat.get("use_async", False)

# Make the function for the nested chat
transfer_func = make_transfer_nested_function(
chat_queue, config, config_reply_func_from_nested_chats, use_async
)

# Add the function to the agent so it can be triggered as a tool call
self.add_single_function(
transfer_func, f"transfer_to_nested_chat_{len(self._function_map)}", transit.condition
)
# We will store them here and establish them in the initiate_swarm_chat
self._nested_chat_handoffs.append({"nested_chats": transit.target, "condition": transit.condition})

else:
raise ValueError("Invalid hand off condition, must be either ON_CONDITION or AFTER_WORK")
Expand Down Expand Up @@ -564,10 +513,6 @@ def add_functions(self, func_list: List[Callable]):
for func in func_list:
self.add_single_function(func)

def associate_groupchat(self, groupchatmanager: GroupChatManager):
"""Associate the group chat with an agent so we can access overall messages and other agents"""
self._groupchatmanager = groupchatmanager

def get_swarm_context_variables(self) -> Dict[str, Any]:
"""Returns the context variables from the tool execution agent"""
for agent in self._groupchatmanager.groupchat.agents:
Expand All @@ -578,7 +523,11 @@ def get_swarm_context_variables(self) -> Dict[str, Any]:

@staticmethod
def process_nested_chat_carryover(
chat: Dict[str, Any], recipient: ConversableAgent, messages: List[Dict[str, Any]], sender: ConversableAgent
chat: Dict[str, Any],
recipient: ConversableAgent,
messages: List[Dict[str, Any]],
sender: ConversableAgent,
trim_n_messages: int = 0,
) -> None:
"""Process carryover messages for a nested chat (typically for the first chat of a swarm)
Expand All @@ -591,6 +540,13 @@ def process_nested_chat_carryover(
"last_msg" - the last message will be incorporated
"reflection_with_llm" - an llm will summarise all the messages and the summary will be incorporated as a single message
Callable - a callable with the signature: my_method(agent: ConversableAgent, messages: List[Dict[str, Any]]) -> str
Args:
chat: The chat dictionary containing the carryover configuration
recipient: The recipient agent
messages: The messages from the parent chat
sender: The sender agent
trim_n_messages: The number of latest messages to trim from the messages list
"""

def concat_carryover(chat_message: str, carryover_message: Union[str, List[Dict[str, Any]]]) -> str:
Expand Down Expand Up @@ -618,36 +574,38 @@ def concat_carryover(chat_message: str, carryover_message: Union[str, List[Dict[

chat_message = chat.get("message", "")

# deep copy and trim the latest messages
content_messages = copy.deepcopy(messages)
content_messages = content_messages[:-trim_n_messages]

if carryover_summary_method == "all":
# Put a string concatenated value of all parent messages into the first message
# (e.g. message = <first nested chat message>\nContext: \n<swarm message 1>\n<swarm message 2>\n...)
carry_over_message = concat_carryover(chat_message, messages)
carry_over_message = concat_carryover(chat_message, content_messages)

elif carryover_summary_method == "last_msg":
# (e.g. message = <first nested chat message>\nContext: \n<last swarm message>)
carry_over_message = concat_carryover(chat_message, messages[-1]["content"])
carry_over_message = concat_carryover(chat_message, content_messages[-1]["content"])

elif carryover_summary_method == "reflection_with_llm":
# If the last message is a tool call, we need to remove it (typical for Swarm as this is triggered by a tool call)
restore_tool_call = False
if "tool_calls" in recipient._oai_messages[sender][-1]:
last_tool_message = recipient._oai_messages[sender].pop()
restore_tool_call = True
# (e.g. message = <first nested chat message>\nContext: \n<llm summary>)

# Add the messages to the nested chat agent for reflection (we'll clear after reflection)
chat["recipient"]._oai_messages[sender] = content_messages

carry_over_message_llm = ConversableAgent._reflection_with_llm_as_summary(
sender=sender,
recipient=recipient,
recipient=chat["recipient"], # Chat recipient LLM config will be used for the reflection
summary_args=carryover_summary_args,
)

carry_over_message = concat_carryover(chat_message, carry_over_message_llm)
recipient._oai_messages[sender] = []

# Restore the tool call message
if restore_tool_call:
recipient._oai_messages[sender].append(last_tool_message)
carry_over_message = concat_carryover(chat_message, carry_over_message_llm)

elif isinstance(carryover_summary_method, Callable):
carry_over_message_result = carryover_summary_method(recipient, messages, carryover_summary_args)
# (e.g. message = <first nested chat message>\nContext: \n<function's return string>)
carry_over_message_result = carryover_summary_method(recipient, content_messages, carryover_summary_args)

carry_over_message = concat_carryover(chat_message, carry_over_message_result)

Expand All @@ -674,9 +632,9 @@ def _summary_from_nested_chats(
Tuple[bool, str]: A tuple where the first element indicates the completion of the chat, and the second element contains the summary of the last chat if any chats were initiated.
"""

# Carryover configuration allowed on the first chat in the queue only
# Carryover configuration allowed on the first chat in the queue only, trim the last two messages specifically for swarm nested chat carryover as these are the messages for the transition to the nested chat agent
if len(chat_queue) > 0 and "carryover_config" in chat_queue[0]:
SwarmAgent.process_nested_chat_carryover(chat_queue[0], recipient, messages, sender)
SwarmAgent.process_nested_chat_carryover(chat_queue[0], recipient, messages, sender, 2)

chat_to_run = ConversableAgent._get_chats_to_run(chat_queue, recipient, messages, sender, config)
if not chat_to_run:
Expand Down
10 changes: 5 additions & 5 deletions test/agentchat/contrib/test_swarm.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,8 @@ def test_on_condition():

# Test with a ConversableAgent
test_conversable_agent = ConversableAgent("test_conversable_agent")
with pytest.raises(AssertionError, match="'agent' must be a SwarmAgent"):
_ = ON_CONDITION(agent=test_conversable_agent, condition="test condition")
with pytest.raises(AssertionError, match="'target' must be a SwarmAgent or a Dict"):
_ = ON_CONDITION(target=test_conversable_agent, condition="test condition")


def test_receiving_agent():
Expand Down Expand Up @@ -245,7 +245,7 @@ def test_on_condition_handoff():
agent1 = SwarmAgent("agent1", llm_config=testing_llm_config)
agent2 = SwarmAgent("agent2", llm_config=testing_llm_config)

agent1.register_hand_off(hand_to=ON_CONDITION(agent=agent2, condition="always take me to agent 2"))
agent1.register_hand_off(hand_to=ON_CONDITION(target=agent2, condition="always take me to agent 2"))

# Fake generate_oai_reply
def mock_generate_oai_reply(*args, **kwargs):
Expand Down Expand Up @@ -428,8 +428,8 @@ def test_non_swarm_in_hand_off():
with pytest.raises(AssertionError, match="Invalid After Work value"):
agent1.register_hand_off(hand_to=AFTER_WORK(bad_agent))

with pytest.raises(AssertionError, match="'agent' must be a SwarmAgent"):
agent1.register_hand_off(hand_to=ON_CONDITION(agent=bad_agent, condition="Testing"))
with pytest.raises(AssertionError, match="'target' must be a SwarmAgent or a Dict"):
agent1.register_hand_off(hand_to=ON_CONDITION(target=bad_agent, condition="Testing"))

with pytest.raises(ValueError, match="hand_to must be a list of ON_CONDITION or AFTER_WORK"):
agent1.register_hand_off(0)
Expand Down

0 comments on commit 3cddfcb

Please sign in to comment.