From 30908aefe7a80d397a9ba7096d3c100ebeb000b9 Mon Sep 17 00:00:00 2001 From: Mark Sze Date: Mon, 2 Dec 2024 21:06:41 +0000 Subject: [PATCH 1/3] Add context_variables to ConversableAgent Signed-off-by: Mark Sze --- autogen/agentchat/conversable_agent.py | 45 ++++++++++++++ test/agentchat/test_conversable_agent.py | 77 +++++++++++++++++++++--- 2 files changed, 112 insertions(+), 10 deletions(-) diff --git a/autogen/agentchat/conversable_agent.py b/autogen/agentchat/conversable_agent.py index 840da79204..0b611a0eaa 100644 --- a/autogen/agentchat/conversable_agent.py +++ b/autogen/agentchat/conversable_agent.py @@ -85,6 +85,7 @@ def __init__( description: Optional[str] = None, chat_messages: Optional[Dict[Agent, List[Dict]]] = None, silent: Optional[bool] = None, + context_variables: Optional[Dict[str, Any]] = None, ): """ Args: @@ -135,6 +136,9 @@ def __init__( resume previous had conversations. Defaults to an empty chat history. silent (bool or None): (Experimental) whether to print the message sent. If None, will use the value of silent in each function. + context_variables (dict or None): Context variables that provide a persistent context for the agent. + The passed in context variables will be deep-copied, not referenced. + Only used in Swarms at this stage. """ # we change code_execution_config below and we have to make sure we don't change the input # in case of UserProxyAgent, without this we could even change the default value {} @@ -193,6 +197,8 @@ def __init__( self.register_reply([Agent, None], ConversableAgent.generate_oai_reply) self.register_reply([Agent, None], ConversableAgent.a_generate_oai_reply, ignore_async_in_sync_chat=True) + self._context_variables = copy.deepcopy(context_variables) if context_variables is not None else {} + # Setting up code execution. # Do not register code execution reply if code execution is disabled. if code_execution_config is not False: @@ -520,6 +526,45 @@ def wrapped_reply_func(recipient, messages=None, sender=None, config=None): ), ) + def get_context_value(self, key: str, default: Any = None) -> Any: + """ + Get a context variable by key. + Args: + key: The key to look up + default: Value to return if key doesn't exist + Returns: + The value associated with the key, or default if not found + """ + return self._context_variables.get(key, default) + + def set_context_value(self, key: str, value: Any) -> None: + """ + Set a context variable. + Args: + key: The key to set + value: The value to associate with the key + """ + self._context_variables[key] = value + + def set_context_values(self, context_variables: Dict[str, Any]) -> None: + """ + Update multiple context variables at once. + Args: + context_variables: Dictionary of variables to update/add + """ + self._context_variables.update(context_variables) + + def pop_context_key(self, key: str, default: Any = None) -> Any: + """ + Remove and return a context variable. + Args: + key: The key to remove + default: Value to return if key doesn't exist + Returns: + The value that was removed, or default if key not found + """ + return self._context_variables.pop(key, default) + @property def system_message(self) -> str: """Return the system message.""" diff --git a/test/agentchat/test_conversable_agent.py b/test/agentchat/test_conversable_agent.py index 81e0036dc0..3e7a30a2df 100755 --- a/test/agentchat/test_conversable_agent.py +++ b/test/agentchat/test_conversable_agent.py @@ -1527,20 +1527,77 @@ def test_handle_carryover(): assert proc_content_empty_carryover == content, "Incorrect carryover processing" +@pytest.mark.skipif(skip_openai, reason=reason) +def test_context_variables(): + # Test initialization with context_variables + initial_context = {"test_key": "test_value", "number": 42, "nested": {"inner": "value"}} + agent = ConversableAgent(name="context_test_agent", llm_config=False, context_variables=initial_context) + + # Check that context was properly initialized + assert agent._context_variables == initial_context + + # Test initialization without context_variables + agent_no_context = ConversableAgent(name="no_context_agent", llm_config=False) + assert agent_no_context._context_variables == {} + + # Test get_context_value + assert agent.get_context_value("test_key") == "test_value" + assert agent.get_context_value("number") == 42 + assert agent.get_context_value("nested") == {"inner": "value"} + assert agent.get_context_value("non_existent") is None + assert agent.get_context_value("non_existent", default="default") == "default" + + # Test set_context_value + agent.set_context_value("new_key", "new_value") + assert agent.get_context_value("new_key") == "new_value" + + # Test overwriting existing value + agent.set_context_value("test_key", "updated_value") + assert agent.get_context_value("test_key") == "updated_value" + + # Test set_context_values + new_values = {"bulk_key1": "bulk_value1", "bulk_key2": "bulk_value2", "test_key": "bulk_updated_value"} + agent.set_context_values(new_values) + assert agent.get_context_value("bulk_key1") == "bulk_value1" + assert agent.get_context_value("bulk_key2") == "bulk_value2" + assert agent.get_context_value("test_key") == "bulk_updated_value" + + # Test pop_context_key + # Pop existing key + popped_value = agent.pop_context_key("bulk_key1") + assert popped_value == "bulk_value1" + assert agent.get_context_value("bulk_key1") is None + + # Pop with default value + default_value = "default_value" + popped_default = agent.pop_context_key("non_existent", default=default_value) + assert popped_default == default_value + + # Pop without default (should return None) + popped_none = agent.pop_context_key("another_non_existent") + assert popped_none is None + + # Verify final state of context + expected_final_context = { + "number": 42, + "nested": {"inner": "value"}, + "new_key": "new_value", + "bulk_key2": "bulk_value2", + "test_key": "bulk_updated_value", + } + assert agent._context_variables == expected_final_context + + if __name__ == "__main__": # test_trigger() # test_context() - # test_max_consecutive_auto_reply() - # test_generate_code_execution_reply() - # test_conversable_agent() - # test_no_llm_config() + # test_handle_carryover(): # test_max_turn() # test_process_before_send() # test_message_func() - - test_summary() - test_adding_duplicate_function_warning() + # test_summary() + # test_adding_duplicate_function_warning() # test_function_registration_e2e_sync() - - test_process_gemini_carryover() - test_process_carryover() + # test_process_gemini_carryover() + # test_process_carryover() + test_context_variables() From c072335561b7b84f03be25c0e0848af754d7f0b0 Mon Sep 17 00:00:00 2001 From: Mark Sze Date: Tue, 3 Dec 2024 20:17:56 +0000 Subject: [PATCH 2/3] Updated getter/setter names Signed-off-by: Mark Sze --- autogen/agentchat/conversable_agent.py | 8 ++--- test/agentchat/test_conversable_agent.py | 42 ++++++++++++------------ 2 files changed, 25 insertions(+), 25 deletions(-) diff --git a/autogen/agentchat/conversable_agent.py b/autogen/agentchat/conversable_agent.py index 0b611a0eaa..7620363fa8 100644 --- a/autogen/agentchat/conversable_agent.py +++ b/autogen/agentchat/conversable_agent.py @@ -526,7 +526,7 @@ def wrapped_reply_func(recipient, messages=None, sender=None, config=None): ), ) - def get_context_value(self, key: str, default: Any = None) -> Any: + def get_context(self, key: str, default: Any = None) -> Any: """ Get a context variable by key. Args: @@ -537,7 +537,7 @@ def get_context_value(self, key: str, default: Any = None) -> Any: """ return self._context_variables.get(key, default) - def set_context_value(self, key: str, value: Any) -> None: + def set_context(self, key: str, value: Any) -> None: """ Set a context variable. Args: @@ -546,7 +546,7 @@ def set_context_value(self, key: str, value: Any) -> None: """ self._context_variables[key] = value - def set_context_values(self, context_variables: Dict[str, Any]) -> None: + def update_context(self, context_variables: Dict[str, Any]) -> None: """ Update multiple context variables at once. Args: @@ -554,7 +554,7 @@ def set_context_values(self, context_variables: Dict[str, Any]) -> None: """ self._context_variables.update(context_variables) - def pop_context_key(self, key: str, default: Any = None) -> Any: + def pop_context(self, key: str, default: Any = None) -> Any: """ Remove and return a context variable. Args: diff --git a/test/agentchat/test_conversable_agent.py b/test/agentchat/test_conversable_agent.py index 3e7a30a2df..320cfb324b 100755 --- a/test/agentchat/test_conversable_agent.py +++ b/test/agentchat/test_conversable_agent.py @@ -1540,41 +1540,41 @@ def test_context_variables(): agent_no_context = ConversableAgent(name="no_context_agent", llm_config=False) assert agent_no_context._context_variables == {} - # Test get_context_value - assert agent.get_context_value("test_key") == "test_value" - assert agent.get_context_value("number") == 42 - assert agent.get_context_value("nested") == {"inner": "value"} - assert agent.get_context_value("non_existent") is None - assert agent.get_context_value("non_existent", default="default") == "default" + # Test get_context + assert agent.get_context("test_key") == "test_value" + assert agent.get_context("number") == 42 + assert agent.get_context("nested") == {"inner": "value"} + assert agent.get_context("non_existent") is None + assert agent.get_context("non_existent", default="default") == "default" - # Test set_context_value - agent.set_context_value("new_key", "new_value") - assert agent.get_context_value("new_key") == "new_value" + # Test set_context + agent.set_context("new_key", "new_value") + assert agent.get_context("new_key") == "new_value" # Test overwriting existing value - agent.set_context_value("test_key", "updated_value") - assert agent.get_context_value("test_key") == "updated_value" + agent.set_context("test_key", "updated_value") + assert agent.get_context("test_key") == "updated_value" - # Test set_context_values + # Test update_context new_values = {"bulk_key1": "bulk_value1", "bulk_key2": "bulk_value2", "test_key": "bulk_updated_value"} - agent.set_context_values(new_values) - assert agent.get_context_value("bulk_key1") == "bulk_value1" - assert agent.get_context_value("bulk_key2") == "bulk_value2" - assert agent.get_context_value("test_key") == "bulk_updated_value" + agent.update_context(new_values) + assert agent.get_context("bulk_key1") == "bulk_value1" + assert agent.get_context("bulk_key2") == "bulk_value2" + assert agent.get_context("test_key") == "bulk_updated_value" - # Test pop_context_key + # Test pop_context # Pop existing key - popped_value = agent.pop_context_key("bulk_key1") + popped_value = agent.pop_context("bulk_key1") assert popped_value == "bulk_value1" - assert agent.get_context_value("bulk_key1") is None + assert agent.get_context("bulk_key1") is None # Pop with default value default_value = "default_value" - popped_default = agent.pop_context_key("non_existent", default=default_value) + popped_default = agent.pop_context("non_existent", default=default_value) assert popped_default == default_value # Pop without default (should return None) - popped_none = agent.pop_context_key("another_non_existent") + popped_none = agent.pop_context("another_non_existent") assert popped_none is None # Verify final state of context From fba4d0852bef6594d06952318e0c8c70bac12e2f Mon Sep 17 00:00:00 2001 From: Mark Sze Date: Wed, 4 Dec 2024 20:26:20 +0000 Subject: [PATCH 3/3] Removed deep copy, updated comment Signed-off-by: Mark Sze --- autogen/agentchat/conversable_agent.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/autogen/agentchat/conversable_agent.py b/autogen/agentchat/conversable_agent.py index 3e67503f9c..b738e6821d 100644 --- a/autogen/agentchat/conversable_agent.py +++ b/autogen/agentchat/conversable_agent.py @@ -138,8 +138,9 @@ def __init__( silent (bool or None): (Experimental) whether to print the message sent. If None, will use the value of silent in each function. context_variables (dict or None): Context variables that provide a persistent context for the agent. - The passed in context variables will be deep-copied, not referenced. - Only used in Swarms at this stage. + Note: Will maintain a reference to the passed in context variables (enabling a shared context) + Only used in Swarms at this stage: + https://ag2ai.github.io/ag2/docs/reference/agentchat/contrib/swarm_agent response_format (BaseModel): Used to specify structured response format for the agent. Currently only available for the OpenAI client. """ # we change code_execution_config below and we have to make sure we don't change the input @@ -200,7 +201,7 @@ def __init__( self.register_reply([Agent, None], ConversableAgent.generate_oai_reply) self.register_reply([Agent, None], ConversableAgent.a_generate_oai_reply, ignore_async_in_sync_chat=True) - self._context_variables = copy.deepcopy(context_variables) if context_variables is not None else {} + self._context_variables = context_variables if context_variables is not None else {} # Setting up code execution. # Do not register code execution reply if code execution is disabled.