Skip to content

Commit

Permalink
Merge pull request #310 from ag2ai/tool_captain
Browse files Browse the repository at this point in the history
Integrating Tools from other frameworks into CaptainAgent
  • Loading branch information
LeoLjl authored Dec 29, 2024
2 parents 721ee7a + d2ebc86 commit 07a3d25
Show file tree
Hide file tree
Showing 4 changed files with 969 additions and 55 deletions.
61 changes: 41 additions & 20 deletions autogen/agentchat/contrib/captainagent.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,16 @@
import hashlib
import json
import os
from typing import Callable, Dict, List, Literal, Optional, Union
from typing import Callable, Literal, Optional, Union

from termcolor import colored

import autogen
from autogen import UserProxyAgent
from autogen.agentchat.conversable_agent import ConversableAgent

from .agent_builder import AgentBuilder
from .tool_retriever import ToolBuilder, get_full_tool_description
from .tool_retriever import ToolBuilder, format_ag2_tool, get_full_tool_description


class CaptainAgent(ConversableAgent):
Expand Down Expand Up @@ -387,8 +389,9 @@ def _run_autobuild(self, group_name: str, execution_task: str, building_task: st
# tool library is enabled, reload tools and bind them to the agents
tool_root_dir = self.tool_root_dir
tool_builder = ToolBuilder(
corpus_path=os.path.join(tool_root_dir, "tool_description.tsv"),
corpus_root=tool_root_dir,
retriever=self._nested_config["autobuild_tool_config"].get("retriever", "all-mpnet-base-v2"),
type=self.tool_type,
)
for idx, agent in enumerate(agent_list):
if idx == len(self.tool_history[group_name]):
Expand All @@ -404,39 +407,57 @@ def _run_autobuild(self, group_name: str, execution_task: str, building_task: st
self.build_history[group_name] = agent_configs.copy()

if self._nested_config.get("autobuild_tool_config", None) and agent_configs["coding"] is True:
print("==> Retrieving tools...", flush=True)
skills = building_task.split("\n")
if len(skills) == 0:
skills = [building_task]

tool_type = "default"
if self._nested_config["autobuild_tool_config"].get("tool_root", "default") == "default":
print(colored("==> Retrieving tools...", "green"), flush=True)
cur_path = os.path.dirname(os.path.abspath(__file__))
tool_root_dir = os.path.join(cur_path, "captainagent", "tools")
elif isinstance(self._nested_config["autobuild_tool_config"].get("tool_root", "default"), list):
# We get a list, in this case, we assume it contains several tools for the agents
tool_root_dir = self._nested_config["autobuild_tool_config"]["tool_root"]
tool_type = "user_defined"
else:
tool_root_dir = self._nested_config["autobuild_tool_config"]["tool_root"]
self.tool_root_dir = tool_root_dir
self.tool_type = tool_type

# Retrieve and build tools based on the smilarities between the skills and the tool description
tool_builder = ToolBuilder(
corpus_path=os.path.join(tool_root_dir, "tool_description.tsv"),
corpus_root=tool_root_dir,
retriever=self._nested_config["autobuild_tool_config"].get("retriever", "all-mpnet-base-v2"),
type=tool_type,
)
for idx, skill in enumerate(skills):
tools = tool_builder.retrieve(skill)
if tool_type == "default":
for idx, skill in enumerate(skills):
tools = tool_builder.retrieve(skill)
docstrings = []
for tool in tools:
category, tool_name = tool.split(" ")[0], tool.split(" ")[1]
tool_path = os.path.join(tool_root_dir, category, f"{tool_name}.py")
docstring = get_full_tool_description(tool_path)
docstrings.append(docstring)
tool_builder.bind(agent_list[idx], "\n\n".join(docstrings))
# the last agent is the user proxy agent, we need special treatment
agent_list[-1] = tool_builder.bind_user_proxy(agent_list[-1], tool_root_dir)
else:
# a list containing all the tools that the agents share
docstrings = []
for tool in tools:
category, tool_name = tool.split(" ")[0], tool.split(" ")[1]
tool_path = os.path.join(tool_root_dir, category, f"{tool_name}.py")
docstring = get_full_tool_description(tool_path)
docstrings.append(docstring)
tool_builder.bind(agent_list[idx], "\n\n".join(docstrings))
# log tools
tool_history = self.tool_history.get(group_name, [])
tool_history.append(docstrings)
self.tool_history[group_name] = tool_history

agent_list[-1] = tool_builder.bind_user_proxy(agent_list[-1], tool_root_dir)

for tool in tool_root_dir:
docstrings.append(format_ag2_tool(tool))
for idx, agent in enumerate(agent_list):
if idx == len(agent_list) - 1:
break
tool_builder.bind(agent, "\n\n".join(docstrings))
agent_list[-1] = tool_builder.bind_user_proxy(agent_list[-1], tool_root_dir)

# log tools
tool_history = self.tool_history.get(group_name, [])
tool_history.append(docstrings)
self.tool_history[group_name] = tool_history
else:
# Build agents from scratch
agent_list, agent_configs = builder.build(
Expand Down
151 changes: 117 additions & 34 deletions autogen/agentchat/contrib/tool_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,26 +20,39 @@
from sentence_transformers import SentenceTransformer, util

from autogen import AssistantAgent, UserProxyAgent
from autogen.coding import LocalCommandLineCodeExecutor
from autogen.coding import CodeExecutor, CodeExtractor, LocalCommandLineCodeExecutor, MarkdownCodeExtractor
from autogen.coding.base import CodeBlock, CodeResult
from autogen.function_utils import load_basemodels_if_needed
from autogen.function_utils import get_function_schema, load_basemodels_if_needed
from autogen.tools import Tool


class ToolBuilder:
TOOL_USING_PROMPT = """# Functions
You have access to the following functions. They can be accessed from the module called 'functions' by their function names.
TOOL_PROMPT_DEFAULT = """\n## Functions
You have access to the following functions. They can be accessed from the module called 'functions' by their function names.
For example, if there is a function called `foo` you could import it by writing `from functions import foo`
{functions}
"""
TOOL_PROMPT_USER_DEFINED = """\n## Functions
You have access to the following functions. You can write python code to call these functions directly without importing them.
{functions}
"""

def __init__(self, corpus_path, retriever="all-mpnet-base-v2"):

self.df = pd.read_csv(corpus_path, sep="\t")
document_list = self.df["document_content"].tolist()
def __init__(self, corpus_root, retriever="all-mpnet-base-v2", type="default"):
if type == "default":
corpus_path = os.path.join(corpus_root, "tool_description.tsv")
self.df = pd.read_csv(corpus_path, sep="\t")
document_list = self.df["document_content"].tolist()
self.TOOL_PROMPT = self.TOOL_PROMPT_DEFAULT
else:
self.TOOL_PROMPT = self.TOOL_PROMPT_USER_DEFINED
# user defined tools, retrieve is actually not needed, just for consistency
document_list = []
for tool in corpus_root:
document_list.append(tool.description)

self.model = SentenceTransformer(retriever)
self.embeddings = self.model.encode(document_list)
self.type = type

def retrieve(self, query, top_k=3):
# Encode the query using the Sentence Transformer model
Expand All @@ -55,39 +68,59 @@ def retrieve(self, query, top_k=3):
def bind(self, agent: AssistantAgent, functions: str):
"""Binds the function to the agent so that agent is aware of it."""
sys_message = agent.system_message
sys_message += self.TOOL_USING_PROMPT.format(functions=functions)
sys_message += self.TOOL_PROMPT.format(functions=functions)
agent.update_system_message(sys_message)
return

def bind_user_proxy(self, agent: UserProxyAgent, tool_root: str):
def bind_user_proxy(self, agent: UserProxyAgent, tool_root: Union[str, list]):
"""
Updates user proxy agent with a executor so that code executor can successfully execute function-related code.
Returns an updated user proxy.
"""
# Find all the functions in the tool root
functions = find_callables(tool_root)

code_execution_config = agent._code_execution_config
executor = LocalCommandLineCodeExecutor(
timeout=code_execution_config.get("timeout", 180),
work_dir=code_execution_config.get("work_dir", "coding"),
functions=functions,
)
code_execution_config = {
"executor": executor,
"last_n_messages": code_execution_config.get("last_n_messages", 1),
}
updated_user_proxy = UserProxyAgent(
name=agent.name,
is_termination_msg=agent._is_termination_msg,
code_execution_config=code_execution_config,
human_input_mode="NEVER",
default_auto_reply=agent._default_auto_reply,
)
return updated_user_proxy


class LocalExecutorWithTools:
if isinstance(tool_root, str):
# Find all the functions in the tool root
functions = find_callables(tool_root)

code_execution_config = agent._code_execution_config
executor = LocalCommandLineCodeExecutor(
timeout=code_execution_config.get("timeout", 180),
work_dir=code_execution_config.get("work_dir", "coding"),
functions=functions,
)
code_execution_config = {
"executor": executor,
"last_n_messages": code_execution_config.get("last_n_messages", 1),
}
updated_user_proxy = UserProxyAgent(
name=agent.name,
is_termination_msg=agent._is_termination_msg,
code_execution_config=code_execution_config,
human_input_mode="NEVER",
default_auto_reply=agent._default_auto_reply,
)
return updated_user_proxy
else:
# second case: user defined tools
code_execution_config = agent._code_execution_config
executor = LocalExecutorWithTools(
tools=tool_root,
work_dir=code_execution_config.get("work_dir", "coding"),
)
code_execution_config = {
"executor": executor,
"last_n_messages": code_execution_config.get("last_n_messages", 1),
}
updated_user_proxy = UserProxyAgent(
name=agent.name,
is_termination_msg=agent._is_termination_msg,
code_execution_config=code_execution_config,
human_input_mode="NEVER",
default_auto_reply=agent._default_auto_reply,
)
return updated_user_proxy


class LocalExecutorWithTools(CodeExecutor):
"""
An executor that executes code blocks with injected tools. In this executor, the func within the tools can be called directly without declaring in the code block.
Expand Down Expand Up @@ -124,6 +157,11 @@ class LocalExecutorWithTools:
work_dir: The working directory for the code execution. Default is the current directory.
"""

@property
def code_extractor(self) -> CodeExtractor:
"""(Experimental) Export a code extractor that can be used by an agent."""
return MarkdownCodeExtractor()

def __init__(self, tools: Optional[List[Tool]] = None, work_dir: Union[Path, str] = Path(".")):
self.tools = tools if tools is not None else []
self.work_dir = work_dir
Expand Down Expand Up @@ -189,6 +227,51 @@ def restart(self):
pass


def format_ag2_tool(tool: Tool):
# get the args first
schema = get_function_schema(tool.func, description=tool.description)

arg_name = list(inspect.signature(tool.func).parameters.keys())[0]
arg_info = schema["function"]["parameters"]["properties"][arg_name]["properties"]

content = f'def {tool.name}({arg_name}):\n """\n'
content += indent(tool.description, " ") + "\n"
content += (
indent(
f"You must format all the arguments into a dictionary and pass them as **kwargs to {arg_name}. You should use print function to get the results.",
" ",
)
+ "\n"
+ indent(f"For example:\n\tresult = {tool.name}({arg_name}={{'arg1': 'value1' }})", " ")
+ "\n"
)
content += indent(f"Arguments passed in {arg_name}:\n", " ")
for arg, info in arg_info.items():
content += indent(f"{arg} ({info['type']}): {info['description']}\n", " " * 2)
content += ' """\n'
return content


def _wrap_function(func):
"""Wrap the function to dump the return value to json.
Handles both sync and async functions.
Args:
func: the function to be wrapped.
Returns:
The wrapped function.
"""

@load_basemodels_if_needed
@functools.wraps(func)
def _wrapped_func(*args, **kwargs):
return func(*args, **kwargs)

return _wrapped_func


def get_full_tool_description(py_file):
"""
Retrieves the function signature for a given Python file.
Expand Down
1 change: 0 additions & 1 deletion notebook/agentchat_captainagent.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,6 @@
"import autogen\n",
"\n",
"config_path = \"OAI_CONFIG_LIST\"\n",
"llm_config = {\"temperature\": 0}\n",
"config_list = autogen.config_list_from_json(\n",
" config_path, filter_dict={\"model\": [\"gpt-4o\"]}\n",
") # You can modify the filter_dict to select your model"
Expand Down
Loading

0 comments on commit 07a3d25

Please sign in to comment.