Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Swarm: Allow functions to update agent's state, including system message, before replying #104

Merged
merged 22 commits into from
Dec 15, 2024
Merged
Show file tree
Hide file tree
Changes from 19 commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
9a8c1c5
Added system_message_func to SwarmAgent and update of sys message whe…
marklysze Nov 28, 2024
db729dd
Add a test for the system message function
marklysze Nov 28, 2024
e768a69
Interim commit with context_variables on ConversableAgent
marklysze Nov 30, 2024
40c0b47
Implemented update_agent_state hook, UPDATE_SYSTEM_MESSAGE
marklysze Nov 30, 2024
0b8ba3b
process_update_agent_states no longer returns messages
marklysze Nov 30, 2024
12b0bbe
Update hook to pass in agent and messages (context available on agent…
marklysze Nov 30, 2024
8fddf90
Merge remote-tracking branch 'origin/main' into swarmsysmsgfunc
marklysze Nov 30, 2024
d212e2b
Updated context variable access methods, update_agent_before_reply pa…
marklysze Nov 30, 2024
bb38573
test update_system_message
linmou Nov 30, 2024
10a4e8f
pre-commit updates
marklysze Dec 1, 2024
623727b
Fix for ConversableAgent's a_generate_reply
marklysze Dec 1, 2024
8188593
Added ConversableAgent context variable tests
marklysze Dec 1, 2024
8425600
Merge remote-tracking branch 'origin/main' into swarmsysmsgfunc
marklysze Dec 2, 2024
3cdad79
Merge branch 'main' into swarmsysmsgfunc
marklysze Dec 3, 2024
482a60e
Merge remote-tracking branch 'origin/main' into swarmsysmsgfunc
marklysze Dec 6, 2024
b9352da
Corrected missing variable from nested chat PR
marklysze Dec 6, 2024
71cc5c7
Restore conversable agent context getters/setters
marklysze Dec 6, 2024
790f037
Docs and update system message callable signature change
marklysze Dec 7, 2024
675c82d
Merge remote-tracking branch 'origin/main' into swarmsysmsgfunc
marklysze Dec 8, 2024
58ba2ad
Merge remote-tracking branch 'origin/main' into swarmsysmsgfunc
marklysze Dec 15, 2024
2c3e063
Updated parameter name to update_agent_state_before_reply
marklysze Dec 15, 2024
bf0de64
Update tests for update_agent_state_before_reply
marklysze Dec 15, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions autogen/agentchat/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from .contrib.swarm_agent import (
AFTER_WORK,
ON_CONDITION,
UPDATE_SYSTEM_MESSAGE,
AfterWorkOption,
SwarmAgent,
SwarmResult,
Expand Down Expand Up @@ -39,4 +40,5 @@
"ON_CONDITION",
"AFTER_WORK",
"AfterWorkOption",
"UPDATE_SYSTEM_MESSAGE",
]
107 changes: 99 additions & 8 deletions autogen/agentchat/contrib/swarm_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import copy
import inspect
import json
import re
import warnings
from dataclasses import dataclass
from enum import Enum
from inspect import signature
Expand Down Expand Up @@ -57,6 +59,29 @@ def __post_init__(self):
assert isinstance(self.condition, str) and self.condition.strip(), "'condition' must be a non-empty string"


@dataclass
class UPDATE_SYSTEM_MESSAGE:
    """Describe how an agent's system message is refreshed.

    ``update_function`` is validated on construction and must be either:

    * a ``str`` template containing at least one ``{name}`` placeholder
      (a template without placeholders only triggers a warning), or
    * a callable taking exactly two parameters (an agent and its list of
      message dicts) whose return annotation is ``str``.

    Raises:
        ValueError: If ``update_function`` is neither a string nor a callable,
            if a callable does not take exactly two parameters, or if its
            return annotation is not ``str``.
    """

    update_function: Union[Callable, str]

    def __post_init__(self):
        """Validate ``update_function`` as soon as the instance is created."""
        if isinstance(self.update_function, str):
            # Find all {var} placeholders in the template string.
            placeholders = re.findall(r"\{(\w+)\}", self.update_function)
            if not placeholders:
                # stacklevel=3 skips __post_init__ and the dataclass-generated
                # __init__, so the warning points at the code that built us.
                warnings.warn(
                    "Update function string contains no variables. This is probably unintended.",
                    stacklevel=3,
                )

        elif callable(self.update_function):
            sig = signature(self.update_function)
            if len(sig.parameters) != 2:
                raise ValueError(
                    "Update function must accept two parameters of type ConversableAgent and List[Dict[str, Any]], respectively"
                )
            # Annotations are plain strings when the defining module uses
            # `from __future__ import annotations` (or quotes them by hand),
            # so accept the textual form of `str` as well as the type itself.
            if sig.return_annotation not in (str, "str"):
                raise ValueError("Update function must return a string")
        else:
            raise ValueError("Update function must be either a string or a callable")


def initiate_swarm_chat(
initial_agent: "SwarmAgent",
messages: Union[List[Dict[str, Any]], str],
Expand Down Expand Up @@ -107,12 +132,27 @@ def custom_afterwork_func(last_speaker: SwarmAgent, messages: List[Dict[str, Any
name="Tool_Execution",
system_message="Tool Execution",
)
tool_execution._set_to_tool_execution(context_variables=context_variables)
tool_execution._set_to_tool_execution()

# Update tool execution agent with all the functions from all the agents
for agent in agents:
tool_execution._function_map.update(agent._function_map)

# Point all SwarmAgent's context variables to this function's context_variables
# providing a single (shared) context across all SwarmAgents in the swarm
for agent in agents + [tool_execution]:
agent._context_variables = context_variables

INIT_AGENT_USED = False

def swarm_transition(last_speaker: SwarmAgent, groupchat: GroupChat):
"""Swarm transition function to determine the next agent in the conversation"""
"""Swarm transition function to determine and prepare the next agent in the conversation"""
next_agent = determine_next_agent(last_speaker, groupchat)

return next_agent

def determine_next_agent(last_speaker: SwarmAgent, groupchat: GroupChat):
"""Determine the next agent in the conversation"""
nonlocal INIT_AGENT_USED
if not INIT_AGENT_USED:
INIT_AGENT_USED = True
Expand Down Expand Up @@ -310,6 +350,9 @@ def __init__(
human_input_mode: Literal["ALWAYS", "NEVER", "TERMINATE"] = "NEVER",
description: Optional[str] = None,
code_execution_config=False,
update_agent_before_reply: Optional[
Union[List[Union[Callable, UPDATE_SYSTEM_MESSAGE]], Callable, UPDATE_SYSTEM_MESSAGE]
] = None,
**kwargs,
) -> None:
super().__init__(
Expand All @@ -335,23 +378,70 @@ def __init__(

self.after_work = None

# Used only in the tool execution agent for context and transferring to the next agent
# Note: context variables are not stored for each agent
self._context_variables = {}
# Used in the tool execution agent to transfer to the next agent
self._next_agent = None

# Store nested chats hand offs as we'll establish these in the initiate_swarm_chat
# List of Dictionaries containing the nested_chats and condition
self._nested_chat_handoffs = []

def _set_to_tool_execution(self, context_variables: Optional[Dict[str, Any]] = None):
self.register_update_agent_before_reply(update_agent_before_reply)

def register_update_agent_before_reply(self, functions: Optional[Union[List[Callable], Callable]]):
    """Register hooks on this agent's "update_agent_state" hook point.

    Every entry is registered through ``self.register_hook``. Plain callables
    are registered unchanged and are invoked as ``hook(agent, messages)``.
    ``UPDATE_SYSTEM_MESSAGE`` entries are wrapped in a hook that computes a new
    system message and applies it with ``agent.update_system_message``.

    Args:
        functions: ``None`` (nothing is registered), a single callable, a
            single ``UPDATE_SYSTEM_MESSAGE``, or a list containing any mix of
            the two.

    Raises:
        ValueError: If ``functions`` is not a list, a callable or an
            ``UPDATE_SYSTEM_MESSAGE``.
    """
    if functions is None:
        return

    if not isinstance(functions, list):
        # `type(f)` of a plain function is `function`, never `Callable`, so a
        # membership test against [..., Callable] rejects every bare callable.
        # Use callable() to honour the single-callable form of the signature.
        if not (isinstance(functions, UPDATE_SYSTEM_MESSAGE) or callable(functions)):
            raise ValueError("functions must be a list of callables")
        functions = [functions]

    # Factory that binds one UPDATE_SYSTEM_MESSAGE per hook; going through a
    # factory avoids the late-binding closure pitfall inside the loop below.
    def create_wrapper(update_func: UPDATE_SYSTEM_MESSAGE):
        def update_system_message_wrapper(
            agent: ConversableAgent, messages: List[Dict[str, Any]]
        ) -> List[Dict[str, Any]]:
            if isinstance(update_func.update_function, str):
                # Templates like "My context variable passport is {passport}" will
                # use the context_variables for substitution
                sys_message = OpenAIWrapper.instantiate(
                    template=update_func.update_function,
                    context=agent._context_variables,
                    allow_format_str_template=True,
                )
            else:
                sys_message = update_func.update_function(agent, messages)

            agent.update_system_message(sys_message)
            return messages

        return update_system_message_wrapper

    for func in functions:
        if isinstance(func, UPDATE_SYSTEM_MESSAGE):
            self.register_hook(hookable_method="update_agent_state", hook=create_wrapper(func))
        else:
            self.register_hook(hookable_method="update_agent_state", hook=func)

def _set_to_tool_execution(self):
"""Set to a special instance of SwarmAgent that is responsible for executing tool calls from other swarm agents.
This agent will be used internally and should not be visible to the user.

It will execute the tool calls and update the context_variables and next_agent accordingly.
It will execute the tool calls and update the referenced context_variables and next_agent accordingly.
"""
self._next_agent = None
self._context_variables = context_variables or {}
self._reply_func_list.clear()
self.register_reply([Agent, None], SwarmAgent.generate_swarm_tool_reply)

Expand Down Expand Up @@ -491,6 +581,7 @@ def generate_swarm_tool_reply(
return False, None

def add_single_function(self, func: Callable, name=None, description=""):
"""Add a single function to the agent, removing context variables for LLM use"""
if name:
func._name = name
else:
Expand Down
19 changes: 19 additions & 0 deletions autogen/agentchat/conversable_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,7 @@ def __init__(
"process_last_received_message": [],
"process_all_messages_before_reply": [],
"process_message_before_send": [],
"update_agent_state": [],
}

def _validate_llm_config(self, llm_config):
Expand Down Expand Up @@ -2091,6 +2092,9 @@ def generate_reply(
if messages is None:
messages = self._oai_messages[sender]

# Call the hookable method that gives registered hooks a chance to update the agent's state (e.g. its system message) before it replies.
self.process_update_agent_states(messages)
linmou marked this conversation as resolved.
Show resolved Hide resolved

# Call the hookable method that gives registered hooks a chance to process the last message.
# Message modifications do not affect the incoming messages or self._oai_messages.
messages = self.process_last_received_message(messages)
Expand Down Expand Up @@ -2161,6 +2165,9 @@ async def a_generate_reply(
if messages is None:
messages = self._oai_messages[sender]

# Call the hookable method that gives registered hooks a chance to update the agent's state (e.g. its system message) before it replies.
self.process_update_agent_states(messages)

# Call the hookable method that gives registered hooks a chance to process all messages.
# Message modifications do not affect the incoming messages or self._oai_messages.
messages = self.process_all_messages_before_reply(messages)
Expand Down Expand Up @@ -2847,6 +2854,18 @@ def register_hook(self, hookable_method: str, hook: Callable):
assert hook not in hook_list, f"{hook} is already registered as a hook."
hook_list.append(hook)

def process_update_agent_states(self, messages: List[Dict]) -> None:
    """
    Run every hook registered under "update_agent_state".

    Hooks are invoked in registration order as ``hook(self, messages)``.
    Nothing is returned and hook return values are discarded, so a hook
    can only affect the conversation by mutating the agent or the
    ``messages`` list it is handed.
    """
    for state_hook in self.hook_lists["update_agent_state"]:
        state_hook(self, messages)

def process_all_messages_before_reply(self, messages: List[Dict]) -> List[Dict]:
"""
Calls any registered capability hooks to process all messages, potentially modifying the messages.
Expand Down
98 changes: 97 additions & 1 deletion test/agentchat/contrib/test_swarm.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copyright (c) 2023 - 2024, Owners of https://github.com/ag2ai
#
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Dict
from typing import Any, Dict, List
from unittest.mock import MagicMock, patch

import pytest
Expand All @@ -10,6 +10,7 @@
__CONTEXT_VARIABLES_PARAM_NAME__,
AFTER_WORK,
ON_CONDITION,
UPDATE_SYSTEM_MESSAGE,
AfterWorkOption,
SwarmAgent,
SwarmResult,
Expand Down Expand Up @@ -461,6 +462,101 @@ def test_initialization():
)


def test_update_system_message():
    """Tests the update_agent_before_reply functionality with multiple scenarios

    Scenarios covered, in order: a callable update function, a string
    template, three invalid UPDATE_SYSTEM_MESSAGE arguments, and a list of
    two update functions on a single agent.
    """

    # Test container to capture system messages
    # (a mutable holder so the nested mock below can write to it without `nonlocal`)
    class MessageContainer:
        def __init__(self):
            self.captured_sys_message = ""

    message_container = MessageContainer()

    # 1. Test with a callable function
    # Two parameters and a `-> str` annotation, as UPDATE_SYSTEM_MESSAGE requires.
    def custom_update_function(agent: ConversableAgent, messages: List[Dict]) -> str:
        return f"System message with {agent.get_context('test_var')} and {len(messages)} messages"

    # 2. Test with a string template
    template_message = "Template message with {test_var}"

    # Create agents with different update configurations
    agent1 = SwarmAgent("agent1", update_agent_before_reply=UPDATE_SYSTEM_MESSAGE(custom_update_function))

    agent2 = SwarmAgent("agent2", update_agent_before_reply=UPDATE_SYSTEM_MESSAGE(template_message))

    # Mock the reply function to capture the system message
    def mock_generate_oai_reply(*args, **kwargs):
        # Capture the system message for verification
        # NOTE(review): presumably args[0] is the replying agent - confirm against register_reply.
        message_container.captured_sys_message = args[0]._oai_system_message[0]["content"]
        return True, "Mock response"

    # Register mock reply for both agents
    agent1.register_reply([ConversableAgent, None], mock_generate_oai_reply)
    agent2.register_reply([ConversableAgent, None], mock_generate_oai_reply)

    # Test context and messages
    test_context = {"test_var": "test_value"}
    test_messages = [{"role": "user", "content": "Test message"}]

    # Run chat with first agent (using callable function)
    chat_result1, context_vars1, last_speaker1 = initiate_swarm_chat(
        initial_agent=agent1, messages=test_messages, agents=[agent1], context_variables=test_context, max_rounds=2
    )

    # Verify callable function result
    assert message_container.captured_sys_message == "System message with test_value and 1 messages"

    # Reset captured message
    message_container.captured_sys_message = ""

    # Run chat with second agent (using string template)
    chat_result2, context_vars2, last_speaker2 = initiate_swarm_chat(
        initial_agent=agent2, messages=test_messages, agents=[agent2], context_variables=test_context, max_rounds=2
    )

    # Verify template result
    assert message_container.captured_sys_message == "Template message with test_value"

    # Test invalid update function
    # In each pytest.raises block below the ValueError comes from
    # UPDATE_SYSTEM_MESSAGE.__post_init__, i.e. before SwarmAgent is constructed.
    with pytest.raises(ValueError, match="Update function must be either a string or a callable"):
        SwarmAgent("agent3", update_agent_before_reply=UPDATE_SYSTEM_MESSAGE(123))

    # Test invalid callable (wrong number of parameters)
    def invalid_update_function(context_variables):
        return "Invalid function"

    with pytest.raises(ValueError, match="Update function must accept two parameters"):
        SwarmAgent("agent4", update_agent_before_reply=UPDATE_SYSTEM_MESSAGE(invalid_update_function))

    # Test invalid callable (wrong return type)
    # Only the return annotation is inspected; the function is never called.
    def invalid_return_function(context_variables, messages) -> dict:
        return {}

    with pytest.raises(ValueError, match="Update function must return a string"):
        SwarmAgent("agent5", update_agent_before_reply=UPDATE_SYSTEM_MESSAGE(invalid_return_function))

    # Test multiple update functions
    def another_update_function(context_variables: Dict[str, Any], messages: List[Dict]) -> str:
        return "Another update"

    agent6 = SwarmAgent(
        "agent6",
        update_agent_before_reply=[
            UPDATE_SYSTEM_MESSAGE(custom_update_function),
            UPDATE_SYSTEM_MESSAGE(another_update_function),
        ],
    )

    agent6.register_reply([ConversableAgent, None], mock_generate_oai_reply)

    chat_result6, context_vars6, last_speaker6 = initiate_swarm_chat(
        initial_agent=agent6, messages=test_messages, agents=[agent6], context_variables=test_context, max_rounds=2
    )

    # Verify last update function took effect
    # (hooks run in registration order, so the second system message overwrites the first)
    assert message_container.captured_sys_message == "Another update"


def test_string_agent_params_for_transfer():
"""Test that string agent parameters are handled correctly without using real LLMs."""
# Define test configuration
Expand Down
Loading
Loading