diff --git a/autogen/agentchat/contrib/swarm_agent.py b/autogen/agentchat/contrib/swarm_agent.py index 5dc60e3002..fcb6db08a1 100644 --- a/autogen/agentchat/contrib/swarm_agent.py +++ b/autogen/agentchat/contrib/swarm_agent.py @@ -73,7 +73,9 @@ def __post_init__(self): elif isinstance(self.update_function, Callable): sig = signature(self.update_function) if len(sig.parameters) != 2: - raise ValueError("Update function must accept two parameters, context_variables and messages") + raise ValueError( + "Update function must accept two parameters of type ConversableAgent and List[Dict[str Any]], respectively" + ) if sig.return_annotation != str: raise ValueError("Update function must return a string") else: @@ -421,7 +423,7 @@ def update_system_message_wrapper( allow_format_str_template=True, ) else: - sys_message = update_func.update_function(agent._context_variables, messages) + sys_message = update_func.update_function(agent, messages) agent.update_system_message(sys_message) return messages diff --git a/test/agentchat/contrib/test_swarm.py b/test/agentchat/contrib/test_swarm.py index f27bfe874f..85130baac2 100644 --- a/test/agentchat/contrib/test_swarm.py +++ b/test/agentchat/contrib/test_swarm.py @@ -473,8 +473,8 @@ def __init__(self): message_container = MessageContainer() # 1. Test with a callable function - def custom_update_function(context_variables: Dict[str, Any], messages: List[Dict]) -> str: - return f"System message with {context_variables['test_var']} and {len(messages)} messages" + def custom_update_function(agent: ConversableAgent, messages: List[Dict]) -> str: + return f"System message with {agent.get_context('test_var')} and {len(messages)} messages" # 2. Test with a string template template_message = "Template message with {test_var}" diff --git a/website/docs/topics/swarm.ipynb b/website/docs/topics/swarm.ipynb index 1724eea982..2edd76fcde 100644 --- a/website/docs/topics/swarm.ipynb +++ b/website/docs/topics/swarm.ipynb @@ -159,8 +159,58 @@ "])\n", "\n", "agent_2.handoff(hand_to=[AFTER_WORK(AfterWorkOption.TERMINATE)]) # Terminate the chat if no handoff is suggested\n", - "```\n", + "```" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Update Agent state before replying\n", + "\n", + "It can be useful to update an agent's state before they reply, particularly their system message/prompt.\n", + "\n", + "When initialising an agent you can use the `update_agent_before_reply` to register the updates run when the agent is selected, but before they reply.\n", + "\n", + "The `update_agent_before_reply` takes a list of any combination of the following (executing them in the provided order):\n", "\n", + "- `UPDATE_SYSTEM_MESSAGE` provides a simple way to update the agent's system message via an f-string that substitutes the values of context variables, or a Callable that returns a string\n", + "- Callable with two parameters of type `ConversableAgent` for the agent and `List[Dict[str Any]]` for the messages, and does not return a value\n", + "\n", + "Below is an example of setting these up when creating a Swarm agent.\n", + "\n", + "```python\n", + "# Creates a system message string\n", + "def create_system_prompt_function(my_agent: ConversableAgent, messages: List[Dict[]]) -> str:\n", + " preferred_name = my_agent.get_context(\"preferred_name\", \"(name not provided)\")\n", + "\n", + " # Note that the returned string will be treated like an f-string using the context variables\n", + " return \"You are a customer service representative helping a customer named \"\n", + " + preferred_name\n", + " + \" and their passport number is '{passport_number}'.\"\n", + "\n", + "# Function to update an Agent's state\n", + "def my_callable_state_update_function(my_agent: ConversableAgent, messages: List[Dict[]]) -> None:\n", + " agent.set_context(\"context_key\", 43)\n", + " agent.update_system_message(\"You are a customer service representative.\")\n", + "\n", + "# Create the SwarmAgent and set agent updates\n", + "customer_service = SwarmAgent(\n", + " name=\"CustomerServiceRep\",\n", + " system_message=\"You are a customer service representative.\",\n", + " update_agent_before_reply=[\n", + " UPDATE_SYSTEM_MESSAGE(\"You are a customer service representative. Quote passport number '{passport_number}'\"),\n", + " UPDATE_SYSTEM_MESSAGE(create_system_prompt_function),\n", + " my_callable_state_update_function]\n", + " ...\n", + ")\n", + "```" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ "### Initialize SwarmChat with `initiate_swarm_chat`\n", "\n", "After a set of swarm agents are created, you can initiate a swarm chat by calling `initiate_swarm_chat`.\n", @@ -185,7 +235,7 @@ "\n", "> How are context variables updated?\n", "\n", - "The context variables will only be updated through custom function calls when returning a `SwarmResult` object. In fact, all interactions with context variables will be done through function calls (accessing and updating). The context variables dictionary is a reference, and any modification will be done in place.\n", + "In a swarm, the context variables are shared amongst Swarm agents. As context variables are available at the agent level, you can use the context variable getters/setters on the agent to view and change the shared context variables. If you're working with a function that returns a `SwarmResult` you should update the passed in context variables and return it in the `SwarmResult`, this will ensure the shared context is updated.\n", "\n", "> What is the difference between ON_CONDITION and AFTER_WORK?\n", "\n",