From 3cddfcb970ebe6bae9cea83cb4232c2d93e98df7 Mon Sep 17 00:00:00 2001 From: Mark Sze Date: Mon, 2 Dec 2024 23:20:59 +0000 Subject: [PATCH] Refactoring, nested chats now create agents for each Signed-off-by: Mark Sze --- autogen/agentchat/contrib/swarm_agent.py | 212 +++++++++-------------- test/agentchat/contrib/test_swarm.py | 10 +- 2 files changed, 90 insertions(+), 132 deletions(-) diff --git a/autogen/agentchat/contrib/swarm_agent.py b/autogen/agentchat/contrib/swarm_agent.py index 1a8eee6341..425baa18dd 100644 --- a/autogen/agentchat/contrib/swarm_agent.py +++ b/autogen/agentchat/contrib/swarm_agent.py @@ -43,25 +43,15 @@ def __post_init__(self): @dataclass class ON_CONDITION: - agent: Optional["SwarmAgent"] = None - nested_chat: Optional[Dict[str, Any]] = None + target: Union["SwarmAgent", Dict[str, Any]] = None condition: str = "" def __post_init__(self): # Ensure valid types - if self.agent is not None: - assert isinstance(self.agent, SwarmAgent), "'agent' must be a SwarmAgent" - - if self.nested_chat is not None: - assert isinstance(self.nested_chat, Dict), "'nested_chat' must be a Dict" - - # Ensure they have an agent or nested_chat - assert self.agent is not None or self.nested_chat is not None, "'agent' or 'nested_chat' must be provided" - - # Ensure they don't have both an agent and a nested_chat - assert not ( - self.agent is not None and self.nested_chat is not None - ), "'agent' and 'nested_chat' cannot both be provided" + if self.target is not None: + assert isinstance(self.target, SwarmAgent) or isinstance( + self.target, Dict + ), "'target' must be a SwarmAgent or a Dict" # Ensure they have a condition assert isinstance(self.condition, str) and self.condition.strip(), "'condition' must be a non-empty string" @@ -113,18 +103,12 @@ def custom_afterwork_func(last_speaker: SwarmAgent, messages: List[Dict[str, Any if isinstance(messages, str): messages = [{"role": "user", "content": messages}] - swarm_agent_names = [agent.name for agent in agents] - tool_execution = SwarmAgent( name="Tool_Execution", system_message="Tool Execution", ) tool_execution._set_to_tool_execution(context_variables=context_variables) - # Update tool execution agent with all the functions from all the agents - for agent in agents: - tool_execution._function_map.update(agent._function_map) - INIT_AGENT_USED = False def swarm_transition(last_speaker: SwarmAgent, groupchat: GroupChat): @@ -179,6 +163,43 @@ def swarm_transition(last_speaker: SwarmAgent, groupchat: GroupChat): else: raise ValueError("Invalid After Work condition") + def create_nested_chats(agent: SwarmAgent, nested_chat_agents: List[SwarmAgent]): + """Create nested chat agents and register nested chats""" + for i, nested_chat_handoff in enumerate(agent._nested_chat_handoffs): + nested_chats: Dict[str, Any] = nested_chat_handoff["nested_chats"] + condition = nested_chat_handoff["condition"] + + # Create a nested chat agent specifically for this nested chat + nested_chat_agent = SwarmAgent(name=f"nested_chat_{agent.name}_{i + 1}") + + nested_chat_agent.register_nested_chats( + nested_chats["chat_queue"], + reply_func_from_nested_chats=nested_chats.get("reply_func_from_nested_chats") + or "summary_from_nested_chats", + config=nested_chats.get("config", None), + trigger=lambda sender: True, + position=0, + use_async=nested_chats.get("use_async", False), + ) + + # After the nested chat is complete, transfer back to the parent agent + nested_chat_agent.register_hand_off(AFTER_WORK(agent=agent)) + + nested_chat_agents.append(nested_chat_agent) + + # Nested chat is triggered through an agent transfer to this nested chat agent + agent.register_hand_off(ON_CONDITION(nested_chat_agent, condition)) + + nested_chat_agents = [] + for agent in agents: + create_nested_chats(agent, nested_chat_agents) + + # Update tool execution agent with all the functions from all the agents + for agent in agents + nested_chat_agents: + tool_execution._function_map.update(agent._function_map) + + swarm_agent_names = [agent.name for agent in agents + nested_chat_agents] + # If there's only one message and there's no identified swarm agent # Start with a user proxy agent, creating one if they haven't passed one in if len(messages) == 1 and "name" not in messages[0] and not user_agent: @@ -187,7 +208,10 @@ def swarm_transition(last_speaker: SwarmAgent, groupchat: GroupChat): temp_user_proxy = [] groupchat = GroupChat( - agents=[tool_execution] + agents + ([user_agent] if user_agent is not None else temp_user_proxy), + agents=[tool_execution] + + agents + + nested_chat_agents + + ([user_agent] if user_agent is not None else temp_user_proxy), messages=[], # Set to empty. We will resume the conversation with the messages max_round=max_rounds, speaker_selection_method=swarm_transition, @@ -195,12 +219,6 @@ def swarm_transition(last_speaker: SwarmAgent, groupchat: GroupChat): manager = GroupChatManager(groupchat) clear_history = True - # We associate the groupchat manager with SwarmAgents - # to be able to access group messages, tool executor context variables - for agent in agents: - if isinstance(agent, SwarmAgent): - agent.associate_groupchat(manager) - if len(messages) > 1: last_agent, last_message = manager.resume(messages=messages) clear_history = False @@ -311,6 +329,10 @@ def __init__( self._context_variables = {} self._next_agent = None + # Store nested chats hand offs as we'll establish these in the initiate_swarm_chat + # List of Dictionaries containing the nested_chats and condition + self._nested_chat_handoffs = [] + def _set_to_tool_execution(self, context_variables: Optional[Dict[str, Any]] = None): """Set to a special instance of SwarmAgent that is responsible for executing tool calls from other swarm agents. This agent will be used internally and should not be visible to the user. @@ -355,97 +377,24 @@ def transfer_to_agent_name() -> SwarmAgent: self.after_work = transit elif isinstance(transit, ON_CONDITION): - if transit.agent: + if isinstance(transit.target, SwarmAgent): # Transition to agent # Create closure with current loop transit value # to ensure the condition matches the one in the loop def make_transfer_function(current_transit: ON_CONDITION): def transfer_to_agent() -> "SwarmAgent": - return current_transit.agent + return current_transit.target return transfer_to_agent transfer_func = make_transfer_function(transit) - self.add_single_function(transfer_func, f"transfer_to_{transit.agent.name}", transit.condition) + self.add_single_function(transfer_func, f"transfer_to_{transit.target.name}", transit.condition) - else: + elif isinstance(transit.target, Dict): # Transition to a nested chat - - # Create closure (see above note) - def make_transfer_nested_function( - chat_queue: List[Dict[str, Any]], - config: Optional[Any], - reply_func_from_nested_chats: Union[str, Callable], - use_async: bool, - ): - # _reply_func = reply_func_from_nested_chats # Explicitly store parameter - - def transfer_to_nested_chat() -> str: - - # All messages, excluding the tool call message for swarm - base_messages = copy.deepcopy(self.chat_messages[self._groupchatmanager]) - base_messages.pop() - - # Note: This flow is based on ConversableAgent.register_nested_chats as we are doing this instead of registering a nested chat - - if use_async: - for chat in chat_queue: - if chat.get("chat_id") is None: - raise ValueError("chat_id is required for async nested chats") - - if use_async: - if callable(reply_func_from_nested_chats): - _reply_func = ( - reply_func_from_nested_chats # Have to re-assign in this nested function - ) - elif reply_func_from_nested_chats == "summary_from_nested_chats": - _reply_func = self._a_summary_from_nested_chats - - if not callable(_reply_func) or not inspect.iscoroutinefunction(_reply_func): - raise ValueError("reply_func_from_nested_chats must be a callable and a coroutine") - - else: - if callable(reply_func_from_nested_chats): - _reply_func = ( - reply_func_from_nested_chats # Have to re-assign in this nested function - ) - elif reply_func_from_nested_chats == "summary_from_nested_chats": - _reply_func = self._summary_from_nested_chats - if not callable(_reply_func): - raise ValueError("reply_func_from_nested_chats must be a callable") - - # Run the summary_from_nested_chats, or equivalent callable, to get the final output of the nested chat - # Recipient will be the SwarmAgent the function is registered to. - _, reply_str = _reply_func( - chat_queue=chat_queue, - recipient=self, - messages=base_messages, - sender=self._groupchatmanager, - config=config, - ) - - return reply_str - - return transfer_to_nested_chat - - # Extract the nested chat configuration - chat_queue = transit.nested_chat["chat_queue"] - config = transit.nested_chat.get("config", None) - config_reply_func_from_nested_chats = transit.nested_chat.get("reply_func_from_nested_chats", None) - if not config_reply_func_from_nested_chats: - config_reply_func_from_nested_chats = "summary_from_nested_chats" - use_async = transit.nested_chat.get("use_async", False) - - # Make the function for the nested chat - transfer_func = make_transfer_nested_function( - chat_queue, config, config_reply_func_from_nested_chats, use_async - ) - - # Add the function to the agent so it can be triggered as a tool call - self.add_single_function( - transfer_func, f"transfer_to_nested_chat_{len(self._function_map)}", transit.condition - ) + # We will store them here and establish them in the initiate_swarm_chat + self._nested_chat_handoffs.append({"nested_chats": transit.target, "condition": transit.condition}) else: raise ValueError("Invalid hand off condition, must be either ON_CONDITION or AFTER_WORK") @@ -564,10 +513,6 @@ def add_functions(self, func_list: List[Callable]): for func in func_list: self.add_single_function(func) - def associate_groupchat(self, groupchatmanager: GroupChatManager): - """Associate the group chat with an agent so we can access overall messages and other agents""" - self._groupchatmanager = groupchatmanager - def get_swarm_context_variables(self) -> Dict[str, Any]: """Returns the context variables from the tool execution agent""" for agent in self._groupchatmanager.groupchat.agents: @@ -578,7 +523,11 @@ def get_swarm_context_variables(self) -> Dict[str, Any]: @staticmethod def process_nested_chat_carryover( - chat: Dict[str, Any], recipient: ConversableAgent, messages: List[Dict[str, Any]], sender: ConversableAgent + chat: Dict[str, Any], + recipient: ConversableAgent, + messages: List[Dict[str, Any]], + sender: ConversableAgent, + trim_n_messages: int = 0, ) -> None: """Process carryover messages for a nested chat (typically for the first chat of a swarm) @@ -591,6 +540,13 @@ def process_nested_chat_carryover( "last_msg" - the last message will be incorporated "reflection_with_llm" - an llm will summarise all the messages and the summary will be incorporated as a single message Callable - a callable with the signature: my_method(agent: ConversableAgent, messages: List[Dict[str, Any]]) -> str + + Args: + chat: The chat dictionary containing the carryover configuration + recipient: The recipient agent + messages: The messages from the parent chat + sender: The sender agent + trim_n_messages: The number of latest messages to trim from the messages list """ def concat_carryover(chat_message: str, carryover_message: Union[str, List[Dict[str, Any]]]) -> str: @@ -618,36 +574,38 @@ def concat_carryover(chat_message: str, carryover_message: Union[str, List[Dict[ chat_message = chat.get("message", "") + # deep copy and trim the latest messages + content_messages = copy.deepcopy(messages) + content_messages = content_messages[:-trim_n_messages] + if carryover_summary_method == "all": # Put a string concatenated value of all parent messages into the first message # (e.g. message = \nContext: \n\n\n...) - carry_over_message = concat_carryover(chat_message, messages) + carry_over_message = concat_carryover(chat_message, content_messages) elif carryover_summary_method == "last_msg": # (e.g. message = \nContext: \n) - carry_over_message = concat_carryover(chat_message, messages[-1]["content"]) + carry_over_message = concat_carryover(chat_message, content_messages[-1]["content"]) elif carryover_summary_method == "reflection_with_llm": - # If the last message is a tool call, we need to remove it (typical for Swarm as this is triggered by a tool call) - restore_tool_call = False - if "tool_calls" in recipient._oai_messages[sender][-1]: - last_tool_message = recipient._oai_messages[sender].pop() - restore_tool_call = True + # (e.g. message = \nContext: \n) + + # Add the messages to the nested chat agent for reflection (we'll clear after reflection) + chat["recipient"]._oai_messages[sender] = content_messages carry_over_message_llm = ConversableAgent._reflection_with_llm_as_summary( sender=sender, - recipient=recipient, + recipient=chat["recipient"], # Chat recipient LLM config will be used for the reflection summary_args=carryover_summary_args, ) - carry_over_message = concat_carryover(chat_message, carry_over_message_llm) + recipient._oai_messages[sender] = [] - # Restore the tool call message - if restore_tool_call: - recipient._oai_messages[sender].append(last_tool_message) + carry_over_message = concat_carryover(chat_message, carry_over_message_llm) elif isinstance(carryover_summary_method, Callable): - carry_over_message_result = carryover_summary_method(recipient, messages, carryover_summary_args) + # (e.g. message = \nContext: \n) + carry_over_message_result = carryover_summary_method(recipient, content_messages, carryover_summary_args) carry_over_message = concat_carryover(chat_message, carry_over_message_result) @@ -674,9 +632,9 @@ def _summary_from_nested_chats( Tuple[bool, str]: A tuple where the first element indicates the completion of the chat, and the second element contains the summary of the last chat if any chats were initiated. """ - # Carryover configuration allowed on the first chat in the queue only + # Carryover configuration allowed on the first chat in the queue only, trim the last two messages specifically for swarm nested chat carryover as these are the messages for the transition to the nested chat agent if len(chat_queue) > 0 and "carryover_config" in chat_queue[0]: - SwarmAgent.process_nested_chat_carryover(chat_queue[0], recipient, messages, sender) + SwarmAgent.process_nested_chat_carryover(chat_queue[0], recipient, messages, sender, 2) chat_to_run = ConversableAgent._get_chats_to_run(chat_queue, recipient, messages, sender, config) if not chat_to_run: diff --git a/test/agentchat/contrib/test_swarm.py b/test/agentchat/contrib/test_swarm.py index 4a63ac477d..3ec4f69a55 100644 --- a/test/agentchat/contrib/test_swarm.py +++ b/test/agentchat/contrib/test_swarm.py @@ -79,8 +79,8 @@ def test_on_condition(): # Test with a ConversableAgent test_conversable_agent = ConversableAgent("test_conversable_agent") - with pytest.raises(AssertionError, match="'agent' must be a SwarmAgent"): - _ = ON_CONDITION(agent=test_conversable_agent, condition="test condition") + with pytest.raises(AssertionError, match="'target' must be a SwarmAgent or a Dict"): + _ = ON_CONDITION(target=test_conversable_agent, condition="test condition") def test_receiving_agent(): @@ -245,7 +245,7 @@ def test_on_condition_handoff(): agent1 = SwarmAgent("agent1", llm_config=testing_llm_config) agent2 = SwarmAgent("agent2", llm_config=testing_llm_config) - agent1.register_hand_off(hand_to=ON_CONDITION(agent=agent2, condition="always take me to agent 2")) + agent1.register_hand_off(hand_to=ON_CONDITION(target=agent2, condition="always take me to agent 2")) # Fake generate_oai_reply def mock_generate_oai_reply(*args, **kwargs): @@ -428,8 +428,8 @@ def test_non_swarm_in_hand_off(): with pytest.raises(AssertionError, match="Invalid After Work value"): agent1.register_hand_off(hand_to=AFTER_WORK(bad_agent)) - with pytest.raises(AssertionError, match="'agent' must be a SwarmAgent"): - agent1.register_hand_off(hand_to=ON_CONDITION(agent=bad_agent, condition="Testing")) + with pytest.raises(AssertionError, match="'target' must be a SwarmAgent or a Dict"): + agent1.register_hand_off(hand_to=ON_CONDITION(target=bad_agent, condition="Testing")) with pytest.raises(ValueError, match="hand_to must be a list of ON_CONDITION or AFTER_WORK"): agent1.register_hand_off(0)