From 9ce614fc87b5358726b2564e881dcbefac5e2630 Mon Sep 17 00:00:00 2001 From: Robert Jambrecic Date: Thu, 19 Dec 2024 15:14:50 +0100 Subject: [PATCH] Refactoring and test_pydantic_ai_tool.py added --- autogen/interop/pydantic_ai/pydantic_ai.py | 2 +- .../pydantic_ai}/pydantic_ai_tool.py | 5 +- autogen/tools/__init__.py | 7 +- autogen/tools/tool.py | 10 +-- .../pydantic_ai/test_pydantic_ai_tool.py | 74 +++++++++++++++++++ 5 files changed, 84 insertions(+), 14 deletions(-) rename autogen/{tools => interop/pydantic_ai}/pydantic_ai_tool.py (96%) create mode 100644 test/interop/pydantic_ai/test_pydantic_ai_tool.py diff --git a/autogen/interop/pydantic_ai/pydantic_ai.py b/autogen/interop/pydantic_ai/pydantic_ai.py index b170ccf501..39d157daa6 100644 --- a/autogen/interop/pydantic_ai/pydantic_ai.py +++ b/autogen/interop/pydantic_ai/pydantic_ai.py @@ -11,8 +11,8 @@ from pydantic_ai import RunContext from pydantic_ai.tools import Tool as PydanticAITool -from ...tools import PydanticAITool as AG2PydanticAITool from ..interoperability import Interoperable +from .pydantic_ai_tool import PydanticAITool as AG2PydanticAITool __all__ = ["PydanticAIInteroperability"] diff --git a/autogen/tools/pydantic_ai_tool.py b/autogen/interop/pydantic_ai/pydantic_ai_tool.py similarity index 96% rename from autogen/tools/pydantic_ai_tool.py rename to autogen/interop/pydantic_ai/pydantic_ai_tool.py index a106cd4b70..629f65e7ad 100644 --- a/autogen/tools/pydantic_ai_tool.py +++ b/autogen/interop/pydantic_ai/pydantic_ai_tool.py @@ -4,9 +4,8 @@ from typing import Any, Callable, Dict -from autogen.agentchat.conversable_agent import ConversableAgent - -from .tool import Tool +from ...agentchat.conversable_agent import ConversableAgent +from ...tools import Tool __all__ = ["PydanticAITool"] diff --git a/autogen/tools/__init__.py b/autogen/tools/__init__.py index d787ae1a60..5902681ce0 100644 --- a/autogen/tools/__init__.py +++ b/autogen/tools/__init__.py @@ -1,4 +1,7 @@ -from .pydantic_ai_tool import PydanticAITool +# Copyright (c) 2023 - 2024, Owners of https://github.com/ag2ai +# +# SPDX-License-Identifier: Apache-2.0 + from .tool import Tool -__all__ = ["PydanticAITool", "Tool"] +__all__ = ["Tool"] diff --git a/autogen/tools/tool.py b/autogen/tools/tool.py index c0e615f37c..43914aa59c 100644 --- a/autogen/tools/tool.py +++ b/autogen/tools/tool.py @@ -2,15 +2,9 @@ # # SPDX-License-Identifier: Apache-2.0 -from typing import Any, Callable, Dict -from unittest.mock import MagicMock +from typing import Any, Callable -from autogen.agentchat.conversable_agent import ConversableAgent - -try: - from crewai.tools import BaseTool as CrewAITool -except ImportError: - CrewAITool = MagicMock() +from ..agentchat.conversable_agent import ConversableAgent __all__ = ["Tool"] diff --git a/test/interop/pydantic_ai/test_pydantic_ai_tool.py b/test/interop/pydantic_ai/test_pydantic_ai_tool.py new file mode 100644 index 0000000000..d6c1d942f9 --- /dev/null +++ b/test/interop/pydantic_ai/test_pydantic_ai_tool.py @@ -0,0 +1,74 @@ +# Copyright (c) 2023 - 2024, Owners of https://github.com/ag2ai +# +# SPDX-License-Identifier: Apache-2.0 + +import sys +import unittest + +import pytest + +from autogen import AssistantAgent + +if sys.version_info >= (3, 9): + from pydantic_ai.tools import Tool + + from autogen.interop.pydantic_ai.pydantic_ai_tool import PydanticAITool as AG2PydanticAITool +else: + Tool = unittest.mock.MagicMock() + AG2PydanticAITool = unittest.mock.MagicMock() + + +# skip if python version is not >= 3.9 +@pytest.mark.skipif( + sys.version_info < (3, 9), reason="Only Python 3.9 and above are supported for LangchainInteroperability" +) +class TestPydanticAITool: + def test_register_for_llm(self) -> None: + def foobar(a: int, b: str, c: dict[str, list[float]]) -> str: # type: ignore[misc] + """Get me foobar. + + Args: + a: apple pie + b: banana cake + c: carrot smoothie + """ + return f"{a} {b} {c}" + + tool = Tool(foobar) + ag2_tool = AG2PydanticAITool( + name=tool.name, + description=tool.description, + func=tool.function, + parameters_json_schema=tool._parameters_json_schema, + ) + config_list = [{"model": "gpt-4o", "api_key": "abc"}] + chatbot = AssistantAgent( + name="chatbot", + llm_config={"config_list": config_list}, + ) + ag2_tool.register_for_llm(chatbot) + expected_tools = [ + { + "type": "function", + "function": { + "name": "foobar", + "description": "Get me foobar.", + "parameters": { + "properties": { + "a": {"description": "apple pie", "title": "A", "type": "integer"}, + "b": {"description": "banana cake", "title": "B", "type": "string"}, + "c": { + "additionalProperties": {"items": {"type": "number"}, "type": "array"}, + "description": "carrot smoothie", + "title": "C", + "type": "object", + }, + }, + "required": ["a", "b", "c"], + "type": "object", + "additionalProperties": False, + }, + }, + } + ] + assert chatbot.llm_config["tools"] == expected_tools # type: ignore[index]