diff --git a/autogen/agentchat/contrib/swarm_agent.py b/autogen/agentchat/contrib/swarm_agent.py index 4e084377c4..56dabd3f27 100644 --- a/autogen/agentchat/contrib/swarm_agent.py +++ b/autogen/agentchat/contrib/swarm_agent.py @@ -25,6 +25,8 @@ # e.g. def my_function(context_variables: Dict[str, Any], my_other_parameters: Any) -> Any: __CONTEXT_VARIABLES_PARAM_NAME__ = "context_variables" +__TOOL_EXECUTOR_NAME__ = "Tool_Execution" + class AfterWorkOption(Enum): TERMINATE = "TERMINATE" @@ -45,6 +47,7 @@ def __post_init__(self): class ON_CONDITION: target: Union["SwarmAgent", Dict[str, Any]] = None condition: str = "" + available: Optional[Union[Callable, str]] = None def __post_init__(self): # Ensure valid types @@ -56,6 +59,9 @@ def __post_init__(self): # Ensure they have a condition assert isinstance(self.condition, str) and self.condition.strip(), "'condition' must be a non-empty string" + if self.available is not None: + assert isinstance(self.available, (Callable, str)), "'available' must be a callable or a string" + def initiate_swarm_chat( initial_agent: "SwarmAgent", @@ -104,10 +110,19 @@ def custom_afterwork_func(last_speaker: SwarmAgent, messages: List[Dict[str, Any messages = [{"role": "user", "content": messages}] tool_execution = SwarmAgent( - name="Tool_Execution", + name=__TOOL_EXECUTOR_NAME__, system_message="Tool Execution", ) - tool_execution._set_to_tool_execution(context_variables=context_variables) + tool_execution._set_to_tool_execution() + + # Update tool execution agent with all the functions from all the agents + for agent in agents: + tool_execution._function_map.update(agent._function_map) + + # Point all SwarmAgent's context variables to this function's context_variables + # providing a single (shared) context across all SwarmAgents in the swarm + for agent in agents + [tool_execution]: + agent._context_variables = context_variables INIT_AGENT_USED = False @@ -209,6 +224,10 @@ def create_nested_chats(agent: SwarmAgent, nested_chat_agents: List[SwarmAgent]) for agent in agents + nested_chat_agents: tool_execution._function_map.update(agent._function_map) + # Add conditional functions to the tool_execution agent + for func_name, (func, on_condition) in agent._conditional_functions.items(): + tool_execution._function_map[func_name] = func + swarm_agent_names = [agent.name for agent in agents + nested_chat_agents] # If there's only one message and there's no identified swarm agent @@ -335,23 +354,31 @@ def __init__( self.after_work = None - # Used only in the tool execution agent for context and transferring to the next agent - # Note: context variables are not stored for each agent - self._context_variables = {} + # Used in the tool execution agent to transfer to the next agent self._next_agent = None # Store nested chats hand offs as we'll establish these in the initiate_swarm_chat # List of Dictionaries containing the nested_chats and condition self._nested_chat_handoffs = [] + # Store conditional functions (and their ON_CONDITION instances) to add/remove later when transitioning to this agent + self._conditional_functions = {} + + # Register the hook to update agent state (except tool executor) + if name != __TOOL_EXECUTOR_NAME__: + self.register_hook("update_agent_state", self._update_agent_state_hook) + + def _update_agent_state_hook(self, agent: Agent, messages: Optional[List[Dict]] = None) -> None: + """Hook to update the agent state and update conditional functions.""" + self._update_conditional_functions() + def _set_to_tool_execution(self, context_variables: Optional[Dict[str, Any]] = None): """Set to a special instance of SwarmAgent that is responsible for executing tool calls from other swarm agents. This agent will be used internally and should not be visible to the user. - It will execute the tool calls and update the context_variables and next_agent accordingly. + It will execute the tool calls and referenced context_variables and next_agent accordingly. """ self._next_agent = None - self._context_variables = context_variables or {} self._reply_func_list.clear() self.register_reply([Agent, None], SwarmAgent.generate_swarm_tool_reply) @@ -400,7 +427,18 @@ def transfer_to_agent() -> "SwarmAgent": return transfer_to_agent transfer_func = make_transfer_function(transit) - self.add_single_function(transfer_func, f"transfer_to_{transit.target.name}", transit.condition) + + # Store function to add/remove later based on it being 'available' + # Function names are made unique and allow multiple ON_CONDITIONS to the same agent + base_func_name = f"transfer_from_{self.name}_to_{transit.target.name}" + func_name = base_func_name + count = 2 + while func_name in self._conditional_functions: + func_name = f"{base_func_name}_{count}" + count += 1 + + # Store function to add/remove later based on it being 'available' + self._conditional_functions[func_name] = (transfer_func, transit) elif isinstance(transit.target, Dict): # Transition to a nested chat @@ -410,6 +448,31 @@ def transfer_to_agent() -> "SwarmAgent": else: raise ValueError("Invalid hand off condition, must be either ON_CONDITION or AFTER_WORK") + def _update_conditional_functions(self): + """Updates the agent's functions based on the ON_CONDITION's available condition.""" + for func_name, (func, on_condition) in self._conditional_functions.items(): + is_available = False + + if on_condition.available is not None: + if isinstance(on_condition.available, Callable): + is_available = on_condition.available(self, next(iter(self.chat_messages.values()))) + elif isinstance(on_condition.available, str): + is_available = self.get_context(on_condition.available) or False + else: + is_available = True + + print(f"DEBUG INFO: Function {func_name} available? {'Yes' if is_available else 'No'}") + + if is_available: + if func_name not in self._function_map: + self.add_single_function(func, func_name, on_condition.condition) + else: + # Remove function using the stored name + if func_name in self._function_map: + self.update_tool_signature(func_name, is_remove=True) + del self._function_map[func_name] + print(f"Removed function: {func_name} from {self.name}") + def generate_swarm_tool_reply( self, messages: Optional[List[Dict]] = None,