diff --git a/autogen/agentchat/conversable_agent.py b/autogen/agentchat/conversable_agent.py
index d2f8763281..840da79204 100644
--- a/autogen/agentchat/conversable_agent.py
+++ b/autogen/agentchat/conversable_agent.py
@@ -659,6 +659,9 @@ def _append_oai_message(self, message: Union[Dict, str], role, conversation_id:
 
         if message.get("role") in ["function", "tool"]:
             oai_message["role"] = message.get("role")
+            if "tool_responses" in oai_message:
+                for tool_response in oai_message["tool_responses"]:
+                    tool_response["content"] = str(tool_response["content"])
         elif "override_role" in message:
             # If we have a direction to override the role then set the
             # role accordingly. Used to customise the role for the
@@ -791,15 +794,16 @@ async def a_send(
                 "Message can't be converted into a valid ChatCompletion message. Either content or function_call must be provided."
             )
 
-    def _print_received_message(self, message: Union[Dict, str], sender: Agent):
+    def _print_received_message(self, message: Union[Dict, str], sender: Agent, skip_head: bool = False):
         iostream = IOStream.get_default()
         # print the message received
-        iostream.print(colored(sender.name, "yellow"), "(to", f"{self.name}):\n", flush=True)
+        if not skip_head:
+            iostream.print(colored(sender.name, "yellow"), "(to", f"{self.name}):\n", flush=True)
         message = self._message_to_dict(message)
 
         if message.get("tool_responses"):  # Handle tool multi-call responses
             for tool_response in message["tool_responses"]:
-                self._print_received_message(tool_response, sender)
+                self._print_received_message(tool_response, sender, skip_head=True)
 
             if message.get("role") == "tool":
                 return  # If role is tool, then content is just a concatenation of all tool_responses
@@ -2288,7 +2292,7 @@ def _format_json_str(jstr):
             result.append(char)
         return "".join(result)
 
-    def execute_function(self, func_call, verbose: bool = False) -> Tuple[bool, Dict[str, str]]:
+    def execute_function(self, func_call, verbose: bool = False) -> Tuple[bool, Dict[str, Any]]:
         """Execute a function call and return the result.
 
         Override this function to modify the way to execute function and tool calls.
@@ -2342,7 +2346,7 @@ def execute_function(self, func_call, verbose: bool = False) -> Tuple[bool, Dict
         return is_exec_success, {
             "name": func_name,
             "role": "function",
-            "content": str(content),
+            "content": content,
         }
 
     async def a_execute_function(self, func_call):
@@ -2397,7 +2401,7 @@ async def a_execute_function(self, func_call):
         return is_exec_success, {
             "name": func_name,
             "role": "function",
-            "content": str(content),
+            "content": content,
         }
 
     def generate_init_message(self, message: Union[Dict, str, None], **kwargs) -> Union[str, Dict]:
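Note on the conversable_agent.py changes above: `execute_function`/`a_execute_function` now return the tool's raw content (`Any`), and the `str()` conversion moves into `_append_oai_message`, so rich objects survive until an OpenAI message is actually built. The sketch below is illustrative only — `FakeResult` and `execute_function_like` are invented stand-ins, not AutoGen APIs — and shows the contract this enables.

```python
from typing import Any, Callable, Dict, Tuple


class FakeResult:
    """Stands in for a rich tool return value such as SwarmResult."""

    def __init__(self, payload: Dict[str, Any]) -> None:
        self.payload = payload

    def __str__(self) -> str:
        return f"FakeResult({self.payload})"


def execute_function_like(func: Callable[..., Any], **args: Any) -> Tuple[bool, Dict[str, Any]]:
    # After this change the raw object is returned as-is; it is only str()-ed
    # later, when the tool response is appended as an OpenAI message.
    content = func(**args)
    return True, {"name": func.__name__, "role": "function", "content": content}


def lookup(user_id: str) -> FakeResult:
    return FakeResult({"user_id": user_id, "tier": "gold"})


ok, message = execute_function_like(lookup, user_id="u-1")
assert isinstance(message["content"], FakeResult)  # rich object preserved for the caller
serialized = str(message["content"])  # what _append_oai_message now does for tool_responses
print(ok, serialized)
```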
diff --git a/autogen/agentchat/groupchat.py b/autogen/agentchat/groupchat.py
index 4e9e107f92..9250d18612 100644
--- a/autogen/agentchat/groupchat.py
+++ b/autogen/agentchat/groupchat.py
@@ -11,7 +11,7 @@
 import re
 import sys
 from dataclasses import dataclass, field
-from typing import Callable, Dict, List, Literal, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
 
 from ..code_utils import content_str
 from ..exception_utils import AgentNameConflict, NoEligibleSpeaker, UndefinedNextAgent
@@ -23,6 +23,8 @@
 from .agent import Agent
 from .contrib.capabilities import transform_messages
 from .conversable_agent import ConversableAgent
+from .swarm import SwarmAgent, SwarmResult
+from .user_proxy_agent import UserProxyAgent
 
 logger = logging.getLogger(__name__)
 
@@ -68,6 +70,7 @@ class GroupChat:
         - "manual": the next speaker is selected manually by user input.
         - "random": the next speaker is selected randomly.
         - "round_robin": the next speaker is selected in a round robin fashion, i.e., iterating in the same order as provided in `agents`.
+        - "swarm": utilises the swarm pattern, where agents continue to speak until they hand off to another agent.
         - a customized speaker selection function (Callable): the function will be called to select the next speaker.
             The function should take the last speaker and the group chat as input and return one of the following:
                 1. an `Agent` class, it must be one of the agents in the group chat.
@@ -109,6 +112,7 @@ def custom_speaker_selection_func(
     - select_speaker_auto_model_client_cls: Custom model client class for the internal speaker select agent used during 'auto' speaker selection (optional)
     - select_speaker_auto_llm_config: LLM config for the internal speaker select agent used during 'auto' speaker selection (optional)
     - role_for_select_speaker_messages: sets the role name for speaker selection when in 'auto' mode, typically 'user' or 'system'. (default: 'system')
+    - context_variables: dictionary of context variables for use with swarm-based group chats
     """
 
     agents: List[Agent]
@@ -116,7 +120,7 @@ def custom_speaker_selection_func(
     max_round: int = 10
     admin_name: str = "Admin"
     func_call_filter: bool = True
-    speaker_selection_method: Union[Literal["auto", "manual", "random", "round_robin"], Callable] = "auto"
+    speaker_selection_method: Union[Literal["auto", "manual", "random", "round_robin", "swarm"], Callable] = "auto"
     max_retries_for_selecting_speaker: int = 2
     allow_repeat_speaker: Optional[Union[bool, List[Agent]]] = None
     allowed_or_disallowed_speaker_transitions: Optional[Dict] = None
@@ -148,8 +152,9 @@ def custom_speaker_selection_func(
     select_speaker_auto_model_client_cls: Optional[Union[ModelClient, List[ModelClient]]] = None
     select_speaker_auto_llm_config: Optional[Union[Dict, Literal[False]]] = None
     role_for_select_speaker_messages: Optional[str] = "system"
+    context_variables: Optional[Dict] = field(default_factory=dict)
 
-    _VALID_SPEAKER_SELECTION_METHODS = ["auto", "manual", "random", "round_robin"]
+    _VALID_SPEAKER_SELECTION_METHODS = ["auto", "manual", "random", "round_robin", "swarm"]
     _VALID_SPEAKER_TRANSITIONS_TYPE = ["allowed", "disallowed", None]
 
     # Define a class attribute for the default introduction message
@@ -276,6 +281,13 @@ def __post_init__(self):
         if self.select_speaker_auto_verbose is None or not isinstance(self.select_speaker_auto_verbose, bool):
             raise ValueError("select_speaker_auto_verbose cannot be None or non-bool")
 
+        # Ensure, for swarms, all agents are swarm agents
+        if self.speaker_selection_method == "swarm":
+            """MS TEMP REMOVE
+            if not all(isinstance(agent, SwarmAgent) for agent in self.agents):
+                raise ValueError("All agents must be of type SwarmAgent when using the 'swarm' speaker selection method.")
+            """
+
     @property
     def agent_names(self) -> List[str]:
         """Return the names of the agents in the group chat."""
@@ -419,6 +431,44 @@ def random_select_speaker(self, agents: Optional[List[Agent]] = None) -> Union[A
             agents = self.agents
         return random.choice(agents)
 
+    def swarm_select_speaker(self, last_speaker: Agent, agents: Optional[List[Agent]] = None) -> Union[Agent, None]:
+        """Select the next speaker using the swarm pattern.
+        Note that this does not need to cater for when the agent is continuing to speak."""
+        messages = self.messages
+        user_agent = None
+        for agent in agents:
+            if isinstance(agent, UserProxyAgent):
+                user_agent = agent
+                break
+
+        if user_agent is None:
+            raise ValueError("We need a UserProxyAgent to continue the conversation.")
+
+        # Always start with the first swarm agent
+        if len(messages) <= 1:
+            if last_speaker == user_agent:
+                for agent in agents:
+                    if isinstance(agent, SwarmAgent):
+                        return agent
+            return user_agent
+
+        last_message = messages[-1]
+        # If the last message is a TRANSFER message, extract the agent name and return that agent
+        if last_message["role"] == "tool":
+            if "content" in last_message and last_message["content"].startswith("TRANSFER:"):
+                agent_name = last_message["content"].split(":")[1].strip()
+                if self.agent_by_name(name=agent_name):
+                    return self.agent_by_name(agent_name)
+            else:
+                # If the agent just called a tool and did not transfer, return the last speaker
+                return last_speaker
+
+        if isinstance(last_speaker, SwarmAgent):
+            return user_agent
+        elif isinstance(last_speaker, UserProxyAgent):
+            return self.agent_by_name(name=messages[-2].get("name", ""))
+
+        # Otherwise, we cannot determine the next speaker
+        raise ValueError("Something wrong with the speaker selection.")
+
     def _prepare_and_select_agents(
         self,
         last_speaker: Agent,
@@ -466,7 +516,9 @@ def _prepare_and_select_agents(
                 f"GroupChat is underpopulated with {n_agents} agents. "
                 "Please add more agents to the GroupChat or use direct communication instead."
             )
-        elif n_agents == 2 and speaker_selection_method.lower() != "round_robin" and allow_repeat_speaker:
+        elif (
+            n_agents == 2 and speaker_selection_method.lower() not in ["round_robin", "swarm"] and allow_repeat_speaker
+        ):
             logger.warning(
                 f"GroupChat is underpopulated with {n_agents} agents. "
                 "Consider setting speaker_selection_method to 'round_robin' or allow_repeat_speaker to False, "
@@ -536,6 +588,8 @@ def _prepare_and_select_agents(
             selected_agent = self.next_agent(last_speaker, graph_eligible_agents)
         elif speaker_selection_method.lower() == "random":
             selected_agent = self.random_select_speaker(graph_eligible_agents)
+        elif speaker_selection_method.lower() == "swarm":
+            selected_agent = self.swarm_select_speaker(last_speaker, graph_eligible_agents)
         else:  # auto
             selected_agent = None
             select_speaker_messages = self.messages.copy()
@@ -1125,6 +1179,34 @@ def print_messages(recipient, messages, sender, config):
         """
         return self._last_speaker
 
+    def _process_reply_from_swarm(self, reply: Union[Dict, List[Dict]], groupchat: GroupChat) -> Optional[Agent]:
+        # If we have a swarm reply, update context variables and determine the next agent
+        if isinstance(reply, list):
+            pass
+        elif isinstance(reply, dict):
+            reply = [reply]
+        else:
+            return None
+        next_agent = None
+        for r in reply:
+            content = r.get("content")
+            if isinstance(content, SwarmResult):
+                if content.context_variables != {}:
+                    groupchat.context_variables.update(content.context_variables)
+                if content.agent is not None:
+                    next_agent = content.agent
+            elif isinstance(content, Agent):
+                next_agent = content
+            r["content"] = str(r["content"])
+        return next_agent
+
+    def _broadcast_message(self, groupchat: GroupChat, message: Dict, speaker: Agent) -> None:
+        # Broadcast the message to all agents except the speaker
+        groupchat.append(message, speaker)
+        for agent in groupchat.agents:
+            if agent != speaker:
+                self.send(message, agent, request_reply=False, silent=True)
+
     def run_chat(
         self,
         messages: Optional[List[Dict]] = None,
@@ -1136,6 +1218,7 @@ def run_chat(
             messages = self._oai_messages[sender]
         message = messages[-1]
         speaker = sender
+        next_speaker = None  # The next swarm agent to speak, determined by the current swarm agent
         groupchat = config
         send_introductions = getattr(groupchat, "send_introductions", False)
         silent = getattr(self, "_silent", False)
@@ -1148,28 +1231,48 @@ def run_chat(
             # NOTE: We do not also append to groupchat.messages,
             # since groupchat handles its own introductions
 
+        if self.groupchat.speaker_selection_method == "swarm":
+            config.allow_repeat_speaker = True  # Swarms allow the last speaker to be the next speaker
+
         if self.client_cache is not None:
             for a in groupchat.agents:
                 a.previous_cache = a.client_cache
                 a.client_cache = self.client_cache
         for i in range(groupchat.max_round):
             self._last_speaker = speaker
-            groupchat.append(message, speaker)
-            # broadcast the message to all agents except the speaker
-            for agent in groupchat.agents:
-                if agent != speaker:
-                    self.send(message, agent, request_reply=False, silent=True)
-            if self._is_termination_msg(message) or i == groupchat.max_round - 1:
-                # The conversation is over or it's the last round
-                break
+            if isinstance(message, list):
+                for m in message:
+                    self._broadcast_message(groupchat, m, speaker)
+                if any(self._is_termination_msg(m) for m in message) or i == groupchat.max_round - 1:
+                    # The conversation is over or it's the last round
+                    break
+            else:
+                self._broadcast_message(groupchat, message, speaker)
+                if self._is_termination_msg(message) or i == groupchat.max_round - 1:
+                    # The conversation is over or it's the last round
+                    break
             try:
-                # select the next speaker
-                speaker = groupchat.select_speaker(speaker, self)
+                if next_speaker:
+                    # Speaker has already been selected (swarm)
+                    speaker = next_speaker
+                    next_speaker = None
+                else:
+                    speaker = groupchat.select_speaker(speaker, self)
+
                 if not silent:
                     iostream = IOStream.get_default()
                     iostream.print(colored(f"\nNext speaker: {speaker.name}\n", "green"), flush=True)
+
+                # Update the context_variables on the agent
+                if isinstance(speaker, SwarmAgent):
+                    speaker.context_variables.update(groupchat.context_variables)
+
                 # let the speaker speak
-                reply = speaker.generate_reply(sender=self)
+                reply = speaker.generate_reply(sender=self)  # reply must be a dict or a list of dicts (a list only for swarm)
+
+                if groupchat.speaker_selection_method == "swarm":
+                    next_speaker = self._process_reply_from_swarm(reply, groupchat)  # process the swarm reply: update context variables and the next agent
+
             except KeyboardInterrupt:
                 # let the admin agent speak if interrupted
                 if groupchat.admin_name in groupchat.agent_names:
@@ -1197,8 +1300,13 @@ def run_chat(
                 reply["content"] = self.clear_agents_history(reply, groupchat)
 
             # The speaker sends the message without requesting a reply
-            speaker.send(reply, self, request_reply=False, silent=silent)
-            message = self.last_message(speaker)
+            if isinstance(reply, list):
+                for r in reply:
+                    speaker.send(r, self, request_reply=False, silent=silent)
+                message = reply
+            else:
+                speaker.send(reply, self, request_reply=False, silent=silent)
+                message = self.last_message(speaker)
         if self.client_cache is not None:
             for a in groupchat.agents:
                 a.client_cache = a.previous_cache
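The groupchat.py changes add the "swarm" speaker-selection method: the last tool message decides whether the same agent keeps speaking, a named agent takes over via a `TRANSFER:` marker, or control returns to the `UserProxyAgent`. Below is a simplified, self-contained sketch of that rule; the `pick_next` helper and the string-keyed agent map are invented for illustration and are not the real `GroupChat` code.

```python
from typing import Dict, List, Optional


def pick_next(messages: List[Dict], last_speaker: str, agents: Dict[str, str], user: str) -> Optional[str]:
    """agents maps agent name -> kind ("swarm" or "user"); returns the name of the next speaker."""
    if len(messages) <= 1:
        # First turn: hand control from the user to the first swarm agent.
        return next((name for name, kind in agents.items() if kind == "swarm"), user)

    last = messages[-1]
    if last.get("role") == "tool":
        content = last.get("content", "")
        if content.startswith("TRANSFER:"):
            return content.split(":")[1].strip()  # explicit handoff to a named agent
        return last_speaker  # a plain tool result: the same agent keeps speaking

    # A swarm agent finished a normal reply: give the floor back to the user,
    # otherwise go back to the agent that spoke before the user.
    return user if agents.get(last_speaker) == "swarm" else messages[-2].get("name")


history = [
    {"role": "user", "name": "human", "content": "start"},
    {"role": "tool", "name": "triage", "content": "TRANSFER: billing"},
]
print(pick_next(history, "triage", {"triage": "swarm", "billing": "swarm", "human": "user"}, "human"))  # billing
```

In the real implementation the same decision is made by `GroupChat.swarm_select_speaker`, and `GroupChatManager._process_reply_from_swarm` can short-circuit it when a `SwarmResult` names the next agent.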
+ """ + values: str = "" + agent: Optional["SwarmAgent"] = None + context_variables: dict = {} + + class Config: # Add this inner class + arbitrary_types_allowed = True + + def __str__(self): + return self.values + +class SwarmAgent(ConversableAgent): + def __init__( + self, + name: str, + system_message: Optional[str] = "You are a helpful AI Assistant.", + llm_config: Optional[Union[Dict, Literal[False]]] = None, + functions: Union[List[Callable], Callable] = None, + is_termination_msg: Optional[Callable[[Dict], bool]] = None, + max_consecutive_auto_reply: Optional[int] = None, + human_input_mode: Literal["ALWAYS", "NEVER", "TERMINATE"] = "NEVER", + description: Optional[str] = None, + context_variables: Optional[Dict[str, Any]] = None, + **kwargs, + ) -> None: + super().__init__( + name, + system_message, + is_termination_msg, + max_consecutive_auto_reply, + human_input_mode, + llm_config=llm_config, + description=description, + **kwargs, + ) + + if isinstance(functions, list): + self.add_functions(functions) + elif isinstance(functions, Callable): + self.add_single_function(functions) + + self._reply_func_list.clear() + self.register_reply([Agent, None], SwarmAgent.generate_reply_with_tool_calls) + self.context_variables = context_variables or {} + + def update_context_variables(self, context_variables: Dict[str, Any]) -> None: + pass + + def __str__(self): + return f"SwarmAgent: {self.name}" + + def generate_reply_with_tool_calls( + self, + messages: Optional[List[Dict]] = None, + sender: Optional[Agent] = None, + config: Optional[OpenAIWrapper] = None, + ) -> Tuple[bool, SwarmResult]: + + client = self.client if config is None else config + if client is None: + return False, None + if messages is None: + messages = self._oai_messages[sender] + + response = self._generate_oai_reply_from_client(client, self._oai_system_message + messages, self.client_cache) + + if isinstance(response, str): + return True, response + elif isinstance(response, dict): + # Tool calls, inject context_variables back in to the response before executing the tools + if "tool_calls" in response: + for tool_call in response["tool_calls"]: + if tool_call["type"] == "function": + function_name = tool_call["function"]["name"] + + # Check if this function exists in our function map + if function_name in self._function_map: + func = self._function_map[function_name] # Get the original function + + # Check if function has context_variables parameter + sig = signature(func) + if __CONTEXT_VARIABLES_PARAM_NAME__ in sig.parameters: + current_args = json.loads(tool_call["function"]["arguments"]) + current_args[__CONTEXT_VARIABLES_PARAM_NAME__] = self.context_variables + # Update the tool call with new arguments + tool_call["function"]["arguments"] = json.dumps(current_args) + + # Generate tool calls reply + _, tool_message = self.generate_tool_calls_reply([response]) + return True, [response] + tool_message["tool_responses"] + else: + raise ValueError("Invalid response type:", type(response)) + + def add_single_function(self, func: Callable, description=""): + func._name = func.__name__ + + if description: + func._description = description + else: + # Use function's docstring, strip whitespace, fall back to empty string + func._description = (func.__doc__ or "").strip() + + f = get_function_schema(func, name=func._name, description=func._description) + + # Remove context_variables parameter from function schema + f_no_context = f.copy() + if __CONTEXT_VARIABLES_PARAM_NAME__ in 
f_no_context["function"]["parameters"]["properties"]: + del f_no_context["function"]["parameters"]["properties"][__CONTEXT_VARIABLES_PARAM_NAME__] + if "required" in f_no_context["function"]["parameters"]: + required = f_no_context["function"]["parameters"]["required"] + f_no_context["function"]["parameters"]["required"] = [ + param for param in required if param != __CONTEXT_VARIABLES_PARAM_NAME__ + ] + # If required list is empty, remove it + if not f_no_context["function"]["parameters"]["required"]: + del f_no_context["function"]["parameters"]["required"] + + self.update_tool_signature(f_no_context, is_remove=False) + self.register_function({func._name: func}) + + def add_functions(self, func_list: List[Callable]): + for func in func_list: + self.add_single_function(func) diff --git a/test/agentchat/test_groupchat.py b/test/agentchat/test_groupchat.py index de5f8bab11..c7c702ca16 100755 --- a/test/agentchat/test_groupchat.py +++ b/test/agentchat/test_groupchat.py @@ -13,7 +13,8 @@ from types import SimpleNamespace from typing import Any, Dict, List, Optional from unittest import TestCase, mock - +import sys +import os import pytest from test_assistant_agent import KEY_LOC, OAI_CONFIG_LIST @@ -21,7 +22,79 @@ from autogen import Agent, AssistantAgent, GroupChat, GroupChatManager from autogen.agentchat.contrib.capabilities import transform_messages, transforms from autogen.exception_utils import AgentNameConflict, UndefinedNextAgent +from autogen.agentchat.swarm.swarm_agent import SwarmAgent, SwarmResult + +sys.path.append(os.path.join(os.path.dirname(__file__), "..")) +from conftest import skip_openai # noqa: E402 + +try: + from openai import OpenAI +except ImportError: + skip = True +else: + skip = False or skip_openai + + +@pytest.mark.skipif(skip, reason="openai not installed OR requested to skip") +def test_swarm_agent(): + context_variables = {"1": False, "2": False, "3": False} + + def update_context_1(context_variables: dict) -> str: + context_variables["1"] = True + return SwarmResult(value="success", context_variables=context_variables) + + def update_context_2_and_transfer_to_3(context_variables: dict) -> str: + context_variables["2"] = True + return SwarmResult(value="success", context_variables=context_variables, agent=agent_3) + + def update_context_3(context_variables: dict) -> str: + context_variables["3"] = True + return SwarmResult(value="success", context_variables=context_variables) + + def transfer_to_agent_2() -> SwarmAgent: + return agent_2 + + agent_1 = SwarmAgent( + name="Agent_1", + system_message="You are Agent 1, first, call the function to update context 1, and transfer to Agent 2", + llm_config=llm_config, + functions=[update_context_1, transfer_to_agent_2], + ) + + agent_2 = SwarmAgent( + name="Agent_2", + system_message="You are Agent 2, call the function that updates context 2 and transfer to Agent 3", + llm_config=llm_config, + functions=[update_context_2_and_transfer_to_3], + ) + + agent_3 = SwarmAgent( + name="Agent_3", + system_message="You are Agent 3, first, call the function to update context 3, and then reply TERMINATE", + llm_config=llm_config, + functions=[update_context_3], + ) + + user = UserProxyAgent( + name="Human_User", + system_message="Human user", + human_input_mode="ALWAYS", + code_execution_config=False, + ) + groupchat = GroupChat( + agents=[user, agent_1, agent_2, agent_3], + messages=[], + max_round=10, + speaker_selection_method="swarm", + context_variables=context_variables, + ) + manager = GroupChatManager(groupchat=groupchat, 
diff --git a/test/agentchat/test_groupchat.py b/test/agentchat/test_groupchat.py
index de5f8bab11..c7c702ca16 100755
--- a/test/agentchat/test_groupchat.py
+++ b/test/agentchat/test_groupchat.py
@@ -13,7 +13,8 @@
 from types import SimpleNamespace
 from typing import Any, Dict, List, Optional
 from unittest import TestCase, mock
-
+import sys
+import os
 import pytest
 from test_assistant_agent import KEY_LOC, OAI_CONFIG_LIST
 
@@ -21,7 +22,79 @@
 from autogen import Agent, AssistantAgent, GroupChat, GroupChatManager
 from autogen.agentchat.contrib.capabilities import transform_messages, transforms
 from autogen.exception_utils import AgentNameConflict, UndefinedNextAgent
+from autogen import UserProxyAgent
+from autogen.agentchat.swarm.swarm_agent import SwarmAgent, SwarmResult
+
+sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
+from conftest import skip_openai  # noqa: E402
+
+try:
+    from openai import OpenAI
+except ImportError:
+    skip = True
+else:
+    skip = False or skip_openai
+
+
+@pytest.mark.skipif(skip, reason="openai not installed OR requested to skip")
+def test_swarm_agent():
+    # LLM configuration for the swarm agents
+    config_list = autogen.config_list_from_json(OAI_CONFIG_LIST, file_location=KEY_LOC)
+    llm_config = {"config_list": config_list}
+
+    context_variables = {"1": False, "2": False, "3": False}
+
+    def update_context_1(context_variables: dict) -> str:
+        context_variables["1"] = True
+        return SwarmResult(values="success", context_variables=context_variables)
+
+    def update_context_2_and_transfer_to_3(context_variables: dict) -> str:
+        context_variables["2"] = True
+        return SwarmResult(values="success", context_variables=context_variables, agent=agent_3)
+
+    def update_context_3(context_variables: dict) -> str:
+        context_variables["3"] = True
+        return SwarmResult(values="success", context_variables=context_variables)
+
+    def transfer_to_agent_2() -> SwarmAgent:
+        return agent_2
+
+    agent_1 = SwarmAgent(
+        name="Agent_1",
+        system_message="You are Agent 1, first, call the function to update context 1, and transfer to Agent 2",
+        llm_config=llm_config,
+        functions=[update_context_1, transfer_to_agent_2],
+    )
+
+    agent_2 = SwarmAgent(
+        name="Agent_2",
+        system_message="You are Agent 2, call the function that updates context 2 and transfer to Agent 3",
+        llm_config=llm_config,
+        functions=[update_context_2_and_transfer_to_3],
+    )
+
+    agent_3 = SwarmAgent(
+        name="Agent_3",
+        system_message="You are Agent 3, first, call the function to update context 3, and then reply TERMINATE",
+        llm_config=llm_config,
+        functions=[update_context_3],
+    )
+
+    user = UserProxyAgent(
+        name="Human_User",
+        system_message="Human user",
+        human_input_mode="ALWAYS",
+        code_execution_config=False,
+    )
+
+    groupchat = GroupChat(
+        agents=[user, agent_1, agent_2, agent_3],
+        messages=[],
+        max_round=10,
+        speaker_selection_method="swarm",
+        context_variables=context_variables,
+    )
+    manager = GroupChatManager(groupchat=groupchat, llm_config=None)
+
+    chat_result = user.initiate_chat(manager, message="start")
+
+    assert context_variables["1"] is True, "Expected context_variables['1'] to be True"
+    assert context_variables["2"] is True, "Expected context_variables['2'] to be True"
+    assert context_variables["3"] is True, "Expected context_variables['3'] to be True"
 
 
 def test_func_call_groupchat():
     agent1 = autogen.ConversableAgent(
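For reference, here is a standalone sketch of how a `SwarmResult`-style return value drives the handoff exercised by the test above; `MiniSwarmResult` and `process_reply` are illustrative stand-ins for `SwarmResult` and `GroupChatManager._process_reply_from_swarm`, not the real classes.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, Optional


@dataclass
class MiniSwarmResult:  # stand-in for SwarmResult, illustrative only
    values: str = ""
    agent: Optional[str] = None  # name of the agent to hand off to
    context_variables: Dict[str, Any] = field(default_factory=dict)


def process_reply(reply_content: Any, shared_context: Dict[str, Any]) -> Optional[str]:
    # Mirrors GroupChatManager._process_reply_from_swarm for a single message:
    # merge context updates and remember the requested next agent.
    next_agent = None
    if isinstance(reply_content, MiniSwarmResult):
        shared_context.update(reply_content.context_variables)
        next_agent = reply_content.agent
    return next_agent


context = {"1": False}
result = MiniSwarmResult(values="success", agent="Agent_3", context_variables={"1": True})
print(process_reply(result, context), context)  # Agent_3 {'1': True}
```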