Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Swarm: Allow functions to update agent's state, including system message, before replying #104

Merged
merged 22 commits into from
Dec 15, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
9a8c1c5
Added system_message_func to SwarmAgent and update of sys message whe…
marklysze Nov 28, 2024
db729dd
Add a test for the system message function
marklysze Nov 28, 2024
e768a69
Interim commit with context_variables on ConversableAgent
marklysze Nov 30, 2024
40c0b47
Implemented update_agent_state hook, UPDATE_SYSTEM_MESSAGE
marklysze Nov 30, 2024
0b8ba3b
process_update_agent_states no longer returns messages
marklysze Nov 30, 2024
12b0bbe
Update hook to pass in agent and messages (context available on agent…
marklysze Nov 30, 2024
8fddf90
Merge remote-tracking branch 'origin/main' into swarmsysmsgfunc
marklysze Nov 30, 2024
d212e2b
Updated context variable access methods, update_agent_before_reply pa…
marklysze Nov 30, 2024
bb38573
test update_system_message
linmou Nov 30, 2024
10a4e8f
pre-commit updates
marklysze Dec 1, 2024
623727b
Fix for ConversableAgent's a_generate_reply
marklysze Dec 1, 2024
8188593
Added ConversableAgent context variable tests
marklysze Dec 1, 2024
8425600
Merge remote-tracking branch 'origin/main' into swarmsysmsgfunc
marklysze Dec 2, 2024
3cdad79
Merge branch 'main' into swarmsysmsgfunc
marklysze Dec 3, 2024
482a60e
Merge remote-tracking branch 'origin/main' into swarmsysmsgfunc
marklysze Dec 6, 2024
b9352da
Corrected missing variable from nested chat PR
marklysze Dec 6, 2024
71cc5c7
Restore conversable agent context getters/setters
marklysze Dec 6, 2024
790f037
Docs and update system message callable signature change
marklysze Dec 7, 2024
675c82d
Merge remote-tracking branch 'origin/main' into swarmsysmsgfunc
marklysze Dec 8, 2024
58ba2ad
Merge remote-tracking branch 'origin/main' into swarmsysmsgfunc
marklysze Dec 15, 2024
2c3e063
Updated parameter name to update_agent_state_before_reply
marklysze Dec 15, 2024
bf0de64
Update tests for update_agent_state_before_reply
marklysze Dec 15, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 19 additions & 1 deletion autogen/agentchat/contrib/swarm_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,17 @@ def custom_afterwork_func(last_speaker: SwarmAgent, messages: List[Dict[str, Any
INIT_AGENT_USED = False

def swarm_transition(last_speaker: SwarmAgent, groupchat: GroupChat):
"""Swarm transition function to determine the next agent in the conversation"""
"""Swarm transition function to determine and prepare the next agent in the conversation"""
next_agent = determine_next_agent(last_speaker, groupchat)

if next_agent and isinstance(next_agent, SwarmAgent):
# Update their state
next_agent.update_state(context_variables, groupchat.messages)

return next_agent

def determine_next_agent(last_speaker: SwarmAgent, groupchat: GroupChat):
"""Determine the next agent in the conversation"""
nonlocal INIT_AGENT_USED
if not INIT_AGENT_USED:
INIT_AGENT_USED = True
Expand Down Expand Up @@ -257,6 +267,7 @@ def __init__(
human_input_mode: Literal["ALWAYS", "NEVER", "TERMINATE"] = "NEVER",
description: Optional[str] = None,
code_execution_config=False,
system_message_func: Optional[Callable] = None,
**kwargs,
) -> None:
super().__init__(
Expand Down Expand Up @@ -286,6 +297,8 @@ def __init__(
self._context_variables = {}
self._next_agent = None

self._system_message_func = system_message_func

marklysze marked this conversation as resolved.
Show resolved Hide resolved
def _set_to_tool_execution(self, context_variables: Optional[Dict[str, Any]] = None):
"""Set to a special instance of SwarmAgent that is responsible for executing tool calls from other swarm agents.
This agent will be used internally and should not be visible to the user.
Expand Down Expand Up @@ -458,6 +471,11 @@ def add_functions(self, func_list: List[Callable]):
for func in func_list:
self.add_single_function(func)

def update_state(self, context_variables: Optional[Dict[str, Any]], messages: List[Dict[str, Any]]):
"""Updates the state of the agent, system message so far. This is called when they're selected and just before they speak."""
if self._system_message_func:
self.update_system_message(self._system_message_func(context_variables, messages))


# Forward references for SwarmAgent in SwarmResult
SwarmResult.update_forward_refs()
44 changes: 44 additions & 0 deletions test/agentchat/contrib/test_swarm.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,6 +460,50 @@ def test_initialization():
initial_agent=agent1, messages=TEST_MESSAGES, agents=[agent1, agent2], max_rounds=3
)

def test_sys_message_func():
"""Tests a custom system message function"""

# This test will use context variables and the messages to construct a custom system message
# This will be available at the point of reply (we use register a reply to capture it at that point)

# To store the system message
class MessageContainer:
def __init__(self):
self.final_sys_message = ""

message_container = MessageContainer()
marklysze marked this conversation as resolved.
Show resolved Hide resolved

def my_sys_message(context_variables, messages) -> str:
return f"This is a custom system message with {context_variables['sample_name']} and a total of {len(messages)} message(s)."

agent1 = SwarmAgent("agent1", system_message_func=my_sys_message)
agent2 = SwarmAgent("agent2")

test_context_variables = {"sample_name": "Bob"}

# Mock a reply to be able to capture the system message
def mock_generate_oai_reply(*args, **kwargs):
message_container.final_sys_message = args[0]._oai_system_message[0][
"content"
] # The SwarmAgent's system message
return True, "This is a mock response from the agent."

agent1.register_reply([ConversableAgent, None], mock_generate_oai_reply)

chat_result, context_vars, last_speaker = initiate_swarm_chat(
initial_agent=agent1,
messages=TEST_MESSAGES,
agents=[agent1, agent2],
context_variables=test_context_variables,
max_rounds=4,
)

# The system message should be the custom message
assert (
message_container.final_sys_message
== "This is a custom system message with Bob and a total of 1 message(s)."
)


if __name__ == "__main__":
pytest.main([__file__])
Loading