diff --git a/autogen/interop/crewai/crewai.py b/autogen/interop/crewai/crewai.py index d6f2a608ba..1588ccd529 100644 --- a/autogen/interop/crewai/crewai.py +++ b/autogen/interop/crewai/crewai.py @@ -18,7 +18,7 @@ def _sanitize_name(s: str) -> str: @register_interoperable_class("crewai") -class CrewAIInteroperability(Interoperable): +class CrewAIInteroperability: """ A class implementing the `Interoperable` protocol for converting CrewAI tools to a general `Tool` format. diff --git a/autogen/interop/langchain/langchain.py b/autogen/interop/langchain/langchain.py index 446617f7e2..3471e83203 100644 --- a/autogen/interop/langchain/langchain.py +++ b/autogen/interop/langchain/langchain.py @@ -13,7 +13,7 @@ @register_interoperable_class("langchain") -class LangChainInteroperability(Interoperable): +class LangChainInteroperability: """ A class implementing the `Interoperable` protocol for converting Langchain tools into a general `Tool` format. diff --git a/autogen/interop/pydantic_ai/pydantic_ai.py b/autogen/interop/pydantic_ai/pydantic_ai.py index fbb12e35b5..db2e22dd4b 100644 --- a/autogen/interop/pydantic_ai/pydantic_ai.py +++ b/autogen/interop/pydantic_ai/pydantic_ai.py @@ -9,15 +9,15 @@ from inspect import signature from typing import Any, Callable, Optional -from ...tools import PydanticAITool as AG2PydanticAITool from ..interoperability import Interoperable from ..registry import register_interoperable_class +from .pydantic_ai_tool import PydanticAITool as AG2PydanticAITool __all__ = ["PydanticAIInteroperability"] @register_interoperable_class("pydanticai") -class PydanticAIInteroperability(Interoperable): +class PydanticAIInteroperability: """ A class implementing the `Interoperable` protocol for converting Pydantic AI tools into a general `Tool` format. diff --git a/autogen/tools/pydantic_ai_tool.py b/autogen/interop/pydantic_ai/pydantic_ai_tool.py similarity index 96% rename from autogen/tools/pydantic_ai_tool.py rename to autogen/interop/pydantic_ai/pydantic_ai_tool.py index a106cd4b70..629f65e7ad 100644 --- a/autogen/tools/pydantic_ai_tool.py +++ b/autogen/interop/pydantic_ai/pydantic_ai_tool.py @@ -4,9 +4,8 @@ from typing import Any, Callable, Dict -from autogen.agentchat.conversable_agent import ConversableAgent - -from .tool import Tool +from ...agentchat.conversable_agent import ConversableAgent +from ...tools import Tool __all__ = ["PydanticAITool"] diff --git a/autogen/tools/__init__.py b/autogen/tools/__init__.py index d787ae1a60..5902681ce0 100644 --- a/autogen/tools/__init__.py +++ b/autogen/tools/__init__.py @@ -1,4 +1,7 @@ -from .pydantic_ai_tool import PydanticAITool +# Copyright (c) 2023 - 2024, Owners of https://github.com/ag2ai +# +# SPDX-License-Identifier: Apache-2.0 + from .tool import Tool -__all__ = ["PydanticAITool", "Tool"] +__all__ = ["Tool"] diff --git a/autogen/tools/tool.py b/autogen/tools/tool.py index 6e41238287..43914aa59c 100644 --- a/autogen/tools/tool.py +++ b/autogen/tools/tool.py @@ -2,10 +2,9 @@ # # SPDX-License-Identifier: Apache-2.0 -from typing import Any, Callable, Dict -from unittest.mock import MagicMock +from typing import Any, Callable -from autogen.agentchat.conversable_agent import ConversableAgent +from ..agentchat.conversable_agent import ConversableAgent __all__ = ["Tool"] diff --git a/test/interop/pydantic_ai/test_pydantic_ai_tool.py b/test/interop/pydantic_ai/test_pydantic_ai_tool.py new file mode 100644 index 0000000000..d6c1d942f9 --- /dev/null +++ b/test/interop/pydantic_ai/test_pydantic_ai_tool.py @@ -0,0 +1,74 @@ +# Copyright (c) 2023 - 2024, Owners of https://github.com/ag2ai +# +# SPDX-License-Identifier: Apache-2.0 + +import sys +import unittest + +import pytest + +from autogen import AssistantAgent + +if sys.version_info >= (3, 9): + from pydantic_ai.tools import Tool + + from autogen.interop.pydantic_ai.pydantic_ai_tool import PydanticAITool as AG2PydanticAITool +else: + Tool = unittest.mock.MagicMock() + AG2PydanticAITool = unittest.mock.MagicMock() + + +# skip if python version is not >= 3.9 +@pytest.mark.skipif( + sys.version_info < (3, 9), reason="Only Python 3.9 and above are supported for LangchainInteroperability" +) +class TestPydanticAITool: + def test_register_for_llm(self) -> None: + def foobar(a: int, b: str, c: dict[str, list[float]]) -> str: # type: ignore[misc] + """Get me foobar. + + Args: + a: apple pie + b: banana cake + c: carrot smoothie + """ + return f"{a} {b} {c}" + + tool = Tool(foobar) + ag2_tool = AG2PydanticAITool( + name=tool.name, + description=tool.description, + func=tool.function, + parameters_json_schema=tool._parameters_json_schema, + ) + config_list = [{"model": "gpt-4o", "api_key": "abc"}] + chatbot = AssistantAgent( + name="chatbot", + llm_config={"config_list": config_list}, + ) + ag2_tool.register_for_llm(chatbot) + expected_tools = [ + { + "type": "function", + "function": { + "name": "foobar", + "description": "Get me foobar.", + "parameters": { + "properties": { + "a": {"description": "apple pie", "title": "A", "type": "integer"}, + "b": {"description": "banana cake", "title": "B", "type": "string"}, + "c": { + "additionalProperties": {"items": {"type": "number"}, "type": "array"}, + "description": "carrot smoothie", + "title": "C", + "type": "object", + }, + }, + "required": ["a", "b", "c"], + "type": "object", + "additionalProperties": False, + }, + }, + } + ] + assert chatbot.llm_config["tools"] == expected_tools # type: ignore[index]