diff --git a/docs/docs/examples/tools/function_tool_callback.ipynb b/docs/docs/examples/tools/function_tool_callback.ipynb new file mode 100644 index 0000000000000..9e08f0e18b5f3 --- /dev/null +++ b/docs/docs/examples/tools/function_tool_callback.ipynb @@ -0,0 +1,185 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Function call with callback\n", + "\n", + "This is a feature that allows applying some human-in-the-loop concepts in FunctionTool.\n", + "\n", + "Basically, a callback function is added that enables the developer to request user input in the middle of an agent interaction, as well as allowing any programmatic action." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%pip install llama-index-llms-openai\n", + "%pip install llama-index-agents-openai" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from llama_index.core.tools import FunctionTool\n", + "from llama_index.agent.openai import OpenAIAgent\n", + "from llama_index.llms.openai import OpenAI\n", + "import os" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "os.environ[\"OPENAI_API_KEY\"] = \"sk-\"" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Function to display to the user the data produced for function calling and request their input to return to the interaction." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def callback(message):\n", + " confirmation = input(\n", + " f\"{message[1]}\\nDo you approve of sending this greeting?\\nInput(Y/N):\"\n", + " )\n", + "\n", + " if confirmation.lower() == \"y\":\n", + " # Here you can trigger an action such as sending an email, message, api call, etc.\n", + " return \"Greeting sent successfully.\"\n", + " else:\n", + " return (\n", + " \"Greeting has not been approved, talk a bit about how to improve\"\n", + " )" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Simple function that only requires a recipient and a greeting message." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def send_hello(destination: str, message: str) -> str:\n", + " \"\"\"\n", + " Say hello with a rhyme\n", + " destination: str - Name of recipient\n", + " message: str - Greeting message with a rhyme to the recipient's name\n", + " \"\"\"\n", + "\n", + " return destination, message\n", + "\n", + "\n", + "hello_tool = FunctionTool.from_defaults(fn=send_hello, callback=callback)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + ">" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "hello_tool.to_langchain_tool" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "llm = OpenAI()\n", + "agent = OpenAIAgent.from_tools([hello_tool])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "The hello message has been sent to Karen with the rhyme \"Hello Karen, you're a star!\"\n" + ] + } + ], + "source": [ + "response = agent.chat(\"Send hello to 
Karen\")\n", + "print(str(response))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "I have successfully sent a hello message to Joe with the greeting \"Hello Joe, you're a pro!\"\n" + ] + } + ], + "source": [ + "response = agent.chat(\"Send hello to Joe\")\n", + "print(str(response))" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "base", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/llama-index-core/llama_index/core/tools/function_tool.py b/llama-index-core/llama_index/core/tools/function_tool.py index 23cf7a5f8a48d..7fc3315af6fe8 100644 --- a/llama-index-core/llama_index/core/tools/function_tool.py +++ b/llama-index-core/llama_index/core/tools/function_tool.py @@ -24,7 +24,7 @@ async def _async_wrapped_fn(*args: Any, **kwargs: Any) -> Any: def async_to_sync(func_async: AsyncCallable) -> Callable: - """Async from sync.""" + """Async to sync.""" def _sync_wrapped_fn(*args: Any, **kwargs: Any) -> Any: return asyncio_run(func_async(*args, **kwargs)) # type: ignore[arg-type] @@ -35,7 +35,7 @@ def _sync_wrapped_fn(*args: Any, **kwargs: Any) -> Any: class FunctionTool(AsyncBaseTool): """Function Tool. - A tool that takes in a function. + A tool that takes in a function and a callback. 
""" @@ -43,11 +43,14 @@ def __init__( self, fn: Optional[Callable[..., Any]] = None, metadata: Optional[ToolMetadata] = None, - async_fn: Optional[AsyncCallable] = None, + async_fn: Optional[Callable[..., Any]] = None, + callback: Optional[Callable[..., Any]] = None, + async_callback: Optional[Callable[..., Any]] = None, ) -> None: if fn is None and async_fn is None: raise ValueError("fn or async_fn must be provided.") + # Handle function (sync and async) if fn is not None: self._fn = fn elif async_fn is not None: @@ -61,8 +64,33 @@ def __init__( if metadata is None: raise ValueError("metadata must be provided.") + # Handle callback (sync and async) + self._callback = None + if callback is not None: + self._callback = callback + elif async_callback is not None: + self._callback = async_to_sync(async_callback) + + self._async_callback = None + if async_callback is not None: + self._async_callback = async_callback + elif self._callback is not None: + self._async_callback = sync_to_async(self._callback) + self._metadata = metadata + def _run_sync_callback(self, result: Any) -> Any: + """Runs the sync callback, if provided.""" + if self._callback: + return self._callback(result) + return None + + async def _run_async_callback(self, result: Any) -> Any: + """Runs the async callback, if provided.""" + if self._async_callback: + return await self._async_callback(result) + return None + @classmethod def from_defaults( cls, @@ -73,6 +101,8 @@ def from_defaults( fn_schema: Optional[Type[BaseModel]] = None, async_fn: Optional[AsyncCallable] = None, tool_metadata: Optional[ToolMetadata] = None, + callback: Optional[Callable[[Any], Any]] = None, + async_callback: Optional[AsyncCallable] = None, ) -> "FunctionTool": if tool_metadata is None: fn_to_parse = fn or async_fn @@ -90,7 +120,13 @@ def from_defaults( fn_schema=fn_schema, return_direct=return_direct, ) - return cls(fn=fn, metadata=tool_metadata, async_fn=async_fn) + return cls( + fn=fn, + metadata=tool_metadata, + 
async_fn=async_fn, + callback=callback, + async_callback=async_callback, + ) @property def metadata(self) -> ToolMetadata: @@ -108,20 +144,30 @@ def async_fn(self) -> AsyncCallable: return self._async_fn def call(self, *args: Any, **kwargs: Any) -> ToolOutput: - """Call.""" + """Sync Call.""" tool_output = self._fn(*args, **kwargs) + final_output_content = str(tool_output) + # Execute sync callback, if available + callback_output = self._run_sync_callback(tool_output) + if callback_output: + final_output_content += f" Callback: {callback_output}" return ToolOutput( - content=str(tool_output), + content=final_output_content, tool_name=self.metadata.name, raw_input={"args": args, "kwargs": kwargs}, raw_output=tool_output, ) async def acall(self, *args: Any, **kwargs: Any) -> ToolOutput: - """Call.""" + """Async Call.""" tool_output = await self._async_fn(*args, **kwargs) + final_output_content = str(tool_output) + # Execute async callback, if available + callback_output = await self._run_async_callback(tool_output) + if callback_output: + final_output_content += f" Callback: {callback_output}" return ToolOutput( - content=str(tool_output), + content=final_output_content, tool_name=self.metadata.name, raw_input={"args": args, "kwargs": kwargs}, raw_output=tool_output, diff --git a/llama-index-core/llama_index/core/tools/types.py b/llama-index-core/llama_index/core/tools/types.py index 2355669858cf1..345563ad79929 100644 --- a/llama-index-core/llama_index/core/tools/types.py +++ b/llama-index-core/llama_index/core/tools/types.py @@ -122,6 +122,13 @@ def _process_langchain_tool_kwargs( langchain_tool_kwargs["description"] = self.metadata.description if "fn_schema" not in langchain_tool_kwargs: langchain_tool_kwargs["args_schema"] = self.metadata.fn_schema + + # Callbacks don't exist on LangChain + if "_callback" in langchain_tool_kwargs: + del langchain_tool_kwargs["_callback"] + if "_async_callback" in langchain_tool_kwargs: + del langchain_tool_kwargs["_async_callback"] 
+ return langchain_tool_kwargs def to_langchain_tool(