Skip to content

Commit

Permalink
Merge pull request #230 from ag2ai/add-tool-imports-pydantic-ai2
Browse files Browse the repository at this point in the history
Add tool support for pydantic ai
  • Loading branch information
davorrunje authored Dec 18, 2024
2 parents fd2b089 + 13cdc35 commit 4ba90f5
Show file tree
Hide file tree
Showing 17 changed files with 567 additions and 19 deletions.
2 changes: 1 addition & 1 deletion autogen/agentchat/conversable_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -2600,7 +2600,7 @@ def update_function_signature(self, func_sig: Union[str, Dict], is_remove: None)

self.client = OpenAIWrapper(**self.llm_config)

def update_tool_signature(self, tool_sig: Union[str, Dict], is_remove: None):
def update_tool_signature(self, tool_sig: Union[str, Dict], is_remove: bool):
"""update a tool_signature in the LLM configuration for tool_call.
Args:
Expand Down
2 changes: 2 additions & 0 deletions autogen/interop/interoperability.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
from .helpers import get_all_interoperability_classes
from .interoperable import Interoperable

__all__ = ["Interoperable"]


class Interoperability:
_interoperability_classes: Dict[str, Type[Interoperable]] = get_all_interoperability_classes()
Expand Down
19 changes: 19 additions & 0 deletions autogen/interop/pydantic_ai/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# Copyright (c) 2023 - 2024, Owners of https://github.com/ag2ai
#
# SPDX-License-Identifier: Apache-2.0

import sys

if sys.version_info < (3, 9):
raise ImportError("This submodule is only supported for Python versions 3.9 and above")

try:
import pydantic_ai.tools
except ImportError:
raise ImportError(
"Please install `interop-pydantic-ai` extra to use this module:\n\n\tpip install ag2[interop-pydantic-ai]"
)

from .pydantic_ai import PydanticAIInteroperability

__all__ = ["PydanticAIInteroperability"]
86 changes: 86 additions & 0 deletions autogen/interop/pydantic_ai/pydantic_ai.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
# Copyright (c) 2023 - 2024, Owners of https://github.com/ag2ai
#
# SPDX-License-Identifier: Apache-2.0


from functools import wraps
from inspect import signature
from typing import Any, Callable, Optional

from pydantic_ai import RunContext
from pydantic_ai.tools import Tool as PydanticAITool

from ...tools import PydanticAITool as AG2PydanticAITool
from ..interoperability import Interoperable

__all__ = ["PydanticAIInteroperability"]


class PydanticAIInteroperability(Interoperable):
@staticmethod
def inject_params( # type: ignore[no-any-unimported]
ctx: Optional[RunContext[Any]],
tool: PydanticAITool,
) -> Callable[..., Any]:
max_retries = tool.max_retries if tool.max_retries is not None else 1
f = tool.function

@wraps(f)
def wrapper(*args: Any, **kwargs: Any) -> Any:
if tool.current_retry >= max_retries:
raise ValueError(f"{tool.name} failed after {max_retries} retries")

try:
if ctx is not None:
kwargs.pop("ctx", None)
ctx.retry = tool.current_retry
result = f(**kwargs, ctx=ctx)
else:
result = f(**kwargs)
tool.current_retry = 0
except Exception as e:
tool.current_retry += 1
raise e

return result

sig = signature(f)
if ctx is not None:
new_params = [param for name, param in sig.parameters.items() if name != "ctx"]
else:
new_params = list(sig.parameters.values())

wrapper.__signature__ = sig.replace(parameters=new_params) # type: ignore[attr-defined]

return wrapper

def convert_tool(self, tool: Any, deps: Any = None) -> AG2PydanticAITool:
if not isinstance(tool, PydanticAITool):
raise ValueError(f"Expected an instance of `pydantic_ai.tools.Tool`, got {type(tool)}")

# needed for type checking
pydantic_ai_tool: PydanticAITool = tool # type: ignore[no-any-unimported]

if deps is not None:
ctx = RunContext(
deps=deps,
retry=0,
# All messages send to or returned by a model.
# This is mostly used on pydantic_ai Agent level.
messages=None, # TODO: check in the future if this is needed on Tool level
tool_name=pydantic_ai_tool.name,
)
else:
ctx = None

func = PydanticAIInteroperability.inject_params(
ctx=ctx,
tool=pydantic_ai_tool,
)

return AG2PydanticAITool(
name=pydantic_ai_tool.name,
description=pydantic_ai_tool.description,
func=func,
parameters_json_schema=pydantic_ai_tool._parameters_json_schema,
)
2 changes: 1 addition & 1 deletion autogen/oai/cerebras.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
CEREBRAS_PRICING_1K = {
# Convert pricing per million to per thousand tokens.
"llama3.1-8b": (0.10 / 1000, 0.10 / 1000),
"llama3.1-70b": (0.60 / 1000, 0.60 / 1000),
"llama-3.3-70b": (0.85 / 1000, 1.20 / 1000),
}


Expand Down
3 changes: 2 additions & 1 deletion autogen/tools/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .pydantic_ai_tool import PydanticAITool
from .tool import Tool

__all__ = ["Tool"]
__all__ = ["PydanticAITool", "Tool"]
29 changes: 29 additions & 0 deletions autogen/tools/pydantic_ai_tool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# Copyright (c) 2023 - 2024, Owners of https://github.com/ag2ai
#
# SPDX-License-Identifier: Apache-2.0

from typing import Any, Callable, Dict

from autogen.agentchat.conversable_agent import ConversableAgent

from .tool import Tool

__all__ = ["PydanticAITool"]


class PydanticAITool(Tool):
def __init__(
self, name: str, description: str, func: Callable[..., Any], parameters_json_schema: Dict[str, Any]
) -> None:
super().__init__(name, description, func)
self._func_schema = {
"type": "function",
"function": {
"name": name,
"description": description,
"parameters": parameters_json_schema,
},
}

def register_for_llm(self, agent: ConversableAgent) -> None:
agent.update_tool_signature(self._func_schema, is_remove=False)
6 changes: 3 additions & 3 deletions notebook/tools_crewai_tools_integration.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
"from crewai_tools import FileWriterTool, ScrapeWebsiteTool\n",
"\n",
"from autogen import AssistantAgent, UserProxyAgent\n",
"from autogen.interoperability.crewai import CrewAIInteroperability"
"from autogen.interop.crewai import CrewAIInteroperability"
]
},
{
Expand Down Expand Up @@ -168,7 +168,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"display_name": ".venv",
"language": "python",
"name": "python3"
},
Expand All @@ -182,7 +182,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.10"
"version": "3.10.16"
}
},
"nbformat": 4,
Expand Down
4 changes: 2 additions & 2 deletions notebook/tools_langchain_tools_integration.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
"from langchain_community.utilities import WikipediaAPIWrapper\n",
"\n",
"from autogen import AssistantAgent, UserProxyAgent\n",
"from autogen.interoperability.langchain import LangchainInteroperability"
"from autogen.interop.langchain import LangchainInteroperability"
]
},
{
Expand Down Expand Up @@ -112,7 +112,7 @@
"ag2_tool = langchain_interop.convert_tool(langchain_tool)\n",
"\n",
"ag2_tool.register_for_execution(user_proxy)\n",
"ag2_tool.register_for_llm(chatbot)\n"
"ag2_tool.register_for_llm(chatbot)"
]
},
{
Expand Down
183 changes: 183 additions & 0 deletions notebook/tools_pydantic_ai_tools_integration.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,183 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Integrating PydanticAI Tools with the AG2 Framework\n",
"\n",
"In this tutorial, we demonstrate how to integrate [PydanticAI Tools](https://ai.pydantic.dev/tools/) into the AG2 framework. This process enables smooth interoperability between the two systems, allowing developers to leverage PydanticAI's powerful tools within AG2's flexible agent-based architecture. By the end of this guide, you will understand how to configure agents, convert PydanticAI tools for use in AG2, and validate the integration with a practical example.\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Installation\n",
"To integrate LangChain tools into the AG2 framework, install the required dependencies:\n",
"\n",
"```bash\n",
"pip install ag2[interop-pydantic-ai]\n",
"```\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Imports\n",
"\n",
"Import necessary modules and tools.\n",
"- `BaseModel`: Used to define data structures for tool inputs and outputs.\n",
"- `RunContext`: Provides context during the execution of tools.\n",
"- `PydanticAITool`: Represents a tool in the PydanticAI framework.\n",
"- `AssistantAgent` and `UserProxyAgent`: Agents that facilitate communication in the AG2 framework.\n",
"- `PydanticAIInteroperability`: A bridge for integrating PydanticAI tools with the AG2 framework."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"from typing import Optional\n",
"\n",
"from pydantic import BaseModel\n",
"from pydantic_ai import RunContext\n",
"from pydantic_ai.tools import Tool as PydanticAITool\n",
"\n",
"from autogen import AssistantAgent, UserProxyAgent\n",
"from autogen.interop.pydantic_ai import PydanticAIInteroperability"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Agent Configuration\n",
"\n",
"Configure the agents for the interaction.\n",
"- `config_list` defines the LLM configurations, including the model and API key.\n",
"- `UserProxyAgent` simulates user inputs without requiring actual human interaction (set to `NEVER`).\n",
"- `AssistantAgent` represents the AI agent, configured with the LLM settings."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"config_list = [{\"model\": \"gpt-4o\", \"api_key\": os.environ[\"OPENAI_API_KEY\"]}]\n",
"user_proxy = UserProxyAgent(\n",
" name=\"User\",\n",
" human_input_mode=\"NEVER\",\n",
")\n",
"\n",
"chatbot = AssistantAgent(\n",
" name=\"chatbot\",\n",
" llm_config={\"config_list\": config_list},\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Tool Integration\n",
"\n",
"Integrate the PydanticAI tool with AG2.\n",
"\n",
"- Define a `Player` model using `BaseModel` to structure the input data.\n",
"- Use `RunContext` to securely inject dependencies (like the `Player` instance) into the tool function without exposing them to the LLM.\n",
"- Implement `get_player` to define the tool's functionality, accessing `ctx.deps` for injected data.\n",
"- Convert the tool to an AG2-compatible format with `PydanticAIInteroperability` and register it for execution and LLM communication.\n",
"- Convert the PydanticAI tool into an AG2-compatible format using `convert_tool`.\n",
"- Register the tool for both execution and communication with the LLM by associating it with the `user_proxy` and `chatbot`."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"class Player(BaseModel):\n",
" name: str\n",
" age: int\n",
"\n",
"\n",
"def get_player(ctx: RunContext[Player], additional_info: Optional[str] = None) -> str: # type: ignore[valid-type]\n",
" \"\"\"Get the player's name.\n",
"\n",
" Args:\n",
" additional_info: Additional information which can be used.\n",
" \"\"\"\n",
" return f\"Name: {ctx.deps.name}, Age: {ctx.deps.age}, Additional info: {additional_info}\" # type: ignore[attr-defined]\n",
"\n",
"\n",
"pydantic_ai_interop = PydanticAIInteroperability()\n",
"pydantic_ai_tool = PydanticAITool(get_player, takes_ctx=True)\n",
"\n",
"# player will be injected as a dependency\n",
"player = Player(name=\"Luka\", age=25)\n",
"ag2_tool = pydantic_ai_interop.convert_tool(tool=pydantic_ai_tool, deps=player)\n",
"\n",
"ag2_tool.register_for_execution(user_proxy)\n",
"ag2_tool.register_for_llm(chatbot)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Initiate a conversation between the `UserProxyAgent` and the `AssistantAgent`.\n",
"\n",
"- Use the `initiate_chat` method to send a message from the `user_proxy` to the `chatbot`.\n",
"- In this example, the user requests the chatbot to retrieve player information, providing \"goal keeper\" as additional context.\n",
"- The `Player` instance is securely injected into the tool using `RunContext`, ensuring the chatbot can retrieve and use this data during the interaction."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"user_proxy.initiate_chat(\n",
" recipient=chatbot, message=\"Get player, for additional information use 'goal keeper'\", max_turns=3\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.16"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Loading

0 comments on commit 4ba90f5

Please sign in to comment.