From 84788047d9c27ee0684ddb13e7bfbc34ce4f1b04 Mon Sep 17 00:00:00 2001 From: Mark Sze Date: Mon, 30 Dec 2024 02:48:39 +0000 Subject: [PATCH] Added tests for refactored initiate_swarm_chat functions, fixed groupchat bug. Signed-off-by: Mark Sze --- autogen/agentchat/groupchat.py | 4 + test/agentchat/contrib/test_swarm.py | 231 ++++++++++++++++++++++++++- 2 files changed, 234 insertions(+), 1 deletion(-) diff --git a/autogen/agentchat/groupchat.py b/autogen/agentchat/groupchat.py index 4a2bc18241..3ba876266c 100644 --- a/autogen/agentchat/groupchat.py +++ b/autogen/agentchat/groupchat.py @@ -1262,6 +1262,10 @@ async def a_run_chat( else: # admin agent is not found in the participants raise + except NoEligibleSpeaker: + # No eligible speaker, terminate the conversation + break + if reply is None: break # The speaker sends the message without requesting a reply diff --git a/test/agentchat/contrib/test_swarm.py b/test/agentchat/contrib/test_swarm.py index 85c24110a0..7ccb70aa3b 100644 --- a/test/agentchat/contrib/test_swarm.py +++ b/test/agentchat/contrib/test_swarm.py @@ -8,16 +8,23 @@ from autogen.agentchat.contrib.swarm_agent import ( __CONTEXT_VARIABLES_PARAM_NAME__, + __TOOL_EXECUTOR_NAME__, AFTER_WORK, ON_CONDITION, UPDATE_SYSTEM_MESSAGE, AfterWorkOption, SwarmAgent, SwarmResult, + _cleanup_temp_user_messages, + _create_nested_chats, + _prepare_swarm_agents, + _process_initial_messages, + _setup_context_variables, + a_initiate_swarm_chat, initiate_swarm_chat, ) from autogen.agentchat.conversable_agent import ConversableAgent -from autogen.agentchat.groupchat import GroupChat +from autogen.agentchat.groupchat import GroupChat, GroupChatManager from autogen.agentchat.user_proxy_agent import UserProxyAgent TEST_MESSAGES = [{"role": "user", "content": "Initial message"}] @@ -780,5 +787,227 @@ def mock_generate_oai_reply_tool(*args, **kwargs): assert "transfer_agent1_to_agent2_3" in agent1._function_map +def test_prepare_swarm_agents(): + """Test preparation of swarm agents including tool executor setup""" + + testing_llm_config = { + "config_list": [ + { + "model": "gpt-4o", + "api_key": "SAMPLE_API_KEY", + } + ] + } + + # Create test agents + agent1 = SwarmAgent("agent1", llm_config=testing_llm_config) + agent2 = SwarmAgent("agent2", llm_config=testing_llm_config) + agent3 = SwarmAgent("agent3", llm_config=testing_llm_config) + + # Add some functions to test tool executor aggregation + def test_func1(): + pass + + def test_func2(): + pass + + agent1.add_single_function(test_func1) + agent2.add_single_function(test_func2) + + # Add handoffs to test validation + agent1.register_hand_off(AFTER_WORK(agent=agent2)) + + # Test valid preparation + tool_executor, nested_chat_agents = _prepare_swarm_agents(agent1, [agent1, agent2]) + + # Verify tool executor setup + assert tool_executor.name == __TOOL_EXECUTOR_NAME__ + assert "test_func1" in tool_executor._function_map + assert "test_func2" in tool_executor._function_map + + # Test invalid initial agent type + with pytest.raises(AssertionError): + _prepare_swarm_agents(ConversableAgent("invalid"), [agent1, agent2]) + + # Test invalid agents list + with pytest.raises(AssertionError): + _prepare_swarm_agents(agent1, [agent1, ConversableAgent("invalid")]) + + # Test missing handoff agent + agent3.register_hand_off(AFTER_WORK(agent=SwarmAgent("missing"))) + with pytest.raises(AssertionError): + _prepare_swarm_agents(agent1, [agent1, agent2, agent3]) + + +def test_create_nested_chats(): + """Test creation of nested chat agents and registration of handoffs""" + + testing_llm_config = { + "config_list": [ + { + "model": "gpt-4o", + "api_key": "SAMPLE_API_KEY", + } + ] + } + + test_agent = SwarmAgent("test_agent", llm_config=testing_llm_config) + test_agent_2 = SwarmAgent("test_agent_2", llm_config=testing_llm_config) + nested_chat_agents = [] + + nested_chat_one = { + "carryover_config": {"summary_method": "last_msg"}, + "recipient": test_agent_2, + "message": "Extract the order details", + "max_turns": 1, + } + + chat_queue = [nested_chat_one] + + # Register a nested chat handoff + nested_chat_config = { + "chat_queue": chat_queue, + "reply_func_from_nested_chats": "summary_from_nested_chats", + "config": None, + "use_async": False, + } + + test_agent.register_hand_off(ON_CONDITION(target=nested_chat_config, condition="test condition")) + + # Create nested chats + _create_nested_chats(test_agent, nested_chat_agents) + + # Verify nested chat agent creation + assert len(nested_chat_agents) == 1 + assert nested_chat_agents[0].name == f"nested_chat_{test_agent.name}_1" + + # Verify nested chat configuration + # The nested chat agent should have a handoff back to the passed in agent + assert nested_chat_agents[0].after_work.agent == test_agent + + +def test_process_initial_messages(): + """Test processing of initial messages in different scenarios""" + + agent1 = SwarmAgent("agent1") + agent2 = SwarmAgent("agent2") + nested_agent = SwarmAgent("nested_chat_agent1_1") + user_agent = UserProxyAgent("test_user") + + # Test single string message + messages = "Initial message" + processed_messages, last_agent, agent_names, temp_users = _process_initial_messages( + messages, None, [agent1, agent2], [nested_agent] + ) + + assert len(processed_messages) == 1 + assert processed_messages[0]["content"] == "Initial message" + assert len(temp_users) == 1 # Should create temporary user + assert temp_users[0].name == "_User" + + # Test message with existing agent name + messages = [{"role": "user", "content": "Test", "name": "agent1"}] + processed_messages, last_agent, agent_names, temp_users = _process_initial_messages( + messages, user_agent, [agent1, agent2], [nested_agent] + ) + + assert last_agent == agent1 + assert len(temp_users) == 0 # Should not create temp user + + # Test message with user agent name + messages = [{"role": "user", "content": "Test", "name": "test_user"}] + processed_messages, last_agent, agent_names, temp_users = _process_initial_messages( + messages, user_agent, [agent1, agent2], [nested_agent] + ) + + assert last_agent == user_agent + assert len(temp_users) == 0 + + # Test invalid agent name + messages = [{"role": "user", "content": "Test", "name": "invalid_agent"}] + with pytest.raises(ValueError): + _process_initial_messages(messages, user_agent, [agent1, agent2], [nested_agent]) + + +def test_setup_context_variables(): + """Test setup of context variables across agents""" + + tool_execution = SwarmAgent(__TOOL_EXECUTOR_NAME__) + agent1 = SwarmAgent("agent1") + agent2 = SwarmAgent("agent2") + + groupchat = GroupChat(agents=[tool_execution, agent1, agent2], messages=[]) + manager = GroupChatManager(groupchat) + + test_context = {"test_key": "test_value"} + + _setup_context_variables(tool_execution, [agent1, agent2], manager, test_context) + + # Verify all agents share the same context_variables reference + assert tool_execution._context_variables is test_context + assert agent1._context_variables is test_context + assert agent2._context_variables is test_context + assert manager._context_variables is test_context + + +def test_cleanup_temp_user_messages(): + """Test cleanup of temporary user messages""" + + chat_result = MagicMock() + chat_result.chat_history = [ + {"role": "user", "name": "_User", "content": "Test 1"}, + {"role": "assistant", "name": "agent1", "content": "Response 1"}, + {"role": "user", "name": "_User", "content": "Test 2"}, + ] + + _cleanup_temp_user_messages(chat_result) + + # Verify _User names are removed + for message in chat_result.chat_history: + if message["role"] == "user": + assert "name" not in message + + +@pytest.mark.asyncio +async def test_a_initiate_swarm_chat(): + """Test async swarm chat""" + + agent1 = SwarmAgent("agent1") + agent2 = SwarmAgent("agent2") + user_agent = UserProxyAgent("test_user") + + # Mock async reply function + async def mock_a_generate_oai_reply(*args, **kwargs): + return True, "This is a mock response from the agent." + + # Register mock replies + agent1.register_reply([ConversableAgent, None], mock_a_generate_oai_reply) + agent2.register_reply([ConversableAgent, None], mock_a_generate_oai_reply) + + # Test with string message + chat_result, context_vars, last_speaker = await a_initiate_swarm_chat( + initial_agent=agent1, messages="Test message", agents=[agent1, agent2], user_agent=user_agent, max_rounds=3 + ) + + assert len(chat_result.chat_history) > 0 + + # Test with message list which should include call a_resume + messages = [{"role": "user", "content": "Test"}, {"role": "assistant", "name": "agent1", "content": "Response"}] + + chat_result, context_vars, last_speaker = await a_initiate_swarm_chat( + initial_agent=agent1, messages=messages, agents=[agent1, agent2], user_agent=user_agent, max_rounds=3 + ) + + assert len(chat_result.chat_history) > 1 + + # Test context variables + test_context = {"test_key": "test_value"} + chat_result, context_vars, last_speaker = await a_initiate_swarm_chat( + initial_agent=agent1, messages="Test", agents=[agent1, agent2], context_variables=test_context, max_rounds=3 + ) + + assert context_vars == test_context + + if __name__ == "__main__": pytest.main([__file__])