From 9a8c1c5517eebca9555fad770442baa8828eb163 Mon Sep 17 00:00:00 2001 From: Mark Sze Date: Thu, 28 Nov 2024 01:53:31 +0000 Subject: [PATCH 01/16] Added system_message_func to SwarmAgent and update of sys message when selected Signed-off-by: Mark Sze --- autogen/agentchat/contrib/swarm_agent.py | 20 +++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/autogen/agentchat/contrib/swarm_agent.py b/autogen/agentchat/contrib/swarm_agent.py index c1c790a906..b3f1bd931c 100644 --- a/autogen/agentchat/contrib/swarm_agent.py +++ b/autogen/agentchat/contrib/swarm_agent.py @@ -110,7 +110,17 @@ def custom_afterwork_func(last_speaker: SwarmAgent, messages: List[Dict[str, Any INIT_AGENT_USED = False def swarm_transition(last_speaker: SwarmAgent, groupchat: GroupChat): - """Swarm transition function to determine the next agent in the conversation""" + """Swarm transition function to determine and prepare the next agent in the conversation""" + next_agent = determine_next_agent(last_speaker, groupchat) + + if next_agent and isinstance(next_agent, SwarmAgent): + # Update their state + next_agent.update_state(context_variables, groupchat.messages) + + return next_agent + + def determine_next_agent(last_speaker: SwarmAgent, groupchat: GroupChat): + """Determine the next agent in the conversation""" nonlocal INIT_AGENT_USED if not INIT_AGENT_USED: INIT_AGENT_USED = True @@ -257,6 +267,7 @@ def __init__( human_input_mode: Literal["ALWAYS", "NEVER", "TERMINATE"] = "NEVER", description: Optional[str] = None, code_execution_config=False, + system_message_func: Optional[Callable] = None, **kwargs, ) -> None: super().__init__( @@ -286,6 +297,8 @@ def __init__( self._context_variables = {} self._next_agent = None + self._system_message_func = system_message_func + def _set_to_tool_execution(self, context_variables: Optional[Dict[str, Any]] = None): """Set to a special instance of SwarmAgent that is responsible for executing tool calls from other swarm agents. This agent will be used internally and should not be visible to the user. @@ -458,6 +471,11 @@ def add_functions(self, func_list: List[Callable]): for func in func_list: self.add_single_function(func) + def update_state(self, context_variables: Optional[Dict[str, Any]], messages: List[Dict[str, Any]]): + """Updates the state of the agent, system message so far. This is called when they're selected and just before they speak.""" + if self._system_message_func: + self.update_system_message(self._system_message_func(context_variables, messages)) + # Forward references for SwarmAgent in SwarmResult SwarmResult.update_forward_refs() From db729dd71096622e280b2bf855bdcf8ca7c9d33e Mon Sep 17 00:00:00 2001 From: Mark Sze Date: Thu, 28 Nov 2024 02:11:20 +0000 Subject: [PATCH 02/16] Add a test for the system message function Signed-off-by: Mark Sze --- test/agentchat/contrib/test_swarm.py | 44 ++++++++++++++++++++++++++++ 1 file changed, 44 insertions(+) diff --git a/test/agentchat/contrib/test_swarm.py b/test/agentchat/contrib/test_swarm.py index 0987ba9ca0..ca906ea194 100644 --- a/test/agentchat/contrib/test_swarm.py +++ b/test/agentchat/contrib/test_swarm.py @@ -460,6 +460,50 @@ def test_initialization(): initial_agent=agent1, messages=TEST_MESSAGES, agents=[agent1, agent2], max_rounds=3 ) + def test_sys_message_func(): + """Tests a custom system message function""" + + # This test will use context variables and the messages to construct a custom system message + # This will be available at the point of reply (we use register a reply to capture it at that point) + + # To store the system message + class MessageContainer: + def __init__(self): + self.final_sys_message = "" + + message_container = MessageContainer() + + def my_sys_message(context_variables, messages) -> str: + return f"This is a custom system message with {context_variables['sample_name']} and a total of {len(messages)} message(s)." + + agent1 = SwarmAgent("agent1", system_message_func=my_sys_message) + agent2 = SwarmAgent("agent2") + + test_context_variables = {"sample_name": "Bob"} + + # Mock a reply to be able to capture the system message + def mock_generate_oai_reply(*args, **kwargs): + message_container.final_sys_message = args[0]._oai_system_message[0][ + "content" + ] # The SwarmAgent's system message + return True, "This is a mock response from the agent." + + agent1.register_reply([ConversableAgent, None], mock_generate_oai_reply) + + chat_result, context_vars, last_speaker = initiate_swarm_chat( + initial_agent=agent1, + messages=TEST_MESSAGES, + agents=[agent1, agent2], + context_variables=test_context_variables, + max_rounds=4, + ) + + # The system message should be the custom message + assert ( + message_container.final_sys_message + == "This is a custom system message with Bob and a total of 1 message(s)." + ) + if __name__ == "__main__": pytest.main([__file__]) From e768a6923973c95ea1a62f1e66628f4de4ab3943 Mon Sep 17 00:00:00 2001 From: Mark Sze Date: Sat, 30 Nov 2024 04:30:55 +0000 Subject: [PATCH 03/16] Interim commit with context_variables on ConversableAgent Signed-off-by: Mark Sze --- autogen/agentchat/contrib/swarm_agent.py | 56 +++++++++++++++++++----- autogen/agentchat/conversable_agent.py | 4 ++ 2 files changed, 50 insertions(+), 10 deletions(-) diff --git a/autogen/agentchat/contrib/swarm_agent.py b/autogen/agentchat/contrib/swarm_agent.py index b3f1bd931c..af8f2272bc 100644 --- a/autogen/agentchat/contrib/swarm_agent.py +++ b/autogen/agentchat/contrib/swarm_agent.py @@ -101,12 +101,17 @@ def custom_afterwork_func(last_speaker: SwarmAgent, messages: List[Dict[str, Any name="Tool_Execution", system_message="Tool Execution", ) - tool_execution._set_to_tool_execution(context_variables=context_variables) + tool_execution._set_to_tool_execution() # Update tool execution agent with all the functions from all the agents for agent in agents: tool_execution._function_map.update(agent._function_map) + # Point all SwarmAgent's context variables to this function's context_variables + # providing a single (shared) context across all SwarmAgents in the swarm + for agent in agents + [tool_execution]: + agent._context_variables = context_variables + INIT_AGENT_USED = False def swarm_transition(last_speaker: SwarmAgent, groupchat: GroupChat): @@ -267,7 +272,7 @@ def __init__( human_input_mode: Literal["ALWAYS", "NEVER", "TERMINATE"] = "NEVER", description: Optional[str] = None, code_execution_config=False, - system_message_func: Optional[Callable] = None, + update_state_functions: Optional[Union[List[Callable], Callable]] = None, **kwargs, ) -> None: super().__init__( @@ -294,19 +299,40 @@ def __init__( self.after_work = None # use in the tool execution agent to transfer to the next agent - self._context_variables = {} self._next_agent = None - self._system_message_func = system_message_func + self.register_update_states_functions(update_state_functions) + + def register_update_states_functions(self, functions: Optional[Union[List[Callable], Callable]]): + """ + Register functions that will be called when the agent is selected and before it speaks. + You can add your own validation or precondition functions here. + + Args: + functions (List[Callable[[], None]]): A list of functions to be registered. Each function + is called when the agent is selected and before it speaks. + """ + + # TEMP - THIS WILL BE UPDATED TO UTILISE A NEW HOOK - update_agent_state - def _set_to_tool_execution(self, context_variables: Optional[Dict[str, Any]] = None): + if functions is None: + return + if not isinstance(functions, list) and not isinstance(functions, Callable): + raise ValueError("functions must be a list of callables") + + if isinstance(functions, Callable): + functions = [functions] + + for func in functions: + self.register_hook("update_states_once_selected", func) + + def _set_to_tool_execution(self): """Set to a special instance of SwarmAgent that is responsible for executing tool calls from other swarm agents. This agent will be used internally and should not be visible to the user. - It will execute the tool calls and update the context_variables and next_agent accordingly. + It will execute the tool calls and update the referenced context_variables and next_agent accordingly. """ self._next_agent = None - self._context_variables = context_variables or {} self._reply_func_list.clear() self.register_reply([Agent, None], SwarmAgent.generate_swarm_tool_reply) @@ -472,9 +498,19 @@ def add_functions(self, func_list: List[Callable]): self.add_single_function(func) def update_state(self, context_variables: Optional[Dict[str, Any]], messages: List[Dict[str, Any]]): - """Updates the state of the agent, system message so far. This is called when they're selected and just before they speak.""" - if self._system_message_func: - self.update_system_message(self._system_message_func(context_variables, messages)) + """Updates the state of the agent prior to reply""" + + # TEMP - THIS WILL BE REPLACED BY A NEW HOOK - update_agent_state + + for hook in self.hook_lists["update_states_once_selected"]: + result = hook(self, context_variables, messages) + + if result is None: + continue + + returned_variables, returned_messages = result + self._context_variables.update(returned_variables) + messages = self.process_all_messages_before_reply(returned_messages) # Forward references for SwarmAgent in SwarmResult diff --git a/autogen/agentchat/conversable_agent.py b/autogen/agentchat/conversable_agent.py index 840da79204..b558038eec 100644 --- a/autogen/agentchat/conversable_agent.py +++ b/autogen/agentchat/conversable_agent.py @@ -85,6 +85,7 @@ def __init__( description: Optional[str] = None, chat_messages: Optional[Dict[Agent, List[Dict]]] = None, silent: Optional[bool] = None, + context_variables: Optional[Dict[str, Any]] = None, ): """ Args: @@ -135,6 +136,7 @@ def __init__( resume previous had conversations. Defaults to an empty chat history. silent (bool or None): (Experimental) whether to print the message sent. If None, will use the value of silent in each function. + context_variables (dict or None): Context variables that provide a persistent context for the agent. Only used in Swarms at this stage. """ # we change code_execution_config below and we have to make sure we don't change the input # in case of UserProxyAgent, without this we could even change the default value {} @@ -193,6 +195,8 @@ def __init__( self.register_reply([Agent, None], ConversableAgent.generate_oai_reply) self.register_reply([Agent, None], ConversableAgent.a_generate_oai_reply, ignore_async_in_sync_chat=True) + self._context_variables = context_variables if context_variables is not None else {} + # Setting up code execution. # Do not register code execution reply if code execution is disabled. if code_execution_config is not False: From 40c0b475c6684181d99fafd7db964e6179d722d2 Mon Sep 17 00:00:00 2001 From: Mark Sze Date: Sat, 30 Nov 2024 05:57:10 +0000 Subject: [PATCH 04/16] Implemented update_agent_state hook, UPDATE_SYSTEM_MESSAGE Signed-off-by: Mark Sze --- autogen/agentchat/__init__.py | 2 + autogen/agentchat/contrib/swarm_agent.py | 73 ++++++++++++++++-------- autogen/agentchat/conversable_agent.py | 25 ++++++++ 3 files changed, 75 insertions(+), 25 deletions(-) diff --git a/autogen/agentchat/__init__.py b/autogen/agentchat/__init__.py index 6c3c12e6ce..c41820bf9b 100644 --- a/autogen/agentchat/__init__.py +++ b/autogen/agentchat/__init__.py @@ -12,6 +12,7 @@ from .contrib.swarm_agent import ( AFTER_WORK, ON_CONDITION, + UPDATE_SYSTEM_MESSAGE, AfterWorkOption, SwarmAgent, SwarmResult, @@ -39,4 +40,5 @@ "ON_CONDITION", "AFTER_WORK", "AfterWorkOption", + "UPDATE_SYSTEM_MESSAGE", ] diff --git a/autogen/agentchat/contrib/swarm_agent.py b/autogen/agentchat/contrib/swarm_agent.py index af8f2272bc..1465e60564 100644 --- a/autogen/agentchat/contrib/swarm_agent.py +++ b/autogen/agentchat/contrib/swarm_agent.py @@ -49,6 +49,23 @@ def __post_init__(self): assert isinstance(self.agent, SwarmAgent), "Agent must be a SwarmAgent" +@dataclass +class UPDATE_SYSTEM_MESSAGE: + update_function: Union[Callable, str] + + def __post_init__(self): + if isinstance(self.update_function, str): + pass + elif isinstance(self.update_function, Callable): + sig = signature(self.update_function) + if len(sig.parameters) != 2: + raise ValueError("Update function must accept two parameters, context_variables and messages") + if sig.return_annotation != str: + raise ValueError("Update function must return a string") + else: + raise ValueError("Update function must be either a string or a callable") + + def initiate_swarm_chat( initial_agent: "SwarmAgent", messages: Union[List[Dict[str, Any]], str], @@ -118,10 +135,6 @@ def swarm_transition(last_speaker: SwarmAgent, groupchat: GroupChat): """Swarm transition function to determine and prepare the next agent in the conversation""" next_agent = determine_next_agent(last_speaker, groupchat) - if next_agent and isinstance(next_agent, SwarmAgent): - # Update their state - next_agent.update_state(context_variables, groupchat.messages) - return next_agent def determine_next_agent(last_speaker: SwarmAgent, groupchat: GroupChat): @@ -301,9 +314,9 @@ def __init__( # use in the tool execution agent to transfer to the next agent self._next_agent = None - self.register_update_states_functions(update_state_functions) + self.register_update_state_functions(update_state_functions) - def register_update_states_functions(self, functions: Optional[Union[List[Callable], Callable]]): + def register_update_state_functions(self, functions: Optional[Union[List[Callable], Callable]]): """ Register functions that will be called when the agent is selected and before it speaks. You can add your own validation or precondition functions here. @@ -312,9 +325,6 @@ def register_update_states_functions(self, functions: Optional[Union[List[Callab functions (List[Callable[[], None]]): A list of functions to be registered. Each function is called when the agent is selected and before it speaks. """ - - # TEMP - THIS WILL BE UPDATED TO UTILISE A NEW HOOK - update_agent_state - if functions is None: return if not isinstance(functions, list) and not isinstance(functions, Callable): @@ -324,7 +334,35 @@ def register_update_states_functions(self, functions: Optional[Union[List[Callab functions = [functions] for func in functions: - self.register_hook("update_states_once_selected", func) + if isinstance(func, UPDATE_SYSTEM_MESSAGE): + + # Wrapper function that allows this to be used in the update_agent_state hook + # Its primary purpose, however, is just to update the agent's system message + # Outer function to create a closure with the update function + def create_wrapper(update_func: UPDATE_SYSTEM_MESSAGE): + def update_system_message_wrapper( + context_variables: Dict[str, Any], messages: List[Dict[str, Any]] + ) -> List[Dict[str, Any]]: + if isinstance(update_func.update_function, str): + # Templates like "My context variable passport is {passport}" will + # use the context_variables for substitution + sys_message = OpenAIWrapper.instantiate( + template=update_func.update_function, + context=context_variables, + allow_format_str_template=True, + ) + else: + sys_message = update_func.update_function(context_variables, messages) + + self.update_system_message(sys_message) + return messages + + return update_system_message_wrapper + + self.register_hook(hookable_method="update_agent_state", hook=create_wrapper(func)) + + else: + self.register_hook(hookable_method="update_agent_state", hook=func) def _set_to_tool_execution(self): """Set to a special instance of SwarmAgent that is responsible for executing tool calls from other swarm agents. @@ -497,21 +535,6 @@ def add_functions(self, func_list: List[Callable]): for func in func_list: self.add_single_function(func) - def update_state(self, context_variables: Optional[Dict[str, Any]], messages: List[Dict[str, Any]]): - """Updates the state of the agent prior to reply""" - - # TEMP - THIS WILL BE REPLACED BY A NEW HOOK - update_agent_state - - for hook in self.hook_lists["update_states_once_selected"]: - result = hook(self, context_variables, messages) - - if result is None: - continue - - returned_variables, returned_messages = result - self._context_variables.update(returned_variables) - messages = self.process_all_messages_before_reply(returned_messages) - # Forward references for SwarmAgent in SwarmResult SwarmResult.update_forward_refs() diff --git a/autogen/agentchat/conversable_agent.py b/autogen/agentchat/conversable_agent.py index b558038eec..e05510e8a5 100644 --- a/autogen/agentchat/conversable_agent.py +++ b/autogen/agentchat/conversable_agent.py @@ -261,6 +261,7 @@ def __init__( "process_last_received_message": [], "process_all_messages_before_reply": [], "process_message_before_send": [], + "update_agent_state": [], } def _validate_llm_config(self, llm_config): @@ -2046,6 +2047,9 @@ def generate_reply( if messages is None: messages = self._oai_messages[sender] + # Call the hookable method that gives registered hooks a chance to update agent state, used for their context variables. + messages = self.process_update_agent_states(messages) + # Call the hookable method that gives registered hooks a chance to process the last message. # Message modifications do not affect the incoming messages or self._oai_messages. messages = self.process_last_received_message(messages) @@ -2116,6 +2120,9 @@ async def a_generate_reply( if messages is None: messages = self._oai_messages[sender] + # Call the hookable method that gives registered hooks a chance to update agent state, used for their context variables. + messages = self.process_update_agent_states(messages) + # Call the hookable method that gives registered hooks a chance to process all messages. # Message modifications do not affect the incoming messages or self._oai_messages. messages = self.process_all_messages_before_reply(messages) @@ -2802,6 +2809,24 @@ def register_hook(self, hookable_method: str, hook: Callable): assert hook not in hook_list, f"{hook} is already registered as a hook." hook_list.append(hook) + def process_update_agent_states(self, messages: List[Dict]) -> List[Dict]: + """ + Calls any registered capability hooks to update the agent's state. + Primarily used to update context variables. + Will, potentially, modify the messages. + """ + hook_list = self.hook_lists["update_agent_state"] + + # If no hooks are registered, or if there are no messages to process, return the original message list. + if len(hook_list) == 0 or messages is None: + return messages + + # Call each hook (in order of registration) to process the messages. + processed_messages = messages + for hook in hook_list: + processed_messages = hook(self._context_variables, processed_messages) + return processed_messages + def process_all_messages_before_reply(self, messages: List[Dict]) -> List[Dict]: """ Calls any registered capability hooks to process all messages, potentially modifying the messages. From 0b8ba3be9e710fef4618c13ea373796492c84976 Mon Sep 17 00:00:00 2001 From: Mark Sze Date: Sat, 30 Nov 2024 06:08:01 +0000 Subject: [PATCH 05/16] process_update_agent_states no longer returns messages Signed-off-by: Mark Sze --- autogen/agentchat/conversable_agent.py | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/autogen/agentchat/conversable_agent.py b/autogen/agentchat/conversable_agent.py index e05510e8a5..4d06fd33e4 100644 --- a/autogen/agentchat/conversable_agent.py +++ b/autogen/agentchat/conversable_agent.py @@ -2048,7 +2048,7 @@ def generate_reply( messages = self._oai_messages[sender] # Call the hookable method that gives registered hooks a chance to update agent state, used for their context variables. - messages = self.process_update_agent_states(messages) + self.process_update_agent_states(messages) # Call the hookable method that gives registered hooks a chance to process the last message. # Message modifications do not affect the incoming messages or self._oai_messages. @@ -2809,7 +2809,7 @@ def register_hook(self, hookable_method: str, hook: Callable): assert hook not in hook_list, f"{hook} is already registered as a hook." hook_list.append(hook) - def process_update_agent_states(self, messages: List[Dict]) -> List[Dict]: + def process_update_agent_states(self, messages: List[Dict]) -> None: """ Calls any registered capability hooks to update the agent's state. Primarily used to update context variables. @@ -2817,15 +2817,9 @@ def process_update_agent_states(self, messages: List[Dict]) -> List[Dict]: """ hook_list = self.hook_lists["update_agent_state"] - # If no hooks are registered, or if there are no messages to process, return the original message list. - if len(hook_list) == 0 or messages is None: - return messages - # Call each hook (in order of registration) to process the messages. - processed_messages = messages for hook in hook_list: - processed_messages = hook(self._context_variables, processed_messages) - return processed_messages + hook(self._context_variables, messages) def process_all_messages_before_reply(self, messages: List[Dict]) -> List[Dict]: """ From 12b0bbe347d61251cc41814066caa8b10c82f375 Mon Sep 17 00:00:00 2001 From: Mark Sze Date: Sat, 30 Nov 2024 08:49:48 +0000 Subject: [PATCH 06/16] Update hook to pass in agent and messages (context available on agent), context access functions Signed-off-by: Mark Sze --- autogen/agentchat/contrib/swarm_agent.py | 6 +-- autogen/agentchat/conversable_agent.py | 47 +++++++++++++++++++++++- 2 files changed, 49 insertions(+), 4 deletions(-) diff --git a/autogen/agentchat/contrib/swarm_agent.py b/autogen/agentchat/contrib/swarm_agent.py index 1465e60564..53258115e0 100644 --- a/autogen/agentchat/contrib/swarm_agent.py +++ b/autogen/agentchat/contrib/swarm_agent.py @@ -341,18 +341,18 @@ def register_update_state_functions(self, functions: Optional[Union[List[Callabl # Outer function to create a closure with the update function def create_wrapper(update_func: UPDATE_SYSTEM_MESSAGE): def update_system_message_wrapper( - context_variables: Dict[str, Any], messages: List[Dict[str, Any]] + agent: ConversableAgent, messages: List[Dict[str, Any]] ) -> List[Dict[str, Any]]: if isinstance(update_func.update_function, str): # Templates like "My context variable passport is {passport}" will # use the context_variables for substitution sys_message = OpenAIWrapper.instantiate( template=update_func.update_function, - context=context_variables, + context=agent._context_variables, allow_format_str_template=True, ) else: - sys_message = update_func.update_function(context_variables, messages) + sys_message = update_func.update_function(self, messages) self.update_system_message(sys_message) return messages diff --git a/autogen/agentchat/conversable_agent.py b/autogen/agentchat/conversable_agent.py index 4d06fd33e4..ef33dd3bc3 100644 --- a/autogen/agentchat/conversable_agent.py +++ b/autogen/agentchat/conversable_agent.py @@ -530,6 +530,51 @@ def system_message(self) -> str: """Return the system message.""" return self._oai_system_message[0]["content"] + def get_context(self, key: str, default: Any = None) -> Any: + """ + Get a context variable by key. + + Args: + key: The key to look up + default: Value to return if key doesn't exist + + Returns: + The value associated with the key, or default if not found + """ + return self._context_variables.get(key, default) + + def set_context(self, key: str, value: Any) -> None: + """ + Set a context variable. + + Args: + key: The key to set + value: The value to associate with the key + """ + self._context_variables[key] = value + + def update_context(self, context_variables: Dict[str, Any]) -> None: + """ + Update multiple context variables at once. + + Args: + context_variables: Dictionary of variables to update/add + """ + self._context_variables.update(context_variables) + + def pop_context(self, key: str, default: Any = None) -> Any: + """ + Remove and return a context variable. + + Args: + key: The key to remove + default: Value to return if key doesn't exist + + Returns: + The value that was removed, or default if key not found + """ + return self._context_variables.pop(key, default) + def update_system_message(self, system_message: str) -> None: """Update the system message. @@ -2819,7 +2864,7 @@ def process_update_agent_states(self, messages: List[Dict]) -> None: # Call each hook (in order of registration) to process the messages. for hook in hook_list: - hook(self._context_variables, messages) + hook(self, messages) def process_all_messages_before_reply(self, messages: List[Dict]) -> List[Dict]: """ From d212e2b223faacee507aa01e7073c7287bbae6ef Mon Sep 17 00:00:00 2001 From: Mark Sze Date: Sat, 30 Nov 2024 09:08:11 +0000 Subject: [PATCH 07/16] Updated context variable access methods, update_agent_before_reply parameter name changed Signed-off-by: Mark Sze --- autogen/agentchat/contrib/swarm_agent.py | 6 +++--- autogen/agentchat/conversable_agent.py | 8 ++++---- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/autogen/agentchat/contrib/swarm_agent.py b/autogen/agentchat/contrib/swarm_agent.py index a9b2a1733f..957551f486 100644 --- a/autogen/agentchat/contrib/swarm_agent.py +++ b/autogen/agentchat/contrib/swarm_agent.py @@ -286,7 +286,7 @@ def __init__( human_input_mode: Literal["ALWAYS", "NEVER", "TERMINATE"] = "NEVER", description: Optional[str] = None, code_execution_config=False, - update_state_functions: Optional[Union[List[Callable], Callable]] = None, + update_agent_before_reply: Optional[Union[List[Callable], Callable]] = None, **kwargs, ) -> None: super().__init__( @@ -315,9 +315,9 @@ def __init__( # use in the tool execution agent to transfer to the next agent self._next_agent = None - self.register_update_state_functions(update_state_functions) + self.register_update_agent_before_reply(update_agent_before_reply) - def register_update_state_functions(self, functions: Optional[Union[List[Callable], Callable]]): + def register_update_agent_before_reply(self, functions: Optional[Union[List[Callable], Callable]]): """ Register functions that will be called when the agent is selected and before it speaks. You can add your own validation or precondition functions here. diff --git a/autogen/agentchat/conversable_agent.py b/autogen/agentchat/conversable_agent.py index ef33dd3bc3..91fa7d7fb0 100644 --- a/autogen/agentchat/conversable_agent.py +++ b/autogen/agentchat/conversable_agent.py @@ -530,7 +530,7 @@ def system_message(self) -> str: """Return the system message.""" return self._oai_system_message[0]["content"] - def get_context(self, key: str, default: Any = None) -> Any: + def get_context_value(self, key: str, default: Any = None) -> Any: """ Get a context variable by key. @@ -543,7 +543,7 @@ def get_context(self, key: str, default: Any = None) -> Any: """ return self._context_variables.get(key, default) - def set_context(self, key: str, value: Any) -> None: + def set_context_value(self, key: str, value: Any) -> None: """ Set a context variable. @@ -553,7 +553,7 @@ def set_context(self, key: str, value: Any) -> None: """ self._context_variables[key] = value - def update_context(self, context_variables: Dict[str, Any]) -> None: + def set_context_values(self, context_variables: Dict[str, Any]) -> None: """ Update multiple context variables at once. @@ -562,7 +562,7 @@ def update_context(self, context_variables: Dict[str, Any]) -> None: """ self._context_variables.update(context_variables) - def pop_context(self, key: str, default: Any = None) -> Any: + def pop_context_key(self, key: str, default: Any = None) -> Any: """ Remove and return a context variable. From bb385735df029ed43a7f454e6164678ae609098c Mon Sep 17 00:00:00 2001 From: "margelnin@gmail.com" Date: Sat, 30 Nov 2024 13:57:26 -0500 Subject: [PATCH 08/16] test update_system_message --- autogen/agentchat/contrib/swarm_agent.py | 18 ++- test/agentchat/contrib/test_swarm.py | 137 +++++++++++++++++------ 2 files changed, 115 insertions(+), 40 deletions(-) diff --git a/autogen/agentchat/contrib/swarm_agent.py b/autogen/agentchat/contrib/swarm_agent.py index 957551f486..ebe9b9af7b 100644 --- a/autogen/agentchat/contrib/swarm_agent.py +++ b/autogen/agentchat/contrib/swarm_agent.py @@ -6,7 +6,9 @@ from dataclasses import dataclass from enum import Enum from inspect import signature +import re from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union +import warnings from pydantic import BaseModel @@ -56,7 +58,11 @@ class UPDATE_SYSTEM_MESSAGE: def __post_init__(self): if isinstance(self.update_function, str): - pass + # find all {var} in the string + vars = re.findall(r"\{(\w+)\}", self.update_function) + if len(vars) == 0: + warnings.warn("Update function string contains no variables. This is probably unintended.") + elif isinstance(self.update_function, Callable): sig = signature(self.update_function) if len(sig.parameters) != 2: @@ -286,7 +292,7 @@ def __init__( human_input_mode: Literal["ALWAYS", "NEVER", "TERMINATE"] = "NEVER", description: Optional[str] = None, code_execution_config=False, - update_agent_before_reply: Optional[Union[List[Callable], Callable]] = None, + update_agent_before_reply: Optional[Union[List[Union[Callable, UPDATE_SYSTEM_MESSAGE]], Callable, UPDATE_SYSTEM_MESSAGE]] = None, **kwargs, ) -> None: super().__init__( @@ -328,10 +334,10 @@ def register_update_agent_before_reply(self, functions: Optional[Union[List[Call """ if functions is None: return - if not isinstance(functions, list) and not isinstance(functions, Callable): + if not isinstance(functions, list) and type(functions) not in [UPDATE_SYSTEM_MESSAGE, Callable]: raise ValueError("functions must be a list of callables") - if isinstance(functions, Callable): + if type(functions) is not list: functions = [functions] for func in functions: @@ -353,9 +359,9 @@ def update_system_message_wrapper( allow_format_str_template=True, ) else: - sys_message = update_func.update_function(self, messages) + sys_message = update_func.update_function(agent._context_variables, messages) - self.update_system_message(sys_message) + agent.update_system_message(sys_message) return messages return update_system_message_wrapper diff --git a/test/agentchat/contrib/test_swarm.py b/test/agentchat/contrib/test_swarm.py index ca906ea194..8cf4de5c3d 100644 --- a/test/agentchat/contrib/test_swarm.py +++ b/test/agentchat/contrib/test_swarm.py @@ -1,7 +1,7 @@ # Copyright (c) 2023 - 2024, Owners of https://github.com/ag2ai # # SPDX-License-Identifier: Apache-2.0 -from typing import Any, Dict +from typing import Any, Dict, List from unittest.mock import MagicMock, patch import pytest @@ -11,6 +11,7 @@ AFTER_WORK, ON_CONDITION, AfterWorkOption, + UPDATE_SYSTEM_MESSAGE, SwarmAgent, SwarmResult, initiate_swarm_chat, @@ -460,50 +461,118 @@ def test_initialization(): initial_agent=agent1, messages=TEST_MESSAGES, agents=[agent1, agent2], max_rounds=3 ) - def test_sys_message_func(): - """Tests a custom system message function""" - # This test will use context variables and the messages to construct a custom system message - # This will be available at the point of reply (we use register a reply to capture it at that point) +def test_update_system_message(): + """Tests the update_agent_before_reply functionality with different scenarios""" + + # Test container to capture system messages + class MessageContainer: + def __init__(self): + self.captured_sys_message = "" + + message_container = MessageContainer() - # To store the system message - class MessageContainer: - def __init__(self): - self.final_sys_message = "" + # 1. Test with a callable function + def custom_update_function(context_variables: Dict[str, Any], messages: List[Dict]) -> str: + return f"System message with {context_variables['test_var']} and {len(messages)} messages" - message_container = MessageContainer() + # 2. Test with a string template + template_message = "Template message with {test_var}" - def my_sys_message(context_variables, messages) -> str: - return f"This is a custom system message with {context_variables['sample_name']} and a total of {len(messages)} message(s)." + # Create agents with different update configurations + agent1 = SwarmAgent( + "agent1", + update_agent_before_reply=UPDATE_SYSTEM_MESSAGE(custom_update_function) + ) + + agent2 = SwarmAgent( + "agent2", + update_agent_before_reply=UPDATE_SYSTEM_MESSAGE(template_message) + ) - agent1 = SwarmAgent("agent1", system_message_func=my_sys_message) - agent2 = SwarmAgent("agent2") + # Mock the reply function to capture the system message + def mock_generate_oai_reply(*args, **kwargs): + # Capture the system message for verification + message_container.captured_sys_message = args[0]._oai_system_message[0]["content"] + return True, "Mock response" - test_context_variables = {"sample_name": "Bob"} + # Register mock reply for both agents + agent1.register_reply([ConversableAgent, None], mock_generate_oai_reply) + agent2.register_reply([ConversableAgent, None], mock_generate_oai_reply) - # Mock a reply to be able to capture the system message - def mock_generate_oai_reply(*args, **kwargs): - message_container.final_sys_message = args[0]._oai_system_message[0][ - "content" - ] # The SwarmAgent's system message - return True, "This is a mock response from the agent." + # Test context and messages + test_context = {"test_var": "test_value"} + test_messages = [{"role": "user", "content": "Test message"}] - agent1.register_reply([ConversableAgent, None], mock_generate_oai_reply) + # Run chat with first agent (using callable function) + chat_result1, context_vars1, last_speaker1 = initiate_swarm_chat( + initial_agent=agent1, + messages=test_messages, + agents=[agent1], + context_variables=test_context, + max_rounds=2 + ) - chat_result, context_vars, last_speaker = initiate_swarm_chat( - initial_agent=agent1, - messages=TEST_MESSAGES, - agents=[agent1, agent2], - context_variables=test_context_variables, - max_rounds=4, - ) + # Verify callable function result + assert message_container.captured_sys_message == "System message with test_value and 1 messages" - # The system message should be the custom message - assert ( - message_container.final_sys_message - == "This is a custom system message with Bob and a total of 1 message(s)." - ) + # Reset captured message + message_container.captured_sys_message = "" + + # Run chat with second agent (using string template) + chat_result2, context_vars2, last_speaker2 = initiate_swarm_chat( + initial_agent=agent2, + messages=test_messages, + agents=[agent2], + context_variables=test_context, + max_rounds=2 + ) + # Verify template result + assert message_container.captured_sys_message == "Template message with test_value" + + # Test invalid update function + with pytest.raises(ValueError, match="Update function must be either a string or a callable"): + SwarmAgent("agent3", update_agent_before_reply=UPDATE_SYSTEM_MESSAGE(123)) + + # Test invalid callable (wrong number of parameters) + def invalid_update_function(context_variables): + return "Invalid function" + + with pytest.raises(ValueError, match="Update function must accept two parameters"): + SwarmAgent("agent4", update_agent_before_reply=UPDATE_SYSTEM_MESSAGE(invalid_update_function)) + + # Test invalid callable (wrong return type) + def invalid_return_function(context_variables, messages) -> dict: + return {} + + with pytest.raises(ValueError, match="Update function must return a string"): + SwarmAgent("agent5", update_agent_before_reply=UPDATE_SYSTEM_MESSAGE(invalid_return_function)) + + # Test multiple update functions + def another_update_function(context_variables: Dict[str, Any], messages: List[Dict]) -> str: + return "Another update" + + agent6 = SwarmAgent( + "agent6", + update_agent_before_reply=[ + UPDATE_SYSTEM_MESSAGE(custom_update_function), + UPDATE_SYSTEM_MESSAGE(another_update_function) + ] + ) + + agent6.register_reply([ConversableAgent, None], mock_generate_oai_reply) + + chat_result6, context_vars6, last_speaker6 = initiate_swarm_chat( + initial_agent=agent6, + messages=test_messages, + agents=[agent6], + context_variables=test_context, + max_rounds=2 + ) + # Verify last update function took effect + assert message_container.captured_sys_message == "Another update" + if __name__ == "__main__": pytest.main([__file__]) From 10a4e8f9c5bac42f48abf32dc28df8baa85a5f74 Mon Sep 17 00:00:00 2001 From: Mark Sze Date: Sun, 1 Dec 2024 19:39:23 +0000 Subject: [PATCH 09/16] pre-commit updates Signed-off-by: Mark Sze --- autogen/agentchat/contrib/swarm_agent.py | 13 ++++--- test/agentchat/contrib/test_swarm.py | 45 ++++++++---------------- 2 files changed, 22 insertions(+), 36 deletions(-) diff --git a/autogen/agentchat/contrib/swarm_agent.py b/autogen/agentchat/contrib/swarm_agent.py index ebe9b9af7b..ff95405b2c 100644 --- a/autogen/agentchat/contrib/swarm_agent.py +++ b/autogen/agentchat/contrib/swarm_agent.py @@ -3,12 +3,12 @@ # SPDX-License-Identifier: Apache-2.0 import copy import json +import re +import warnings from dataclasses import dataclass from enum import Enum from inspect import signature -import re from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union -import warnings from pydantic import BaseModel @@ -62,7 +62,7 @@ def __post_init__(self): vars = re.findall(r"\{(\w+)\}", self.update_function) if len(vars) == 0: warnings.warn("Update function string contains no variables. This is probably unintended.") - + elif isinstance(self.update_function, Callable): sig = signature(self.update_function) if len(sig.parameters) != 2: @@ -292,7 +292,9 @@ def __init__( human_input_mode: Literal["ALWAYS", "NEVER", "TERMINATE"] = "NEVER", description: Optional[str] = None, code_execution_config=False, - update_agent_before_reply: Optional[Union[List[Union[Callable, UPDATE_SYSTEM_MESSAGE]], Callable, UPDATE_SYSTEM_MESSAGE]] = None, + update_agent_before_reply: Optional[ + Union[List[Union[Callable, UPDATE_SYSTEM_MESSAGE]], Callable, UPDATE_SYSTEM_MESSAGE] + ] = None, **kwargs, ) -> None: super().__init__( @@ -337,7 +339,7 @@ def register_update_agent_before_reply(self, functions: Optional[Union[List[Call if not isinstance(functions, list) and type(functions) not in [UPDATE_SYSTEM_MESSAGE, Callable]: raise ValueError("functions must be a list of callables") - if type(functions) is not list: + if not isinstance(functions, list): functions = [functions] for func in functions: @@ -508,6 +510,7 @@ def generate_swarm_tool_reply( return False, None def add_single_function(self, func: Callable, name=None, description=""): + """Add a single function to the agent, removing context variables for LLM use""" if name: func._name = name else: diff --git a/test/agentchat/contrib/test_swarm.py b/test/agentchat/contrib/test_swarm.py index 8cf4de5c3d..bc79661a93 100644 --- a/test/agentchat/contrib/test_swarm.py +++ b/test/agentchat/contrib/test_swarm.py @@ -10,8 +10,8 @@ __CONTEXT_VARIABLES_PARAM_NAME__, AFTER_WORK, ON_CONDITION, - AfterWorkOption, UPDATE_SYSTEM_MESSAGE, + AfterWorkOption, SwarmAgent, SwarmResult, initiate_swarm_chat, @@ -463,13 +463,13 @@ def test_initialization(): def test_update_system_message(): - """Tests the update_agent_before_reply functionality with different scenarios""" - + """Tests the update_agent_before_reply functionality with multiple scenarios""" + # Test container to capture system messages class MessageContainer: def __init__(self): self.captured_sys_message = "" - + message_container = MessageContainer() # 1. Test with a callable function @@ -480,15 +480,9 @@ def custom_update_function(context_variables: Dict[str, Any], messages: List[Dic template_message = "Template message with {test_var}" # Create agents with different update configurations - agent1 = SwarmAgent( - "agent1", - update_agent_before_reply=UPDATE_SYSTEM_MESSAGE(custom_update_function) - ) - - agent2 = SwarmAgent( - "agent2", - update_agent_before_reply=UPDATE_SYSTEM_MESSAGE(template_message) - ) + agent1 = SwarmAgent("agent1", update_agent_before_reply=UPDATE_SYSTEM_MESSAGE(custom_update_function)) + + agent2 = SwarmAgent("agent2", update_agent_before_reply=UPDATE_SYSTEM_MESSAGE(template_message)) # Mock the reply function to capture the system message def mock_generate_oai_reply(*args, **kwargs): @@ -506,11 +500,7 @@ def mock_generate_oai_reply(*args, **kwargs): # Run chat with first agent (using callable function) chat_result1, context_vars1, last_speaker1 = initiate_swarm_chat( - initial_agent=agent1, - messages=test_messages, - agents=[agent1], - context_variables=test_context, - max_rounds=2 + initial_agent=agent1, messages=test_messages, agents=[agent1], context_variables=test_context, max_rounds=2 ) # Verify callable function result @@ -521,11 +511,7 @@ def mock_generate_oai_reply(*args, **kwargs): # Run chat with second agent (using string template) chat_result2, context_vars2, last_speaker2 = initiate_swarm_chat( - initial_agent=agent2, - messages=test_messages, - agents=[agent2], - context_variables=test_context, - max_rounds=2 + initial_agent=agent2, messages=test_messages, agents=[agent2], context_variables=test_context, max_rounds=2 ) # Verify template result @@ -557,22 +543,19 @@ def another_update_function(context_variables: Dict[str, Any], messages: List[Di "agent6", update_agent_before_reply=[ UPDATE_SYSTEM_MESSAGE(custom_update_function), - UPDATE_SYSTEM_MESSAGE(another_update_function) - ] + UPDATE_SYSTEM_MESSAGE(another_update_function), + ], ) agent6.register_reply([ConversableAgent, None], mock_generate_oai_reply) chat_result6, context_vars6, last_speaker6 = initiate_swarm_chat( - initial_agent=agent6, - messages=test_messages, - agents=[agent6], - context_variables=test_context, - max_rounds=2 + initial_agent=agent6, messages=test_messages, agents=[agent6], context_variables=test_context, max_rounds=2 ) # Verify last update function took effect assert message_container.captured_sys_message == "Another update" - + + if __name__ == "__main__": pytest.main([__file__]) From 623727b19e9582e40aad632cc61cd9b9c4a6d480 Mon Sep 17 00:00:00 2001 From: Mark Sze Date: Sun, 1 Dec 2024 20:48:28 +0000 Subject: [PATCH 10/16] Fix for ConversableAgent's a_generate_reply Signed-off-by: Mark Sze --- autogen/agentchat/conversable_agent.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/autogen/agentchat/conversable_agent.py b/autogen/agentchat/conversable_agent.py index 91fa7d7fb0..385ea5bc1b 100644 --- a/autogen/agentchat/conversable_agent.py +++ b/autogen/agentchat/conversable_agent.py @@ -2166,7 +2166,7 @@ async def a_generate_reply( messages = self._oai_messages[sender] # Call the hookable method that gives registered hooks a chance to update agent state, used for their context variables. - messages = self.process_update_agent_states(messages) + self.process_update_agent_states(messages) # Call the hookable method that gives registered hooks a chance to process all messages. # Message modifications do not affect the incoming messages or self._oai_messages. From 8188593e5b25868cc0a59a533543d24928be97bb Mon Sep 17 00:00:00 2001 From: Mark Sze Date: Sun, 1 Dec 2024 20:57:04 +0000 Subject: [PATCH 11/16] Added ConversableAgent context variable tests Signed-off-by: Mark Sze --- test/agentchat/test_conversable_agent.py | 72 ++++++++++++++++++++++-- 1 file changed, 66 insertions(+), 6 deletions(-) diff --git a/test/agentchat/test_conversable_agent.py b/test/agentchat/test_conversable_agent.py index 81e0036dc0..cc059ce29b 100755 --- a/test/agentchat/test_conversable_agent.py +++ b/test/agentchat/test_conversable_agent.py @@ -1527,6 +1527,67 @@ def test_handle_carryover(): assert proc_content_empty_carryover == content, "Incorrect carryover processing" +@pytest.mark.skipif(skip_openai, reason=reason) +def test_context_variables(): + # Test initialization with context_variables + initial_context = {"test_key": "test_value", "number": 42, "nested": {"inner": "value"}} + agent = ConversableAgent(name="context_test_agent", llm_config=False, context_variables=initial_context) + + # Check that context was properly initialized + assert agent._context_variables == initial_context + + # Test initialization without context_variables + agent_no_context = ConversableAgent(name="no_context_agent", llm_config=False) + assert agent_no_context._context_variables == {} + + # Test get_context_value + assert agent.get_context_value("test_key") == "test_value" + assert agent.get_context_value("number") == 42 + assert agent.get_context_value("nested") == {"inner": "value"} + assert agent.get_context_value("non_existent") is None + assert agent.get_context_value("non_existent", default="default") == "default" + + # Test set_context_value + agent.set_context_value("new_key", "new_value") + assert agent.get_context_value("new_key") == "new_value" + + # Test overwriting existing value + agent.set_context_value("test_key", "updated_value") + assert agent.get_context_value("test_key") == "updated_value" + + # Test set_context_values + new_values = {"bulk_key1": "bulk_value1", "bulk_key2": "bulk_value2", "test_key": "bulk_updated_value"} + agent.set_context_values(new_values) + assert agent.get_context_value("bulk_key1") == "bulk_value1" + assert agent.get_context_value("bulk_key2") == "bulk_value2" + assert agent.get_context_value("test_key") == "bulk_updated_value" + + # Test pop_context_key + # Pop existing key + popped_value = agent.pop_context_key("bulk_key1") + assert popped_value == "bulk_value1" + assert agent.get_context_value("bulk_key1") is None + + # Pop with default value + default_value = "default_value" + popped_default = agent.pop_context_key("non_existent", default=default_value) + assert popped_default == default_value + + # Pop without default (should return None) + popped_none = agent.pop_context_key("another_non_existent") + assert popped_none is None + + # Verify final state of context + expected_final_context = { + "number": 42, + "nested": {"inner": "value"}, + "new_key": "new_value", + "bulk_key2": "bulk_value2", + "test_key": "bulk_updated_value", + } + assert agent._context_variables == expected_final_context + + if __name__ == "__main__": # test_trigger() # test_context() @@ -1537,10 +1598,9 @@ def test_handle_carryover(): # test_max_turn() # test_process_before_send() # test_message_func() - - test_summary() - test_adding_duplicate_function_warning() + # test_summary() + # test_adding_duplicate_function_warning() # test_function_registration_e2e_sync() - - test_process_gemini_carryover() - test_process_carryover() + # test_process_gemini_carryover() + # test_process_carryover() + test_context_variables() From b9352daaa00be902ba39644dfe03bed9d8deadce Mon Sep 17 00:00:00 2001 From: Mark Sze Date: Fri, 6 Dec 2024 23:40:12 +0000 Subject: [PATCH 12/16] Corrected missing variable from nested chat PR Signed-off-by: Mark Sze --- autogen/agentchat/contrib/swarm_agent.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/autogen/agentchat/contrib/swarm_agent.py b/autogen/agentchat/contrib/swarm_agent.py index aa90a7f441..5dc60e3002 100644 --- a/autogen/agentchat/contrib/swarm_agent.py +++ b/autogen/agentchat/contrib/swarm_agent.py @@ -376,9 +376,13 @@ def __init__( self.after_work = None - # use in the tool execution agent to transfer to the next agent + # Used in the tool execution agent to transfer to the next agent self._next_agent = None + # Store nested chats hand offs as we'll establish these in the initiate_swarm_chat + # List of Dictionaries containing the nested_chats and condition + self._nested_chat_handoffs = [] + self.register_update_agent_before_reply(update_agent_before_reply) def register_update_agent_before_reply(self, functions: Optional[Union[List[Callable], Callable]]): From 71cc5c7b4c92bf6e30ba9a53bd0cf05d182914cc Mon Sep 17 00:00:00 2001 From: Mark Sze Date: Fri, 6 Dec 2024 23:48:01 +0000 Subject: [PATCH 13/16] Restore conversable agent context getters/setters Signed-off-by: Mark Sze --- autogen/agentchat/conversable_agent.py | 45 -------------------------- 1 file changed, 45 deletions(-) diff --git a/autogen/agentchat/conversable_agent.py b/autogen/agentchat/conversable_agent.py index 8251cadab7..db69574f96 100644 --- a/autogen/agentchat/conversable_agent.py +++ b/autogen/agentchat/conversable_agent.py @@ -572,51 +572,6 @@ def system_message(self) -> str: """Return the system message.""" return self._oai_system_message[0]["content"] - def get_context_value(self, key: str, default: Any = None) -> Any: - """ - Get a context variable by key. - - Args: - key: The key to look up - default: Value to return if key doesn't exist - - Returns: - The value associated with the key, or default if not found - """ - return self._context_variables.get(key, default) - - def set_context_value(self, key: str, value: Any) -> None: - """ - Set a context variable. - - Args: - key: The key to set - value: The value to associate with the key - """ - self._context_variables[key] = value - - def set_context_values(self, context_variables: Dict[str, Any]) -> None: - """ - Update multiple context variables at once. - - Args: - context_variables: Dictionary of variables to update/add - """ - self._context_variables.update(context_variables) - - def pop_context_key(self, key: str, default: Any = None) -> Any: - """ - Remove and return a context variable. - - Args: - key: The key to remove - default: Value to return if key doesn't exist - - Returns: - The value that was removed, or default if key not found - """ - return self._context_variables.pop(key, default) - def update_system_message(self, system_message: str) -> None: """Update the system message. From 790f0372a3a5765b55e3ba9f4ee2e82b60950e2e Mon Sep 17 00:00:00 2001 From: Mark Sze Date: Sat, 7 Dec 2024 00:48:35 +0000 Subject: [PATCH 14/16] Docs and update system message callable signature change Signed-off-by: Mark Sze --- autogen/agentchat/contrib/swarm_agent.py | 6 ++- test/agentchat/contrib/test_swarm.py | 4 +- website/docs/topics/swarm.ipynb | 54 +++++++++++++++++++++++- 3 files changed, 58 insertions(+), 6 deletions(-) diff --git a/autogen/agentchat/contrib/swarm_agent.py b/autogen/agentchat/contrib/swarm_agent.py index 5dc60e3002..fcb6db08a1 100644 --- a/autogen/agentchat/contrib/swarm_agent.py +++ b/autogen/agentchat/contrib/swarm_agent.py @@ -73,7 +73,9 @@ def __post_init__(self): elif isinstance(self.update_function, Callable): sig = signature(self.update_function) if len(sig.parameters) != 2: - raise ValueError("Update function must accept two parameters, context_variables and messages") + raise ValueError( + "Update function must accept two parameters of type ConversableAgent and List[Dict[str Any]], respectively" + ) if sig.return_annotation != str: raise ValueError("Update function must return a string") else: @@ -421,7 +423,7 @@ def update_system_message_wrapper( allow_format_str_template=True, ) else: - sys_message = update_func.update_function(agent._context_variables, messages) + sys_message = update_func.update_function(agent, messages) agent.update_system_message(sys_message) return messages diff --git a/test/agentchat/contrib/test_swarm.py b/test/agentchat/contrib/test_swarm.py index f27bfe874f..85130baac2 100644 --- a/test/agentchat/contrib/test_swarm.py +++ b/test/agentchat/contrib/test_swarm.py @@ -473,8 +473,8 @@ def __init__(self): message_container = MessageContainer() # 1. Test with a callable function - def custom_update_function(context_variables: Dict[str, Any], messages: List[Dict]) -> str: - return f"System message with {context_variables['test_var']} and {len(messages)} messages" + def custom_update_function(agent: ConversableAgent, messages: List[Dict]) -> str: + return f"System message with {agent.get_context('test_var')} and {len(messages)} messages" # 2. Test with a string template template_message = "Template message with {test_var}" diff --git a/website/docs/topics/swarm.ipynb b/website/docs/topics/swarm.ipynb index 1724eea982..2edd76fcde 100644 --- a/website/docs/topics/swarm.ipynb +++ b/website/docs/topics/swarm.ipynb @@ -159,8 +159,58 @@ "])\n", "\n", "agent_2.handoff(hand_to=[AFTER_WORK(AfterWorkOption.TERMINATE)]) # Terminate the chat if no handoff is suggested\n", - "```\n", + "```" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Update Agent state before replying\n", + "\n", + "It can be useful to update an agent's state before they reply, particularly their system message/prompt.\n", + "\n", + "When initialising an agent you can use the `update_agent_before_reply` to register the updates run when the agent is selected, but before they reply.\n", + "\n", + "The `update_agent_before_reply` takes a list of any combination of the following (executing them in the provided order):\n", "\n", + "- `UPDATE_SYSTEM_MESSAGE` provides a simple way to update the agent's system message via an f-string that substitutes the values of context variables, or a Callable that returns a string\n", + "- Callable with two parameters of type `ConversableAgent` for the agent and `List[Dict[str Any]]` for the messages, and does not return a value\n", + "\n", + "Below is an example of setting these up when creating a Swarm agent.\n", + "\n", + "```python\n", + "# Creates a system message string\n", + "def create_system_prompt_function(my_agent: ConversableAgent, messages: List[Dict[]]) -> str:\n", + " preferred_name = my_agent.get_context(\"preferred_name\", \"(name not provided)\")\n", + "\n", + " # Note that the returned string will be treated like an f-string using the context variables\n", + " return \"You are a customer service representative helping a customer named \"\n", + " + preferred_name\n", + " + \" and their passport number is '{passport_number}'.\"\n", + "\n", + "# Function to update an Agent's state\n", + "def my_callable_state_update_function(my_agent: ConversableAgent, messages: List[Dict[]]) -> None:\n", + " agent.set_context(\"context_key\", 43)\n", + " agent.update_system_message(\"You are a customer service representative.\")\n", + "\n", + "# Create the SwarmAgent and set agent updates\n", + "customer_service = SwarmAgent(\n", + " name=\"CustomerServiceRep\",\n", + " system_message=\"You are a customer service representative.\",\n", + " update_agent_before_reply=[\n", + " UPDATE_SYSTEM_MESSAGE(\"You are a customer service representative. Quote passport number '{passport_number}'\"),\n", + " UPDATE_SYSTEM_MESSAGE(create_system_prompt_function),\n", + " my_callable_state_update_function]\n", + " ...\n", + ")\n", + "```" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ "### Initialize SwarmChat with `initiate_swarm_chat`\n", "\n", "After a set of swarm agents are created, you can initiate a swarm chat by calling `initiate_swarm_chat`.\n", @@ -185,7 +235,7 @@ "\n", "> How are context variables updated?\n", "\n", - "The context variables will only be updated through custom function calls when returning a `SwarmResult` object. In fact, all interactions with context variables will be done through function calls (accessing and updating). The context variables dictionary is a reference, and any modification will be done in place.\n", + "In a swarm, the context variables are shared amongst Swarm agents. As context variables are available at the agent level, you can use the context variable getters/setters on the agent to view and change the shared context variables. If you're working with a function that returns a `SwarmResult` you should update the passed in context variables and return it in the `SwarmResult`, this will ensure the shared context is updated.\n", "\n", "> What is the difference between ON_CONDITION and AFTER_WORK?\n", "\n", From 2c3e0638d09fbfeca5e0db33255b975683d4c6bf Mon Sep 17 00:00:00 2001 From: Mark Sze Date: Sun, 15 Dec 2024 21:20:17 +0000 Subject: [PATCH 15/16] Updated parameter name to update_agent_state_before_reply Signed-off-by: Mark Sze --- autogen/agentchat/contrib/swarm_agent.py | 6 +++--- autogen/agentchat/conversable_agent.py | 6 +++--- website/docs/topics/swarm.ipynb | 8 ++++---- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/autogen/agentchat/contrib/swarm_agent.py b/autogen/agentchat/contrib/swarm_agent.py index fcb6db08a1..c525c53db3 100644 --- a/autogen/agentchat/contrib/swarm_agent.py +++ b/autogen/agentchat/contrib/swarm_agent.py @@ -350,7 +350,7 @@ def __init__( human_input_mode: Literal["ALWAYS", "NEVER", "TERMINATE"] = "NEVER", description: Optional[str] = None, code_execution_config=False, - update_agent_before_reply: Optional[ + update_agent_state_before_reply: Optional[ Union[List[Union[Callable, UPDATE_SYSTEM_MESSAGE]], Callable, UPDATE_SYSTEM_MESSAGE] ] = None, **kwargs, @@ -385,9 +385,9 @@ def __init__( # List of Dictionaries containing the nested_chats and condition self._nested_chat_handoffs = [] - self.register_update_agent_before_reply(update_agent_before_reply) + self.register_update_agent_state_before_reply(update_agent_state_before_reply) - def register_update_agent_before_reply(self, functions: Optional[Union[List[Callable], Callable]]): + def register_update_agent_state_before_reply(self, functions: Optional[Union[List[Callable], Callable]]): """ Register functions that will be called when the agent is selected and before it speaks. You can add your own validation or precondition functions here. diff --git a/autogen/agentchat/conversable_agent.py b/autogen/agentchat/conversable_agent.py index db69574f96..ffd6923721 100644 --- a/autogen/agentchat/conversable_agent.py +++ b/autogen/agentchat/conversable_agent.py @@ -2093,7 +2093,7 @@ def generate_reply( messages = self._oai_messages[sender] # Call the hookable method that gives registered hooks a chance to update agent state, used for their context variables. - self.process_update_agent_states(messages) + self.update_agent_state_before_reply(messages) # Call the hookable method that gives registered hooks a chance to process the last message. # Message modifications do not affect the incoming messages or self._oai_messages. @@ -2166,7 +2166,7 @@ async def a_generate_reply( messages = self._oai_messages[sender] # Call the hookable method that gives registered hooks a chance to update agent state, used for their context variables. - self.process_update_agent_states(messages) + self.update_agent_state_before_reply(messages) # Call the hookable method that gives registered hooks a chance to process all messages. # Message modifications do not affect the incoming messages or self._oai_messages. @@ -2854,7 +2854,7 @@ def register_hook(self, hookable_method: str, hook: Callable): assert hook not in hook_list, f"{hook} is already registered as a hook." hook_list.append(hook) - def process_update_agent_states(self, messages: List[Dict]) -> None: + def update_agent_state_before_reply(self, messages: List[Dict]) -> None: """ Calls any registered capability hooks to update the agent's state. Primarily used to update context variables. diff --git a/website/docs/topics/swarm.ipynb b/website/docs/topics/swarm.ipynb index 2edd76fcde..82a0a3cc83 100644 --- a/website/docs/topics/swarm.ipynb +++ b/website/docs/topics/swarm.ipynb @@ -168,11 +168,11 @@ "source": [ "### Update Agent state before replying\n", "\n", - "It can be useful to update an agent's state before they reply, particularly their system message/prompt.\n", + "It can be useful to update a swarm agent's state before they reply. For example, using an agent's context variables you could change their system message based on the state of the workflow.\n", "\n", - "When initialising an agent you can use the `update_agent_before_reply` to register the updates run when the agent is selected, but before they reply.\n", + "When initialising a swarm agent use the `update_agent_state_before_reply` parameter to register updates that run after the agent is selected, but before they reply.\n", "\n", - "The `update_agent_before_reply` takes a list of any combination of the following (executing them in the provided order):\n", + "`update_agent_state_before_reply` takes a list of any combination of the following (executing them in the provided order):\n", "\n", "- `UPDATE_SYSTEM_MESSAGE` provides a simple way to update the agent's system message via an f-string that substitutes the values of context variables, or a Callable that returns a string\n", "- Callable with two parameters of type `ConversableAgent` for the agent and `List[Dict[str Any]]` for the messages, and does not return a value\n", @@ -198,7 +198,7 @@ "customer_service = SwarmAgent(\n", " name=\"CustomerServiceRep\",\n", " system_message=\"You are a customer service representative.\",\n", - " update_agent_before_reply=[\n", + " update_agent_state_before_reply=[\n", " UPDATE_SYSTEM_MESSAGE(\"You are a customer service representative. Quote passport number '{passport_number}'\"),\n", " UPDATE_SYSTEM_MESSAGE(create_system_prompt_function),\n", " my_callable_state_update_function]\n", From bf0de6407166662a93b8bc95f16e53c4b4aec035 Mon Sep 17 00:00:00 2001 From: Mark Sze Date: Sun, 15 Dec 2024 21:26:54 +0000 Subject: [PATCH 16/16] Update tests for update_agent_state_before_reply Signed-off-by: Mark Sze --- test/agentchat/contrib/test_swarm.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/test/agentchat/contrib/test_swarm.py b/test/agentchat/contrib/test_swarm.py index 85130baac2..ae2f3cf9b9 100644 --- a/test/agentchat/contrib/test_swarm.py +++ b/test/agentchat/contrib/test_swarm.py @@ -463,7 +463,7 @@ def test_initialization(): def test_update_system_message(): - """Tests the update_agent_before_reply functionality with multiple scenarios""" + """Tests the update_agent_state_before_reply functionality with multiple scenarios""" # Test container to capture system messages class MessageContainer: @@ -480,9 +480,9 @@ def custom_update_function(agent: ConversableAgent, messages: List[Dict]) -> str template_message = "Template message with {test_var}" # Create agents with different update configurations - agent1 = SwarmAgent("agent1", update_agent_before_reply=UPDATE_SYSTEM_MESSAGE(custom_update_function)) + agent1 = SwarmAgent("agent1", update_agent_state_before_reply=UPDATE_SYSTEM_MESSAGE(custom_update_function)) - agent2 = SwarmAgent("agent2", update_agent_before_reply=UPDATE_SYSTEM_MESSAGE(template_message)) + agent2 = SwarmAgent("agent2", update_agent_state_before_reply=UPDATE_SYSTEM_MESSAGE(template_message)) # Mock the reply function to capture the system message def mock_generate_oai_reply(*args, **kwargs): @@ -519,21 +519,21 @@ def mock_generate_oai_reply(*args, **kwargs): # Test invalid update function with pytest.raises(ValueError, match="Update function must be either a string or a callable"): - SwarmAgent("agent3", update_agent_before_reply=UPDATE_SYSTEM_MESSAGE(123)) + SwarmAgent("agent3", update_agent_state_before_reply=UPDATE_SYSTEM_MESSAGE(123)) # Test invalid callable (wrong number of parameters) def invalid_update_function(context_variables): return "Invalid function" with pytest.raises(ValueError, match="Update function must accept two parameters"): - SwarmAgent("agent4", update_agent_before_reply=UPDATE_SYSTEM_MESSAGE(invalid_update_function)) + SwarmAgent("agent4", update_agent_state_before_reply=UPDATE_SYSTEM_MESSAGE(invalid_update_function)) # Test invalid callable (wrong return type) def invalid_return_function(context_variables, messages) -> dict: return {} with pytest.raises(ValueError, match="Update function must return a string"): - SwarmAgent("agent5", update_agent_before_reply=UPDATE_SYSTEM_MESSAGE(invalid_return_function)) + SwarmAgent("agent5", update_agent_state_before_reply=UPDATE_SYSTEM_MESSAGE(invalid_return_function)) # Test multiple update functions def another_update_function(context_variables: Dict[str, Any], messages: List[Dict]) -> str: @@ -541,7 +541,7 @@ def another_update_function(context_variables: Dict[str, Any], messages: List[Di agent6 = SwarmAgent( "agent6", - update_agent_before_reply=[ + update_agent_state_before_reply=[ UPDATE_SYSTEM_MESSAGE(custom_update_function), UPDATE_SYSTEM_MESSAGE(another_update_function), ],