Skip to content

Commit

Permalink
Added tests for refactored initiate_swarm_chat functions, fixed group…
Browse files Browse the repository at this point in the history
…chat bug.

Signed-off-by: Mark Sze <[email protected]>
  • Loading branch information
marklysze committed Dec 30, 2024
1 parent e23f4f3 commit 8478804
Show file tree
Hide file tree
Showing 2 changed files with 234 additions and 1 deletion.
4 changes: 4 additions & 0 deletions autogen/agentchat/groupchat.py
Original file line number Diff line number Diff line change
Expand Up @@ -1262,6 +1262,10 @@ async def a_run_chat(
else:
# admin agent is not found in the participants
raise
except NoEligibleSpeaker:
# No eligible speaker, terminate the conversation
break

if reply is None:
break
# The speaker sends the message without requesting a reply
Expand Down
231 changes: 230 additions & 1 deletion test/agentchat/contrib/test_swarm.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,23 @@

from autogen.agentchat.contrib.swarm_agent import (
__CONTEXT_VARIABLES_PARAM_NAME__,
__TOOL_EXECUTOR_NAME__,
AFTER_WORK,
ON_CONDITION,
UPDATE_SYSTEM_MESSAGE,
AfterWorkOption,
SwarmAgent,
SwarmResult,
_cleanup_temp_user_messages,
_create_nested_chats,
_prepare_swarm_agents,
_process_initial_messages,
_setup_context_variables,
a_initiate_swarm_chat,
initiate_swarm_chat,
)
from autogen.agentchat.conversable_agent import ConversableAgent
from autogen.agentchat.groupchat import GroupChat
from autogen.agentchat.groupchat import GroupChat, GroupChatManager
from autogen.agentchat.user_proxy_agent import UserProxyAgent

TEST_MESSAGES = [{"role": "user", "content": "Initial message"}]
Expand Down Expand Up @@ -780,5 +787,227 @@ def mock_generate_oai_reply_tool(*args, **kwargs):
assert "transfer_agent1_to_agent2_3" in agent1._function_map


def test_prepare_swarm_agents():
"""Test preparation of swarm agents including tool executor setup"""

testing_llm_config = {
"config_list": [
{
"model": "gpt-4o",
"api_key": "SAMPLE_API_KEY",
}
]
}

# Create test agents
agent1 = SwarmAgent("agent1", llm_config=testing_llm_config)
agent2 = SwarmAgent("agent2", llm_config=testing_llm_config)
agent3 = SwarmAgent("agent3", llm_config=testing_llm_config)

# Add some functions to test tool executor aggregation
def test_func1():
pass

def test_func2():
pass

agent1.add_single_function(test_func1)
agent2.add_single_function(test_func2)

# Add handoffs to test validation
agent1.register_hand_off(AFTER_WORK(agent=agent2))

# Test valid preparation
tool_executor, nested_chat_agents = _prepare_swarm_agents(agent1, [agent1, agent2])

# Verify tool executor setup
assert tool_executor.name == __TOOL_EXECUTOR_NAME__
assert "test_func1" in tool_executor._function_map
assert "test_func2" in tool_executor._function_map

# Test invalid initial agent type
with pytest.raises(AssertionError):
_prepare_swarm_agents(ConversableAgent("invalid"), [agent1, agent2])

# Test invalid agents list
with pytest.raises(AssertionError):
_prepare_swarm_agents(agent1, [agent1, ConversableAgent("invalid")])

# Test missing handoff agent
agent3.register_hand_off(AFTER_WORK(agent=SwarmAgent("missing")))
with pytest.raises(AssertionError):
_prepare_swarm_agents(agent1, [agent1, agent2, agent3])


def test_create_nested_chats():
"""Test creation of nested chat agents and registration of handoffs"""

testing_llm_config = {
"config_list": [
{
"model": "gpt-4o",
"api_key": "SAMPLE_API_KEY",
}
]
}

test_agent = SwarmAgent("test_agent", llm_config=testing_llm_config)
test_agent_2 = SwarmAgent("test_agent_2", llm_config=testing_llm_config)
nested_chat_agents = []

nested_chat_one = {
"carryover_config": {"summary_method": "last_msg"},
"recipient": test_agent_2,
"message": "Extract the order details",
"max_turns": 1,
}

chat_queue = [nested_chat_one]

# Register a nested chat handoff
nested_chat_config = {
"chat_queue": chat_queue,
"reply_func_from_nested_chats": "summary_from_nested_chats",
"config": None,
"use_async": False,
}

test_agent.register_hand_off(ON_CONDITION(target=nested_chat_config, condition="test condition"))

# Create nested chats
_create_nested_chats(test_agent, nested_chat_agents)

# Verify nested chat agent creation
assert len(nested_chat_agents) == 1
assert nested_chat_agents[0].name == f"nested_chat_{test_agent.name}_1"

# Verify nested chat configuration
# The nested chat agent should have a handoff back to the passed in agent
assert nested_chat_agents[0].after_work.agent == test_agent


def test_process_initial_messages():
"""Test processing of initial messages in different scenarios"""

agent1 = SwarmAgent("agent1")
agent2 = SwarmAgent("agent2")
nested_agent = SwarmAgent("nested_chat_agent1_1")
user_agent = UserProxyAgent("test_user")

# Test single string message
messages = "Initial message"
processed_messages, last_agent, agent_names, temp_users = _process_initial_messages(
messages, None, [agent1, agent2], [nested_agent]
)

assert len(processed_messages) == 1
assert processed_messages[0]["content"] == "Initial message"
assert len(temp_users) == 1 # Should create temporary user
assert temp_users[0].name == "_User"

# Test message with existing agent name
messages = [{"role": "user", "content": "Test", "name": "agent1"}]
processed_messages, last_agent, agent_names, temp_users = _process_initial_messages(
messages, user_agent, [agent1, agent2], [nested_agent]
)

assert last_agent == agent1
assert len(temp_users) == 0 # Should not create temp user

# Test message with user agent name
messages = [{"role": "user", "content": "Test", "name": "test_user"}]
processed_messages, last_agent, agent_names, temp_users = _process_initial_messages(
messages, user_agent, [agent1, agent2], [nested_agent]
)

assert last_agent == user_agent
assert len(temp_users) == 0

# Test invalid agent name
messages = [{"role": "user", "content": "Test", "name": "invalid_agent"}]
with pytest.raises(ValueError):
_process_initial_messages(messages, user_agent, [agent1, agent2], [nested_agent])


def test_setup_context_variables():
"""Test setup of context variables across agents"""

tool_execution = SwarmAgent(__TOOL_EXECUTOR_NAME__)
agent1 = SwarmAgent("agent1")
agent2 = SwarmAgent("agent2")

groupchat = GroupChat(agents=[tool_execution, agent1, agent2], messages=[])
manager = GroupChatManager(groupchat)

test_context = {"test_key": "test_value"}

_setup_context_variables(tool_execution, [agent1, agent2], manager, test_context)

# Verify all agents share the same context_variables reference
assert tool_execution._context_variables is test_context
assert agent1._context_variables is test_context
assert agent2._context_variables is test_context
assert manager._context_variables is test_context


def test_cleanup_temp_user_messages():
"""Test cleanup of temporary user messages"""

chat_result = MagicMock()
chat_result.chat_history = [
{"role": "user", "name": "_User", "content": "Test 1"},
{"role": "assistant", "name": "agent1", "content": "Response 1"},
{"role": "user", "name": "_User", "content": "Test 2"},
]

_cleanup_temp_user_messages(chat_result)

# Verify _User names are removed
for message in chat_result.chat_history:
if message["role"] == "user":
assert "name" not in message


@pytest.mark.asyncio
async def test_a_initiate_swarm_chat():
"""Test async swarm chat"""

agent1 = SwarmAgent("agent1")
agent2 = SwarmAgent("agent2")
user_agent = UserProxyAgent("test_user")

# Mock async reply function
async def mock_a_generate_oai_reply(*args, **kwargs):
return True, "This is a mock response from the agent."

# Register mock replies
agent1.register_reply([ConversableAgent, None], mock_a_generate_oai_reply)
agent2.register_reply([ConversableAgent, None], mock_a_generate_oai_reply)

# Test with string message
chat_result, context_vars, last_speaker = await a_initiate_swarm_chat(
initial_agent=agent1, messages="Test message", agents=[agent1, agent2], user_agent=user_agent, max_rounds=3
)

assert len(chat_result.chat_history) > 0

# Test with message list which should include call a_resume
messages = [{"role": "user", "content": "Test"}, {"role": "assistant", "name": "agent1", "content": "Response"}]

chat_result, context_vars, last_speaker = await a_initiate_swarm_chat(
initial_agent=agent1, messages=messages, agents=[agent1, agent2], user_agent=user_agent, max_rounds=3
)

assert len(chat_result.chat_history) > 1

# Test context variables
test_context = {"test_key": "test_value"}
chat_result, context_vars, last_speaker = await a_initiate_swarm_chat(
initial_agent=agent1, messages="Test", agents=[agent1, agent2], context_variables=test_context, max_rounds=3
)

assert context_vars == test_context


if __name__ == "__main__":
pytest.main([__file__])

0 comments on commit 8478804

Please sign in to comment.