diff --git a/test/agentchat/contrib/test_swarm.py b/test/agentchat/contrib/test_swarm.py index b3ffa9cf15..b7f4ad40c6 100644 --- a/test/agentchat/contrib/test_swarm.py +++ b/test/agentchat/contrib/test_swarm.py @@ -1,7 +1,7 @@ # Copyright (c) 2023 - 2024, Owners of https://github.com/ag2ai # # SPDX-License-Identifier: Apache-2.0 -from typing import Any, Dict +from typing import Any, Dict, List, Union from unittest.mock import MagicMock, patch import pytest @@ -16,6 +16,7 @@ initiate_swarm_chat, ) from autogen.agentchat.conversable_agent import ConversableAgent +from autogen.agentchat.groupchat import GroupChat from autogen.agentchat.user_proxy_agent import UserProxyAgent TEST_MESSAGES = [{"role": "user", "content": "Initial message"}] @@ -559,5 +560,129 @@ def hello_world(context_variables: dict) -> SwarmResult: ) +def test_after_work_callable(): + """Test Callable in an AFTER_WORK handoff""" + + testing_llm_config = { + "config_list": [ + { + "model": "gpt-4o", + "api_key": "SAMPLE_API_KEY", + } + ] + } + + agent1 = SwarmAgent("agent1", llm_config=testing_llm_config) + agent2 = SwarmAgent("agent2", llm_config=testing_llm_config) + agent3 = SwarmAgent("agent3", llm_config=testing_llm_config) + + def return_agent( + last_speaker: SwarmAgent, messages: List[Dict[str, Any]], groupchat: GroupChat + ) -> Union[AfterWorkOption, SwarmAgent, str]: + return agent2 + + def return_agent_str( + last_speaker: SwarmAgent, messages: List[Dict[str, Any]], groupchat: GroupChat + ) -> Union[AfterWorkOption, SwarmAgent, str]: + return "agent3" + + def return_after_work_option( + last_speaker: SwarmAgent, messages: List[Dict[str, Any]], groupchat: GroupChat + ) -> Union[AfterWorkOption, SwarmAgent, str]: + return AfterWorkOption.TERMINATE + + agent1.register_hand_off( + hand_to=[ + AFTER_WORK(agent=return_agent), + ] + ) + + agent2.register_hand_off( + hand_to=[ + AFTER_WORK(agent=return_agent_str), + ] + ) + + agent3.register_hand_off( + hand_to=[ + AFTER_WORK(agent=return_after_work_option), + ] + ) + + # Fake generate_oai_reply + def mock_generate_oai_reply(*args, **kwargs): + return True, "This is a mock response from the agent." + + # Mock LLM responses + agent1.register_reply([ConversableAgent, None], mock_generate_oai_reply) + agent2.register_reply([ConversableAgent, None], mock_generate_oai_reply) + agent3.register_reply([ConversableAgent, None], mock_generate_oai_reply) + + chat_result, context_vars, last_speaker = initiate_swarm_chat( + initial_agent=agent1, + messages=TEST_MESSAGES, + agents=[agent1, agent2, agent3], + max_rounds=5, + ) + + # Confirm transitions and it terminated with 4 messages + assert chat_result.chat_history[1]["name"] == "agent1" + assert chat_result.chat_history[2]["name"] == "agent2" + assert chat_result.chat_history[3]["name"] == "agent3" + assert len(chat_result.chat_history) == 4 + + +def test_on_condition_unique_function_names(): + """Test that ON_CONDITION in handoffs generate unique function names""" + + testing_llm_config = { + "config_list": [ + { + "model": "gpt-4o", + "api_key": "SAMPLE_API_KEY", + } + ] + } + + agent1 = SwarmAgent("agent1", llm_config=testing_llm_config) + agent2 = SwarmAgent("agent2", llm_config=testing_llm_config) + + agent1.register_hand_off( + hand_to=[ + ON_CONDITION(target=agent2, condition="always take me to agent 2"), + ON_CONDITION(target=agent2, condition="sometimes take me there"), + ON_CONDITION(target=agent2, condition="always take me there"), + ] + ) + + # Fake generate_oai_reply + def mock_generate_oai_reply(*args, **kwargs): + return True, "This is a mock response from the agent." + + # Fake generate_oai_reply + def mock_generate_oai_reply_tool(*args, **kwargs): + return True, { + "role": "assistant", + "name": "agent1", + "tool_calls": [{"type": "function", "function": {"name": "transfer_agent1_to_agent2"}}], + } + + # Mock LLM responses + agent1.register_reply([ConversableAgent, None], mock_generate_oai_reply_tool) + agent2.register_reply([ConversableAgent, None], mock_generate_oai_reply) + + chat_result, context_vars, last_speaker = initiate_swarm_chat( + initial_agent=agent1, + messages=TEST_MESSAGES, + agents=[agent1, agent2], + max_rounds=5, + ) + + # Check that agent1 has 3 functions and they have unique names + assert "transfer_agent1_to_agent2" in agent1._function_map + assert "transfer_agent1_to_agent2_2" in agent1._function_map + assert "transfer_agent1_to_agent2_3" in agent1._function_map + + if __name__ == "__main__": pytest.main([__file__])