From ae823c919f9fe97c6fd7b2adeeed329113ea3777 Mon Sep 17 00:00:00 2001 From: Mark Sze Date: Mon, 25 Nov 2024 04:49:09 +0000 Subject: [PATCH 1/5] Swarm tests and bug fixes --- .github/workflows/contrib-tests.yml | 37 ++ autogen/agentchat/contrib/swarm_agent.py | 85 ++++- test/agentchat/contrib/test_swarm.py | 444 +++++++++++++++++++++++ 3 files changed, 548 insertions(+), 18 deletions(-) create mode 100644 test/agentchat/contrib/test_swarm.py diff --git a/.github/workflows/contrib-tests.yml b/.github/workflows/contrib-tests.yml index 9e93490ea7..1c9e473414 100644 --- a/.github/workflows/contrib-tests.yml +++ b/.github/workflows/contrib-tests.yml @@ -732,3 +732,40 @@ jobs: with: file: ./coverage.xml flags: unittests + + SwarmTest: + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: [ubuntu-latest, macos-latest, windows-latest] + python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"] + exclude: + - os: macos-latest + python-version: "3.9" + steps: + - uses: actions/checkout@v4 + with: + lfs: true + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + - name: Install packages and dependencies for all tests + run: | + python -m pip install --upgrade pip wheel + pip install pytest-cov>=5 + - name: Set AUTOGEN_USE_DOCKER based on OS + shell: bash + run: | + if [[ ${{ matrix.os }} != ubuntu-latest ]]; then + echo "AUTOGEN_USE_DOCKER=False" >> $GITHUB_ENV + fi + - name: Coverage + run: | + pytest test/agentchat/contrib/test_swarm.py --skip-openai + - name: Upload coverage to Codecov + uses: codecov/codecov-action@v3 + with: + file: ./coverage.xml + flags: unittests diff --git a/autogen/agentchat/contrib/swarm_agent.py b/autogen/agentchat/contrib/swarm_agent.py index ffdde772c2..1e305858d8 100644 --- a/autogen/agentchat/contrib/swarm_agent.py +++ b/autogen/agentchat/contrib/swarm_agent.py @@ -44,6 +44,10 @@ class ON_CONDITION: agent: "SwarmAgent" condition: str = "" + # Ensure that agent is a SwarmAgent + def __post_init__(self): + assert isinstance(self.agent, SwarmAgent), "Agent must be a SwarmAgent" + def initiate_swarm_chat( initial_agent: "SwarmAgent", @@ -80,7 +84,12 @@ def custom_afterwork_func(last_speaker: SwarmAgent, messages: List[Dict[str, Any SwarmAgent: Last speaker. """ assert isinstance(initial_agent, SwarmAgent), "initial_agent must be a SwarmAgent" - assert all(isinstance(agent, SwarmAgent) for agent in agents), "agents must be a list of SwarmAgents" + assert all(isinstance(agent, SwarmAgent) for agent in agents), "Agents must be a list of SwarmAgents" + # Ensure all agents in hand-off after-works are in the passed in agents list + for agent in agents: + if agent.after_work is not None: + if isinstance(agent.after_work.agent, SwarmAgent): + assert agent.after_work.agent in agents, "Agent in hand-off must be in the agents list" context_variables = context_variables or {} if isinstance(messages, str): @@ -175,9 +184,12 @@ def swarm_transition(last_speaker: SwarmAgent, groupchat: GroupChat): last_message = messages[0] if "name" in last_message: - if "name" in swarm_agent_names: + if last_message["name"] in swarm_agent_names: # If there's a name in the message and it's a swarm agent, use that last_agent = groupchat.agent_by_name(name=last_message["name"]) + elif user_agent and last_message["name"] == user_agent.name: + # If the user agent is passed in and is the first message + last_agent = user_agent else: raise ValueError(f"Invalid swarm agent name in last message: {last_message['name']}") else: @@ -260,9 +272,13 @@ def __init__( ) if isinstance(functions, list): + if not all(isinstance(func, Callable) for func in functions): + raise TypeError("All elements in the functions list must be callable") self.add_functions(functions) elif isinstance(functions, Callable): self.add_single_function(functions) + elif functions is not None: + raise TypeError("Functions must be a callable or a list of callables") self.after_work = None @@ -299,11 +315,18 @@ def transfer_to_agent_name() -> SwarmAgent: 1. register the function with the agent 2. register the schema with the agent, description set to the condition """ + # Ensure that hand_to is a list or ON_CONDITION or AFTER_WORK + if not isinstance(hand_to, (list, ON_CONDITION, AFTER_WORK)): + raise ValueError("hand_to must be a list of ON_CONDITION or AFTER_WORK") + if isinstance(hand_to, (ON_CONDITION, AFTER_WORK)): hand_to = [hand_to] for transit in hand_to: if isinstance(transit, AFTER_WORK): + assert isinstance( + transit.agent, (AfterWorkOption, SwarmAgent, str, Callable) + ), "Invalid After Work value" self.after_work = transit elif isinstance(transit, ON_CONDITION): @@ -340,8 +363,18 @@ def generate_swarm_tool_reply( message = messages[-1] if "tool_calls" in message: - # 1. add context_variables to the tool call arguments - for tool_call in message["tool_calls"]: + + tool_calls = len(message["tool_calls"]) + + # Loop through tool calls individually (so context can be updated after each function call) + next_agent = None + tool_responses_inner = [] + contents = [] + for index in range(tool_calls): + + # 1. add context_variables to the tool call arguments + tool_call = message["tool_calls"][index] + if tool_call["type"] == "function": function_name = tool_call["function"]["name"] @@ -357,20 +390,36 @@ def generate_swarm_tool_reply( # Update the tool call with new arguments tool_call["function"]["arguments"] = json.dumps(current_args) - # 2. generate tool calls reply - _, tool_message = self.generate_tool_calls_reply([message]) - - # 3. update context_variables and next_agent, convert content to string - for tool_response in tool_message["tool_responses"]: - content = tool_response.get("content") - if isinstance(content, SwarmResult): - if content.context_variables != {}: - self._context_variables.update(content.context_variables) - if content.agent is not None: - self._next_agent = content.agent - elif isinstance(content, Agent): - self._next_agent = content - tool_response["content"] = str(tool_response["content"]) + # Copy the message + message_copy = message.copy() + tool_calls_copy = message_copy["tool_calls"] + + # remove all the tool calls except the one at the index + message_copy["tool_calls"] = [tool_calls_copy[index]] + + # 2. generate tool calls reply + _, tool_message = self.generate_tool_calls_reply([message_copy]) + + # 3. update context_variables and next_agent, convert content to string + for tool_response in tool_message["tool_responses"]: + content = tool_response.get("content") + if isinstance(content, SwarmResult): + if content.context_variables != {}: + self._context_variables.update(content.context_variables) + if content.agent is not None: + self._next_agent = content.agent + elif isinstance(content, Agent): + self._next_agent = content + + tool_responses_inner.append(tool_response) + contents.append(str(tool_response["content"])) + + self._next_agent = next_agent + + # Put the tool responses and content strings back into the response message + # Caters for multiple tool calls + tool_message["tool_responses"] = tool_responses_inner + tool_message["content"] = "\n".join(contents) return True, tool_message return False, None diff --git a/test/agentchat/contrib/test_swarm.py b/test/agentchat/contrib/test_swarm.py new file mode 100644 index 0000000000..29bddef7ed --- /dev/null +++ b/test/agentchat/contrib/test_swarm.py @@ -0,0 +1,444 @@ +import os +import sys +from typing import Any, Dict +from unittest.mock import MagicMock, Mock, call, patch + +import pytest + +from autogen import ConversableAgent, UserProxyAgent, config_list_from_json +from autogen.agentchat.contrib.swarm_agent import ( + __CONTEXT_VARIABLES_PARAM_NAME__, + AFTER_WORK, + ON_CONDITION, + AfterWorkOption, + SwarmAgent, + SwarmResult, + initiate_swarm_chat, +) + +TEST_MESSAGES = [{"role": "user", "content": "Initial message"}] + + +def test_swarm_agent_initialization(): + """Test SwarmAgent initialization with valid and invalid parameters""" + + # Valid initialization + agent = SwarmAgent("test_agent") + assert agent.name == "test_agent" + assert agent.human_input_mode == "NEVER" + + # Invalid functions parameter + with pytest.raises(TypeError): + SwarmAgent("test_agent", functions="invalid") + + +def test_swarm_result(): + """Test SwarmResult initialization and string conversion""" + # Valid initialization + result = SwarmResult(values="test result") + assert str(result) == "test result" + assert result.context_variables == {} + + # Test with context variables + context = {"key": "value"} + result = SwarmResult(values="test", context_variables=context) + assert result.context_variables == context + + # Test with agent + agent = SwarmAgent("test") + result = SwarmResult(values="test", agent=agent) + assert result.agent == agent + + +def test_after_work_initialization(): + """Test AFTER_WORK initialization with different options""" + # Test with AfterWorkOption + after_work = AFTER_WORK(AfterWorkOption.TERMINATE) + assert after_work.agent == AfterWorkOption.TERMINATE + + # Test with string + after_work = AFTER_WORK("TERMINATE") + assert after_work.agent == AfterWorkOption.TERMINATE + + # Test with SwarmAgent + agent = SwarmAgent("test") + after_work = AFTER_WORK(agent) + assert after_work.agent == agent + + # Test with Callable + def test_callable(x: int) -> SwarmAgent: + return agent + + after_work = AFTER_WORK(test_callable) + assert after_work.agent == test_callable + + # Test with invalid option + with pytest.raises(ValueError): + AFTER_WORK("INVALID_OPTION") + + +def test_on_condition(): + """Test ON_CONDITION initialization""" + agent = SwarmAgent("test") + condition = ON_CONDITION(agent=agent, condition="test condition") + assert condition.agent == agent + assert condition.condition == "test condition" + + # Test with a ConversableAgent + test_conversable_agent = ConversableAgent("test_conversable_agent") + with pytest.raises(AssertionError, match="Agent must be a SwarmAgent"): + condition = ON_CONDITION(agent=test_conversable_agent, condition="test condition") + + +def test_receiving_agent(): + """Test the receiving agent based on various starting messages""" + + # 1. Test with a single message - should always be the initial agent + messages_one_no_name = [{"role": "user", "content": "Initial message"}] + + test_initial_agent = SwarmAgent("InitialAgent") + + # Test the chat + chat_result, context_vars, last_speaker = initiate_swarm_chat( + initial_agent=test_initial_agent, messages=messages_one_no_name, agents=[test_initial_agent] + ) + + # Make sure the first speaker (second message) is the initialagent + assert "name" not in chat_result.chat_history[0] # _User should not exist + assert chat_result.chat_history[1].get("name") == "InitialAgent" + + # 2. Test with a single message from an existing agent (should still be initial agent) + test_second_agent = SwarmAgent("SecondAgent") + + messages_one_w_name = [{"role": "user", "content": "Initial message", "name": "SecondAgent"}] + + # Test the chat + chat_result, context_vars, last_speaker = initiate_swarm_chat( + initial_agent=test_initial_agent, messages=messages_one_w_name, agents=[test_initial_agent, test_second_agent] + ) + + assert chat_result.chat_history[0].get("name") == "SecondAgent" + assert chat_result.chat_history[1].get("name") == "InitialAgent" + + # 3. Test with a single message from a user agent, user passed in + + test_user = UserProxyAgent("MyUser") + + messages_one_w_name = [{"role": "user", "content": "Initial message", "name": "MyUser"}] + + # Test the chat + chat_result, context_vars, last_speaker = initiate_swarm_chat( + initial_agent=test_second_agent, + user_agent=test_user, + messages=messages_one_w_name, + agents=[test_initial_agent, test_second_agent], + ) + + assert chat_result.chat_history[0].get("name") == "MyUser" # Should persist + assert chat_result.chat_history[1].get("name") == "SecondAgent" + + +def test_swarm_transitions(): + """Test different swarm transition scenarios""" + agent1 = SwarmAgent("agent1") + agent2 = SwarmAgent("agent2") + + # Test initial transition + chat_result, context_vars, last_speaker = initiate_swarm_chat( + initial_agent=agent1, messages=TEST_MESSAGES, agents=[agent1, agent2] + ) + assert last_speaker == agent1 + + # If we have multiple messages, first agent is still the initial_agent + multiple_messages = [ + {"role": "user", "content": "First message"}, + {"role": "assistant", "name": "agent2", "content": "Response"}, + ] + + chat_result, context_vars, last_speaker = initiate_swarm_chat( + initial_agent=agent1, messages=multiple_messages, agents=[agent1, agent2] + ) + + assert isinstance(last_speaker, SwarmAgent) + assert last_speaker == agent1 + + +def test_after_work_options(): + """Test different after work options""" + + agent1 = SwarmAgent("agent1") + agent2 = SwarmAgent("agent2") + user_agent = UserProxyAgent("test_user") + + # Fake generate_oai_reply + def mock_generate_oai_reply(*args, **kwargs): + return True, "This is a mock response from the agent." + + # Mock an LLM response by overriding the generate_oai_reply function + for agent in [agent1, agent2]: + for reply_func_tuple in agent._reply_func_list: + if reply_func_tuple["reply_func"].__name__ == "generate_oai_reply": + reply_func_tuple["reply_func"] = mock_generate_oai_reply + + # 1. Test TERMINATE + agent1.after_work = AFTER_WORK(AfterWorkOption.TERMINATE) + chat_result, context_vars, last_speaker = initiate_swarm_chat( + initial_agent=agent1, messages=TEST_MESSAGES, agents=[agent1, agent2] + ) + assert last_speaker == agent1 + + # 2. Test REVERT_TO_USER + agent1.after_work = AFTER_WORK(AfterWorkOption.REVERT_TO_USER) + + test_messages = [ + {"role": "user", "content": "Initial message"}, + {"role": "assistant", "name": "agent1", "content": "Response"}, + ] + + with patch("builtins.input", return_value="continue"): + chat_result, context_vars, last_speaker = initiate_swarm_chat( + initial_agent=agent1, messages=test_messages, agents=[agent1, agent2], user_agent=user_agent, max_rounds=4 + ) + + # Ensure that after agent1 is finished, it goes to user (4th message) + assert chat_result.chat_history[3]["name"] == "test_user" + + # 3. Test STAY + agent1.after_work = AFTER_WORK(AfterWorkOption.STAY) + chat_result, context_vars, last_speaker = initiate_swarm_chat( + initial_agent=agent1, messages=test_messages, agents=[agent1, agent2], max_rounds=4 + ) + + # Stay on agent1 + assert chat_result.chat_history[3]["name"] == "agent1" + + # 4. Test Callable + + # Transfer to agent2 + def test_callable(last_speaker, messages, groupchat, context_variables): + return agent2 + + agent1.after_work = AFTER_WORK(test_callable) + + chat_result, context_vars, last_speaker = initiate_swarm_chat( + initial_agent=agent1, messages=test_messages, agents=[agent1, agent2], max_rounds=4 + ) + + # We should have transferred to agent2 after agent1 has finished + assert chat_result.chat_history[3]["name"] == "agent2" + + +def test_temporary_user_proxy(): + """Test that temporary user proxy agent name is cleared""" + agent1 = SwarmAgent("agent1") + agent2 = SwarmAgent("agent2") + + chat_result, context_vars, last_speaker = initiate_swarm_chat( + initial_agent=agent1, messages=TEST_MESSAGES, agents=[agent1, agent2] + ) + + # Verify no message has name "_User" + for message in chat_result.chat_history: + assert message.get("name") != "_User" + + +def test_context_variables_updating(): + """Test context variables handling in tool calls""" + + testing_llm_config = { + "config_list": [ + { + "model": "gpt-4o", + "api_key": "SAMPLE_API_KEY", + } + ] + } + + # Starting context variable, this will increment in the swarm + test_context_variables = {"my_key": 0} + + # Increment the context variable + def test_func(context_variables: Dict[str, Any], param1: str) -> str: + context_variables["my_key"] += 1 + return SwarmResult(values=f"Test {param1}", context_variables=context_variables, agent=agent1) + + agent1 = SwarmAgent("agent1", functions=[test_func], llm_config=testing_llm_config) + agent2 = SwarmAgent("agent2", functions=[test_func], llm_config=testing_llm_config) + + # Fake generate_oai_reply + def mock_generate_oai_reply(*args, **kwargs): + return True, "This is a mock response from the agent." + + # Fake generate_oai_reply + def mock_generate_oai_reply_tool(*args, **kwargs): + return True, { + "role": "assistant", + "name": "agent1", + "tool_calls": [{"type": "function", "function": {"name": "test_func", "arguments": '{"param1": "test"}'}}], + } + + # Mock an LLM response by overriding the generate_oai_reply function + for agent in [agent1, agent2]: + for reply_func_tuple in agent._reply_func_list: + if reply_func_tuple["reply_func"].__name__ == "generate_oai_reply": + if agent == agent1: + reply_func_tuple["reply_func"] = mock_generate_oai_reply + elif agent == agent2: + reply_func_tuple["reply_func"] = mock_generate_oai_reply_tool + + # Test message with a tool call + tool_call_messages = [ + {"role": "user", "content": "Initial message"}, + ] + + chat_result, context_vars, last_speaker = initiate_swarm_chat( + initial_agent=agent2, + messages=tool_call_messages, + agents=[agent1, agent2], + context_variables=test_context_variables, + max_rounds=3, + ) + + # Ensure we've incremented the context variable + assert context_vars["my_key"] == 1 + + +def test_context_variables_updating_multi_tools(): + """Test context variables handling in tool calls""" + + testing_llm_config = { + "config_list": [ + { + "model": "gpt-4o", + "api_key": "SAMPLE_API_KEY", + } + ] + } + + # Starting context variable, this will increment in the swarm + test_context_variables = {"my_key": 0} + + # Increment the context variable + def test_func_1(context_variables: Dict[str, Any], param1: str) -> str: + context_variables["my_key"] += 1 + return SwarmResult(values=f"Test 1 {param1}", context_variables=context_variables, agent=agent1) + + # Increment the context variable + def test_func_2(context_variables: Dict[str, Any], param2: str) -> str: + context_variables["my_key"] += 100 + return SwarmResult(values=f"Test 2 {param2}", context_variables=context_variables, agent=agent1) + + agent1 = SwarmAgent("agent1", llm_config=testing_llm_config) + agent2 = SwarmAgent("agent2", functions=[test_func_1, test_func_2], llm_config=testing_llm_config) + + # Fake generate_oai_reply + def mock_generate_oai_reply(*args, **kwargs): + return True, "This is a mock response from the agent." + + # Fake generate_oai_reply + def mock_generate_oai_reply_tool(*args, **kwargs): + return True, { + "role": "assistant", + "name": "agent1", + "tool_calls": [ + {"type": "function", "function": {"name": "test_func_1", "arguments": '{"param1": "test"}'}}, + {"type": "function", "function": {"name": "test_func_2", "arguments": '{"param2": "test"}'}}, + ], + } + + # Mock an LLM response by overriding the generate_oai_reply function + for agent in [agent1, agent2]: + for reply_func_tuple in agent._reply_func_list: + if reply_func_tuple["reply_func"].__name__ == "generate_oai_reply": + if agent == agent1: + reply_func_tuple["reply_func"] = mock_generate_oai_reply + elif agent == agent2: + reply_func_tuple["reply_func"] = mock_generate_oai_reply_tool + + # Test message with a tool call + tool_call_messages = [ + {"role": "user", "content": "Initial message"}, + ] + + chat_result, context_vars, last_speaker = initiate_swarm_chat( + initial_agent=agent2, + messages=tool_call_messages, + agents=[agent1, agent2], + context_variables=test_context_variables, + max_rounds=3, + ) + + # Ensure we've incremented the context variable + # in both tools, updated values should traverse + # 0 + 1 (func 1) + 100 (func 2) = 101 + assert context_vars["my_key"] == 101 + + +def test_invalid_parameters(): + """Test various invalid parameter combinations""" + agent1 = SwarmAgent("agent1") + agent2 = SwarmAgent("agent2") + + # Test invalid initial agent type + with pytest.raises(AssertionError): + initiate_swarm_chat(initial_agent="not_an_agent", messages=TEST_MESSAGES, agents=[agent1, agent2]) + + # Test invalid agents list + with pytest.raises(AssertionError): + initiate_swarm_chat(initial_agent=agent1, messages=TEST_MESSAGES, agents=["not_an_agent", agent2]) + + # Test invalid after_work type + with pytest.raises(ValueError): + initiate_swarm_chat(initial_agent=agent1, messages=TEST_MESSAGES, agents=[agent1, agent2], after_work="invalid") + + +def test_non_swarm_in_hand_off(): + """Test that SwarmAgents in the group chat are the only agents in hand-offs""" + + agent1 = SwarmAgent("agent1") + bad_agent = ConversableAgent("bad_agent") + + with pytest.raises(AssertionError, match="Invalid After Work value"): + agent1.register_hand_off(hand_to=AFTER_WORK(bad_agent)) + + with pytest.raises(AssertionError, match="Invalid After Work value"): + agent1.register_hand_off(hand_to=AFTER_WORK(0)) + + with pytest.raises(AssertionError, match="Agent must be a SwarmAgent"): + agent1.register_hand_off(hand_to=ON_CONDITION(0, "Testing")) + + with pytest.raises(AssertionError, match="Agent must be a SwarmAgent"): + agent1.register_hand_off(hand_to=ON_CONDITION(bad_agent, "Testing")) + + with pytest.raises(ValueError, match="hand_to must be a list of ON_CONDITION or AFTER_WORK"): + agent1.register_hand_off(0) + + +def test_initialization(): + """Test initiate_swarm_chat""" + + agent1 = SwarmAgent("agent1") + agent2 = SwarmAgent("agent2") + agent3 = SwarmAgent("agent3") + bad_agent = ConversableAgent("bad_agent") + + with pytest.raises(AssertionError, match="Agents must be a list of SwarmAgents"): + chat_result, context_vars, last_speaker = initiate_swarm_chat( + initial_agent=agent2, messages=TEST_MESSAGES, agents=[agent1, agent2, bad_agent], max_rounds=3 + ) + + with pytest.raises(AssertionError, match="initial_agent must be a SwarmAgent"): + chat_result, context_vars, last_speaker = initiate_swarm_chat( + initial_agent=bad_agent, messages=TEST_MESSAGES, agents=[agent1, agent2], max_rounds=3 + ) + + agent1.register_hand_off(hand_to=AFTER_WORK(agent3)) + + with pytest.raises(AssertionError, match="Agent in hand-off must be in the agents list"): + chat_result, context_vars, last_speaker = initiate_swarm_chat( + initial_agent=agent1, messages=TEST_MESSAGES, agents=[agent1, agent2], max_rounds=3 + ) + + +if __name__ == "__main__": + pytest.main([__file__]) From e4a334e2808b9ecfd2c8e1831ba603b16b13cd78 Mon Sep 17 00:00:00 2001 From: Mark Sze Date: Mon, 25 Nov 2024 19:10:08 +0000 Subject: [PATCH 2/5] Fixed a bug with the next agent in function results, added ON_CONDITION test, test tidy ups for comments --- autogen/agentchat/contrib/swarm_agent.py | 4 +- test/agentchat/contrib/test_swarm.py | 156 ++++++++++++----------- 2 files changed, 83 insertions(+), 77 deletions(-) diff --git a/autogen/agentchat/contrib/swarm_agent.py b/autogen/agentchat/contrib/swarm_agent.py index 1e305858d8..c1c790a906 100644 --- a/autogen/agentchat/contrib/swarm_agent.py +++ b/autogen/agentchat/contrib/swarm_agent.py @@ -407,9 +407,9 @@ def generate_swarm_tool_reply( if content.context_variables != {}: self._context_variables.update(content.context_variables) if content.agent is not None: - self._next_agent = content.agent + next_agent = content.agent elif isinstance(content, Agent): - self._next_agent = content + next_agent = content tool_responses_inner.append(tool_response) contents.append(str(tool_response["content"])) diff --git a/test/agentchat/contrib/test_swarm.py b/test/agentchat/contrib/test_swarm.py index 29bddef7ed..a9d1693bf8 100644 --- a/test/agentchat/contrib/test_swarm.py +++ b/test/agentchat/contrib/test_swarm.py @@ -1,11 +1,8 @@ -import os -import sys from typing import Any, Dict -from unittest.mock import MagicMock, Mock, call, patch +from unittest.mock import patch import pytest -from autogen import ConversableAgent, UserProxyAgent, config_list_from_json from autogen.agentchat.contrib.swarm_agent import ( __CONTEXT_VARIABLES_PARAM_NAME__, AFTER_WORK, @@ -15,6 +12,8 @@ SwarmResult, initiate_swarm_chat, ) +from autogen.agentchat.conversable_agent import ConversableAgent +from autogen.agentchat.user_proxy_agent import UserProxyAgent TEST_MESSAGES = [{"role": "user", "content": "Initial message"}] @@ -22,11 +21,6 @@ def test_swarm_agent_initialization(): """Test SwarmAgent initialization with valid and invalid parameters""" - # Valid initialization - agent = SwarmAgent("test_agent") - assert agent.name == "test_agent" - assert agent.human_input_mode == "NEVER" - # Invalid functions parameter with pytest.raises(TypeError): SwarmAgent("test_agent", functions="invalid") @@ -79,15 +73,11 @@ def test_callable(x: int) -> SwarmAgent: def test_on_condition(): """Test ON_CONDITION initialization""" - agent = SwarmAgent("test") - condition = ON_CONDITION(agent=agent, condition="test condition") - assert condition.agent == agent - assert condition.condition == "test condition" # Test with a ConversableAgent test_conversable_agent = ConversableAgent("test_conversable_agent") with pytest.raises(AssertionError, match="Agent must be a SwarmAgent"): - condition = ON_CONDITION(agent=test_conversable_agent, condition="test condition") + _ = ON_CONDITION(agent=test_conversable_agent, condition="test condition") def test_receiving_agent(): @@ -159,7 +149,6 @@ def test_swarm_transitions(): initial_agent=agent1, messages=multiple_messages, agents=[agent1, agent2] ) - assert isinstance(last_speaker, SwarmAgent) assert last_speaker == agent1 @@ -174,11 +163,9 @@ def test_after_work_options(): def mock_generate_oai_reply(*args, **kwargs): return True, "This is a mock response from the agent." - # Mock an LLM response by overriding the generate_oai_reply function - for agent in [agent1, agent2]: - for reply_func_tuple in agent._reply_func_list: - if reply_func_tuple["reply_func"].__name__ == "generate_oai_reply": - reply_func_tuple["reply_func"] = mock_generate_oai_reply + # Mock LLM responses + agent1.register_reply([ConversableAgent, None], mock_generate_oai_reply) + agent2.register_reply([ConversableAgent, None], mock_generate_oai_reply) # 1. Test TERMINATE agent1.after_work = AFTER_WORK(AfterWorkOption.TERMINATE) @@ -228,6 +215,50 @@ def test_callable(last_speaker, messages, groupchat, context_variables): assert chat_result.chat_history[3]["name"] == "agent2" +def test_on_condition_handoff(): + """Test ON_CONDITION in handoffs""" + + testing_llm_config = { + "config_list": [ + { + "model": "gpt-4o", + "api_key": "SAMPLE_API_KEY", + } + ] + } + + agent1 = SwarmAgent("agent1", llm_config=testing_llm_config) + agent2 = SwarmAgent("agent2", llm_config=testing_llm_config) + + agent1.register_hand_off(hand_to=ON_CONDITION(agent2, "always take me to agent 2")) + + # Fake generate_oai_reply + def mock_generate_oai_reply(*args, **kwargs): + return True, "This is a mock response from the agent." + + # Fake generate_oai_reply + def mock_generate_oai_reply_tool(*args, **kwargs): + return True, { + "role": "assistant", + "name": "agent1", + "tool_calls": [{"type": "function", "function": {"name": "transfer_to_agent2"}}], + } + + # Mock LLM responses + agent1.register_reply([ConversableAgent, None], mock_generate_oai_reply_tool) + agent2.register_reply([ConversableAgent, None], mock_generate_oai_reply) + + chat_result, context_vars, last_speaker = initiate_swarm_chat( + initial_agent=agent1, + messages=TEST_MESSAGES, + agents=[agent1, agent2], + max_rounds=5, + ) + + # We should have transferred to agent2 after agent1 has finished + assert chat_result.chat_history[3]["name"] == "agent2" + + def test_temporary_user_proxy(): """Test that temporary user proxy agent name is cleared""" agent1 = SwarmAgent("agent1") @@ -242,7 +273,7 @@ def test_temporary_user_proxy(): assert message.get("name") != "_User" -def test_context_variables_updating(): +def test_context_variables_updating_multi_tools(): """Test context variables handling in tool calls""" testing_llm_config = { @@ -258,12 +289,17 @@ def test_context_variables_updating(): test_context_variables = {"my_key": 0} # Increment the context variable - def test_func(context_variables: Dict[str, Any], param1: str) -> str: + def test_func_1(context_variables: Dict[str, Any], param1: str) -> str: context_variables["my_key"] += 1 - return SwarmResult(values=f"Test {param1}", context_variables=context_variables, agent=agent1) + return SwarmResult(values=f"Test 1 {param1}", context_variables=context_variables, agent=agent1) + + # Increment the context variable + def test_func_2(context_variables: Dict[str, Any], param2: str) -> str: + context_variables["my_key"] += 100 + return SwarmResult(values=f"Test 2 {param2}", context_variables=context_variables, agent=agent1) - agent1 = SwarmAgent("agent1", functions=[test_func], llm_config=testing_llm_config) - agent2 = SwarmAgent("agent2", functions=[test_func], llm_config=testing_llm_config) + agent1 = SwarmAgent("agent1", llm_config=testing_llm_config) + agent2 = SwarmAgent("agent2", functions=[test_func_1, test_func_2], llm_config=testing_llm_config) # Fake generate_oai_reply def mock_generate_oai_reply(*args, **kwargs): @@ -274,37 +310,32 @@ def mock_generate_oai_reply_tool(*args, **kwargs): return True, { "role": "assistant", "name": "agent1", - "tool_calls": [{"type": "function", "function": {"name": "test_func", "arguments": '{"param1": "test"}'}}], + "tool_calls": [ + {"type": "function", "function": {"name": "test_func_1", "arguments": '{"param1": "test"}'}}, + {"type": "function", "function": {"name": "test_func_2", "arguments": '{"param2": "test"}'}}, + ], } - # Mock an LLM response by overriding the generate_oai_reply function - for agent in [agent1, agent2]: - for reply_func_tuple in agent._reply_func_list: - if reply_func_tuple["reply_func"].__name__ == "generate_oai_reply": - if agent == agent1: - reply_func_tuple["reply_func"] = mock_generate_oai_reply - elif agent == agent2: - reply_func_tuple["reply_func"] = mock_generate_oai_reply_tool - - # Test message with a tool call - tool_call_messages = [ - {"role": "user", "content": "Initial message"}, - ] + # Mock LLM responses + agent1.register_reply([ConversableAgent, None], mock_generate_oai_reply) + agent2.register_reply([ConversableAgent, None], mock_generate_oai_reply_tool) chat_result, context_vars, last_speaker = initiate_swarm_chat( initial_agent=agent2, - messages=tool_call_messages, + messages=TEST_MESSAGES, agents=[agent1, agent2], context_variables=test_context_variables, max_rounds=3, ) # Ensure we've incremented the context variable - assert context_vars["my_key"] == 1 + # in both tools, updated values should traverse + # 0 + 1 (func 1) + 100 (func 2) = 101 + assert context_vars["my_key"] == 101 -def test_context_variables_updating_multi_tools(): - """Test context variables handling in tool calls""" +def test_function_transfer(): + """Tests a function call that has a transfer to agent in the SwarmResult""" testing_llm_config = { "config_list": [ @@ -323,13 +354,8 @@ def test_func_1(context_variables: Dict[str, Any], param1: str) -> str: context_variables["my_key"] += 1 return SwarmResult(values=f"Test 1 {param1}", context_variables=context_variables, agent=agent1) - # Increment the context variable - def test_func_2(context_variables: Dict[str, Any], param2: str) -> str: - context_variables["my_key"] += 100 - return SwarmResult(values=f"Test 2 {param2}", context_variables=context_variables, agent=agent1) - agent1 = SwarmAgent("agent1", llm_config=testing_llm_config) - agent2 = SwarmAgent("agent2", functions=[test_func_1, test_func_2], llm_config=testing_llm_config) + agent2 = SwarmAgent("agent2", functions=[test_func_1], llm_config=testing_llm_config) # Fake generate_oai_reply def mock_generate_oai_reply(*args, **kwargs): @@ -342,36 +368,22 @@ def mock_generate_oai_reply_tool(*args, **kwargs): "name": "agent1", "tool_calls": [ {"type": "function", "function": {"name": "test_func_1", "arguments": '{"param1": "test"}'}}, - {"type": "function", "function": {"name": "test_func_2", "arguments": '{"param2": "test"}'}}, ], } - # Mock an LLM response by overriding the generate_oai_reply function - for agent in [agent1, agent2]: - for reply_func_tuple in agent._reply_func_list: - if reply_func_tuple["reply_func"].__name__ == "generate_oai_reply": - if agent == agent1: - reply_func_tuple["reply_func"] = mock_generate_oai_reply - elif agent == agent2: - reply_func_tuple["reply_func"] = mock_generate_oai_reply_tool - - # Test message with a tool call - tool_call_messages = [ - {"role": "user", "content": "Initial message"}, - ] + # Mock LLM responses + agent1.register_reply([ConversableAgent, None], mock_generate_oai_reply) + agent2.register_reply([ConversableAgent, None], mock_generate_oai_reply_tool) chat_result, context_vars, last_speaker = initiate_swarm_chat( initial_agent=agent2, - messages=tool_call_messages, + messages=TEST_MESSAGES, agents=[agent1, agent2], context_variables=test_context_variables, - max_rounds=3, + max_rounds=4, ) - # Ensure we've incremented the context variable - # in both tools, updated values should traverse - # 0 + 1 (func 1) + 100 (func 2) = 101 - assert context_vars["my_key"] == 101 + assert chat_result.chat_history[3]["name"] == "agent1" def test_invalid_parameters(): @@ -401,12 +413,6 @@ def test_non_swarm_in_hand_off(): with pytest.raises(AssertionError, match="Invalid After Work value"): agent1.register_hand_off(hand_to=AFTER_WORK(bad_agent)) - with pytest.raises(AssertionError, match="Invalid After Work value"): - agent1.register_hand_off(hand_to=AFTER_WORK(0)) - - with pytest.raises(AssertionError, match="Agent must be a SwarmAgent"): - agent1.register_hand_off(hand_to=ON_CONDITION(0, "Testing")) - with pytest.raises(AssertionError, match="Agent must be a SwarmAgent"): agent1.register_hand_off(hand_to=ON_CONDITION(bad_agent, "Testing")) From 57d59302052a5463f093548e0b39b174150dcbba Mon Sep 17 00:00:00 2001 From: Mark Sze Date: Mon, 25 Nov 2024 20:52:59 +0000 Subject: [PATCH 3/5] Update testing workflow to install AG2 Signed-off-by: Mark Sze --- .github/workflows/contrib-tests.yml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.github/workflows/contrib-tests.yml b/.github/workflows/contrib-tests.yml index 1c9e473414..ac0240eb43 100644 --- a/.github/workflows/contrib-tests.yml +++ b/.github/workflows/contrib-tests.yml @@ -755,6 +755,9 @@ jobs: run: | python -m pip install --upgrade pip wheel pip install pytest-cov>=5 + - name: Install packages and dependencies for Swarms + run: | + pip install -e . - name: Set AUTOGEN_USE_DOCKER based on OS shell: bash run: | From 5a199a8d7ea17fadd3dbc10434779a8503ccf8d3 Mon Sep 17 00:00:00 2001 From: Mark Sze Date: Mon, 25 Nov 2024 21:02:38 +0000 Subject: [PATCH 4/5] Add license to test file Signed-off-by: Mark Sze --- test/agentchat/contrib/test_swarm.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/test/agentchat/contrib/test_swarm.py b/test/agentchat/contrib/test_swarm.py index a9d1693bf8..3d8383e055 100644 --- a/test/agentchat/contrib/test_swarm.py +++ b/test/agentchat/contrib/test_swarm.py @@ -1,3 +1,6 @@ +# Copyright (c) 2023 - 2024, Owners of https://github.com/ag2ai +# +# SPDX-License-Identifier: Apache-2.0 from typing import Any, Dict from unittest.mock import patch From 575b180542ebf2de5cd4b79b7c51cdf6f2aa64c9 Mon Sep 17 00:00:00 2001 From: Mark Sze Date: Tue, 26 Nov 2024 23:19:41 +0000 Subject: [PATCH 5/5] Removed unnecessary test, added resume test Signed-off-by: Mark Sze --- test/agentchat/contrib/test_swarm.py | 44 ++++++++++++++++++---------- 1 file changed, 28 insertions(+), 16 deletions(-) diff --git a/test/agentchat/contrib/test_swarm.py b/test/agentchat/contrib/test_swarm.py index 3d8383e055..0987ba9ca0 100644 --- a/test/agentchat/contrib/test_swarm.py +++ b/test/agentchat/contrib/test_swarm.py @@ -2,7 +2,7 @@ # # SPDX-License-Identifier: Apache-2.0 from typing import Any, Dict -from unittest.mock import patch +from unittest.mock import MagicMock, patch import pytest @@ -131,28 +131,40 @@ def test_receiving_agent(): assert chat_result.chat_history[1].get("name") == "SecondAgent" -def test_swarm_transitions(): - """Test different swarm transition scenarios""" - agent1 = SwarmAgent("agent1") - agent2 = SwarmAgent("agent2") +def test_resume_speaker(): + """Tests resumption of chat with multiple messages""" - # Test initial transition - chat_result, context_vars, last_speaker = initiate_swarm_chat( - initial_agent=agent1, messages=TEST_MESSAGES, agents=[agent1, agent2] - ) - assert last_speaker == agent1 + test_initial_agent = SwarmAgent("InitialAgent") + test_second_agent = SwarmAgent("SecondAgent") - # If we have multiple messages, first agent is still the initial_agent + # For multiple messages, last agent initiates the chat multiple_messages = [ {"role": "user", "content": "First message"}, - {"role": "assistant", "name": "agent2", "content": "Response"}, + {"role": "assistant", "name": "InitialAgent", "content": "Second message"}, + {"role": "assistant", "name": "SecondAgent", "content": "Third message"}, ] - chat_result, context_vars, last_speaker = initiate_swarm_chat( - initial_agent=agent1, messages=multiple_messages, agents=[agent1, agent2] - ) + # Patch initiate_chat on agents so we can monitor which started the conversation + with patch.object(test_initial_agent, "initiate_chat") as mock_initial_chat, patch.object( + test_second_agent, "initiate_chat" + ) as mock_second_chat: - assert last_speaker == agent1 + mock_chat_result = MagicMock() + mock_chat_result.chat_history = multiple_messages + + # Set up the return value for the mock that will be called + mock_second_chat.return_value = mock_chat_result + + # Run the function + chat_result, context_vars, last_speaker = initiate_swarm_chat( + initial_agent=test_initial_agent, messages=multiple_messages, agents=[test_initial_agent, test_second_agent] + ) + + # Ensure the second agent initiated the chat + mock_second_chat.assert_called_once() + + # And it wasn't the initial_agent's agent + mock_initial_chat.assert_not_called() def test_after_work_options():