Skip to content

Commit

Permalink
polishing
Browse files Browse the repository at this point in the history
  • Loading branch information
davorrunje committed Dec 19, 2024
2 parents 9ce614f + 5adbb53 commit 3f644d0
Show file tree
Hide file tree
Showing 19 changed files with 260 additions and 292 deletions.
6 changes: 6 additions & 0 deletions =8
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
Requirement already satisfied: pytest in ./.venv-3.9/lib/python3.9/site-packages (7.4.4)
Requirement already satisfied: iniconfig in ./.venv-3.9/lib/python3.9/site-packages (from pytest) (2.0.0)
Requirement already satisfied: packaging in ./.venv-3.9/lib/python3.9/site-packages (from pytest) (24.2)
Requirement already satisfied: pluggy<2.0,>=0.12 in ./.venv-3.9/lib/python3.9/site-packages (from pytest) (1.5.0)
Requirement already satisfied: exceptiongroup>=1.0.0rc8 in ./.venv-3.9/lib/python3.9/site-packages (from pytest) (1.2.2)
Requirement already satisfied: tomli>=1.0.0 in ./.venv-3.9/lib/python3.9/site-packages (from pytest) (2.2.1)
6 changes: 5 additions & 1 deletion autogen/interop/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,11 @@
#
# SPDX-License-Identifier: Apache-2.0

from .crewai import CrewAIInteroperability
from .interoperability import Interoperability
from .interoperable import Interoperable
from .langchain import LangChainInteroperability
from .pydantic_ai import PydanticAIInteroperability
from .registry import register_interoperable_class

__all__ = ["Interoperable", "Interoperability"]
__all__ = ["Interoperability", "Interoperable", "register_interoperable_class"]
10 changes: 0 additions & 10 deletions autogen/interop/crewai/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,6 @@
#
# SPDX-License-Identifier: Apache-2.0

import sys

if sys.version_info < (3, 10) or sys.version_info >= (3, 13):
raise ImportError("This submodule is only supported for Python versions 3.10, 3.11, and 3.12")

try:
import crewai.tools
except ImportError:
raise ImportError("Please install `interop-crewai` extra to use this module:\n\n\tpip install ag2[interop-crewai]")

from .crewai import CrewAIInteroperability

__all__ = ["CrewAIInteroperability"]
26 changes: 21 additions & 5 deletions autogen/interop/crewai/crewai.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@
# SPDX-License-Identifier: Apache-2.0

import re
from typing import Any, cast

from crewai.tools import BaseTool as CrewAITool
import sys
from typing import Any, Optional

from ...tools import Tool
from ..interoperable import Interoperable
from ..registry import register_interoperable_class

__all__ = ["CrewAIInteroperability"]

Expand All @@ -17,15 +17,17 @@ def _sanitize_name(s: str) -> str:
return re.sub(r"\W|^(?=\d)", "_", s)


class CrewAIInteroperability(Interoperable):
@register_interoperable_class("crewai")
class CrewAIInteroperability:
"""
A class implementing the `Interoperable` protocol for converting CrewAI tools
to a general `Tool` format.
This class takes a `CrewAITool` and converts it into a standard `Tool` object.
"""

def convert_tool(self, tool: Any, **kwargs: Any) -> Tool:
@classmethod
def convert_tool(cls, tool: Any, **kwargs: Any) -> Tool:
"""
Converts a given CrewAI tool into a general `Tool` format.
Expand All @@ -44,6 +46,8 @@ def convert_tool(self, tool: Any, **kwargs: Any) -> Tool:
ValueError: If the provided tool is not an instance of `CrewAITool`, or if
any additional arguments are passed.
"""
from crewai.tools import BaseTool as CrewAITool

if not isinstance(tool, CrewAITool):
raise ValueError(f"Expected an instance of `crewai.tools.BaseTool`, got {type(tool)}")
if kwargs:
Expand All @@ -66,3 +70,15 @@ def func(args: crewai_tool.args_schema) -> Any: # type: ignore[no-any-unimporte
description=description,
func=func,
)

@classmethod
def get_unsupported_reason(cls) -> Optional[str]:
if sys.version_info < (3, 10) or sys.version_info >= (3, 13):
return "This submodule is only supported for Python versions 3.10, 3.11, and 3.12"

try:
import crewai.tools
except ImportError:
return "Please install `interop-crewai` extra to use this module:\n\n\tpip install ag2[interop-crewai]"

return None
55 changes: 0 additions & 55 deletions autogen/interop/helpers.py

This file was deleted.

53 changes: 15 additions & 38 deletions autogen/interop/interoperability.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
from typing import Any, Dict, List, Type

from ..tools import Tool
from .helpers import get_all_interoperability_classes
from .interoperable import Interoperable
from .registry import InteroperableRegistry

__all__ = ["Interoperable"]

Expand All @@ -18,18 +18,10 @@ class Interoperability:
for retrieving and registering interoperability classes.
"""

_interoperability_classes: Dict[str, Type[Interoperable]] = get_all_interoperability_classes()
registry = InteroperableRegistry.get_instance()

def __init__(self) -> None:
"""
Initializes an instance of the Interoperability class.
This constructor does not perform any specific actions as the class is primarily used for its class
methods to manage interoperability classes.
"""
pass

def convert_tool(self, *, tool: Any, type: str, **kwargs: Any) -> Tool:
@classmethod
def convert_tool(cls, *, tool: Any, type: str, **kwargs: Any) -> Tool:
"""
Converts a given tool to an instance of a specified interoperability type.
Expand All @@ -44,8 +36,7 @@ def convert_tool(self, *, tool: Any, type: str, **kwargs: Any) -> Tool:
Raises:
ValueError: If the interoperability class for the provided type is not found.
"""
interop_cls = self.get_interoperability_class(type)
interop = interop_cls()
interop = cls.get_interoperability_class(type)
return interop.convert_tool(tool, **kwargs)

@classmethod
Expand All @@ -62,35 +53,21 @@ def get_interoperability_class(cls, type: str) -> Type[Interoperable]:
Raises:
ValueError: If no interoperability class is found for the provided type.
"""
if type not in cls._interoperability_classes:
raise ValueError(f"Interoperability class {type} not found")
return cls._interoperability_classes[type]
supported_types = cls.registry.get_supported_types()
if type not in supported_types:
supported_types_formated = ", ".join(["'t'" for t in supported_types])
raise ValueError(
f"Interoperability class {type} is not supported, supported types: {supported_types_formated}"
)

return cls.registry.get_class(type)

@classmethod
def supported_types(cls) -> List[str]:
def get_supported_types(cls) -> List[str]:
"""
Returns a sorted list of all supported interoperability types.
Returns:
List[str]: A sorted list of strings representing the supported interoperability types.
"""
return sorted(cls._interoperability_classes.keys())

@classmethod
def register_interoperability_class(cls, name: str, interoperability_class: Type[Interoperable]) -> None:
"""
Registers a new interoperability class with the given name.
Args:
name (str): The name to associate with the interoperability class.
interoperability_class (Type[Interoperable]): The class implementing the Interoperable protocol.
Raises:
ValueError: If the provided class does not implement the Interoperable protocol.
"""
if not issubclass(interoperability_class, Interoperable):
raise ValueError(
f"Expected a class implementing `Interoperable` protocol, got {type(interoperability_class)}"
)

cls._interoperability_classes[name] = interoperability_class
return sorted(cls.registry.get_supported_types())
16 changes: 14 additions & 2 deletions autogen/interop/interoperable.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
#
# SPDX-License-Identifier: Apache-2.0

from typing import Any, Protocol, runtime_checkable
from typing import Any, Optional, Protocol, runtime_checkable

from ..tools import Tool

Expand All @@ -18,7 +18,8 @@ class Interoperable(Protocol):
`convert_tool` to convert a given tool into a desired format or type.
"""

def convert_tool(self, tool: Any, **kwargs: Any) -> Tool:
@classmethod
def convert_tool(cls, tool: Any, **kwargs: Any) -> Tool:
"""
Converts a given tool to a desired format or type.
Expand All @@ -32,3 +33,14 @@ def convert_tool(self, tool: Any, **kwargs: Any) -> Tool:
Tool: The converted tool in the desired format or type.
"""
...

@classmethod
def get_unsupported_reason(cls) -> Optional[str]:
"""Returns the reason for the tool being unsupported.
This method should be implemented by any class adhering to the `Interoperable` protocol.
Returns:
str: The reason for the interoperability class being unsupported.
"""
...
16 changes: 2 additions & 14 deletions autogen/interop/langchain/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,6 @@
#
# SPDX-License-Identifier: Apache-2.0

import sys
from .langchain import LangChainInteroperability

if sys.version_info < (3, 9):
raise ImportError("This submodule is only supported for Python versions 3.9 and above")

try:
import langchain.tools
except ImportError:
raise ImportError(
"Please install `interop-langchain` extra to use this module:\n\n\tpip install ag2[interop-langchain]"
)

from .langchain import LangchainInteroperability

__all__ = ["LangchainInteroperability"]
__all__ = ["LangChainInteroperability"]
34 changes: 26 additions & 8 deletions autogen/interop/langchain/langchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,18 @@
#
# SPDX-License-Identifier: Apache-2.0

from typing import Any

from langchain_core.tools import BaseTool as LangchainTool
import sys
from typing import Any, Optional

from ...tools import Tool
from ..interoperable import Interoperable
from ..registry import register_interoperable_class

__all__ = ["LangchainInteroperability"]
__all__ = ["LangChainInteroperability"]


class LangchainInteroperability(Interoperable):
@register_interoperable_class("langchain")
class LangChainInteroperability:
"""
A class implementing the `Interoperable` protocol for converting Langchain tools
into a general `Tool` format.
Expand All @@ -22,7 +23,8 @@ class LangchainInteroperability(Interoperable):
the `Tool` format.
"""

def convert_tool(self, tool: Any, **kwargs: Any) -> Tool:
@classmethod
def convert_tool(cls, tool: Any, **kwargs: Any) -> Tool:
"""
Converts a given Langchain tool into a general `Tool` format.
Expand All @@ -41,19 +43,35 @@ def convert_tool(self, tool: Any, **kwargs: Any) -> Tool:
ValueError: If the provided tool is not an instance of `LangchainTool`, or if
any additional arguments are passed.
"""
from langchain_core.tools import BaseTool as LangchainTool

if not isinstance(tool, LangchainTool):
raise ValueError(f"Expected an instance of `langchain_core.tools.BaseTool`, got {type(tool)}")
if kwargs:
raise ValueError(f"The LangchainInteroperability does not support any additional arguments, got {kwargs}")

# needed for type checking
langchain_tool: LangchainTool = tool # type: ignore[no-any-unimported]
langchain_tool: LangchainTool = tool # type: ignore

def func(tool_input: langchain_tool.args_schema) -> Any: # type: ignore[no-any-unimported]
def func(tool_input: langchain_tool.args_schema) -> Any: # type: ignore
return langchain_tool.run(tool_input.model_dump())

return Tool(
name=langchain_tool.name,
description=langchain_tool.description,
func=func,
)

@classmethod
def get_unsupported_reason(cls) -> Optional[str]:
if sys.version_info < (3, 9):
return "This submodule is only supported for Python versions 3.9 and above"

try:
import langchain_core.tools
except ImportError:
return (
"Please install `interop-langchain` extra to use this module:\n\n\tpip install ag2[interop-langchain]"
)

return None
12 changes: 0 additions & 12 deletions autogen/interop/pydantic_ai/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,6 @@
#
# SPDX-License-Identifier: Apache-2.0

import sys

if sys.version_info < (3, 9):
raise ImportError("This submodule is only supported for Python versions 3.9 and above")

try:
import pydantic_ai.tools
except ImportError:
raise ImportError(
"Please install `interop-pydantic-ai` extra to use this module:\n\n\tpip install ag2[interop-pydantic-ai]"
)

from .pydantic_ai import PydanticAIInteroperability

__all__ = ["PydanticAIInteroperability"]
Loading

0 comments on commit 3f644d0

Please sign in to comment.