Skip to content

Commit

Permalink
Moved functions and update_agent_state_before_reply parameters and as…
Browse files Browse the repository at this point in the history
…sociated functions

Signed-off-by: Mark Sze <[email protected]>
  • Loading branch information
marklysze committed Dec 30, 2024
1 parent 07a3d25 commit e082623
Show file tree
Hide file tree
Showing 2 changed files with 173 additions and 6 deletions.
28 changes: 22 additions & 6 deletions autogen/agentchat/contrib/swarm_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,16 @@

from ..agent import Agent
from ..chat import ChatResult
from ..conversable_agent import ConversableAgent
from ..conversable_agent import __CONTEXT_VARIABLES_PARAM_NAME__, UPDATE_SYSTEM_MESSAGE, ConversableAgent
from ..groupchat import GroupChat, GroupChatManager
from ..user_proxy_agent import UserProxyAgent

""" MS REMOVE
# Parameter name for context variables
# Use the value in functions and they will be substituted with the context variables:
# e.g. def my_function(context_variables: Dict[str, Any], my_other_parameters: Any) -> Any:
__CONTEXT_VARIABLES_PARAM_NAME__ = "context_variables"
"""

__TOOL_EXECUTOR_NAME__ = "Tool_Execution"

Expand Down Expand Up @@ -84,6 +86,7 @@ def __post_init__(self):
assert isinstance(self.available, (Callable, str)), "'available' must be a callable or a string"


r''' MS REMOVE
@dataclass
class UPDATE_SYSTEM_MESSAGE:
"""Update the agent's system message before they reply
Expand Down Expand Up @@ -114,6 +117,7 @@ def __post_init__(self):
raise ValueError("Update function must return a string")
else:
raise ValueError("Update function must be either a string or a callable")
'''


def initiate_swarm_chat(
Expand Down Expand Up @@ -379,9 +383,7 @@ class SwarmAgent(ConversableAgent):
SwarmAgent is a subclass of ConversableAgent.
Additional args:
functions (List[Callable]): A list of functions to register with the agent.
update_agent_state_before_reply (List[Callable]): A list of functions, including UPDATE_SYSTEM_MESSAGEs, called to update the agent before it replies.
TRANSFERRING TO CONVERSABLEAGENT - INTERFACE SHOULD BE IDENTICAL
"""

def __init__(
Expand Down Expand Up @@ -409,9 +411,12 @@ def __init__(
llm_config=llm_config,
description=description,
code_execution_config=code_execution_config,
functions=functions,
update_agent_state_before_reply=update_agent_state_before_reply,
**kwargs,
)

""" MS REMOVE
if isinstance(functions, list):
if not all(isinstance(func, Callable) for func in functions):
raise TypeError("All elements in the functions list must be callable")
Expand All @@ -420,6 +425,7 @@ def __init__(
self.add_single_function(functions)
elif functions is not None:
raise TypeError("Functions must be a callable or a list of callables")
"""

self.after_work = None

Expand All @@ -430,7 +436,9 @@ def __init__(
# List of Dictionaries containing the nested_chats and condition
self._nested_chat_handoffs = []

""" MS REMOVE
self.register_update_agent_state_before_reply(update_agent_state_before_reply)
"""

# Store conditional functions (and their ON_CONDITION instances) to add/remove later when transitioning to this agent
self._conditional_functions = {}
Expand All @@ -439,6 +447,7 @@ def __init__(
if name != __TOOL_EXECUTOR_NAME__:
self.register_hook("update_agent_state", self._update_conditional_functions)

''' MS REMOVE
def register_update_agent_state_before_reply(self, functions: Optional[Union[list[Callable], Callable]]):
"""
Register functions that will be called when the agent is selected and before it speaks.
Expand Down Expand Up @@ -486,6 +495,7 @@ def update_system_message_wrapper(
else:
self.register_hook(hookable_method="update_agent_state", hook=func)
'''

def _set_to_tool_execution(self):
"""Set to a special instance of SwarmAgent that is responsible for executing tool calls from other swarm agents.
Expand All @@ -497,8 +507,10 @@ def _set_to_tool_execution(self):
self._reply_func_list.clear()
self.register_reply([Agent, None], SwarmAgent.generate_swarm_tool_reply)

""" MS REMOVE - NOT TRANSFERRED
def __str__(self):
return f"SwarmAgent --> {self.name}"
"""

def register_hand_off(
self,
Expand Down Expand Up @@ -566,7 +578,7 @@ def transfer_to_agent() -> "SwarmAgent":
raise ValueError("Invalid hand off condition, must be either ON_CONDITION or AFTER_WORK")

@staticmethod
def _update_conditional_functions(agent: Agent, messages: Optional[list[dict]] = None) -> None:
def _update_conditional_functions(agent: ConversableAgent, messages: Optional[list[dict]] = None) -> None:
"""Updates the agent's functions based on the ON_CONDITION's available condition."""
for func_name, (func, on_condition) in agent._conditional_functions.items():
is_available = True
Expand All @@ -579,7 +591,7 @@ def _update_conditional_functions(agent: Agent, messages: Optional[list[dict]] =

if is_available:
if func_name not in agent._function_map:
agent.add_single_function(func, func_name, on_condition.condition)
agent._add_single_function(func, func_name, on_condition.condition)
else:
# Remove function using the stored name
if func_name in agent._function_map:
Expand Down Expand Up @@ -666,6 +678,7 @@ def generate_swarm_tool_reply(
return True, tool_message
return False, None

''' MS REMOVE
def add_single_function(self, func: Callable, name=None, description=""):
"""Add a single function to the agent, removing context variables for LLM use"""
if name:
Expand Down Expand Up @@ -696,10 +709,13 @@ def add_single_function(self, func: Callable, name=None, description=""):
self.update_tool_signature(f_no_context, is_remove=False)
self.register_function({func._name: func})
'''

""" MS REMOVE
def add_functions(self, func_list: list[Callable]):
for func in func_list:
self.add_single_function(func)
"""

@staticmethod
def process_nested_chat_carryover(
Expand Down
151 changes: 151 additions & 0 deletions autogen/agentchat/conversable_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
import re
import warnings
from collections import defaultdict
from dataclasses import dataclass
from inspect import signature
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Type, TypeVar, Union

from openai import BadRequestError
Expand Down Expand Up @@ -48,6 +50,43 @@

F = TypeVar("F", bound=Callable[..., Any])

# Parameter name for context variables
# Use the value in functions and they will be substituted with the context variables:
# e.g. def my_function(context_variables: Dict[str, Any], my_other_parameters: Any) -> Any:
__CONTEXT_VARIABLES_PARAM_NAME__ = "context_variables"


@dataclass
class UPDATE_SYSTEM_MESSAGE:
"""Update the agent's system message before they reply
Args:
update_function: The string or function to update the agent's system message. Can be a string or a Callable.
If a string, it will be used as a template and substitute the context variables.
If a Callable, it should have the signature:
def my_update_function(agent: ConversableAgent, messages: List[Dict[str, Any]]) -> str
"""

update_function: Union[Callable, str]

def __post_init__(self):
if isinstance(self.update_function, str):
# find all {var} in the string
vars = re.findall(r"\{(\w+)\}", self.update_function)
if len(vars) == 0:
warnings.warn("Update function string contains no variables. This is probably unintended.")

elif isinstance(self.update_function, Callable):
sig = signature(self.update_function)
if len(sig.parameters) != 2:
raise ValueError(
"Update function must accept two parameters of type ConversableAgent and List[Dict[str Any]], respectively"
)
if sig.return_annotation != str:
raise ValueError("Update function must return a string")
else:
raise ValueError("Update function must be either a string or a callable")


class ConversableAgent(LLMAgent):
"""(In preview) A class for generic conversable agents which can be configured as assistant or user proxy.
Expand Down Expand Up @@ -85,6 +124,10 @@ def __init__(
chat_messages: Optional[dict[Agent, list[dict]]] = None,
silent: Optional[bool] = None,
context_variables: Optional[dict[str, Any]] = None,
functions: Union[list[Callable], Callable] = None,
update_agent_state_before_reply: Optional[
Union[list[Union[Callable, UPDATE_SYSTEM_MESSAGE]], Callable, UPDATE_SYSTEM_MESSAGE]
] = None,
):
"""
Args:
Expand Down Expand Up @@ -139,6 +182,7 @@ def __init__(
Note: Will maintain a reference to the passed in context variables (enabling a shared context)
Only used in Swarms at this stage:
https://docs.ag2.ai/docs/reference/agentchat/contrib/swarm_agent
functions (List[Callable]): A list of functions to register with the agent.
"""
# we change code_execution_config below and we have to make sure we don't change the input
# in case of UserProxyAgent, without this we could even change the default value {}
Expand All @@ -161,6 +205,7 @@ def __init__(
else (lambda x: content_str(x.get("content")) == "TERMINATE")
)
self.silent = silent

# Take a copy to avoid modifying the given dict
if isinstance(llm_config, dict):
try:
Expand Down Expand Up @@ -199,6 +244,16 @@ def __init__(

self._context_variables = context_variables if context_variables is not None else {}

# Register functions to the agent
if isinstance(functions, list):
if not all(isinstance(func, Callable) for func in functions):
raise TypeError("All elements in the functions list must be callable")
self._add_functions(functions)
elif isinstance(functions, Callable):
self._add_single_function(functions)
elif functions is not None:
raise TypeError("Functions must be a callable or a list of callables")

# Setting up code execution.
# Do not register code execution reply if code execution is disabled.
if code_execution_config is not False:
Expand Down Expand Up @@ -266,6 +321,102 @@ def __init__(
"update_agent_state": [],
}

# Associate agent update state hooks
self._register_update_agent_state_before_reply(update_agent_state_before_reply)

def _add_functions(self, func_list: list[Callable]):
"""Add (Register) a list of functions to the agent
Args:
func_list (list[Callable]): A list of functions to register with the agent."""
for func in func_list:
self._add_single_function(func)

def _add_single_function(self, func: Callable, name: Optional[str] = None, description: Optional[str] = ""):
"""Add a single function to the agent, removing context variables for LLM use.
Args:
func (Callable): The function to register.
name (str): The name of the function. If not provided, the function's name will be used.
description (str): The description of the function, used by the LLM. If not provided, the function's docstring will be used.
"""
if name:
func._name = name
else:
func._name = func.__name__

if description:
func._description = description
else:
# Use function's docstring, strip whitespace, fall back to empty string
func._description = (func.__doc__ or "").strip()

f = get_function_schema(func, name=func._name, description=func._description)

# Remove context_variables parameter from function schema
f_no_context = f.copy()
if __CONTEXT_VARIABLES_PARAM_NAME__ in f_no_context["function"]["parameters"]["properties"]:
del f_no_context["function"]["parameters"]["properties"][__CONTEXT_VARIABLES_PARAM_NAME__]
if "required" in f_no_context["function"]["parameters"]:
required = f_no_context["function"]["parameters"]["required"]
f_no_context["function"]["parameters"]["required"] = [
param for param in required if param != __CONTEXT_VARIABLES_PARAM_NAME__
]
# If required list is empty, remove it
if not f_no_context["function"]["parameters"]["required"]:
del f_no_context["function"]["parameters"]["required"]

self.update_tool_signature(f_no_context, is_remove=False)
self.register_function({func._name: func})

def _register_update_agent_state_before_reply(self, functions: Optional[Union[list[Callable], Callable]]):
"""
Register functions that will be called when the agent is selected and before it speaks.
You can add your own validation or precondition functions here.
Args:
functions (List[Callable[[], None]]): A list of functions to be registered. Each function
is called when the agent is selected and before it speaks.
"""
if functions is None:
return
if not isinstance(functions, list) and type(functions) not in [UPDATE_SYSTEM_MESSAGE, Callable]:
raise ValueError("functions must be a list of callables")

if not isinstance(functions, list):
functions = [functions]

for func in functions:
if isinstance(func, UPDATE_SYSTEM_MESSAGE):

# Wrapper function that allows this to be used in the update_agent_state hook
# Its primary purpose, however, is just to update the agent's system message
# Outer function to create a closure with the update function
def create_wrapper(update_func: UPDATE_SYSTEM_MESSAGE):
def update_system_message_wrapper(
agent: ConversableAgent, messages: list[dict[str, Any]]
) -> list[dict[str, Any]]:
if isinstance(update_func.update_function, str):
# Templates like "My context variable passport is {passport}" will
# use the context_variables for substitution
sys_message = OpenAIWrapper.instantiate(
template=update_func.update_function,
context=agent._context_variables,
allow_format_str_template=True,
)
else:
sys_message = update_func.update_function(agent, messages)

agent.update_system_message(sys_message)
return messages

return update_system_message_wrapper

self.register_hook(hookable_method="update_agent_state", hook=create_wrapper(func))

else:
self.register_hook(hookable_method="update_agent_state", hook=func)

def _validate_llm_config(self, llm_config):
assert llm_config in (None, False) or isinstance(
llm_config, dict
Expand Down

0 comments on commit e082623

Please sign in to comment.