Skip to content

Commit

Permalink
Initial commit - conditional ON_CONDITION with 'available'
Browse files Browse the repository at this point in the history
Signed-off-by: Mark Sze <[email protected]>
  • Loading branch information
marklysze committed Dec 12, 2024
1 parent 4fe28fc commit b5d3d98
Showing 1 changed file with 71 additions and 8 deletions.
79 changes: 71 additions & 8 deletions autogen/agentchat/contrib/swarm_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
# e.g. def my_function(context_variables: Dict[str, Any], my_other_parameters: Any) -> Any:
__CONTEXT_VARIABLES_PARAM_NAME__ = "context_variables"

__TOOL_EXECUTOR_NAME__ = "Tool_Execution"


class AfterWorkOption(Enum):
TERMINATE = "TERMINATE"
Expand All @@ -45,6 +47,7 @@ def __post_init__(self):
class ON_CONDITION:
target: Union["SwarmAgent", Dict[str, Any]] = None
condition: str = ""
available: Optional[Union[Callable, str]] = None

def __post_init__(self):
# Ensure valid types
Expand All @@ -56,6 +59,9 @@ def __post_init__(self):
# Ensure they have a condition
assert isinstance(self.condition, str) and self.condition.strip(), "'condition' must be a non-empty string"

if self.available is not None:
assert isinstance(self.available, (Callable, str)), "'available' must be a callable or a string"


def initiate_swarm_chat(
initial_agent: "SwarmAgent",
Expand Down Expand Up @@ -104,10 +110,19 @@ def custom_afterwork_func(last_speaker: SwarmAgent, messages: List[Dict[str, Any
messages = [{"role": "user", "content": messages}]

tool_execution = SwarmAgent(
name="Tool_Execution",
name=__TOOL_EXECUTOR_NAME__,
system_message="Tool Execution",
)
tool_execution._set_to_tool_execution(context_variables=context_variables)
tool_execution._set_to_tool_execution()

# Update tool execution agent with all the functions from all the agents
for agent in agents:
tool_execution._function_map.update(agent._function_map)

# Point all SwarmAgent's context variables to this function's context_variables
# providing a single (shared) context across all SwarmAgents in the swarm
for agent in agents + [tool_execution]:
agent._context_variables = context_variables

INIT_AGENT_USED = False

Expand Down Expand Up @@ -209,6 +224,10 @@ def create_nested_chats(agent: SwarmAgent, nested_chat_agents: List[SwarmAgent])
for agent in agents + nested_chat_agents:
tool_execution._function_map.update(agent._function_map)

# Add conditional functions to the tool_execution agent
for func_name, (func, on_condition) in agent._conditional_functions.items():
tool_execution._function_map[func_name] = func

swarm_agent_names = [agent.name for agent in agents + nested_chat_agents]

# If there's only one message and there's no identified swarm agent
Expand Down Expand Up @@ -335,23 +354,31 @@ def __init__(

self.after_work = None

# Used only in the tool execution agent for context and transferring to the next agent
# Note: context variables are not stored for each agent
self._context_variables = {}
# Used in the tool execution agent to transfer to the next agent
self._next_agent = None

# Store nested chats hand offs as we'll establish these in the initiate_swarm_chat
# List of Dictionaries containing the nested_chats and condition
self._nested_chat_handoffs = []

# Store conditional functions (and their ON_CONDITION instances) to add/remove later when transitioning to this agent
self._conditional_functions = {}

# Register the hook to update agent state (except tool executor)
if name != __TOOL_EXECUTOR_NAME__:
self.register_hook("update_agent_state", self._update_agent_state_hook)

def _update_agent_state_hook(self, agent: Agent, messages: Optional[List[Dict]] = None) -> None:
"""Hook to update the agent state and update conditional functions."""
self._update_conditional_functions()

def _set_to_tool_execution(self, context_variables: Optional[Dict[str, Any]] = None):
"""Set to a special instance of SwarmAgent that is responsible for executing tool calls from other swarm agents.
This agent will be used internally and should not be visible to the user.
It will execute the tool calls and update the context_variables and next_agent accordingly.
It will execute the tool calls and referenced context_variables and next_agent accordingly.
"""
self._next_agent = None
self._context_variables = context_variables or {}
self._reply_func_list.clear()
self.register_reply([Agent, None], SwarmAgent.generate_swarm_tool_reply)

Expand Down Expand Up @@ -400,7 +427,18 @@ def transfer_to_agent() -> "SwarmAgent":
return transfer_to_agent

transfer_func = make_transfer_function(transit)
self.add_single_function(transfer_func, f"transfer_to_{transit.target.name}", transit.condition)

# Store function to add/remove later based on it being 'available'
# Function names are made unique and allow multiple ON_CONDITIONS to the same agent
base_func_name = f"transfer_from_{self.name}_to_{transit.target.name}"
func_name = base_func_name
count = 2
while func_name in self._conditional_functions:
func_name = f"{base_func_name}_{count}"
count += 1

# Store function to add/remove later based on it being 'available'
self._conditional_functions[func_name] = (transfer_func, transit)

elif isinstance(transit.target, Dict):
# Transition to a nested chat
Expand All @@ -410,6 +448,31 @@ def transfer_to_agent() -> "SwarmAgent":
else:
raise ValueError("Invalid hand off condition, must be either ON_CONDITION or AFTER_WORK")

def _update_conditional_functions(self):
"""Updates the agent's functions based on the ON_CONDITION's available condition."""
for func_name, (func, on_condition) in self._conditional_functions.items():
is_available = False

if on_condition.available is not None:
if isinstance(on_condition.available, Callable):
is_available = on_condition.available(self, next(iter(self.chat_messages.values())))
elif isinstance(on_condition.available, str):
is_available = self.get_context(on_condition.available) or False
else:
is_available = True

print(f"DEBUG INFO: Function {func_name} available? {'Yes' if is_available else 'No'}")

if is_available:
if func_name not in self._function_map:
self.add_single_function(func, func_name, on_condition.condition)
else:
# Remove function using the stored name
if func_name in self._function_map:
self.update_tool_signature(func_name, is_remove=True)
del self._function_map[func_name]
print(f"Removed function: {func_name} from {self.name}")

def generate_swarm_tool_reply(
self,
messages: Optional[List[Dict]] = None,
Expand Down

0 comments on commit b5d3d98

Please sign in to comment.