diff --git a/=8 b/=8 new file mode 100644 index 0000000000..d47c38ce90 --- /dev/null +++ b/=8 @@ -0,0 +1,6 @@ +Requirement already satisfied: pytest in ./.venv-3.9/lib/python3.9/site-packages (7.4.4) +Requirement already satisfied: iniconfig in ./.venv-3.9/lib/python3.9/site-packages (from pytest) (2.0.0) +Requirement already satisfied: packaging in ./.venv-3.9/lib/python3.9/site-packages (from pytest) (24.2) +Requirement already satisfied: pluggy<2.0,>=0.12 in ./.venv-3.9/lib/python3.9/site-packages (from pytest) (1.5.0) +Requirement already satisfied: exceptiongroup>=1.0.0rc8 in ./.venv-3.9/lib/python3.9/site-packages (from pytest) (1.2.2) +Requirement already satisfied: tomli>=1.0.0 in ./.venv-3.9/lib/python3.9/site-packages (from pytest) (2.2.1) diff --git a/autogen/interop/__init__.py b/autogen/interop/__init__.py index 90c5f92d7b..8f070c8f24 100644 --- a/autogen/interop/__init__.py +++ b/autogen/interop/__init__.py @@ -2,7 +2,11 @@ # # SPDX-License-Identifier: Apache-2.0 +from .crewai import CrewAIInteroperability from .interoperability import Interoperability from .interoperable import Interoperable +from .langchain import LangChainInteroperability +from .pydantic_ai import PydanticAIInteroperability +from .registry import register_interoperable_class -__all__ = ["Interoperable", "Interoperability"] +__all__ = ["Interoperability", "Interoperable", "register_interoperable_class"] diff --git a/autogen/interop/crewai/__init__.py b/autogen/interop/crewai/__init__.py index f3018b8787..50abbf3913 100644 --- a/autogen/interop/crewai/__init__.py +++ b/autogen/interop/crewai/__init__.py @@ -2,16 +2,6 @@ # # SPDX-License-Identifier: Apache-2.0 -import sys - -if sys.version_info < (3, 10) or sys.version_info >= (3, 13): - raise ImportError("This submodule is only supported for Python versions 3.10, 3.11, and 3.12") - -try: - import crewai.tools -except ImportError: - raise ImportError("Please install `interop-crewai` extra to use this module:\n\n\tpip install ag2[interop-crewai]") - from .crewai import CrewAIInteroperability __all__ = ["CrewAIInteroperability"] diff --git a/autogen/interop/crewai/crewai.py b/autogen/interop/crewai/crewai.py index 97abda9e77..1588ccd529 100644 --- a/autogen/interop/crewai/crewai.py +++ b/autogen/interop/crewai/crewai.py @@ -3,12 +3,12 @@ # SPDX-License-Identifier: Apache-2.0 import re -from typing import Any, cast - -from crewai.tools import BaseTool as CrewAITool +import sys +from typing import Any, Optional from ...tools import Tool from ..interoperable import Interoperable +from ..registry import register_interoperable_class __all__ = ["CrewAIInteroperability"] @@ -17,7 +17,8 @@ def _sanitize_name(s: str) -> str: return re.sub(r"\W|^(?=\d)", "_", s) -class CrewAIInteroperability(Interoperable): +@register_interoperable_class("crewai") +class CrewAIInteroperability: """ A class implementing the `Interoperable` protocol for converting CrewAI tools to a general `Tool` format. @@ -25,7 +26,8 @@ class CrewAIInteroperability(Interoperable): This class takes a `CrewAITool` and converts it into a standard `Tool` object. """ - def convert_tool(self, tool: Any, **kwargs: Any) -> Tool: + @classmethod + def convert_tool(cls, tool: Any, **kwargs: Any) -> Tool: """ Converts a given CrewAI tool into a general `Tool` format. @@ -44,6 +46,8 @@ def convert_tool(self, tool: Any, **kwargs: Any) -> Tool: ValueError: If the provided tool is not an instance of `CrewAITool`, or if any additional arguments are passed. """ + from crewai.tools import BaseTool as CrewAITool + if not isinstance(tool, CrewAITool): raise ValueError(f"Expected an instance of `crewai.tools.BaseTool`, got {type(tool)}") if kwargs: @@ -66,3 +70,15 @@ def func(args: crewai_tool.args_schema) -> Any: # type: ignore[no-any-unimporte description=description, func=func, ) + + @classmethod + def get_unsupported_reason(cls) -> Optional[str]: + if sys.version_info < (3, 10) or sys.version_info >= (3, 13): + return "This submodule is only supported for Python versions 3.10, 3.11, and 3.12" + + try: + import crewai.tools + except ImportError: + return "Please install `interop-crewai` extra to use this module:\n\n\tpip install ag2[interop-crewai]" + + return None diff --git a/autogen/interop/helpers.py b/autogen/interop/helpers.py deleted file mode 100644 index 5ec24afd9b..0000000000 --- a/autogen/interop/helpers.py +++ /dev/null @@ -1,55 +0,0 @@ -# Copyright (c) 2023 - 2024, Owners of https://github.com/ag2ai -# -# SPDX-License-Identifier: Apache-2.0 -# - -import importlib -import inspect -import logging -import pkgutil -import sys -from typing import Any, Dict, List, Set, Type - -from .interoperable import Interoperable - -logger = logging.getLogger(__name__) - - -def import_submodules(package_name: str) -> List[str]: - package = importlib.import_module(package_name) - imported_modules: List[str] = [] - for loader, module_name, is_pkg in pkgutil.walk_packages(package.__path__, package.__name__ + "."): - try: - importlib.import_module(module_name) - - imported_modules.append(module_name) - except Exception as e: - logger.info(f"Error importing {module_name}, most likely perfectly fine: {e}") - - return imported_modules - - -def find_classes_implementing_protocol(imported_modules: List[str], protocol: Type[Any]) -> List[Type[Any]]: - implementing_classes: Set[Type[Any]] = set() - for module in imported_modules: - for _, obj in inspect.getmembers(sys.modules[module], inspect.isclass): - if issubclass(obj, protocol) and obj is not protocol: - implementing_classes.add(obj) - - return list(implementing_classes) - - -def get_all_interoperability_classes() -> Dict[str, Type[Interoperable]]: - imported_modules = import_submodules("autogen.interop") - classes = find_classes_implementing_protocol(imported_modules, Interoperable) - - # check that all classes finish with 'Interoperability' - for cls in classes: - if not cls.__name__.endswith("Interoperability"): - raise RuntimeError(f"Class {cls} does not end with 'Interoperability'") - - retval = { - cls.__name__.split("Interoperability")[0].lower(): cls for cls in classes if cls.__name__ != "Interoperability" - } - - return retval diff --git a/autogen/interop/interoperability.py b/autogen/interop/interoperability.py index 27df3cd9c7..b86285d6a6 100644 --- a/autogen/interop/interoperability.py +++ b/autogen/interop/interoperability.py @@ -4,8 +4,8 @@ from typing import Any, Dict, List, Type from ..tools import Tool -from .helpers import get_all_interoperability_classes from .interoperable import Interoperable +from .registry import InteroperableRegistry __all__ = ["Interoperable"] @@ -18,18 +18,10 @@ class Interoperability: for retrieving and registering interoperability classes. """ - _interoperability_classes: Dict[str, Type[Interoperable]] = get_all_interoperability_classes() + registry = InteroperableRegistry.get_instance() - def __init__(self) -> None: - """ - Initializes an instance of the Interoperability class. - - This constructor does not perform any specific actions as the class is primarily used for its class - methods to manage interoperability classes. - """ - pass - - def convert_tool(self, *, tool: Any, type: str, **kwargs: Any) -> Tool: + @classmethod + def convert_tool(cls, *, tool: Any, type: str, **kwargs: Any) -> Tool: """ Converts a given tool to an instance of a specified interoperability type. @@ -44,8 +36,7 @@ def convert_tool(self, *, tool: Any, type: str, **kwargs: Any) -> Tool: Raises: ValueError: If the interoperability class for the provided type is not found. """ - interop_cls = self.get_interoperability_class(type) - interop = interop_cls() + interop = cls.get_interoperability_class(type) return interop.convert_tool(tool, **kwargs) @classmethod @@ -62,35 +53,21 @@ def get_interoperability_class(cls, type: str) -> Type[Interoperable]: Raises: ValueError: If no interoperability class is found for the provided type. """ - if type not in cls._interoperability_classes: - raise ValueError(f"Interoperability class {type} not found") - return cls._interoperability_classes[type] + supported_types = cls.registry.get_supported_types() + if type not in supported_types: + supported_types_formated = ", ".join(["'t'" for t in supported_types]) + raise ValueError( + f"Interoperability class {type} is not supported, supported types: {supported_types_formated}" + ) + + return cls.registry.get_class(type) @classmethod - def supported_types(cls) -> List[str]: + def get_supported_types(cls) -> List[str]: """ Returns a sorted list of all supported interoperability types. Returns: List[str]: A sorted list of strings representing the supported interoperability types. """ - return sorted(cls._interoperability_classes.keys()) - - @classmethod - def register_interoperability_class(cls, name: str, interoperability_class: Type[Interoperable]) -> None: - """ - Registers a new interoperability class with the given name. - - Args: - name (str): The name to associate with the interoperability class. - interoperability_class (Type[Interoperable]): The class implementing the Interoperable protocol. - - Raises: - ValueError: If the provided class does not implement the Interoperable protocol. - """ - if not issubclass(interoperability_class, Interoperable): - raise ValueError( - f"Expected a class implementing `Interoperable` protocol, got {type(interoperability_class)}" - ) - - cls._interoperability_classes[name] = interoperability_class + return sorted(cls.registry.get_supported_types()) diff --git a/autogen/interop/interoperable.py b/autogen/interop/interoperable.py index 75aefaaf25..185e36089d 100644 --- a/autogen/interop/interoperable.py +++ b/autogen/interop/interoperable.py @@ -2,7 +2,7 @@ # # SPDX-License-Identifier: Apache-2.0 -from typing import Any, Protocol, runtime_checkable +from typing import Any, Optional, Protocol, runtime_checkable from ..tools import Tool @@ -18,7 +18,8 @@ class Interoperable(Protocol): `convert_tool` to convert a given tool into a desired format or type. """ - def convert_tool(self, tool: Any, **kwargs: Any) -> Tool: + @classmethod + def convert_tool(cls, tool: Any, **kwargs: Any) -> Tool: """ Converts a given tool to a desired format or type. @@ -32,3 +33,14 @@ def convert_tool(self, tool: Any, **kwargs: Any) -> Tool: Tool: The converted tool in the desired format or type. """ ... + + @classmethod + def get_unsupported_reason(cls) -> Optional[str]: + """Returns the reason for the tool being unsupported. + + This method should be implemented by any class adhering to the `Interoperable` protocol. + + Returns: + str: The reason for the interoperability class being unsupported. + """ + ... diff --git a/autogen/interop/langchain/__init__.py b/autogen/interop/langchain/__init__.py index 233c8642c3..1aa1f7892c 100644 --- a/autogen/interop/langchain/__init__.py +++ b/autogen/interop/langchain/__init__.py @@ -2,18 +2,6 @@ # # SPDX-License-Identifier: Apache-2.0 -import sys +from .langchain import LangChainInteroperability -if sys.version_info < (3, 9): - raise ImportError("This submodule is only supported for Python versions 3.9 and above") - -try: - import langchain.tools -except ImportError: - raise ImportError( - "Please install `interop-langchain` extra to use this module:\n\n\tpip install ag2[interop-langchain]" - ) - -from .langchain import LangchainInteroperability - -__all__ = ["LangchainInteroperability"] +__all__ = ["LangChainInteroperability"] diff --git a/autogen/interop/langchain/langchain.py b/autogen/interop/langchain/langchain.py index 925e00431a..3471e83203 100644 --- a/autogen/interop/langchain/langchain.py +++ b/autogen/interop/langchain/langchain.py @@ -2,17 +2,18 @@ # # SPDX-License-Identifier: Apache-2.0 -from typing import Any - -from langchain_core.tools import BaseTool as LangchainTool +import sys +from typing import Any, Optional from ...tools import Tool from ..interoperable import Interoperable +from ..registry import register_interoperable_class -__all__ = ["LangchainInteroperability"] +__all__ = ["LangChainInteroperability"] -class LangchainInteroperability(Interoperable): +@register_interoperable_class("langchain") +class LangChainInteroperability: """ A class implementing the `Interoperable` protocol for converting Langchain tools into a general `Tool` format. @@ -22,7 +23,8 @@ class LangchainInteroperability(Interoperable): the `Tool` format. """ - def convert_tool(self, tool: Any, **kwargs: Any) -> Tool: + @classmethod + def convert_tool(cls, tool: Any, **kwargs: Any) -> Tool: """ Converts a given Langchain tool into a general `Tool` format. @@ -41,15 +43,17 @@ def convert_tool(self, tool: Any, **kwargs: Any) -> Tool: ValueError: If the provided tool is not an instance of `LangchainTool`, or if any additional arguments are passed. """ + from langchain_core.tools import BaseTool as LangchainTool + if not isinstance(tool, LangchainTool): raise ValueError(f"Expected an instance of `langchain_core.tools.BaseTool`, got {type(tool)}") if kwargs: raise ValueError(f"The LangchainInteroperability does not support any additional arguments, got {kwargs}") # needed for type checking - langchain_tool: LangchainTool = tool # type: ignore[no-any-unimported] + langchain_tool: LangchainTool = tool # type: ignore - def func(tool_input: langchain_tool.args_schema) -> Any: # type: ignore[no-any-unimported] + def func(tool_input: langchain_tool.args_schema) -> Any: # type: ignore return langchain_tool.run(tool_input.model_dump()) return Tool( @@ -57,3 +61,17 @@ def func(tool_input: langchain_tool.args_schema) -> Any: # type: ignore[no-any- description=langchain_tool.description, func=func, ) + + @classmethod + def get_unsupported_reason(cls) -> Optional[str]: + if sys.version_info < (3, 9): + return "This submodule is only supported for Python versions 3.9 and above" + + try: + import langchain_core.tools + except ImportError: + return ( + "Please install `interop-langchain` extra to use this module:\n\n\tpip install ag2[interop-langchain]" + ) + + return None diff --git a/autogen/interop/pydantic_ai/__init__.py b/autogen/interop/pydantic_ai/__init__.py index 55d52347ce..c022ebc414 100644 --- a/autogen/interop/pydantic_ai/__init__.py +++ b/autogen/interop/pydantic_ai/__init__.py @@ -2,18 +2,6 @@ # # SPDX-License-Identifier: Apache-2.0 -import sys - -if sys.version_info < (3, 9): - raise ImportError("This submodule is only supported for Python versions 3.9 and above") - -try: - import pydantic_ai.tools -except ImportError: - raise ImportError( - "Please install `interop-pydantic-ai` extra to use this module:\n\n\tpip install ag2[interop-pydantic-ai]" - ) - from .pydantic_ai import PydanticAIInteroperability __all__ = ["PydanticAIInteroperability"] diff --git a/autogen/interop/pydantic_ai/pydantic_ai.py b/autogen/interop/pydantic_ai/pydantic_ai.py index 39d157daa6..db2e22dd4b 100644 --- a/autogen/interop/pydantic_ai/pydantic_ai.py +++ b/autogen/interop/pydantic_ai/pydantic_ai.py @@ -3,21 +3,21 @@ # SPDX-License-Identifier: Apache-2.0 +import sys import warnings from functools import wraps from inspect import signature from typing import Any, Callable, Optional -from pydantic_ai import RunContext -from pydantic_ai.tools import Tool as PydanticAITool - from ..interoperability import Interoperable +from ..registry import register_interoperable_class from .pydantic_ai_tool import PydanticAITool as AG2PydanticAITool __all__ = ["PydanticAIInteroperability"] -class PydanticAIInteroperability(Interoperable): +@register_interoperable_class("pydanticai") +class PydanticAIInteroperability: """ A class implementing the `Interoperable` protocol for converting Pydantic AI tools into a general `Tool` format. @@ -29,9 +29,9 @@ class PydanticAIInteroperability(Interoperable): """ @staticmethod - def inject_params( # type: ignore[no-any-unimported] - ctx: Optional[RunContext[Any]], - tool: PydanticAITool, + def inject_params( + ctx: Any, + tool: Any, ) -> Callable[..., Any]: """ Wraps the tool's function to inject context parameters and handle retries. @@ -40,8 +40,7 @@ def inject_params( # type: ignore[no-any-unimported] when invoked and that retries are managed according to the tool's settings. Args: - ctx (Optional[RunContext[Any]]): The run context, which may include dependencies - and retry information. + ctx (Optional[RunContext[Any]]): The run context, which may include dependencies and retry information. tool (PydanticAITool): The Pydantic AI tool whose function is to be wrapped. Returns: @@ -50,30 +49,36 @@ def inject_params( # type: ignore[no-any-unimported] Raises: ValueError: If the tool fails after the maximum number of retries. """ - max_retries = tool.max_retries if tool.max_retries is not None else 1 - f = tool.function + from pydantic_ai import RunContext + from pydantic_ai.tools import Tool as PydanticAITool + + ctx_typed: Optional[RunContext[Any]] = ctx # type: ignore + tool_typed: PydanticAITool[Any] = tool # type: ignore + + max_retries = tool_typed.max_retries if tool_typed.max_retries is not None else 1 + f = tool_typed.function @wraps(f) def wrapper(*args: Any, **kwargs: Any) -> Any: - if tool.current_retry >= max_retries: - raise ValueError(f"{tool.name} failed after {max_retries} retries") + if tool_typed.current_retry >= max_retries: + raise ValueError(f"{tool_typed.name} failed after {max_retries} retries") try: - if ctx is not None: + if ctx_typed is not None: kwargs.pop("ctx", None) - ctx.retry = tool.current_retry - result = f(**kwargs, ctx=ctx) + ctx_typed.retry = tool_typed.current_retry + result = f(**kwargs, ctx=ctx_typed) # type: ignore[call-arg] else: - result = f(**kwargs) - tool.current_retry = 0 + result = f(**kwargs) # type: ignore[call-arg] + tool_typed.current_retry = 0 except Exception as e: - tool.current_retry += 1 + tool_typed.current_retry += 1 raise e return result sig = signature(f) - if ctx is not None: + if ctx_typed is not None: new_params = [param for name, param in sig.parameters.items() if name != "ctx"] else: new_params = list(sig.parameters.values()) @@ -82,7 +87,8 @@ def wrapper(*args: Any, **kwargs: Any) -> Any: return wrapper - def convert_tool(self, tool: Any, deps: Any = None, **kwargs: Any) -> AG2PydanticAITool: + @classmethod + def convert_tool(cls, tool: Any, deps: Any = None, **kwargs: Any) -> AG2PydanticAITool: """ Converts a given Pydantic AI tool into a general `Tool` format. @@ -103,11 +109,14 @@ def convert_tool(self, tool: Any, deps: Any = None, **kwargs: Any) -> AG2Pydanti dependencies are missing for tools that require a context. UserWarning: If the `deps` argument is provided for a tool that does not take a context. """ + from pydantic_ai import RunContext + from pydantic_ai.tools import Tool as PydanticAITool + if not isinstance(tool, PydanticAITool): raise ValueError(f"Expected an instance of `pydantic_ai.tools.Tool`, got {type(tool)}") # needed for type checking - pydantic_ai_tool: PydanticAITool = tool # type: ignore[no-any-unimported] + pydantic_ai_tool: PydanticAITool[Any] = tool # type: ignore[no-any-unimported] if tool.takes_ctx and deps is None: raise ValueError("If the tool takes a context, the `deps` argument must be provided") @@ -123,7 +132,7 @@ def convert_tool(self, tool: Any, deps: Any = None, **kwargs: Any) -> AG2Pydanti retry=0, # All messages send to or returned by a model. # This is mostly used on pydantic_ai Agent level. - messages=None, # TODO: check in the future if this is needed on Tool level + messages=[], # TODO: check in the future if this is needed on Tool level tool_name=pydantic_ai_tool.name, ) else: @@ -140,3 +149,15 @@ def convert_tool(self, tool: Any, deps: Any = None, **kwargs: Any) -> AG2Pydanti func=func, parameters_json_schema=pydantic_ai_tool._parameters_json_schema, ) + + @classmethod + def get_unsupported_reason(cls) -> Optional[str]: + if sys.version_info < (3, 9): + return "This submodule is only supported for Python versions 3.9 and above" + + try: + import pydantic_ai.tools + except ImportError: + return "Please install `interop-pydantic-ai` extra to use this module:\n\n\tpip install ag2[interop-pydantic-ai]" + + return None diff --git a/autogen/interop/registry.py b/autogen/interop/registry.py new file mode 100644 index 0000000000..443dcb5beb --- /dev/null +++ b/autogen/interop/registry.py @@ -0,0 +1,67 @@ +# Copyright (c) 2023 - 2024, Owners of https://github.com/ag2ai +# +# SPDX-License-Identifier: Apache-2.0 + +from typing import Callable, Dict, Generic, List, Type, TypeVar + +from .interoperable import Interoperable + +__all__ = ["register_interoperable_class", "InteroperableRegistry"] + +InteroperableClass = TypeVar("InteroperableClass", bound=Type[Interoperable]) + + +class InteroperableRegistry: + def __init__(self) -> None: + self._registry: Dict[str, Type[Interoperable]] = {} + + def register(self, short_name: str, cls: InteroperableClass) -> InteroperableClass: + if short_name in self._registry: + raise ValueError(f"Duplicate registration for {short_name}") + + self._registry[short_name] = cls + + return cls + + def get_short_names(self) -> List[str]: + return sorted(self._registry.keys()) + + def get_supported_types(self) -> List[str]: + short_names = self.get_short_names() + supported_types = [name for name in short_names if self._registry[name].get_unsupported_reason() is None] + return supported_types + + def get_class(self, short_name: str) -> Type[Interoperable]: + return self._registry[short_name] + + @classmethod + def get_instance(cls) -> "InteroperableRegistry": + return _register + + +# global registry +_register = InteroperableRegistry() + + +# register decorator +def register_interoperable_class(short_name: str) -> Callable[[InteroperableClass], InteroperableClass]: + """Register an Interoperable class in the global registry. + + Returns: + Callable[[InteroperableClass], InteroperableClass]: Decorator function + + Example: + ```python + @register_interoperable_class("myinterop") + class MyInteroperability(Interoperable): + def convert_tool(self, tool: Any) -> Tool: + # implementation + ... + ``` + """ + + def inner(cls: InteroperableClass) -> InteroperableClass: + global _register + return _register.register(short_name, cls) + + return inner diff --git a/pyproject.toml b/pyproject.toml index 74a6d04c1e..c209c0b2fe 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -92,7 +92,8 @@ no_implicit_optional = true check_untyped_defs = true warn_return_any = true show_error_codes = true -warn_unused_ignores = true + +warn_unused_ignores = false disallow_incomplete_defs = true disallow_untyped_decorators = true diff --git a/setup.py b/setup.py index 63cade834a..c438608894 100644 --- a/setup.py +++ b/setup.py @@ -72,6 +72,9 @@ interop_crewai = ["crewai[tools]>=0.86,<1; python_version>='3.10' and python_version<'3.13'"] interop_langchain = ["langchain-community>=0.3.12,<1; python_version>='3.9'"] interop_pydantic_ai = ["pydantic-ai>=0.0.13,<1; python_version>='3.9'"] +interop = interop_crewai + interop_langchain + interop_pydantic_ai + +types = (["mypy==1.9.0", "pytest"] + jupyter_executor + interop,) if current_os in ["Windows", "Darwin"]: retrieve_chat_pgvector.extend(["psycopg[binary]>=3.1.18"]) @@ -117,7 +120,7 @@ "cosmosdb": ["azure-cosmos>=4.2.0"], "websockets": ["websockets>=12.0,<13"], "jupyter-executor": jupyter_executor, - "types": ["mypy==1.9.0", "pytest>=6.1.1,<8"] + jupyter_executor, + "types": types, "long-context": ["llmlingua<0.3"], "anthropic": ["anthropic>=0.23.1"], "cerebras": ["cerebras_cloud_sdk>=1.0.0"], @@ -129,7 +132,7 @@ "interop-crewai": interop_crewai, "interop-langchain": interop_langchain, "interop-pydantic-ai": interop_pydantic_ai, - "interop": interop_crewai + interop_langchain + interop_pydantic_ai, + "interop": interop, "neo4j": neo4j, } diff --git a/test/interop/crewai/test_crewai.py b/test/interop/crewai/test_crewai.py index 1a2cbbd513..0d39926638 100644 --- a/test/interop/crewai/test_crewai.py +++ b/test/interop/crewai/test_crewai.py @@ -31,21 +31,18 @@ class TestCrewAIInteroperability: @pytest.fixture(autouse=True) def setup(self) -> None: - self.crewai_interop = CrewAIInteroperability() crewai_tool = FileReadTool() self.model_type = crewai_tool.args_schema - self.tool = self.crewai_interop.convert_tool(crewai_tool) + self.tool = CrewAIInteroperability.convert_tool(crewai_tool) def test_type_checks(self) -> None: # mypy should fail if the type checks are not correct - interop: Interoperable = self.crewai_interop + interop: Interoperable = CrewAIInteroperability() + # runtime check assert isinstance(interop, Interoperable) - def test_init(self) -> None: - assert isinstance(self.crewai_interop, Interoperable) - def test_convert_tool(self) -> None: with TemporaryDirectory() as tmp_dir: file_path = f"{tmp_dir}/test.txt" @@ -93,3 +90,17 @@ def test_with_llm(self) -> None: return assert False, "Tool response not found in chat messages" + + def test_get_unsupported_reason(self) -> None: + assert CrewAIInteroperability.get_unsupported_reason() is None + + +@pytest.mark.skipif( + sys.version_info >= (3, 10) or sys.version_info < (3, 13), reason="Crew AI Interoperability is supported" +) +class TestCrewAIInteroperabilityIfNotSupported: + def test_get_unsupported_reason(self) -> None: + assert ( + CrewAIInteroperability.get_unsupported_reason() + == "This submodule is only supported for Python versions 3.10, 3.11, and 3.12" + ) diff --git a/test/interop/langchain/test_langchain.py b/test/interop/langchain/test_langchain.py index 070c396784..376612ad11 100644 --- a/test/interop/langchain/test_langchain.py +++ b/test/interop/langchain/test_langchain.py @@ -16,17 +16,17 @@ if sys.version_info >= (3, 9): from langchain.tools import tool - from autogen.interop.langchain import LangchainInteroperability + from autogen.interop.langchain import LangChainInteroperability else: tool = unittest.mock.MagicMock() - LangchainInteroperability = unittest.mock.MagicMock() + LangChainInteroperability = unittest.mock.MagicMock() # skip if python version is not >= 3.9 @pytest.mark.skipif( sys.version_info < (3, 9), reason="Only Python 3.9 and above are supported for LangchainInteroperability" ) -class TestLangchainInteroperability: +class TestLangChainInteroperability: @pytest.fixture(autouse=True) def setup(self) -> None: class SearchInput(BaseModel): @@ -37,19 +37,16 @@ def search(query: SearchInput) -> str: """Look up things online.""" return "LangChain Integration" - self.langchain_interop = LangchainInteroperability() self.model_type = search.args_schema - self.tool = self.langchain_interop.convert_tool(search) + self.tool = LangChainInteroperability.convert_tool(search) def test_type_checks(self) -> None: # mypy should fail if the type checks are not correct - interop: Interoperable = self.langchain_interop + interop: Interoperable = LangChainInteroperability() + # runtime check assert isinstance(interop, Interoperable) - def test_init(self) -> None: - assert isinstance(self.langchain_interop, Interoperable) - def test_convert_tool(self) -> None: assert self.tool.name == "search-tool" assert self.tool.description == "Look up things online." @@ -82,12 +79,15 @@ def test_with_llm(self) -> None: assert False, "No tool response found in chat messages" + def test_get_unsupported_reason(self) -> None: + assert LangChainInteroperability.get_unsupported_reason() is None + # skip if python version is not >= 3.9 @pytest.mark.skipif( sys.version_info < (3, 9), reason="Only Python 3.9 and above are supported for LangchainInteroperability" ) -class TestLangchainInteroperabilityWithoutPydanticInput: +class TestLangChainInteroperabilityWithoutPydanticInput: @pytest.fixture(autouse=True) def setup(self) -> None: @tool @@ -95,8 +95,7 @@ def search(query: str, max_length: int) -> str: """Look up things online.""" return f"LangChain Integration, max_length: {max_length}" - self.langchain_interop = LangchainInteroperability() - self.tool = self.langchain_interop.convert_tool(search) + self.tool = LangChainInteroperability.convert_tool(search) self.model_type = search.args_schema def test_convert_tool(self) -> None: @@ -130,3 +129,12 @@ def test_with_llm(self) -> None: return assert False, "No tool response found in chat messages" + + +@pytest.mark.skipif(sys.version_info >= (3, 9), reason="LangChain Interoperability is supported") +class TestLangChainInteroperabilityIfNotSupported: + def test_get_unsupported_reason(self) -> None: + assert ( + LangChainInteroperability.get_unsupported_reason() + == "This submodule is only supported for Python versions 3.9 and above" + ) diff --git a/test/interop/pydantic_ai/test_pydantic_ai.py b/test/interop/pydantic_ai/test_pydantic_ai.py index 6b5dae9590..764a44e855 100644 --- a/test/interop/pydantic_ai/test_pydantic_ai.py +++ b/test/interop/pydantic_ai/test_pydantic_ai.py @@ -38,19 +38,15 @@ def roll_dice() -> str: """Roll a six-sided dice and return the result.""" return str(random.randint(1, 6)) - self.pydantic_ai_interop = PydanticAIInteroperability() pydantic_ai_tool = PydanticAITool(roll_dice, max_retries=3) - self.tool = self.pydantic_ai_interop.convert_tool(pydantic_ai_tool) + self.tool = PydanticAIInteroperability.convert_tool(pydantic_ai_tool) def test_type_checks(self) -> None: # mypy should fail if the type checks are not correct - interop: Interoperable = self.pydantic_ai_interop + interop: Interoperable = PydanticAIInteroperability() # runtime check assert isinstance(interop, Interoperable) - def test_init(self) -> None: - assert isinstance(self.pydantic_ai_interop, Interoperable) - def test_convert_tool(self) -> None: assert self.tool.name == "roll_dice" assert self.tool.description == "Roll a six-sided dice and return the result." @@ -162,14 +158,13 @@ def get_player(ctx: RunContext[Player], additional_info: Optional[str] = None) - """ return f"Name: {ctx.deps.name}, Age: {ctx.deps.age}, Additional info: {additional_info}" # type: ignore[attr-defined] - self.pydantic_ai_interop = PydanticAIInteroperability() self.pydantic_ai_tool = PydanticAITool(get_player, takes_ctx=True) player = Player(name="Luka", age=25) - self.tool = self.pydantic_ai_interop.convert_tool(tool=self.pydantic_ai_tool, deps=player) + self.tool = PydanticAIInteroperability.convert_tool(tool=self.pydantic_ai_tool, deps=player) def test_convert_tool_raises_error_if_take_ctx_is_true_and_deps_is_none(self) -> None: with pytest.raises(ValueError, match="If the tool takes a context, the `deps` argument must be provided"): - self.pydantic_ai_interop.convert_tool(tool=self.pydantic_ai_tool, deps=None) + PydanticAIInteroperability.convert_tool(tool=self.pydantic_ai_tool, deps=None) def test_expected_tools(self) -> None: config_list = [{"model": "gpt-4o", "api_key": "abc"}] @@ -229,3 +224,12 @@ def test_with_llm(self) -> None: return assert False, "No tool response found in chat messages" + + +@pytest.mark.skipif(sys.version_info >= (3, 9), reason="LangChain Interoperability is supported") +class TestPydanticAIInteroperabilityIfNotSupported: + def test_get_unsupported_reason(self) -> None: + assert ( + PydanticAIInteroperability.get_unsupported_reason() + == "This submodule is only supported for Python versions 3.9 and above" + ) diff --git a/test/interop/test_helpers.py b/test/interop/test_helpers.py deleted file mode 100644 index 76f52a85f4..0000000000 --- a/test/interop/test_helpers.py +++ /dev/null @@ -1,69 +0,0 @@ -# Copyright (c) 2023 - 2024, Owners of https://github.com/ag2ai -# -# SPDX-License-Identifier: Apache-2.0 - -import sys - -import pytest - -from autogen.interop import Interoperability, Interoperable -from autogen.interop.helpers import ( - find_classes_implementing_protocol, - get_all_interoperability_classes, - import_submodules, -) - - -class TestHelpers: - @pytest.fixture(autouse=True) - def setup_method(self) -> None: - self.imported_modules = import_submodules("autogen.interop") - - def test_import_submodules(self) -> None: - assert "autogen.interop.helpers" in self.imported_modules - - def test_find_classes_implementing_protocol(self) -> None: - actual = find_classes_implementing_protocol(self.imported_modules, Interoperable) - print(f"test_find_classes_implementing_protocol: {actual=}") - - assert Interoperability in actual - expected_count = 1 - - if sys.version_info >= (3, 10) and sys.version_info < (3, 13): - from autogen.interop.crewai import CrewAIInteroperability - - assert CrewAIInteroperability in actual - expected_count += 1 - - if sys.version_info >= (3, 9): - from autogen.interop.langchain import LangchainInteroperability - from autogen.interop.pydantic_ai import PydanticAIInteroperability - - assert LangchainInteroperability in actual - assert PydanticAIInteroperability in actual - expected_count += 2 - - assert len(actual) == expected_count - - def test_get_all_interoperability_classes(self) -> None: - - actual = get_all_interoperability_classes() - - if sys.version_info < (3, 9): - assert actual == {} - - if sys.version_info >= (3, 10) and sys.version_info < (3, 13): - from autogen.interop.crewai import CrewAIInteroperability - from autogen.interop.langchain import LangchainInteroperability - from autogen.interop.pydantic_ai import PydanticAIInteroperability - - assert actual == { - "pydanticai": PydanticAIInteroperability, - "crewai": CrewAIInteroperability, - "langchain": LangchainInteroperability, - } - - if (sys.version_info >= (3, 9) and sys.version_info < (3, 10)) and sys.version_info >= (3, 13): - from autogen.interop.langchain import LangchainInteroperability - - assert actual == {"langchain": LangchainInteroperability} diff --git a/test/interop/test_interoperability.py b/test/interop/test_interoperability.py index df7789dd13..3925f056f1 100644 --- a/test/interop/test_interoperability.py +++ b/test/interop/test_interoperability.py @@ -8,13 +8,12 @@ import pytest -from autogen.interop import Interoperability, Interoperable -from autogen.tools.tool import Tool +from autogen.interop import Interoperability class TestInteroperability: def test_supported_types(self) -> None: - actual = Interoperability.supported_types() + actual = Interoperability.get_supported_types() if sys.version_info < (3, 9): assert actual == [] @@ -28,26 +27,6 @@ def test_supported_types(self) -> None: if sys.version_info >= (3, 13): assert actual == ["langchain", "pydanticai"] - def test_register_interoperability_class(self) -> None: - org_interoperability_classes = Interoperability._interoperability_classes - try: - - class MyInteroperability: - def convert_tool(self, tool: Any, **kwargs: Any) -> Tool: - return Tool(name="test", description="test description", func=tool) - - Interoperability.register_interoperability_class("my_interop", MyInteroperability) - assert Interoperability.get_interoperability_class("my_interop") == MyInteroperability - - interop = Interoperability() - tool = interop.convert_tool(type="my_interop", tool=lambda x: x) - assert tool.name == "test" - assert tool.description == "test description" - assert tool.func("hello") == "hello" - - finally: - Interoperability._interoperability_classes = org_interoperability_classes - @pytest.mark.skipif( sys.version_info < (3, 10) or sys.version_info >= (3, 13), reason="Only Python 3.10, 3.11, 3.12 are supported" ) @@ -56,8 +35,7 @@ def test_crewai(self) -> None: crewai_tool = FileReadTool() - interoperability = Interoperability() - tool = interoperability.convert_tool(type="crewai", tool=crewai_tool) + tool = Interoperability.convert_tool(type="crewai", tool=crewai_tool) with TemporaryDirectory() as tmp_dir: file_path = f"{tmp_dir}/test.txt"