Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add context variables to ConversableAgent #137

Merged
merged 4 commits into from
Dec 4, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 46 additions & 1 deletion autogen/agentchat/conversable_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ def __init__(
description: Optional[str] = None,
chat_messages: Optional[Dict[Agent, List[Dict]]] = None,
silent: Optional[bool] = None,
context_variables: Optional[Dict[str, Any]] = None,
response_format: Optional[BaseModel] = None,
):
"""
Expand Down Expand Up @@ -136,7 +137,10 @@ def __init__(
resume previous had conversations. Defaults to an empty chat history.
silent (bool or None): (Experimental) whether to print the message sent. If None, will use the value of
silent in each function.
response_format(BaseModel): Used to specify structured response format for the agent. Not available for all LLMs.
context_variables (dict or None): Context variables that provide a persistent context for the agent.
The passed in context variables will be deep-copied, not referenced.
Only used in Swarms at this stage.
marklysze marked this conversation as resolved.
Show resolved Hide resolved
response_format (BaseModel): Used to specify structured response format for the agent. Currently only available for the OpenAI client.
"""
# we change code_execution_config below and we have to make sure we don't change the input
# in case of UserProxyAgent, without this we could even change the default value {}
Expand Down Expand Up @@ -196,6 +200,8 @@ def __init__(
self.register_reply([Agent, None], ConversableAgent.generate_oai_reply)
self.register_reply([Agent, None], ConversableAgent.a_generate_oai_reply, ignore_async_in_sync_chat=True)

self._context_variables = copy.deepcopy(context_variables) if context_variables is not None else {}
marklysze marked this conversation as resolved.
Show resolved Hide resolved

# Setting up code execution.
# Do not register code execution reply if code execution is disabled.
if code_execution_config is not False:
Expand Down Expand Up @@ -523,6 +529,45 @@ def wrapped_reply_func(recipient, messages=None, sender=None, config=None):
),
)

def get_context(self, key: str, default: Any = None) -> Any:
"""
Get a context variable by key.
Args:
key: The key to look up
default: Value to return if key doesn't exist
Returns:
The value associated with the key, or default if not found
"""
return self._context_variables.get(key, default)

def set_context(self, key: str, value: Any) -> None:
"""
Set a context variable.
Args:
key: The key to set
value: The value to associate with the key
"""
self._context_variables[key] = value

def update_context(self, context_variables: Dict[str, Any]) -> None:
"""
Update multiple context variables at once.
Args:
context_variables: Dictionary of variables to update/add
"""
self._context_variables.update(context_variables)

def pop_context(self, key: str, default: Any = None) -> Any:
"""
Remove and return a context variable.
Args:
key: The key to remove
default: Value to return if key doesn't exist
Returns:
The value that was removed, or default if key not found
"""
return self._context_variables.pop(key, default)

@property
def system_message(self) -> str:
"""Return the system message."""
Expand Down
77 changes: 67 additions & 10 deletions test/agentchat/test_conversable_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -1527,20 +1527,77 @@ def test_handle_carryover():
assert proc_content_empty_carryover == content, "Incorrect carryover processing"


@pytest.mark.skipif(skip_openai, reason=reason)
def test_context_variables():
# Test initialization with context_variables
initial_context = {"test_key": "test_value", "number": 42, "nested": {"inner": "value"}}
agent = ConversableAgent(name="context_test_agent", llm_config=False, context_variables=initial_context)

# Check that context was properly initialized
assert agent._context_variables == initial_context

# Test initialization without context_variables
agent_no_context = ConversableAgent(name="no_context_agent", llm_config=False)
assert agent_no_context._context_variables == {}

# Test get_context
assert agent.get_context("test_key") == "test_value"
assert agent.get_context("number") == 42
assert agent.get_context("nested") == {"inner": "value"}
assert agent.get_context("non_existent") is None
assert agent.get_context("non_existent", default="default") == "default"

# Test set_context
agent.set_context("new_key", "new_value")
assert agent.get_context("new_key") == "new_value"

# Test overwriting existing value
agent.set_context("test_key", "updated_value")
assert agent.get_context("test_key") == "updated_value"

# Test update_context
new_values = {"bulk_key1": "bulk_value1", "bulk_key2": "bulk_value2", "test_key": "bulk_updated_value"}
agent.update_context(new_values)
assert agent.get_context("bulk_key1") == "bulk_value1"
assert agent.get_context("bulk_key2") == "bulk_value2"
assert agent.get_context("test_key") == "bulk_updated_value"

# Test pop_context
# Pop existing key
popped_value = agent.pop_context("bulk_key1")
assert popped_value == "bulk_value1"
assert agent.get_context("bulk_key1") is None

# Pop with default value
default_value = "default_value"
popped_default = agent.pop_context("non_existent", default=default_value)
assert popped_default == default_value

# Pop without default (should return None)
popped_none = agent.pop_context("another_non_existent")
assert popped_none is None

# Verify final state of context
expected_final_context = {
"number": 42,
"nested": {"inner": "value"},
"new_key": "new_value",
"bulk_key2": "bulk_value2",
"test_key": "bulk_updated_value",
}
assert agent._context_variables == expected_final_context


if __name__ == "__main__":
# test_trigger()
# test_context()
# test_max_consecutive_auto_reply()
# test_generate_code_execution_reply()
# test_conversable_agent()
# test_no_llm_config()
# test_handle_carryover():
# test_max_turn()
# test_process_before_send()
# test_message_func()

test_summary()
test_adding_duplicate_function_warning()
# test_summary()
# test_adding_duplicate_function_warning()
# test_function_registration_e2e_sync()

test_process_gemini_carryover()
test_process_carryover()
# test_process_gemini_carryover()
# test_process_carryover()
test_context_variables()
Loading