diff --git a/autogen/agentchat/groupchat.py b/autogen/agentchat/groupchat.py index 166ecff4c4..9250d18612 100644 --- a/autogen/agentchat/groupchat.py +++ b/autogen/agentchat/groupchat.py @@ -445,7 +445,10 @@ def swarm_select_speaker(self, last_speaker: Agent, agents: Optional[List[Agent] # Always start with the first speaker if len(messages) <= 1: - print("aaaaaaa") + if last_speaker == user_agent: + for agent in agents: + if isinstance(agent, SwarmAgent): + return agent return user_agent last_message = messages[-1] # If the last message is a TRANSFER message, extract agent name and return them @@ -1189,19 +1192,12 @@ def _process_reply_from_swarm(self, reply: Union[Dict, List[Dict]], groupchat: G content = r.get("content") if isinstance(content, SwarmResult): if content.context_variables != {}: - self.groupchat.context_variables.update(content.context_variables) + groupchat.context_variables.update(content.context_variables) if content.agent is not None: next_agent = content.agent - - # Change content back to a string for consistency with messages - r["content"] = content.values elif isinstance(content, Agent): next_agent = content - - # Change content back to a string - # Consider adjusting this message, e.g. f"Transfer to {next_agent.name}" - r["content"] = next_agent.name - + r["content"] = str(r["content"]) return next_agent def _broadcast_message(self, groupchat: GroupChat, message: Dict, speaker: Agent) -> None: @@ -1275,7 +1271,7 @@ def run_chat( reply = speaker.generate_reply(sender=self) # reply must be a dict or a list of dicts(only for swarm) if groupchat.speaker_selection_method == "swarm": - next_speaker = self._process_reply_from_swarm(reply, speaker) # process the swarm reply: Update + next_speaker = self._process_reply_from_swarm(reply, groupchat) # process the swarm reply: Update except KeyboardInterrupt: # let the admin agent speak if interrupted @@ -1322,7 +1318,6 @@ async def a_run_chat( messages: Optional[List[Dict]] = None, sender: Optional[Agent] = None, config: Optional[GroupChat] = None, - context_variables: Optional[Dict] = {}, # For Swarms ): """Run a group chat asynchronously.""" if messages is None: diff --git a/autogen/agentchat/swarm/swarm_agent.py b/autogen/agentchat/swarm/swarm_agent.py index 529d0e68f8..f91d5055c8 100644 --- a/autogen/agentchat/swarm/swarm_agent.py +++ b/autogen/agentchat/swarm/swarm_agent.py @@ -29,15 +29,15 @@ class SwarmResult(BaseModel): agent (SwarmAgent): The swarm agent instance, if applicable. context_variables (dict): A dictionary of context variables. """ - values: str = "" agent: Optional["SwarmAgent"] = None context_variables: dict = {} + class Config: # Add this inner class + arbitrary_types_allowed = True -# 1. SwarmResult should be a single instance, a single tool call can return one result only. -# 2. In generate_reply_with_tool_calls, We only process the tool_responses from a single message from generate_tool_calls_reply - + def __str__(self): + return self.values class SwarmAgent(ConversableAgent): def __init__( @@ -92,9 +92,6 @@ def generate_reply_with_tool_calls( if messages is None: messages = self._oai_messages[sender] - # print("messages", messages) - # print(self.llm_config['tools']) - # exit() response = self._generate_oai_reply_from_client(client, self._oai_system_message + messages, self.client_cache) if isinstance(response, str): @@ -120,13 +117,6 @@ def generate_reply_with_tool_calls( # Generate tool calls reply _, tool_message = self.generate_tool_calls_reply([response]) - - # a tool_response example: - # { - # "role": "tool", - # "content": A str, or an object (SwarmResult, SwarmAgent, etc.) - # "tool_call_id": - # }, return True, [response] + tool_message["tool_responses"] else: raise ValueError("Invalid response type:", type(response)) diff --git a/test/agentchat/test_groupchat.py b/test/agentchat/test_groupchat.py index de5f8bab11..c7c702ca16 100755 --- a/test/agentchat/test_groupchat.py +++ b/test/agentchat/test_groupchat.py @@ -13,7 +13,8 @@ from types import SimpleNamespace from typing import Any, Dict, List, Optional from unittest import TestCase, mock - +import sys +import os import pytest from test_assistant_agent import KEY_LOC, OAI_CONFIG_LIST @@ -21,7 +22,79 @@ from autogen import Agent, AssistantAgent, GroupChat, GroupChatManager from autogen.agentchat.contrib.capabilities import transform_messages, transforms from autogen.exception_utils import AgentNameConflict, UndefinedNextAgent +from autogen.agentchat.swarm.swarm_agent import SwarmAgent, SwarmResult + +sys.path.append(os.path.join(os.path.dirname(__file__), "..")) +from conftest import skip_openai # noqa: E402 + +try: + from openai import OpenAI +except ImportError: + skip = True +else: + skip = False or skip_openai + + +@pytest.mark.skipif(skip, reason="openai not installed OR requested to skip") +def test_swarm_agent(): + context_variables = {"1": False, "2": False, "3": False} + + def update_context_1(context_variables: dict) -> str: + context_variables["1"] = True + return SwarmResult(value="success", context_variables=context_variables) + + def update_context_2_and_transfer_to_3(context_variables: dict) -> str: + context_variables["2"] = True + return SwarmResult(value="success", context_variables=context_variables, agent=agent_3) + + def update_context_3(context_variables: dict) -> str: + context_variables["3"] = True + return SwarmResult(value="success", context_variables=context_variables) + + def transfer_to_agent_2() -> SwarmAgent: + return agent_2 + + agent_1 = SwarmAgent( + name="Agent_1", + system_message="You are Agent 1, first, call the function to update context 1, and transfer to Agent 2", + llm_config=llm_config, + functions=[update_context_1, transfer_to_agent_2], + ) + + agent_2 = SwarmAgent( + name="Agent_2", + system_message="You are Agent 2, call the function that updates context 2 and transfer to Agent 3", + llm_config=llm_config, + functions=[update_context_2_and_transfer_to_3], + ) + + agent_3 = SwarmAgent( + name="Agent_3", + system_message="You are Agent 3, first, call the function to update context 3, and then reply TERMINATE", + llm_config=llm_config, + functions=[update_context_3], + ) + + user = UserProxyAgent( + name="Human_User", + system_message="Human user", + human_input_mode="ALWAYS", + code_execution_config=False, + ) + groupchat = GroupChat( + agents=[user, agent_1, agent_2, agent_3], + messages=[], + max_round=10, + speaker_selection_method="swarm", + context_variables=context_variables, + ) + manager = GroupChatManager(groupchat=groupchat, llm_config=None) + + chat_result = user.initiate_chat(manager, message="start") + assert context_variables["1"] == True, "Expected context_variables['1'] to be True" + assert context_variables["2"] == True, "Expected context_variables['2'] to be True" + assert context_variables["3"] == True, "Expected context_variables['3'] to be True" def test_func_call_groupchat(): agent1 = autogen.ConversableAgent(