Skip to content

Commit

Permalink
refactoring: making convert_tool a class method
Browse files Browse the repository at this point in the history
  • Loading branch information
davorrunje committed Dec 19, 2024
1 parent 94c2378 commit 2de4a90
Show file tree
Hide file tree
Showing 11 changed files with 66 additions and 60 deletions.
2 changes: 1 addition & 1 deletion autogen/interop/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from .crewai import CrewAIInteroperability
from .interoperability import Interoperability
from .interoperable import Interoperable
from .langchain import LangchainInteroperability
from .langchain import LangChainInteroperability
from .pydantic_ai import PydanticAIInteroperability
from .registry import register_interoperable_class

Expand Down
5 changes: 3 additions & 2 deletions autogen/interop/crewai/crewai.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import re
import sys
from typing import Any, Optional, cast
from typing import Any, Optional

from ...tools import Tool
from ..interoperable import Interoperable
Expand All @@ -26,7 +26,8 @@ class CrewAIInteroperability(Interoperable):
This class takes a `CrewAITool` and converts it into a standard `Tool` object.
"""

def convert_tool(self, tool: Any, **kwargs: Any) -> Tool:
@classmethod
def convert_tool(cls, tool: Any, **kwargs: Any) -> Tool:
"""
Converts a given CrewAI tool into a general `Tool` format.
Expand Down
27 changes: 11 additions & 16 deletions autogen/interop/interoperability.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,10 @@ class Interoperability:
for retrieving and registering interoperability classes.
"""

def __init__(self) -> None:
"""
Initializes an instance of the Interoperability class.
This constructor does not perform any specific actions as the class is primarily used for its class
methods to manage interoperability classes.
"""
self.registry = InteroperableRegistry.get_instance()
registry = InteroperableRegistry.get_instance()

def convert_tool(self, *, tool: Any, type: str, **kwargs: Any) -> Tool:
@classmethod
def convert_tool(cls, *, tool: Any, type: str, **kwargs: Any) -> Tool:
"""
Converts a given tool to an instance of a specified interoperability type.
Expand All @@ -42,11 +36,11 @@ def convert_tool(self, *, tool: Any, type: str, **kwargs: Any) -> Tool:
Raises:
ValueError: If the interoperability class for the provided type is not found.
"""
interop_cls = self.get_interoperability_class(type)
interop = interop_cls()
interop = cls.get_interoperability_class(type)
return interop.convert_tool(tool, **kwargs)

def get_interoperability_class(self, type: str) -> Type[Interoperable]:
@classmethod
def get_interoperability_class(cls, type: str) -> Type[Interoperable]:
"""
Retrieves the interoperability class corresponding to the specified type.
Expand All @@ -59,20 +53,21 @@ def get_interoperability_class(self, type: str) -> Type[Interoperable]:
Raises:
ValueError: If no interoperability class is found for the provided type.
"""
supported_types = self.registry.get_supported_types()
supported_types = cls.registry.get_supported_types()
if type not in supported_types:
supported_types_formated = ", ".join(["'t'" for t in supported_types])
raise ValueError(
f"Interoperability class {type} is not supported, supported types: {supported_types_formated}"
)

return self.registry.get_class(type)
return cls.registry.get_class(type)

def get_supported_types(self) -> List[str]:
@classmethod
def get_supported_types(cls) -> List[str]:
"""
Returns a sorted list of all supported interoperability types.
Returns:
List[str]: A sorted list of strings representing the supported interoperability types.
"""
return sorted(self.registry.get_supported_types())
return sorted(cls.registry.get_supported_types())
3 changes: 2 additions & 1 deletion autogen/interop/interoperable.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ class Interoperable(Protocol):
`convert_tool` to convert a given tool into a desired format or type.
"""

def convert_tool(self, tool: Any, **kwargs: Any) -> Tool:
@classmethod
def convert_tool(cls, tool: Any, **kwargs: Any) -> Tool:
"""
Converts a given tool to a desired format or type.
Expand Down
4 changes: 2 additions & 2 deletions autogen/interop/langchain/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@
#
# SPDX-License-Identifier: Apache-2.0

from .langchain import LangchainInteroperability
from .langchain import LangChainInteroperability

__all__ = ["LangchainInteroperability"]
__all__ = ["LangChainInteroperability"]
7 changes: 4 additions & 3 deletions autogen/interop/langchain/langchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@
from ..interoperable import Interoperable
from ..registry import register_interoperable_class

__all__ = ["LangchainInteroperability"]
__all__ = ["LangChainInteroperability"]


@register_interoperable_class("langchain")
class LangchainInteroperability(Interoperable):
class LangChainInteroperability(Interoperable):
"""
A class implementing the `Interoperable` protocol for converting Langchain tools
into a general `Tool` format.
Expand All @@ -23,7 +23,8 @@ class LangchainInteroperability(Interoperable):
the `Tool` format.
"""

def convert_tool(self, tool: Any, **kwargs: Any) -> Tool:
@classmethod
def convert_tool(cls, tool: Any, **kwargs: Any) -> Tool:
"""
Converts a given Langchain tool into a general `Tool` format.
Expand Down
4 changes: 3 additions & 1 deletion autogen/interop/pydantic_ai/pydantic_ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ class PydanticAIInteroperability(Interoperable):
into the tool's function.
"""

@staticmethod
def inject_params(
ctx: Any,
tool: Any,
Expand Down Expand Up @@ -86,7 +87,8 @@ def wrapper(*args: Any, **kwargs: Any) -> Any:

return wrapper

def convert_tool(self, tool: Any, deps: Any = None, **kwargs: Any) -> AG2PydanticAITool:
@classmethod
def convert_tool(cls, tool: Any, deps: Any = None, **kwargs: Any) -> AG2PydanticAITool:
"""
Converts a given Pydantic AI tool into a general `Tool` format.
Expand Down
15 changes: 5 additions & 10 deletions test/interop/crewai/test_crewai.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,21 +31,18 @@
class TestCrewAIInteroperability:
@pytest.fixture(autouse=True)
def setup(self) -> None:
self.crewai_interop = CrewAIInteroperability()

crewai_tool = FileReadTool()
self.model_type = crewai_tool.args_schema
self.tool = self.crewai_interop.convert_tool(crewai_tool)
self.tool = CrewAIInteroperability.convert_tool(crewai_tool)

def test_type_checks(self) -> None:
# mypy should fail if the type checks are not correct
interop: Interoperable = self.crewai_interop
interop: Interoperable = CrewAIInteroperability()

# runtime check
assert isinstance(interop, Interoperable)

def test_init(self) -> None:
assert isinstance(self.crewai_interop, Interoperable)

def test_convert_tool(self) -> None:
with TemporaryDirectory() as tmp_dir:
file_path = f"{tmp_dir}/test.txt"
Expand Down Expand Up @@ -95,17 +92,15 @@ def test_with_llm(self) -> None:
assert False, "Tool response not found in chat messages"

def test_get_unsupported_reason(self) -> None:
crewai_interop = CrewAIInteroperability()
assert crewai_interop.get_unsupported_reason() is None
assert CrewAIInteroperability.get_unsupported_reason() is None


@pytest.mark.skipif(
sys.version_info >= (3, 10) or sys.version_info < (3, 13), reason="Crew AI Interoperability is supported"
)
class TestCrewAIInteroperabilityIfNotSupported:
def test_get_unsupported_reason(self) -> None:
crewai_interop = CrewAIInteroperability()
assert (
crewai_interop.get_unsupported_reason()
CrewAIInteroperability.get_unsupported_reason()
== "This submodule is only supported for Python versions 3.10, 3.11, and 3.12"
)
32 changes: 20 additions & 12 deletions test/interop/langchain/test_langchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,17 @@
if sys.version_info >= (3, 9):
from langchain.tools import tool

from autogen.interop.langchain import LangchainInteroperability
from autogen.interop.langchain import LangChainInteroperability
else:
tool = unittest.mock.MagicMock()
LangchainInteroperability = unittest.mock.MagicMock()
LangChainInteroperability = unittest.mock.MagicMock()


# skip if python version is not >= 3.9
@pytest.mark.skipif(
sys.version_info < (3, 9), reason="Only Python 3.9 and above are supported for LangchainInteroperability"
)
class TestLangchainInteroperability:
class TestLangChainInteroperability:
@pytest.fixture(autouse=True)
def setup(self) -> None:
class SearchInput(BaseModel):
Expand All @@ -37,19 +37,16 @@ def search(query: SearchInput) -> str:
"""Look up things online."""
return "LangChain Integration"

self.langchain_interop = LangchainInteroperability()
self.model_type = search.args_schema
self.tool = self.langchain_interop.convert_tool(search)
self.tool = LangChainInteroperability.convert_tool(search)

def test_type_checks(self) -> None:
# mypy should fail if the type checks are not correct
interop: Interoperable = self.langchain_interop
interop: Interoperable = LangChainInteroperability()

# runtime check
assert isinstance(interop, Interoperable)

def test_init(self) -> None:
assert isinstance(self.langchain_interop, Interoperable)

def test_convert_tool(self) -> None:
assert self.tool.name == "search-tool"
assert self.tool.description == "Look up things online."
Expand Down Expand Up @@ -82,21 +79,23 @@ def test_with_llm(self) -> None:

assert False, "No tool response found in chat messages"

def test_get_unsupported_reason(self) -> None:
assert LangChainInteroperability.get_unsupported_reason() is None


# skip if python version is not >= 3.9
@pytest.mark.skipif(
sys.version_info < (3, 9), reason="Only Python 3.9 and above are supported for LangchainInteroperability"
)
class TestLangchainInteroperabilityWithoutPydanticInput:
class TestLangChainInteroperabilityWithoutPydanticInput:
@pytest.fixture(autouse=True)
def setup(self) -> None:
@tool
def search(query: str, max_length: int) -> str:
"""Look up things online."""
return f"LangChain Integration, max_length: {max_length}"

self.langchain_interop = LangchainInteroperability()
self.tool = self.langchain_interop.convert_tool(search)
self.tool = LangChainInteroperability.convert_tool(search)
self.model_type = search.args_schema

def test_convert_tool(self) -> None:
Expand Down Expand Up @@ -130,3 +129,12 @@ def test_with_llm(self) -> None:
return

assert False, "No tool response found in chat messages"


@pytest.mark.skipif(sys.version_info >= (3, 9), reason="LangChain Interoperability is supported")
class TestLangChainInteroperabilityIfNotSupported:
def test_get_unsupported_reason(self) -> None:
assert (
LangChainInteroperability.get_unsupported_reason()
== "This submodule is only supported for Python versions 3.9 and above"
)
22 changes: 13 additions & 9 deletions test/interop/pydantic_ai/test_pydantic_ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,19 +38,15 @@ def roll_dice() -> str:
"""Roll a six-sided dice and return the result."""
return str(random.randint(1, 6))

self.pydantic_ai_interop = PydanticAIInteroperability()
pydantic_ai_tool = PydanticAITool(roll_dice, max_retries=3)
self.tool = self.pydantic_ai_interop.convert_tool(pydantic_ai_tool)
self.tool = PydanticAIInteroperability.convert_tool(pydantic_ai_tool)

def test_type_checks(self) -> None:
# mypy should fail if the type checks are not correct
interop: Interoperable = self.pydantic_ai_interop
interop: Interoperable = PydanticAIInteroperability()
# runtime check
assert isinstance(interop, Interoperable)

def test_init(self) -> None:
assert isinstance(self.pydantic_ai_interop, Interoperable)

def test_convert_tool(self) -> None:
assert self.tool.name == "roll_dice"
assert self.tool.description == "Roll a six-sided dice and return the result."
Expand Down Expand Up @@ -162,14 +158,13 @@ def get_player(ctx: RunContext[Player], additional_info: Optional[str] = None) -
"""
return f"Name: {ctx.deps.name}, Age: {ctx.deps.age}, Additional info: {additional_info}" # type: ignore[attr-defined]

self.pydantic_ai_interop = PydanticAIInteroperability()
self.pydantic_ai_tool = PydanticAITool(get_player, takes_ctx=True)
player = Player(name="Luka", age=25)
self.tool = self.pydantic_ai_interop.convert_tool(tool=self.pydantic_ai_tool, deps=player)
self.tool = PydanticAIInteroperability.convert_tool(tool=self.pydantic_ai_tool, deps=player)

def test_convert_tool_raises_error_if_take_ctx_is_true_and_deps_is_none(self) -> None:
with pytest.raises(ValueError, match="If the tool takes a context, the `deps` argument must be provided"):
self.pydantic_ai_interop.convert_tool(tool=self.pydantic_ai_tool, deps=None)
PydanticAIInteroperability.convert_tool(tool=self.pydantic_ai_tool, deps=None)

def test_expected_tools(self) -> None:
config_list = [{"model": "gpt-4o", "api_key": "abc"}]
Expand Down Expand Up @@ -229,3 +224,12 @@ def test_with_llm(self) -> None:
return

assert False, "No tool response found in chat messages"


@pytest.mark.skipif(sys.version_info >= (3, 9), reason="LangChain Interoperability is supported")
class TestPydanticAIInteroperabilityIfNotSupported:
def test_get_unsupported_reason(self) -> None:
assert (
PydanticAIInteroperability.get_unsupported_reason()
== "This submodule is only supported for Python versions 3.9 and above"
)
5 changes: 2 additions & 3 deletions test/interop/test_interoperability.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

class TestInteroperability:
def test_supported_types(self) -> None:
actual = Interoperability().get_supported_types()
actual = Interoperability.get_supported_types()

if sys.version_info < (3, 9):
assert actual == []
Expand All @@ -35,8 +35,7 @@ def test_crewai(self) -> None:

crewai_tool = FileReadTool()

interoperability = Interoperability()
tool = interoperability.convert_tool(type="crewai", tool=crewai_tool)
tool = Interoperability.convert_tool(type="crewai", tool=crewai_tool)

with TemporaryDirectory() as tmp_dir:
file_path = f"{tmp_dir}/test.txt"
Expand Down

0 comments on commit 2de4a90

Please sign in to comment.