From 2de4a90cb70705f6c65ab79d6f516640225fe776 Mon Sep 17 00:00:00 2001 From: Davor Runje Date: Thu, 19 Dec 2024 15:13:07 +0100 Subject: [PATCH] refactoring: making convert_tool a class method --- autogen/interop/__init__.py | 2 +- autogen/interop/crewai/crewai.py | 5 +-- autogen/interop/interoperability.py | 27 +++++++---------- autogen/interop/interoperable.py | 3 +- autogen/interop/langchain/__init__.py | 4 +-- autogen/interop/langchain/langchain.py | 7 +++-- autogen/interop/pydantic_ai/pydantic_ai.py | 4 ++- test/interop/crewai/test_crewai.py | 15 +++------ test/interop/langchain/test_langchain.py | 32 ++++++++++++-------- test/interop/pydantic_ai/test_pydantic_ai.py | 22 ++++++++------ test/interop/test_interoperability.py | 5 ++- 11 files changed, 66 insertions(+), 60 deletions(-) diff --git a/autogen/interop/__init__.py b/autogen/interop/__init__.py index 3afb5a529f..8f070c8f24 100644 --- a/autogen/interop/__init__.py +++ b/autogen/interop/__init__.py @@ -5,7 +5,7 @@ from .crewai import CrewAIInteroperability from .interoperability import Interoperability from .interoperable import Interoperable -from .langchain import LangchainInteroperability +from .langchain import LangChainInteroperability from .pydantic_ai import PydanticAIInteroperability from .registry import register_interoperable_class diff --git a/autogen/interop/crewai/crewai.py b/autogen/interop/crewai/crewai.py index 43465fa875..d6f2a608ba 100644 --- a/autogen/interop/crewai/crewai.py +++ b/autogen/interop/crewai/crewai.py @@ -4,7 +4,7 @@ import re import sys -from typing import Any, Optional, cast +from typing import Any, Optional from ...tools import Tool from ..interoperable import Interoperable @@ -26,7 +26,8 @@ class CrewAIInteroperability(Interoperable): This class takes a `CrewAITool` and converts it into a standard `Tool` object. """ - def convert_tool(self, tool: Any, **kwargs: Any) -> Tool: + @classmethod + def convert_tool(cls, tool: Any, **kwargs: Any) -> Tool: """ Converts a given CrewAI tool into a general `Tool` format. diff --git a/autogen/interop/interoperability.py b/autogen/interop/interoperability.py index 243f759d07..b86285d6a6 100644 --- a/autogen/interop/interoperability.py +++ b/autogen/interop/interoperability.py @@ -18,16 +18,10 @@ class Interoperability: for retrieving and registering interoperability classes. """ - def __init__(self) -> None: - """ - Initializes an instance of the Interoperability class. - - This constructor does not perform any specific actions as the class is primarily used for its class - methods to manage interoperability classes. - """ - self.registry = InteroperableRegistry.get_instance() + registry = InteroperableRegistry.get_instance() - def convert_tool(self, *, tool: Any, type: str, **kwargs: Any) -> Tool: + @classmethod + def convert_tool(cls, *, tool: Any, type: str, **kwargs: Any) -> Tool: """ Converts a given tool to an instance of a specified interoperability type. @@ -42,11 +36,11 @@ def convert_tool(self, *, tool: Any, type: str, **kwargs: Any) -> Tool: Raises: ValueError: If the interoperability class for the provided type is not found. """ - interop_cls = self.get_interoperability_class(type) - interop = interop_cls() + interop = cls.get_interoperability_class(type) return interop.convert_tool(tool, **kwargs) - def get_interoperability_class(self, type: str) -> Type[Interoperable]: + @classmethod + def get_interoperability_class(cls, type: str) -> Type[Interoperable]: """ Retrieves the interoperability class corresponding to the specified type. @@ -59,20 +53,21 @@ def get_interoperability_class(self, type: str) -> Type[Interoperable]: Raises: ValueError: If no interoperability class is found for the provided type. """ - supported_types = self.registry.get_supported_types() + supported_types = cls.registry.get_supported_types() if type not in supported_types: supported_types_formated = ", ".join(["'t'" for t in supported_types]) raise ValueError( f"Interoperability class {type} is not supported, supported types: {supported_types_formated}" ) - return self.registry.get_class(type) + return cls.registry.get_class(type) - def get_supported_types(self) -> List[str]: + @classmethod + def get_supported_types(cls) -> List[str]: """ Returns a sorted list of all supported interoperability types. Returns: List[str]: A sorted list of strings representing the supported interoperability types. """ - return sorted(self.registry.get_supported_types()) + return sorted(cls.registry.get_supported_types()) diff --git a/autogen/interop/interoperable.py b/autogen/interop/interoperable.py index 2a11ace35d..185e36089d 100644 --- a/autogen/interop/interoperable.py +++ b/autogen/interop/interoperable.py @@ -18,7 +18,8 @@ class Interoperable(Protocol): `convert_tool` to convert a given tool into a desired format or type. """ - def convert_tool(self, tool: Any, **kwargs: Any) -> Tool: + @classmethod + def convert_tool(cls, tool: Any, **kwargs: Any) -> Tool: """ Converts a given tool to a desired format or type. diff --git a/autogen/interop/langchain/__init__.py b/autogen/interop/langchain/__init__.py index 9af4e592cc..1aa1f7892c 100644 --- a/autogen/interop/langchain/__init__.py +++ b/autogen/interop/langchain/__init__.py @@ -2,6 +2,6 @@ # # SPDX-License-Identifier: Apache-2.0 -from .langchain import LangchainInteroperability +from .langchain import LangChainInteroperability -__all__ = ["LangchainInteroperability"] +__all__ = ["LangChainInteroperability"] diff --git a/autogen/interop/langchain/langchain.py b/autogen/interop/langchain/langchain.py index 841689e283..446617f7e2 100644 --- a/autogen/interop/langchain/langchain.py +++ b/autogen/interop/langchain/langchain.py @@ -9,11 +9,11 @@ from ..interoperable import Interoperable from ..registry import register_interoperable_class -__all__ = ["LangchainInteroperability"] +__all__ = ["LangChainInteroperability"] @register_interoperable_class("langchain") -class LangchainInteroperability(Interoperable): +class LangChainInteroperability(Interoperable): """ A class implementing the `Interoperable` protocol for converting Langchain tools into a general `Tool` format. @@ -23,7 +23,8 @@ class LangchainInteroperability(Interoperable): the `Tool` format. """ - def convert_tool(self, tool: Any, **kwargs: Any) -> Tool: + @classmethod + def convert_tool(cls, tool: Any, **kwargs: Any) -> Tool: """ Converts a given Langchain tool into a general `Tool` format. diff --git a/autogen/interop/pydantic_ai/pydantic_ai.py b/autogen/interop/pydantic_ai/pydantic_ai.py index d3260b4e82..fbb12e35b5 100644 --- a/autogen/interop/pydantic_ai/pydantic_ai.py +++ b/autogen/interop/pydantic_ai/pydantic_ai.py @@ -28,6 +28,7 @@ class PydanticAIInteroperability(Interoperable): into the tool's function. """ + @staticmethod def inject_params( ctx: Any, tool: Any, @@ -86,7 +87,8 @@ def wrapper(*args: Any, **kwargs: Any) -> Any: return wrapper - def convert_tool(self, tool: Any, deps: Any = None, **kwargs: Any) -> AG2PydanticAITool: + @classmethod + def convert_tool(cls, tool: Any, deps: Any = None, **kwargs: Any) -> AG2PydanticAITool: """ Converts a given Pydantic AI tool into a general `Tool` format. diff --git a/test/interop/crewai/test_crewai.py b/test/interop/crewai/test_crewai.py index 00666740fe..0d39926638 100644 --- a/test/interop/crewai/test_crewai.py +++ b/test/interop/crewai/test_crewai.py @@ -31,21 +31,18 @@ class TestCrewAIInteroperability: @pytest.fixture(autouse=True) def setup(self) -> None: - self.crewai_interop = CrewAIInteroperability() crewai_tool = FileReadTool() self.model_type = crewai_tool.args_schema - self.tool = self.crewai_interop.convert_tool(crewai_tool) + self.tool = CrewAIInteroperability.convert_tool(crewai_tool) def test_type_checks(self) -> None: # mypy should fail if the type checks are not correct - interop: Interoperable = self.crewai_interop + interop: Interoperable = CrewAIInteroperability() + # runtime check assert isinstance(interop, Interoperable) - def test_init(self) -> None: - assert isinstance(self.crewai_interop, Interoperable) - def test_convert_tool(self) -> None: with TemporaryDirectory() as tmp_dir: file_path = f"{tmp_dir}/test.txt" @@ -95,8 +92,7 @@ def test_with_llm(self) -> None: assert False, "Tool response not found in chat messages" def test_get_unsupported_reason(self) -> None: - crewai_interop = CrewAIInteroperability() - assert crewai_interop.get_unsupported_reason() is None + assert CrewAIInteroperability.get_unsupported_reason() is None @pytest.mark.skipif( @@ -104,8 +100,7 @@ def test_get_unsupported_reason(self) -> None: ) class TestCrewAIInteroperabilityIfNotSupported: def test_get_unsupported_reason(self) -> None: - crewai_interop = CrewAIInteroperability() assert ( - crewai_interop.get_unsupported_reason() + CrewAIInteroperability.get_unsupported_reason() == "This submodule is only supported for Python versions 3.10, 3.11, and 3.12" ) diff --git a/test/interop/langchain/test_langchain.py b/test/interop/langchain/test_langchain.py index 070c396784..376612ad11 100644 --- a/test/interop/langchain/test_langchain.py +++ b/test/interop/langchain/test_langchain.py @@ -16,17 +16,17 @@ if sys.version_info >= (3, 9): from langchain.tools import tool - from autogen.interop.langchain import LangchainInteroperability + from autogen.interop.langchain import LangChainInteroperability else: tool = unittest.mock.MagicMock() - LangchainInteroperability = unittest.mock.MagicMock() + LangChainInteroperability = unittest.mock.MagicMock() # skip if python version is not >= 3.9 @pytest.mark.skipif( sys.version_info < (3, 9), reason="Only Python 3.9 and above are supported for LangchainInteroperability" ) -class TestLangchainInteroperability: +class TestLangChainInteroperability: @pytest.fixture(autouse=True) def setup(self) -> None: class SearchInput(BaseModel): @@ -37,19 +37,16 @@ def search(query: SearchInput) -> str: """Look up things online.""" return "LangChain Integration" - self.langchain_interop = LangchainInteroperability() self.model_type = search.args_schema - self.tool = self.langchain_interop.convert_tool(search) + self.tool = LangChainInteroperability.convert_tool(search) def test_type_checks(self) -> None: # mypy should fail if the type checks are not correct - interop: Interoperable = self.langchain_interop + interop: Interoperable = LangChainInteroperability() + # runtime check assert isinstance(interop, Interoperable) - def test_init(self) -> None: - assert isinstance(self.langchain_interop, Interoperable) - def test_convert_tool(self) -> None: assert self.tool.name == "search-tool" assert self.tool.description == "Look up things online." @@ -82,12 +79,15 @@ def test_with_llm(self) -> None: assert False, "No tool response found in chat messages" + def test_get_unsupported_reason(self) -> None: + assert LangChainInteroperability.get_unsupported_reason() is None + # skip if python version is not >= 3.9 @pytest.mark.skipif( sys.version_info < (3, 9), reason="Only Python 3.9 and above are supported for LangchainInteroperability" ) -class TestLangchainInteroperabilityWithoutPydanticInput: +class TestLangChainInteroperabilityWithoutPydanticInput: @pytest.fixture(autouse=True) def setup(self) -> None: @tool @@ -95,8 +95,7 @@ def search(query: str, max_length: int) -> str: """Look up things online.""" return f"LangChain Integration, max_length: {max_length}" - self.langchain_interop = LangchainInteroperability() - self.tool = self.langchain_interop.convert_tool(search) + self.tool = LangChainInteroperability.convert_tool(search) self.model_type = search.args_schema def test_convert_tool(self) -> None: @@ -130,3 +129,12 @@ def test_with_llm(self) -> None: return assert False, "No tool response found in chat messages" + + +@pytest.mark.skipif(sys.version_info >= (3, 9), reason="LangChain Interoperability is supported") +class TestLangChainInteroperabilityIfNotSupported: + def test_get_unsupported_reason(self) -> None: + assert ( + LangChainInteroperability.get_unsupported_reason() + == "This submodule is only supported for Python versions 3.9 and above" + ) diff --git a/test/interop/pydantic_ai/test_pydantic_ai.py b/test/interop/pydantic_ai/test_pydantic_ai.py index 6b5dae9590..764a44e855 100644 --- a/test/interop/pydantic_ai/test_pydantic_ai.py +++ b/test/interop/pydantic_ai/test_pydantic_ai.py @@ -38,19 +38,15 @@ def roll_dice() -> str: """Roll a six-sided dice and return the result.""" return str(random.randint(1, 6)) - self.pydantic_ai_interop = PydanticAIInteroperability() pydantic_ai_tool = PydanticAITool(roll_dice, max_retries=3) - self.tool = self.pydantic_ai_interop.convert_tool(pydantic_ai_tool) + self.tool = PydanticAIInteroperability.convert_tool(pydantic_ai_tool) def test_type_checks(self) -> None: # mypy should fail if the type checks are not correct - interop: Interoperable = self.pydantic_ai_interop + interop: Interoperable = PydanticAIInteroperability() # runtime check assert isinstance(interop, Interoperable) - def test_init(self) -> None: - assert isinstance(self.pydantic_ai_interop, Interoperable) - def test_convert_tool(self) -> None: assert self.tool.name == "roll_dice" assert self.tool.description == "Roll a six-sided dice and return the result." @@ -162,14 +158,13 @@ def get_player(ctx: RunContext[Player], additional_info: Optional[str] = None) - """ return f"Name: {ctx.deps.name}, Age: {ctx.deps.age}, Additional info: {additional_info}" # type: ignore[attr-defined] - self.pydantic_ai_interop = PydanticAIInteroperability() self.pydantic_ai_tool = PydanticAITool(get_player, takes_ctx=True) player = Player(name="Luka", age=25) - self.tool = self.pydantic_ai_interop.convert_tool(tool=self.pydantic_ai_tool, deps=player) + self.tool = PydanticAIInteroperability.convert_tool(tool=self.pydantic_ai_tool, deps=player) def test_convert_tool_raises_error_if_take_ctx_is_true_and_deps_is_none(self) -> None: with pytest.raises(ValueError, match="If the tool takes a context, the `deps` argument must be provided"): - self.pydantic_ai_interop.convert_tool(tool=self.pydantic_ai_tool, deps=None) + PydanticAIInteroperability.convert_tool(tool=self.pydantic_ai_tool, deps=None) def test_expected_tools(self) -> None: config_list = [{"model": "gpt-4o", "api_key": "abc"}] @@ -229,3 +224,12 @@ def test_with_llm(self) -> None: return assert False, "No tool response found in chat messages" + + +@pytest.mark.skipif(sys.version_info >= (3, 9), reason="LangChain Interoperability is supported") +class TestPydanticAIInteroperabilityIfNotSupported: + def test_get_unsupported_reason(self) -> None: + assert ( + PydanticAIInteroperability.get_unsupported_reason() + == "This submodule is only supported for Python versions 3.9 and above" + ) diff --git a/test/interop/test_interoperability.py b/test/interop/test_interoperability.py index 2b1e0704ab..3925f056f1 100644 --- a/test/interop/test_interoperability.py +++ b/test/interop/test_interoperability.py @@ -13,7 +13,7 @@ class TestInteroperability: def test_supported_types(self) -> None: - actual = Interoperability().get_supported_types() + actual = Interoperability.get_supported_types() if sys.version_info < (3, 9): assert actual == [] @@ -35,8 +35,7 @@ def test_crewai(self) -> None: crewai_tool = FileReadTool() - interoperability = Interoperability() - tool = interoperability.convert_tool(type="crewai", tool=crewai_tool) + tool = Interoperability.convert_tool(type="crewai", tool=crewai_tool) with TemporaryDirectory() as tmp_dir: file_path = f"{tmp_dir}/test.txt"