From bf0de6407166662a93b8bc95f16e53c4b4aec035 Mon Sep 17 00:00:00 2001 From: Mark Sze Date: Sun, 15 Dec 2024 21:26:54 +0000 Subject: [PATCH] Update tests for update_agent_state_before_reply Signed-off-by: Mark Sze --- test/agentchat/contrib/test_swarm.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/test/agentchat/contrib/test_swarm.py b/test/agentchat/contrib/test_swarm.py index 85130baac2..ae2f3cf9b9 100644 --- a/test/agentchat/contrib/test_swarm.py +++ b/test/agentchat/contrib/test_swarm.py @@ -463,7 +463,7 @@ def test_initialization(): def test_update_system_message(): - """Tests the update_agent_before_reply functionality with multiple scenarios""" + """Tests the update_agent_state_before_reply functionality with multiple scenarios""" # Test container to capture system messages class MessageContainer: @@ -480,9 +480,9 @@ def custom_update_function(agent: ConversableAgent, messages: List[Dict]) -> str template_message = "Template message with {test_var}" # Create agents with different update configurations - agent1 = SwarmAgent("agent1", update_agent_before_reply=UPDATE_SYSTEM_MESSAGE(custom_update_function)) + agent1 = SwarmAgent("agent1", update_agent_state_before_reply=UPDATE_SYSTEM_MESSAGE(custom_update_function)) - agent2 = SwarmAgent("agent2", update_agent_before_reply=UPDATE_SYSTEM_MESSAGE(template_message)) + agent2 = SwarmAgent("agent2", update_agent_state_before_reply=UPDATE_SYSTEM_MESSAGE(template_message)) # Mock the reply function to capture the system message def mock_generate_oai_reply(*args, **kwargs): @@ -519,21 +519,21 @@ def mock_generate_oai_reply(*args, **kwargs): # Test invalid update function with pytest.raises(ValueError, match="Update function must be either a string or a callable"): - SwarmAgent("agent3", update_agent_before_reply=UPDATE_SYSTEM_MESSAGE(123)) + SwarmAgent("agent3", update_agent_state_before_reply=UPDATE_SYSTEM_MESSAGE(123)) # Test invalid callable (wrong number of parameters) def invalid_update_function(context_variables): return "Invalid function" with pytest.raises(ValueError, match="Update function must accept two parameters"): - SwarmAgent("agent4", update_agent_before_reply=UPDATE_SYSTEM_MESSAGE(invalid_update_function)) + SwarmAgent("agent4", update_agent_state_before_reply=UPDATE_SYSTEM_MESSAGE(invalid_update_function)) # Test invalid callable (wrong return type) def invalid_return_function(context_variables, messages) -> dict: return {} with pytest.raises(ValueError, match="Update function must return a string"): - SwarmAgent("agent5", update_agent_before_reply=UPDATE_SYSTEM_MESSAGE(invalid_return_function)) + SwarmAgent("agent5", update_agent_state_before_reply=UPDATE_SYSTEM_MESSAGE(invalid_return_function)) # Test multiple update functions def another_update_function(context_variables: Dict[str, Any], messages: List[Dict]) -> str: @@ -541,7 +541,7 @@ def another_update_function(context_variables: Dict[str, Any], messages: List[Di agent6 = SwarmAgent( "agent6", - update_agent_before_reply=[ + update_agent_state_before_reply=[ UPDATE_SYSTEM_MESSAGE(custom_update_function), UPDATE_SYSTEM_MESSAGE(another_update_function), ],