diff --git a/autogen/interop/__init__.py b/autogen/interop/__init__.py index b5dc8c9521..90c5f92d7b 100644 --- a/autogen/interop/__init__.py +++ b/autogen/interop/__init__.py @@ -2,6 +2,7 @@ # # SPDX-License-Identifier: Apache-2.0 -from .interoperability import Interoperable +from .interoperability import Interoperability +from .interoperable import Interoperable -__all__ = ["Interoperable"] +__all__ = ["Interoperable", "Interoperability"] diff --git a/autogen/interop/crewai/crewai.py b/autogen/interop/crewai/crewai.py index 98df36cf05..c391b22888 100644 --- a/autogen/interop/crewai/crewai.py +++ b/autogen/interop/crewai/crewai.py @@ -8,7 +8,7 @@ from crewai.tools import BaseTool as CrewAITool from ...tools import Tool -from ..interoperability import Interoperable +from ..interoperable import Interoperable __all__ = ["CrewAIInteroperability"] diff --git a/autogen/interop/helpers.py b/autogen/interop/helpers.py new file mode 100644 index 0000000000..5ec24afd9b --- /dev/null +++ b/autogen/interop/helpers.py @@ -0,0 +1,55 @@ +# Copyright (c) 2023 - 2024, Owners of https://github.com/ag2ai +# +# SPDX-License-Identifier: Apache-2.0 +# + +import importlib +import inspect +import logging +import pkgutil +import sys +from typing import Any, Dict, List, Set, Type + +from .interoperable import Interoperable + +logger = logging.getLogger(__name__) + + +def import_submodules(package_name: str) -> List[str]: + package = importlib.import_module(package_name) + imported_modules: List[str] = [] + for loader, module_name, is_pkg in pkgutil.walk_packages(package.__path__, package.__name__ + "."): + try: + importlib.import_module(module_name) + + imported_modules.append(module_name) + except Exception as e: + logger.info(f"Error importing {module_name}, most likely perfectly fine: {e}") + + return imported_modules + + +def find_classes_implementing_protocol(imported_modules: List[str], protocol: Type[Any]) -> List[Type[Any]]: + implementing_classes: Set[Type[Any]] = set() + for module in imported_modules: + for _, obj in inspect.getmembers(sys.modules[module], inspect.isclass): + if issubclass(obj, protocol) and obj is not protocol: + implementing_classes.add(obj) + + return list(implementing_classes) + + +def get_all_interoperability_classes() -> Dict[str, Type[Interoperable]]: + imported_modules = import_submodules("autogen.interop") + classes = find_classes_implementing_protocol(imported_modules, Interoperable) + + # check that all classes finish with 'Interoperability' + for cls in classes: + if not cls.__name__.endswith("Interoperability"): + raise RuntimeError(f"Class {cls} does not end with 'Interoperability'") + + retval = { + cls.__name__.split("Interoperability")[0].lower(): cls for cls in classes if cls.__name__ != "Interoperability" + } + + return retval diff --git a/autogen/interop/interoperability.py b/autogen/interop/interoperability.py index 1e3c786f36..1ae0a528ed 100644 --- a/autogen/interop/interoperability.py +++ b/autogen/interop/interoperability.py @@ -1,13 +1,40 @@ # Copyright (c) 2023 - 2024, Owners of https://github.com/ag2ai # # SPDX-License-Identifier: Apache-2.0 - -from typing import Any, Protocol, runtime_checkable +from typing import Any, Dict, List, Type from ..tools import Tool +from .helpers import get_all_interoperability_classes +from .interoperable import Interoperable __all__ = ["Interoperable"] -@runtime_checkable -class Interoperable(Protocol): - def convert_tool(self, tool: Any) -> Tool: ... +class Interoperability: + _interoperability_classes: Dict[str, Type[Interoperable]] = get_all_interoperability_classes() + + def __init__(self) -> None: + pass + + def convert_tool(self, *, tool: Any, type: str) -> Tool: + interop_cls = self.get_interoperability_class(type) + interop = interop_cls() + return interop.convert_tool(tool) + + @classmethod + def get_interoperability_class(cls, type: str) -> Type[Interoperable]: + if type not in cls._interoperability_classes: + raise ValueError(f"Interoperability class {type} not found") + return cls._interoperability_classes[type] + + @classmethod + def supported_types(cls) -> List[str]: + return sorted(cls._interoperability_classes.keys()) + + @classmethod + def register_interoperability_class(cls, name: str, interoperability_class: Type[Interoperable]) -> None: + if not issubclass(interoperability_class, Interoperable): + raise ValueError( + f"Expected a class implementing `Interoperable` protocol, got {type(interoperability_class)}" + ) + + cls._interoperability_classes[name] = interoperability_class diff --git a/autogen/interop/interoperable.py b/autogen/interop/interoperable.py new file mode 100644 index 0000000000..dfe0f82500 --- /dev/null +++ b/autogen/interop/interoperable.py @@ -0,0 +1,14 @@ +# Copyright (c) 2023 - 2024, Owners of https://github.com/ag2ai +# +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any, Protocol, runtime_checkable + +from ..tools import Tool + +__all__ = ["Interoperable"] + + +@runtime_checkable +class Interoperable(Protocol): + def convert_tool(self, tool: Any) -> Tool: ... diff --git a/autogen/interop/langchain/langchain.py b/autogen/interop/langchain/langchain.py index ad5c2a12a2..b3f4713c63 100644 --- a/autogen/interop/langchain/langchain.py +++ b/autogen/interop/langchain/langchain.py @@ -7,7 +7,7 @@ from langchain_core.tools import BaseTool as LangchainTool from ...tools import Tool -from ..interoperability import Interoperable +from ..interoperable import Interoperable __all__ = ["LangchainInteroperability"] diff --git a/test/interop/crewai/test_crewai.py b/test/interop/crewai/test_crewai.py index 1f0bae4bb1..1a2cbbd513 100644 --- a/test/interop/crewai/test_crewai.py +++ b/test/interop/crewai/test_crewai.py @@ -4,8 +4,8 @@ import os import sys -import unittest from tempfile import TemporaryDirectory +from unittest.mock import MagicMock import pytest from conftest import reason, skip_openai @@ -13,7 +13,7 @@ if sys.version_info >= (3, 10) and sys.version_info < (3, 13): from crewai_tools import FileReadTool else: - FileReadTool = unittest.mock.MagicMock() + FileReadTool = MagicMock() from autogen import AssistantAgent, UserProxyAgent from autogen.interop import Interoperable @@ -21,7 +21,7 @@ if sys.version_info >= (3, 10) and sys.version_info < (3, 13): from autogen.interop.crewai import CrewAIInteroperability else: - CrewAIInteroperability = unittest.mock.MagicMock() + CrewAIInteroperability = MagicMock() # skip if python version is not in [3.10, 3.11, 3.12] diff --git a/test/interop/test_helpers.py b/test/interop/test_helpers.py new file mode 100644 index 0000000000..f3a5ad09ef --- /dev/null +++ b/test/interop/test_helpers.py @@ -0,0 +1,62 @@ +# Copyright (c) 2023 - 2024, Owners of https://github.com/ag2ai +# +# SPDX-License-Identifier: Apache-2.0 + +import sys + +import pytest + +from autogen.interop import Interoperability, Interoperable +from autogen.interop.helpers import ( + find_classes_implementing_protocol, + get_all_interoperability_classes, + import_submodules, +) + + +class TestHelpers: + @pytest.fixture(autouse=True) + def setup_method(self) -> None: + self.imported_modules = import_submodules("autogen.interop") + + def test_import_submodules(self) -> None: + assert "autogen.interop.helpers" in self.imported_modules + + def test_find_classes_implementing_protocol(self) -> None: + actual = find_classes_implementing_protocol(self.imported_modules, Interoperable) + print(f"test_find_classes_implementing_protocol: {actual=}") + + assert Interoperability in actual + expected_count = 1 + + if sys.version_info >= (3, 10) and sys.version_info < (3, 13): + from autogen.interop.crewai import CrewAIInteroperability + + assert CrewAIInteroperability in actual + expected_count += 1 + + if sys.version_info >= (3, 9): + from autogen.interop.langchain import LangchainInteroperability + + assert LangchainInteroperability in actual + expected_count += 1 + + assert len(actual) == expected_count + + def test_get_all_interoperability_classes(self) -> None: + + actual = get_all_interoperability_classes() + + if sys.version_info < (3, 9): + assert actual == {} + + if sys.version_info >= (3, 10) and sys.version_info < (3, 13): + from autogen.interop.crewai import CrewAIInteroperability + from autogen.interop.langchain import LangchainInteroperability + + assert actual == {"crewai": CrewAIInteroperability, "langchain": LangchainInteroperability} + + if (sys.version_info >= (3, 9) and sys.version_info < (3, 10)) and sys.version_info >= (3, 13): + from autogen.interop.langchain import LangchainInteroperability + + assert actual == {"langchain": LangchainInteroperability} diff --git a/test/interop/test_interoperability.py b/test/interop/test_interoperability.py index 164854f6aa..dacdc16c87 100644 --- a/test/interop/test_interoperability.py +++ b/test/interop/test_interoperability.py @@ -2,8 +2,83 @@ # # SPDX-License-Identifier: Apache-2.0 -from autogen.interop import Interoperable +import sys +from tempfile import TemporaryDirectory +from typing import Any +import pytest -def test_interoperable() -> None: - assert Interoperable is not None +from autogen.interop import Interoperability, Interoperable +from autogen.tools.tool import Tool + + +class TestInteroperability: + def test_supported_types(self) -> None: + actual = Interoperability.supported_types() + + if sys.version_info < (3, 9): + assert actual == [] + + if sys.version_info >= (3, 9) and sys.version_info < (3, 10): + assert actual == ["langchain"] + + if sys.version_info >= (3, 10) and sys.version_info < (3, 13): + assert actual == ["crewai", "langchain"] + + if sys.version_info >= (3, 13): + assert actual == ["langchain"] + + def test_register_interoperability_class(self) -> None: + org_interoperability_classes = Interoperability._interoperability_classes + try: + + class MyInteroperability: + def convert_tool(self, tool: Any) -> Tool: + return Tool(name="test", description="test description", func=tool) + + Interoperability.register_interoperability_class("my_interop", MyInteroperability) + assert Interoperability.get_interoperability_class("my_interop") == MyInteroperability + + interop = Interoperability() + tool = interop.convert_tool(type="my_interop", tool=lambda x: x) + assert tool.name == "test" + assert tool.description == "test description" + assert tool.func("hello") == "hello" + + finally: + Interoperability._interoperability_classes = org_interoperability_classes + + @pytest.mark.skipif( + sys.version_info < (3, 10) or sys.version_info >= (3, 13), reason="Only Python 3.10, 3.11, 3.12 are supported" + ) + def test_crewai(self) -> None: + from crewai_tools import FileReadTool + + crewai_tool = FileReadTool() + + interoperability = Interoperability() + tool = interoperability.convert_tool(type="crewai", tool=crewai_tool) + + with TemporaryDirectory() as tmp_dir: + file_path = f"{tmp_dir}/test.txt" + with open(file_path, "w") as file: + file.write("Hello, World!") + + assert tool.name == "Read_a_file_s_content" + assert ( + tool.description + == "A tool that can be used to read a file's content. (IMPORTANT: When using arguments, put them all in an `args` dictionary)" + ) + + model_type = crewai_tool.args_schema + + args = model_type(file_path=file_path) + + assert tool.func(args=args) == "Hello, World!" + + @pytest.mark.skipif( + sys.version_info < (3, 9), reason="Only Python 3.9 and above are supported for LangchainInteroperability" + ) + @pytest.mark.skip(reason="This test is not yet implemented") + def test_langchain(self) -> None: + raise NotImplementedError("This test is not yet implemented") diff --git a/test/interop/test_interoperable.py b/test/interop/test_interoperable.py new file mode 100644 index 0000000000..164854f6aa --- /dev/null +++ b/test/interop/test_interoperable.py @@ -0,0 +1,9 @@ +# Copyright (c) 2023 - 2024, Owners of https://github.com/ag2ai +# +# SPDX-License-Identifier: Apache-2.0 + +from autogen.interop import Interoperable + + +def test_interoperable() -> None: + assert Interoperable is not None