Skip to content

Commit

Permalink
Implementation of fallbacks for human in the loop
Browse files Browse the repository at this point in the history
  • Loading branch information
marklysze committed Nov 19, 2024
1 parent a4f0907 commit 94730ec
Showing 1 changed file with 90 additions and 16 deletions.
106 changes: 90 additions & 16 deletions autogen/agentchat/contrib/swarm_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,12 @@
from inspect import signature
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union

from openai.types.chat.chat_completion import ChatCompletion
from pydantic import BaseModel

from autogen.agentchat import Agent, ChatResult, ConversableAgent, GroupChat, GroupChatManager, UserProxyAgent
from autogen.function_utils import get_function_schema
from autogen.oai import OpenAIWrapper


def parse_json_object(response: str) -> dict:
return json.loads(response)


# Parameter name for context variables
# Use the value in functions and they will be substituted with the context variables:
# e.g. def my_function(context_variables: Dict[str, Any], my_other_parameters: Any) -> Any:
Expand All @@ -24,12 +18,43 @@ def initialize_swarm_chat(
init_agent: "SwarmAgent",
messages: Union[List[Dict[str, Any]], str],
agents: List["SwarmAgent"],
user_agent: Optional[UserProxyAgent] = None,
max_rounds: int = 20,
context_variables: Optional[Dict[str, Any]] = {},
context_variables: Optional[Dict[str, Any]] = None,
fallback_method: Union[Literal["TERMINATE", "REVERT_TO_USER", "STAY"], Callable] = "REVERT_TO_USER",
) -> Tuple[ChatResult, Dict[str, Any], "SwarmAgent"]:
"""Initialize and run a swarm chat
Args:
init_agent: The initial agent of the conversation.
messages: Initial message(s).
agents: List of swarm agents.
user_agent: Optional user proxy agent for falling back to.
max_rounds: Maximum number of conversation rounds.
context_variables: Starting context variables.
fallback_method: Method to handle conversation continuation when an agent doesn't select the next agent. This fallback_method is considered after the speaking agent's fallback_method. Default is "REVERT_TO_USER".
Could be any of the following (case insensitive):
- "TERMINATE": End the conversation if no next agent is selected
- "REVERT_TO_USER": Return to the passed in user_agent if no next agent is selected. Is equivalent to "TERMINATE" if no user_agent is passed in.
- "STAY": Stay with the current agent if no next agent is selected
- Callable: A custom function that takes the current agent, messages, groupchat, and context_variables as arguments and returns the next agent. The function should return None to terminate.
```python
def custom_fallback_func(last_speaker: SwarmAgent, messages: List[Dict[str, Any]], groupchat: GroupChat, context_variables: Optional[Dict[str, Any]]) -> Optional[SwarmAgent]:
```
Returns:
ChatResult: Conversations chat history.
Dict[str, Any]: Updated Context variables.
SwarmAgent: Last speaker.
"""
context_variables = context_variables or {}
if isinstance(fallback_method, str):
fallback_method = fallback_method.upper()

if isinstance(messages, str):
messages = [{"role": "user", "content": messages}]

swarm_agent_names = [agent.name for agent in agents]

tool_execution = SwarmAgent(
name="Tool_Execution",
system_message="Tool Execution",
Expand All @@ -38,13 +63,16 @@ def initialize_swarm_chat(
is_tool_execution=True,
context_variables=context_variables,
)

# Update tool execution agent with all the functions from all the agents
for agent in agents:
tool_execution._function_map.update(agent._function_map)

INIT_AGENT_USED = False

def swarm_transition(last_speaker: SwarmAgent, groupchat: GroupChat):
"""Swarm transition function to determine the next agent in the conversation"""

nonlocal INIT_AGENT_USED
if not INIT_AGENT_USED:
INIT_AGENT_USED = True
Expand All @@ -57,15 +85,48 @@ def swarm_transition(last_speaker: SwarmAgent, groupchat: GroupChat):
tool_execution.next_agent = None
return next_agent

# No next agent has been selected, if last agent is the tool executor,
# we need to go back to the agent before this last tool execution
if last_speaker == tool_execution:
return groupchat.agent_by_name(name=messages[-2].get("name", ""))
last_swarm_speaker = get_last_swarm_speaker()

# If the user last spoke, return to the agent prior
if user_agent and last_speaker == user_agent:
return last_swarm_speaker

# No agent selected via hand-offs (tool calls)
# Check the agent's fallback method
if last_swarm_speaker.fallback_method:
next_agent = get_fallback_agent(last_swarm_speaker.fallback_method, last_swarm_speaker)

if next_agent is not None:
return next_agent

return None
# Check the swarm's fallback method
# Returns None, ending swarm, if still no agent selected
return get_fallback_agent(fallback_method, last_swarm_speaker)

def get_last_swarm_speaker() -> "SwarmAgent":
"""Get the last swarm agent that spoke in the message history"""
for message in reversed(messages):
if "name" in message and message["name"] in swarm_agent_names:
agent = groupchat.agent_by_name(name=message["name"])
if isinstance(agent, SwarmAgent):
return agent

raise ValueError("No swarm agent found in the message history")

def get_fallback_agent(method: Union[str, Callable], agent: "SwarmAgent") -> Optional["SwarmAgent"]:
"""Get the next agent based on the fallback method"""

if method == "TERMINATE" or method == "REVERT_TO_USER" and not user_agent:
return None
elif method == "REVERT_TO_USER":
return user_agent
elif method == "STAY":
return agent
elif callable(method):
return method(agent, messages, groupchat, context_variables)

groupchat = GroupChat(
agents=[tool_execution] + agents,
agents=[tool_execution] + agents + ([user_agent] if user_agent is not None else []),
messages=messages,
max_round=max_rounds,
speaker_selection_method=swarm_transition,
Expand Down Expand Up @@ -93,15 +154,15 @@ class SwarmResult(BaseModel):
"""
Encapsulates the possible return values for a swarm agent function.
arguments:
Args:
values (str): The result values as a string.
agent (SwarmAgent): The swarm agent instance, if applicable.
context_variables (dict): A dictionary of context variables.
"""

values: str = ""
agent: Optional["SwarmAgent"] = None
context_variables: dict = {}
context_variables: Dict[str, Any] = {}

class Config: # Add this inner class
arbitrary_types_allowed = True
Expand All @@ -111,6 +172,17 @@ def __str__(self):


class SwarmAgent(ConversableAgent):
"""Swarm agent for participating in a swarm.
SwarmAgent is a subclass of ConversableAgent.
Additional args:
context_variables (dict): A dictionary of context variables.
is_tool_execution (bool): A flag to indicate if the agent is a tool execution agent.
fallback_method (str or Callable): Method to handle conversation continuation when an agent doesn't select the next agent. Default is None.
"""

def __init__(
self,
name: str,
Expand All @@ -123,6 +195,7 @@ def __init__(
description: Optional[str] = None,
context_variables: Optional[Dict[str, Any]] = None,
is_tool_execution: Optional[bool] = False,
fallback_method: Optional[Union[Literal["TERMINATE", "REVERT_TO_USER", "STAY"], Callable]] = None,
**kwargs,
) -> None:
super().__init__(
Expand All @@ -145,11 +218,12 @@ def __init__(
self._reply_func_list.clear()
self.register_reply([Agent, None], SwarmAgent.generate_swarm_tool_reply)

self.fallback_method = fallback_method
self.context_variables = context_variables or {}
self.next_agent = None # use in the tool execution agent to transfer to the next agent

def update_context_variables(self, context_variables: Dict[str, Any]) -> None:
pass
self.context_variables.update(context_variables)

def __str__(self):
return f"SwarmAgent: {self.name}"
Expand Down

0 comments on commit 94730ec

Please sign in to comment.