Skip to content

Commit

Permalink
Adding tests
Browse files Browse the repository at this point in the history
Signed-off-by: Mark Sze <[email protected]>
  • Loading branch information
marklysze committed Dec 13, 2024
1 parent cf59a66 commit e3dba1e
Showing 1 changed file with 126 additions and 1 deletion.
127 changes: 126 additions & 1 deletion test/agentchat/contrib/test_swarm.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copyright (c) 2023 - 2024, Owners of https://github.com/ag2ai
#
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Dict
from typing import Any, Dict, List, Union
from unittest.mock import MagicMock, patch

import pytest
Expand All @@ -16,6 +16,7 @@
initiate_swarm_chat,
)
from autogen.agentchat.conversable_agent import ConversableAgent
from autogen.agentchat.groupchat import GroupChat
from autogen.agentchat.user_proxy_agent import UserProxyAgent

TEST_MESSAGES = [{"role": "user", "content": "Initial message"}]
Expand Down Expand Up @@ -559,5 +560,129 @@ def hello_world(context_variables: dict) -> SwarmResult:
)


def test_after_work_callable():
"""Test Callable in an AFTER_WORK handoff"""

testing_llm_config = {
"config_list": [
{
"model": "gpt-4o",
"api_key": "SAMPLE_API_KEY",
}
]
}

agent1 = SwarmAgent("agent1", llm_config=testing_llm_config)
agent2 = SwarmAgent("agent2", llm_config=testing_llm_config)
agent3 = SwarmAgent("agent3", llm_config=testing_llm_config)

def return_agent(
last_speaker: SwarmAgent, messages: List[Dict[str, Any]], groupchat: GroupChat
) -> Union[AfterWorkOption, SwarmAgent, str]:
return agent2

def return_agent_str(
last_speaker: SwarmAgent, messages: List[Dict[str, Any]], groupchat: GroupChat
) -> Union[AfterWorkOption, SwarmAgent, str]:
return "agent3"

def return_after_work_option(
last_speaker: SwarmAgent, messages: List[Dict[str, Any]], groupchat: GroupChat
) -> Union[AfterWorkOption, SwarmAgent, str]:
return AfterWorkOption.TERMINATE

agent1.register_hand_off(
hand_to=[
AFTER_WORK(agent=return_agent),
]
)

agent2.register_hand_off(
hand_to=[
AFTER_WORK(agent=return_agent_str),
]
)

agent3.register_hand_off(
hand_to=[
AFTER_WORK(agent=return_after_work_option),
]
)

# Fake generate_oai_reply
def mock_generate_oai_reply(*args, **kwargs):
return True, "This is a mock response from the agent."

# Mock LLM responses
agent1.register_reply([ConversableAgent, None], mock_generate_oai_reply)
agent2.register_reply([ConversableAgent, None], mock_generate_oai_reply)
agent3.register_reply([ConversableAgent, None], mock_generate_oai_reply)

chat_result, context_vars, last_speaker = initiate_swarm_chat(
initial_agent=agent1,
messages=TEST_MESSAGES,
agents=[agent1, agent2, agent3],
max_rounds=5,
)

# Confirm transitions and it terminated with 4 messages
assert chat_result.chat_history[1]["name"] == "agent1"
assert chat_result.chat_history[2]["name"] == "agent2"
assert chat_result.chat_history[3]["name"] == "agent3"
assert len(chat_result.chat_history) == 4


def test_on_condition_unique_function_names():
"""Test that ON_CONDITION in handoffs generate unique function names"""

testing_llm_config = {
"config_list": [
{
"model": "gpt-4o",
"api_key": "SAMPLE_API_KEY",
}
]
}

agent1 = SwarmAgent("agent1", llm_config=testing_llm_config)
agent2 = SwarmAgent("agent2", llm_config=testing_llm_config)

agent1.register_hand_off(
hand_to=[
ON_CONDITION(target=agent2, condition="always take me to agent 2"),
ON_CONDITION(target=agent2, condition="sometimes take me there"),
ON_CONDITION(target=agent2, condition="always take me there"),
]
)

# Fake generate_oai_reply
def mock_generate_oai_reply(*args, **kwargs):
return True, "This is a mock response from the agent."

# Fake generate_oai_reply
def mock_generate_oai_reply_tool(*args, **kwargs):
return True, {
"role": "assistant",
"name": "agent1",
"tool_calls": [{"type": "function", "function": {"name": "transfer_agent1_to_agent2"}}],
}

# Mock LLM responses
agent1.register_reply([ConversableAgent, None], mock_generate_oai_reply_tool)
agent2.register_reply([ConversableAgent, None], mock_generate_oai_reply)

chat_result, context_vars, last_speaker = initiate_swarm_chat(
initial_agent=agent1,
messages=TEST_MESSAGES,
agents=[agent1, agent2],
max_rounds=5,
)

# Check that agent1 has 3 functions and they have unique names
assert "transfer_agent1_to_agent2" in agent1._function_map
assert "transfer_agent1_to_agent2_2" in agent1._function_map
assert "transfer_agent1_to_agent2_3" in agent1._function_map


if __name__ == "__main__":
pytest.main([__file__])

0 comments on commit e3dba1e

Please sign in to comment.