Function tool callback #16637

Open · wants to merge 12 commits into base: main
185 changes: 185 additions & 0 deletions docs/docs/examples/tools/function_tool_callback.ipynb
@@ -0,0 +1,185 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Function call with callback\n",
"\n",
"This is a feature that allows applying some human-in-the-loop concepts in FunctionTool.\n",
"\n",
"Basically, a callback function is added that enables the developer to request user input in the middle of an agent interaction, as well as allowing any programmatic action."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%pip install llama-index-llms-openai\n",
"%pip install llama-index-agents-openai"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from llama_index.core.tools import FunctionTool\n",
"from llama_index.agent.openai import OpenAIAgent\n",
"from llama_index.llms.openai import OpenAI\n",
"import os"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"os.environ[\"OPENAI_API_KEY\"] = \"sk-\""
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Function to display to the user the data produced for function calling and request their input to return to the interaction."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def callback(message):\n",
" confirmation = input(\n",
" f\"{message[1]}\\nDo you approve of sending this greeting?\\nInput(Y/N):\"\n",
" )\n",
"\n",
" if confirmation.lower() == \"y\":\n",
" # Here you can trigger an action such as sending an email, message, api call, etc.\n",
" return \"Greeting sent successfully.\"\n",
" else:\n",
" return (\n",
" \"Greeting has not been approved, talk a bit about how to improve\"\n",
" )"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Simple function that only requires a recipient and a greeting message."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def send_hello(destination: str, message: str) -> str:\n",
" \"\"\"\n",
" Say hello with a rhyme\n",
" destination: str - Name of recipient\n",
" message: str - Greeting message with a rhyme to the recipient's name\n",
" \"\"\"\n",
"\n",
" return destination, message\n",
"\n",
"\n",
"hello_tool = FunctionTool.from_defaults(fn=send_hello, callback=callback)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<bound method FunctionTool.to_langchain_tool of <llama_index.core.tools.function_tool.FunctionTool object at 0x7f7da9fa5670>>"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"hello_tool.to_langchain_tool"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"llm = OpenAI()\n",
"agent = OpenAIAgent.from_tools([hello_tool])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The hello message has been sent to Karen with the rhyme \"Hello Karen, you're a star!\"\n"
]
}
],
"source": [
"response = agent.chat(\"Send hello to Karen\")\n",
"print(str(response))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"I have successfully sent a hello message to Joe with the greeting \"Hello Joe, you're a pro!\"\n"
]
}
],
"source": [
"response = agent.chat(\"Send hello to Joe\")\n",
"print(str(response))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "base",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
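
The notebook only exercises the sync callback path. Below is a minimal sketch of the async variant under the same API, using the PR's async_callback parameter (the names and confirmation logic here are illustrative, not part of the PR):

import asyncio

from llama_index.core.tools import FunctionTool


async def async_callback(result):
    # `result` is the (destination, greeting) tuple returned by the tool function.
    destination, greeting = result
    # A non-blocking confirmation step would go here (e.g. awaiting front-end approval).
    return "Greeting sent successfully."


def send_hello(destination: str, message: str) -> tuple:
    """Say hello with a rhyme."""
    return destination, message


hello_tool = FunctionTool.from_defaults(fn=send_hello, async_callback=async_callback)

# acall() awaits the callback and appends its return value to the tool output.
output = asyncio.run(
    hello_tool.acall(destination="Karen", message="Hello Karen, you're a star!")
)
print(output)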
62 changes: 54 additions & 8 deletions llama-index-core/llama_index/core/tools/function_tool.py
@@ -24,7 +24,7 @@ async def _async_wrapped_fn(*args: Any, **kwargs: Any) -> Any:


def async_to_sync(func_async: AsyncCallable) -> Callable:
"""Async from sync."""
"""Async to sync."""

def _sync_wrapped_fn(*args: Any, **kwargs: Any) -> Any:
return asyncio_run(func_async(*args, **kwargs)) # type: ignore[arg-type]
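
For illustration, a minimal sketch of what this helper does (asyncio_run is the core's wrapper for driving a coroutine to completion):

async def fetch() -> str:
    return "done"

fetch_sync = async_to_sync(fetch)
assert fetch_sync() == "done"  # the coroutine runs to completion synchronously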
@@ -35,19 +35,22 @@ def _sync_wrapped_fn(*args: Any, **kwargs: Any) -> Any:
class FunctionTool(AsyncBaseTool):
"""Function Tool.

A tool that takes in a function.
A tool that takes in a function and a callback.

"""

def __init__(
self,
fn: Optional[Callable[..., Any]] = None,
metadata: Optional[ToolMetadata] = None,
async_fn: Optional[AsyncCallable] = None,
async_fn: Optional[Callable[..., Any]] = None,
callback: Optional[Callable[..., Any]] = None,
async_callback: Optional[Callable[..., Any]] = None,
) -> None:
if fn is None and async_fn is None:
raise ValueError("fn or async_fn must be provided.")

# Handle function (sync and async)
if fn is not None:
self._fn = fn
elif async_fn is not None:
@@ -61,8 +64,33 @@ def __init__(
if metadata is None:
raise ValueError("metadata must be provided.")

# Handle callback (sync and async)
self._callback = None
if callback is not None:
self._callback = callback
elif async_callback is not None:
self._callback = async_to_sync(async_callback)

self._async_callback = None
if async_callback is not None:
self._async_callback = async_callback
elif self._callback is not None:
self._async_callback = sync_to_async(self._callback)

self._metadata = metadata

def _run_sync_callback(self, result: Any) -> Any:
"""Runs the sync callback, if provided."""
if self._callback:
return self._callback(result)
return None

async def _run_async_callback(self, result: Any) -> Any:
"""Runs the async callback, if provided."""
if self._async_callback:
return await self._async_callback(result)
return None

@classmethod
def from_defaults(
cls,
@@ -73,6 +101,8 @@ def from_defaults(
fn_schema: Optional[Type[BaseModel]] = None,
async_fn: Optional[AsyncCallable] = None,
tool_metadata: Optional[ToolMetadata] = None,
callback: Optional[Callable[[Any], Any]] = None,
async_callback: Optional[AsyncCallable] = None,
) -> "FunctionTool":
if tool_metadata is None:
fn_to_parse = fn or async_fn
@@ -90,7 +120,13 @@ def from_defaults(
fn_schema=fn_schema,
return_direct=return_direct,
)
return cls(fn=fn, metadata=tool_metadata, async_fn=async_fn)
return cls(
fn=fn,
metadata=tool_metadata,
async_fn=async_fn,
callback=callback,
async_callback=async_callback,
)

@property
def metadata(self) -> ToolMetadata:
@@ -108,20 +144,30 @@ def async_fn(self) -> AsyncCallable:
return self._async_fn

def call(self, *args: Any, **kwargs: Any) -> ToolOutput:
"""Call."""
"""Sync Call."""
tool_output = self._fn(*args, **kwargs)
final_output_content = str(tool_output)
# Execute sync callback, if available
callback_output = self._run_sync_callback(tool_output)
if callback_output:
final_output_content += f" Callback: {callback_output}"
Collaborator

Actually, looking at this again, this seems rather opinionated 😅 What's the intended use case for this?

I kind of expected the callback to just either be logging something, or outright modifying the tool output in place.

Appending a string like this seems kind of strange.

Contributor Author

I'll explain the use case we applied here to see if it becomes a bit clearer:

I have a tool that sends information to an external system of my company. The agent is responsible for setting all the parameters for the API call, but I can't blindly trust the information constructed by the agent, so I use this callback strategy to request user confirmation. The confirmation is returned to the agent's interaction, allowing the agent to decide the next steps. We apply this pattern to confirmations between agents and the front end; the attached example notebook shows the same flow in an abstracted, easier-to-follow form.

Basically, I need the callback's return value to influence the remainder of the agent's flow. I couldn't think of a simpler way to adapt the classes without violating any principles. If the user doesn't want the callback to influence the flow, they simply return nothing from the function.

Collaborator

It seems to me a better design might be

callback_output: ToolOutput | None = self._run_sync_callback(tool_output)

Basically the callback either returns a new ToolOutput to override the existing one, or returns None and the original tool_output is used.

Collaborator

Then it's up to the user how the callback changes the tool output.
return ToolOutput(
content=str(tool_output),
content=final_output_content,
tool_name=self.metadata.name,
raw_input={"args": args, "kwargs": kwargs},
raw_output=tool_output,
)

async def acall(self, *args: Any, **kwargs: Any) -> ToolOutput:
"""Call."""
"""Async Call."""
tool_output = await self._async_fn(*args, **kwargs)
final_output_content = str(tool_output)
# Execute async callback, if available
callback_output = await self._run_async_callback(tool_output)
if callback_output:
final_output_content += f" Callback: {callback_output}"
return ToolOutput(
content=str(tool_output),
content=final_output_content,
tool_name=self.metadata.name,
raw_input={"args": args, "kwargs": kwargs},
raw_output=tool_output,
7 changes: 7 additions & 0 deletions llama-index-core/llama_index/core/tools/types.py
@@ -122,6 +122,13 @@ def _process_langchain_tool_kwargs(
langchain_tool_kwargs["description"] = self.metadata.description
if "fn_schema" not in langchain_tool_kwargs:
langchain_tool_kwargs["args_schema"] = self.metadata.fn_schema

# Callbacks don't exist on langchain tools
if "_callback" in langchain_tool_kwargs:
del langchain_tool_kwargs["_callback"]
if "_async_callback" in langchain_tool_kwargs:
del langchain_tool_kwargs["_async_callback"]

return langchain_tool_kwargs

def to_langchain_tool(
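
As a quick check of the guard above, converting a tool that carries callbacks should still work (a sketch, assuming langchain is installed and reusing the notebook's hello_tool):

# _callback/_async_callback are stripped from the kwargs so the langchain
# tool constructor never receives attributes it doesn't understand.
lc_tool = hello_tool.to_langchain_tool()
print(lc_tool.name, lc_tool.description)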