Skip to content

Commit

Permalink
Update tests for update_agent_state_before_reply
Browse files Browse the repository at this point in the history
Signed-off-by: Mark Sze <[email protected]>
  • Loading branch information
marklysze committed Dec 15, 2024
1 parent 2c3e063 commit bf0de64
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions test/agentchat/contrib/test_swarm.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,7 +463,7 @@ def test_initialization():


def test_update_system_message():
"""Tests the update_agent_before_reply functionality with multiple scenarios"""
"""Tests the update_agent_state_before_reply functionality with multiple scenarios"""

# Test container to capture system messages
class MessageContainer:
Expand All @@ -480,9 +480,9 @@ def custom_update_function(agent: ConversableAgent, messages: List[Dict]) -> str
template_message = "Template message with {test_var}"

# Create agents with different update configurations
agent1 = SwarmAgent("agent1", update_agent_before_reply=UPDATE_SYSTEM_MESSAGE(custom_update_function))
agent1 = SwarmAgent("agent1", update_agent_state_before_reply=UPDATE_SYSTEM_MESSAGE(custom_update_function))

agent2 = SwarmAgent("agent2", update_agent_before_reply=UPDATE_SYSTEM_MESSAGE(template_message))
agent2 = SwarmAgent("agent2", update_agent_state_before_reply=UPDATE_SYSTEM_MESSAGE(template_message))

# Mock the reply function to capture the system message
def mock_generate_oai_reply(*args, **kwargs):
Expand Down Expand Up @@ -519,29 +519,29 @@ def mock_generate_oai_reply(*args, **kwargs):

# Test invalid update function
with pytest.raises(ValueError, match="Update function must be either a string or a callable"):
SwarmAgent("agent3", update_agent_before_reply=UPDATE_SYSTEM_MESSAGE(123))
SwarmAgent("agent3", update_agent_state_before_reply=UPDATE_SYSTEM_MESSAGE(123))

# Test invalid callable (wrong number of parameters)
def invalid_update_function(context_variables):
return "Invalid function"

with pytest.raises(ValueError, match="Update function must accept two parameters"):
SwarmAgent("agent4", update_agent_before_reply=UPDATE_SYSTEM_MESSAGE(invalid_update_function))
SwarmAgent("agent4", update_agent_state_before_reply=UPDATE_SYSTEM_MESSAGE(invalid_update_function))

# Test invalid callable (wrong return type)
def invalid_return_function(context_variables, messages) -> dict:
return {}

with pytest.raises(ValueError, match="Update function must return a string"):
SwarmAgent("agent5", update_agent_before_reply=UPDATE_SYSTEM_MESSAGE(invalid_return_function))
SwarmAgent("agent5", update_agent_state_before_reply=UPDATE_SYSTEM_MESSAGE(invalid_return_function))

# Test multiple update functions
def another_update_function(context_variables: Dict[str, Any], messages: List[Dict]) -> str:
return "Another update"

agent6 = SwarmAgent(
"agent6",
update_agent_before_reply=[
update_agent_state_before_reply=[
UPDATE_SYSTEM_MESSAGE(custom_update_function),
UPDATE_SYSTEM_MESSAGE(another_update_function),
],
Expand Down

0 comments on commit bf0de64

Please sign in to comment.