diff --git a/autogen/agentchat/conversable_agent.py b/autogen/agentchat/conversable_agent.py index 5d8bbdca8d..b738e6821d 100644 --- a/autogen/agentchat/conversable_agent.py +++ b/autogen/agentchat/conversable_agent.py @@ -85,6 +85,7 @@ def __init__( description: Optional[str] = None, chat_messages: Optional[Dict[Agent, List[Dict]]] = None, silent: Optional[bool] = None, + context_variables: Optional[Dict[str, Any]] = None, response_format: Optional[BaseModel] = None, ): """ @@ -136,7 +137,11 @@ def __init__( resume previous had conversations. Defaults to an empty chat history. silent (bool or None): (Experimental) whether to print the message sent. If None, will use the value of silent in each function. - response_format(BaseModel): Used to specify structured response format for the agent. Not available for all LLMs. + context_variables (dict or None): Context variables that provide a persistent context for the agent. + Note: Will maintain a reference to the passed in context variables (enabling a shared context) + Only used in Swarms at this stage: + https://ag2ai.github.io/ag2/docs/reference/agentchat/contrib/swarm_agent + response_format (BaseModel): Used to specify structured response format for the agent. Currently only available for the OpenAI client. """ # we change code_execution_config below and we have to make sure we don't change the input # in case of UserProxyAgent, without this we could even change the default value {} @@ -196,6 +201,8 @@ def __init__( self.register_reply([Agent, None], ConversableAgent.generate_oai_reply) self.register_reply([Agent, None], ConversableAgent.a_generate_oai_reply, ignore_async_in_sync_chat=True) + self._context_variables = context_variables if context_variables is not None else {} + # Setting up code execution. # Do not register code execution reply if code execution is disabled. if code_execution_config is not False: @@ -523,6 +530,45 @@ def wrapped_reply_func(recipient, messages=None, sender=None, config=None): ), ) + def get_context(self, key: str, default: Any = None) -> Any: + """ + Get a context variable by key. + Args: + key: The key to look up + default: Value to return if key doesn't exist + Returns: + The value associated with the key, or default if not found + """ + return self._context_variables.get(key, default) + + def set_context(self, key: str, value: Any) -> None: + """ + Set a context variable. + Args: + key: The key to set + value: The value to associate with the key + """ + self._context_variables[key] = value + + def update_context(self, context_variables: Dict[str, Any]) -> None: + """ + Update multiple context variables at once. + Args: + context_variables: Dictionary of variables to update/add + """ + self._context_variables.update(context_variables) + + def pop_context(self, key: str, default: Any = None) -> Any: + """ + Remove and return a context variable. + Args: + key: The key to remove + default: Value to return if key doesn't exist + Returns: + The value that was removed, or default if key not found + """ + return self._context_variables.pop(key, default) + @property def system_message(self) -> str: """Return the system message.""" diff --git a/test/agentchat/test_conversable_agent.py b/test/agentchat/test_conversable_agent.py index 81e0036dc0..320cfb324b 100755 --- a/test/agentchat/test_conversable_agent.py +++ b/test/agentchat/test_conversable_agent.py @@ -1527,20 +1527,77 @@ def test_handle_carryover(): assert proc_content_empty_carryover == content, "Incorrect carryover processing" +@pytest.mark.skipif(skip_openai, reason=reason) +def test_context_variables(): + # Test initialization with context_variables + initial_context = {"test_key": "test_value", "number": 42, "nested": {"inner": "value"}} + agent = ConversableAgent(name="context_test_agent", llm_config=False, context_variables=initial_context) + + # Check that context was properly initialized + assert agent._context_variables == initial_context + + # Test initialization without context_variables + agent_no_context = ConversableAgent(name="no_context_agent", llm_config=False) + assert agent_no_context._context_variables == {} + + # Test get_context + assert agent.get_context("test_key") == "test_value" + assert agent.get_context("number") == 42 + assert agent.get_context("nested") == {"inner": "value"} + assert agent.get_context("non_existent") is None + assert agent.get_context("non_existent", default="default") == "default" + + # Test set_context + agent.set_context("new_key", "new_value") + assert agent.get_context("new_key") == "new_value" + + # Test overwriting existing value + agent.set_context("test_key", "updated_value") + assert agent.get_context("test_key") == "updated_value" + + # Test update_context + new_values = {"bulk_key1": "bulk_value1", "bulk_key2": "bulk_value2", "test_key": "bulk_updated_value"} + agent.update_context(new_values) + assert agent.get_context("bulk_key1") == "bulk_value1" + assert agent.get_context("bulk_key2") == "bulk_value2" + assert agent.get_context("test_key") == "bulk_updated_value" + + # Test pop_context + # Pop existing key + popped_value = agent.pop_context("bulk_key1") + assert popped_value == "bulk_value1" + assert agent.get_context("bulk_key1") is None + + # Pop with default value + default_value = "default_value" + popped_default = agent.pop_context("non_existent", default=default_value) + assert popped_default == default_value + + # Pop without default (should return None) + popped_none = agent.pop_context("another_non_existent") + assert popped_none is None + + # Verify final state of context + expected_final_context = { + "number": 42, + "nested": {"inner": "value"}, + "new_key": "new_value", + "bulk_key2": "bulk_value2", + "test_key": "bulk_updated_value", + } + assert agent._context_variables == expected_final_context + + if __name__ == "__main__": # test_trigger() # test_context() - # test_max_consecutive_auto_reply() - # test_generate_code_execution_reply() - # test_conversable_agent() - # test_no_llm_config() + # test_handle_carryover(): # test_max_turn() # test_process_before_send() # test_message_func() - - test_summary() - test_adding_duplicate_function_warning() + # test_summary() + # test_adding_duplicate_function_warning() # test_function_registration_e2e_sync() - - test_process_gemini_carryover() - test_process_carryover() + # test_process_gemini_carryover() + # test_process_carryover() + test_context_variables()