-
Notifications
You must be signed in to change notification settings - Fork 121
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #230 from ag2ai/add-tool-imports-pydantic-ai2
Add tool support for pydantic ai
- Loading branch information
Showing
17 changed files
with
567 additions
and
19 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
# Copyright (c) 2023 - 2024, Owners of https://github.com/ag2ai | ||
# | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
import sys | ||
|
||
if sys.version_info < (3, 9): | ||
raise ImportError("This submodule is only supported for Python versions 3.9 and above") | ||
|
||
try: | ||
import pydantic_ai.tools | ||
except ImportError: | ||
raise ImportError( | ||
"Please install `interop-pydantic-ai` extra to use this module:\n\n\tpip install ag2[interop-pydantic-ai]" | ||
) | ||
|
||
from .pydantic_ai import PydanticAIInteroperability | ||
|
||
__all__ = ["PydanticAIInteroperability"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
# Copyright (c) 2023 - 2024, Owners of https://github.com/ag2ai | ||
# | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
|
||
from functools import wraps | ||
from inspect import signature | ||
from typing import Any, Callable, Optional | ||
|
||
from pydantic_ai import RunContext | ||
from pydantic_ai.tools import Tool as PydanticAITool | ||
|
||
from ...tools import PydanticAITool as AG2PydanticAITool | ||
from ..interoperability import Interoperable | ||
|
||
__all__ = ["PydanticAIInteroperability"] | ||
|
||
|
||
class PydanticAIInteroperability(Interoperable): | ||
@staticmethod | ||
def inject_params( # type: ignore[no-any-unimported] | ||
ctx: Optional[RunContext[Any]], | ||
tool: PydanticAITool, | ||
) -> Callable[..., Any]: | ||
max_retries = tool.max_retries if tool.max_retries is not None else 1 | ||
f = tool.function | ||
|
||
@wraps(f) | ||
def wrapper(*args: Any, **kwargs: Any) -> Any: | ||
if tool.current_retry >= max_retries: | ||
raise ValueError(f"{tool.name} failed after {max_retries} retries") | ||
|
||
try: | ||
if ctx is not None: | ||
kwargs.pop("ctx", None) | ||
ctx.retry = tool.current_retry | ||
result = f(**kwargs, ctx=ctx) | ||
else: | ||
result = f(**kwargs) | ||
tool.current_retry = 0 | ||
except Exception as e: | ||
tool.current_retry += 1 | ||
raise e | ||
|
||
return result | ||
|
||
sig = signature(f) | ||
if ctx is not None: | ||
new_params = [param for name, param in sig.parameters.items() if name != "ctx"] | ||
else: | ||
new_params = list(sig.parameters.values()) | ||
|
||
wrapper.__signature__ = sig.replace(parameters=new_params) # type: ignore[attr-defined] | ||
|
||
return wrapper | ||
|
||
def convert_tool(self, tool: Any, deps: Any = None) -> AG2PydanticAITool: | ||
if not isinstance(tool, PydanticAITool): | ||
raise ValueError(f"Expected an instance of `pydantic_ai.tools.Tool`, got {type(tool)}") | ||
|
||
# needed for type checking | ||
pydantic_ai_tool: PydanticAITool = tool # type: ignore[no-any-unimported] | ||
|
||
if deps is not None: | ||
ctx = RunContext( | ||
deps=deps, | ||
retry=0, | ||
# All messages send to or returned by a model. | ||
# This is mostly used on pydantic_ai Agent level. | ||
messages=None, # TODO: check in the future if this is needed on Tool level | ||
tool_name=pydantic_ai_tool.name, | ||
) | ||
else: | ||
ctx = None | ||
|
||
func = PydanticAIInteroperability.inject_params( | ||
ctx=ctx, | ||
tool=pydantic_ai_tool, | ||
) | ||
|
||
return AG2PydanticAITool( | ||
name=pydantic_ai_tool.name, | ||
description=pydantic_ai_tool.description, | ||
func=func, | ||
parameters_json_schema=pydantic_ai_tool._parameters_json_schema, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
from .pydantic_ai_tool import PydanticAITool | ||
from .tool import Tool | ||
|
||
__all__ = ["Tool"] | ||
__all__ = ["PydanticAITool", "Tool"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
# Copyright (c) 2023 - 2024, Owners of https://github.com/ag2ai | ||
# | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
from typing import Any, Callable, Dict | ||
|
||
from autogen.agentchat.conversable_agent import ConversableAgent | ||
|
||
from .tool import Tool | ||
|
||
__all__ = ["PydanticAITool"] | ||
|
||
|
||
class PydanticAITool(Tool): | ||
def __init__( | ||
self, name: str, description: str, func: Callable[..., Any], parameters_json_schema: Dict[str, Any] | ||
) -> None: | ||
super().__init__(name, description, func) | ||
self._func_schema = { | ||
"type": "function", | ||
"function": { | ||
"name": name, | ||
"description": description, | ||
"parameters": parameters_json_schema, | ||
}, | ||
} | ||
|
||
def register_for_llm(self, agent: ConversableAgent) -> None: | ||
agent.update_tool_signature(self._func_schema, is_remove=False) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,183 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"# Integrating PydanticAI Tools with the AG2 Framework\n", | ||
"\n", | ||
"In this tutorial, we demonstrate how to integrate [PydanticAI Tools](https://ai.pydantic.dev/tools/) into the AG2 framework. This process enables smooth interoperability between the two systems, allowing developers to leverage PydanticAI's powerful tools within AG2's flexible agent-based architecture. By the end of this guide, you will understand how to configure agents, convert PydanticAI tools for use in AG2, and validate the integration with a practical example.\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"## Installation\n", | ||
"To integrate LangChain tools into the AG2 framework, install the required dependencies:\n", | ||
"\n", | ||
"```bash\n", | ||
"pip install ag2[interop-pydantic-ai]\n", | ||
"```\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"## Imports\n", | ||
"\n", | ||
"Import necessary modules and tools.\n", | ||
"- `BaseModel`: Used to define data structures for tool inputs and outputs.\n", | ||
"- `RunContext`: Provides context during the execution of tools.\n", | ||
"- `PydanticAITool`: Represents a tool in the PydanticAI framework.\n", | ||
"- `AssistantAgent` and `UserProxyAgent`: Agents that facilitate communication in the AG2 framework.\n", | ||
"- `PydanticAIInteroperability`: A bridge for integrating PydanticAI tools with the AG2 framework." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 1, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import os\n", | ||
"from typing import Optional\n", | ||
"\n", | ||
"from pydantic import BaseModel\n", | ||
"from pydantic_ai import RunContext\n", | ||
"from pydantic_ai.tools import Tool as PydanticAITool\n", | ||
"\n", | ||
"from autogen import AssistantAgent, UserProxyAgent\n", | ||
"from autogen.interop.pydantic_ai import PydanticAIInteroperability" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"## Agent Configuration\n", | ||
"\n", | ||
"Configure the agents for the interaction.\n", | ||
"- `config_list` defines the LLM configurations, including the model and API key.\n", | ||
"- `UserProxyAgent` simulates user inputs without requiring actual human interaction (set to `NEVER`).\n", | ||
"- `AssistantAgent` represents the AI agent, configured with the LLM settings." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 2, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"config_list = [{\"model\": \"gpt-4o\", \"api_key\": os.environ[\"OPENAI_API_KEY\"]}]\n", | ||
"user_proxy = UserProxyAgent(\n", | ||
" name=\"User\",\n", | ||
" human_input_mode=\"NEVER\",\n", | ||
")\n", | ||
"\n", | ||
"chatbot = AssistantAgent(\n", | ||
" name=\"chatbot\",\n", | ||
" llm_config={\"config_list\": config_list},\n", | ||
")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"## Tool Integration\n", | ||
"\n", | ||
"Integrate the PydanticAI tool with AG2.\n", | ||
"\n", | ||
"- Define a `Player` model using `BaseModel` to structure the input data.\n", | ||
"- Use `RunContext` to securely inject dependencies (like the `Player` instance) into the tool function without exposing them to the LLM.\n", | ||
"- Implement `get_player` to define the tool's functionality, accessing `ctx.deps` for injected data.\n", | ||
"- Convert the tool to an AG2-compatible format with `PydanticAIInteroperability` and register it for execution and LLM communication.\n", | ||
"- Convert the PydanticAI tool into an AG2-compatible format using `convert_tool`.\n", | ||
"- Register the tool for both execution and communication with the LLM by associating it with the `user_proxy` and `chatbot`." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 3, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"class Player(BaseModel):\n", | ||
" name: str\n", | ||
" age: int\n", | ||
"\n", | ||
"\n", | ||
"def get_player(ctx: RunContext[Player], additional_info: Optional[str] = None) -> str: # type: ignore[valid-type]\n", | ||
" \"\"\"Get the player's name.\n", | ||
"\n", | ||
" Args:\n", | ||
" additional_info: Additional information which can be used.\n", | ||
" \"\"\"\n", | ||
" return f\"Name: {ctx.deps.name}, Age: {ctx.deps.age}, Additional info: {additional_info}\" # type: ignore[attr-defined]\n", | ||
"\n", | ||
"\n", | ||
"pydantic_ai_interop = PydanticAIInteroperability()\n", | ||
"pydantic_ai_tool = PydanticAITool(get_player, takes_ctx=True)\n", | ||
"\n", | ||
"# player will be injected as a dependency\n", | ||
"player = Player(name=\"Luka\", age=25)\n", | ||
"ag2_tool = pydantic_ai_interop.convert_tool(tool=pydantic_ai_tool, deps=player)\n", | ||
"\n", | ||
"ag2_tool.register_for_execution(user_proxy)\n", | ||
"ag2_tool.register_for_llm(chatbot)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"Initiate a conversation between the `UserProxyAgent` and the `AssistantAgent`.\n", | ||
"\n", | ||
"- Use the `initiate_chat` method to send a message from the `user_proxy` to the `chatbot`.\n", | ||
"- In this example, the user requests the chatbot to retrieve player information, providing \"goal keeper\" as additional context.\n", | ||
"- The `Player` instance is securely injected into the tool using `RunContext`, ensuring the chatbot can retrieve and use this data during the interaction." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"user_proxy.initiate_chat(\n", | ||
" recipient=chatbot, message=\"Get player, for additional information use 'goal keeper'\", max_turns=3\n", | ||
")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": ".venv", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.10.16" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |
Oops, something went wrong.