diff --git a/autogen/agentchat/contrib/swarm_agent.py b/autogen/agentchat/contrib/swarm_agent.py index 6748ca2f7f..93507529ee 100644 --- a/autogen/agentchat/contrib/swarm_agent.py +++ b/autogen/agentchat/contrib/swarm_agent.py @@ -122,6 +122,17 @@ def swarm_transition(last_speaker: SwarmAgent, groupchat: GroupChat): if tool_execution._next_agent is not None: next_agent = tool_execution._next_agent tool_execution._next_agent = None + + # Check for string, access agent from group chat. + + if isinstance(next_agent, str): + if next_agent in swarm_agent_names: + next_agent = groupchat.agent_by_name(name=next_agent) + else: + raise ValueError( + f"No agent found with the name '{next_agent}'. Ensure the agent exists in the swarm." + ) + return next_agent # get the last swarm agent @@ -228,7 +239,7 @@ class SwarmResult(BaseModel): """ values: str = "" - agent: Optional["SwarmAgent"] = None + agent: Optional[Union["SwarmAgent", str]] = None context_variables: Dict[str, Any] = {} class Config: # Add this inner class diff --git a/test/agentchat/contrib/test_swarm.py b/test/agentchat/contrib/test_swarm.py index 0987ba9ca0..828b6c837e 100644 --- a/test/agentchat/contrib/test_swarm.py +++ b/test/agentchat/contrib/test_swarm.py @@ -461,5 +461,103 @@ def test_initialization(): ) +def test_string_agent_params_for_transfer(): + """Test that string agent parameters are handled correctly without using real LLMs.""" + # Define test configuration + testing_llm_config = { + "config_list": [ + { + "model": "gpt-4o", + "api_key": "SAMPLE_API_KEY", + } + ] + } + + # Define a simple function for testing + def hello_world(context_variables: dict) -> SwarmResult: + value = "Hello, World!" + return SwarmResult(values=value, context_variables=context_variables, agent="agent_2") + + # Create SwarmAgent instances + agent_1 = SwarmAgent( + name="agent_1", + system_message="Your task is to call hello_world() function.", + llm_config=testing_llm_config, + functions=[hello_world], + ) + agent_2 = SwarmAgent( + name="agent_2", + system_message="Your task is to let the user know what the previous agent said.", + llm_config=testing_llm_config, + ) + + # Mock LLM responses + def mock_generate_oai_reply_agent1(*args, **kwargs): + return True, { + "role": "assistant", + "name": "agent_1", + "tool_calls": [{"type": "function", "function": {"name": "hello_world", "arguments": "{}"}}], + "content": "I will call the hello_world function.", + } + + def mock_generate_oai_reply_agent2(*args, **kwargs): + return True, { + "role": "assistant", + "name": "agent_2", + "content": "The previous agent called hello_world and got: Hello, World!", + } + + # Register mock responses + agent_1.register_reply([ConversableAgent, None], mock_generate_oai_reply_agent1) + agent_2.register_reply([ConversableAgent, None], mock_generate_oai_reply_agent2) + + # Initiate the swarm chat + chat_result, final_context, last_active_agent = initiate_swarm_chat( + initial_agent=agent_1, + agents=[agent_1, agent_2], + context_variables={}, + messages="Begin by calling the hello_world() function.", + after_work=AFTER_WORK(AfterWorkOption.TERMINATE), + max_rounds=5, + ) + + # Assertions to verify the behavior + assert chat_result.chat_history[3]["name"] == "agent_2" + assert last_active_agent.name == "agent_2" + + # Define a simple function for testing + def hello_world(context_variables: dict) -> SwarmResult: + value = "Hello, World!" + return SwarmResult(values=value, context_variables=context_variables, agent="agent_unknown") + + agent_1 = SwarmAgent( + name="agent_1", + system_message="Your task is to call hello_world() function.", + llm_config=testing_llm_config, + functions=[hello_world], + ) + agent_2 = SwarmAgent( + name="agent_2", + system_message="Your task is to let the user know what the previous agent said.", + llm_config=testing_llm_config, + ) + + # Register mock responses + agent_1.register_reply([ConversableAgent, None], mock_generate_oai_reply_agent1) + agent_2.register_reply([ConversableAgent, None], mock_generate_oai_reply_agent2) + + with pytest.raises( + ValueError, match="No agent found with the name 'agent_unknown'. Ensure the agent exists in the swarm." + ): + chat_result, final_context, last_active_agent = initiate_swarm_chat( + initial_agent=agent_1, + agents=[agent_1, agent_2], + context_variables={}, + messages="Begin by calling the hello_world() function.", + after_work=AFTER_WORK(AfterWorkOption.TERMINATE), + max_rounds=5, + ) + + if __name__ == "__main__": pytest.main([__file__]) diff --git a/website/docs/topics/swarm.ipynb b/website/docs/topics/swarm.ipynb index 05a8454346..a445fb5db3 100644 --- a/website/docs/topics/swarm.ipynb +++ b/website/docs/topics/swarm.ipynb @@ -4,7 +4,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "# Swarm Ochestration\n", + "# Swarm Orchestration\n", "\n", "With AG2, you can initiate a Swarm Chat similar to OpenAI's [Swarm](https://github.com/openai/swarm). This orchestration offers two main features:\n", "\n",