Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/add-tool-imports' into add-tool-…
Browse files Browse the repository at this point in the history
…imports-pydantic-ai2
  • Loading branch information
rjambrecic committed Dec 18, 2024
2 parents c5da59d + fd2b089 commit c929a27
Show file tree
Hide file tree
Showing 10 changed files with 258 additions and 15 deletions.
5 changes: 3 additions & 2 deletions autogen/interop/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#
# SPDX-License-Identifier: Apache-2.0

from .interoperability import Interoperable
from .interoperability import Interoperability
from .interoperable import Interoperable

__all__ = ["Interoperable"]
__all__ = ["Interoperable", "Interoperability"]
2 changes: 1 addition & 1 deletion autogen/interop/crewai/crewai.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from crewai.tools import BaseTool as CrewAITool

from ...tools import Tool
from ..interoperability import Interoperable
from ..interoperable import Interoperable

__all__ = ["CrewAIInteroperability"]

Expand Down
55 changes: 55 additions & 0 deletions autogen/interop/helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# Copyright (c) 2023 - 2024, Owners of https://github.com/ag2ai
#
# SPDX-License-Identifier: Apache-2.0
#

import importlib
import inspect
import logging
import pkgutil
import sys
from typing import Any, Dict, List, Set, Type

from .interoperable import Interoperable

logger = logging.getLogger(__name__)


def import_submodules(package_name: str) -> List[str]:
package = importlib.import_module(package_name)
imported_modules: List[str] = []
for loader, module_name, is_pkg in pkgutil.walk_packages(package.__path__, package.__name__ + "."):
try:
importlib.import_module(module_name)

imported_modules.append(module_name)
except Exception as e:
logger.info(f"Error importing {module_name}, most likely perfectly fine: {e}")

return imported_modules


def find_classes_implementing_protocol(imported_modules: List[str], protocol: Type[Any]) -> List[Type[Any]]:
implementing_classes: Set[Type[Any]] = set()
for module in imported_modules:
for _, obj in inspect.getmembers(sys.modules[module], inspect.isclass):
if issubclass(obj, protocol) and obj is not protocol:
implementing_classes.add(obj)

return list(implementing_classes)


def get_all_interoperability_classes() -> Dict[str, Type[Interoperable]]:
imported_modules = import_submodules("autogen.interop")
classes = find_classes_implementing_protocol(imported_modules, Interoperable)

# check that all classes finish with 'Interoperability'
for cls in classes:
if not cls.__name__.endswith("Interoperability"):
raise RuntimeError(f"Class {cls} does not end with 'Interoperability'")

retval = {
cls.__name__.split("Interoperability")[0].lower(): cls for cls in classes if cls.__name__ != "Interoperability"
}

return retval
37 changes: 32 additions & 5 deletions autogen/interop/interoperability.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,40 @@
# Copyright (c) 2023 - 2024, Owners of https://github.com/ag2ai
#
# SPDX-License-Identifier: Apache-2.0

from typing import Any, Protocol, runtime_checkable
from typing import Any, Dict, List, Type

from ..tools import Tool
from .helpers import get_all_interoperability_classes
from .interoperable import Interoperable

__all__ = ["Interoperable"]

@runtime_checkable
class Interoperable(Protocol):
def convert_tool(self, tool: Any) -> Tool: ...
class Interoperability:
_interoperability_classes: Dict[str, Type[Interoperable]] = get_all_interoperability_classes()

def __init__(self) -> None:
pass

def convert_tool(self, *, tool: Any, type: str) -> Tool:
interop_cls = self.get_interoperability_class(type)
interop = interop_cls()
return interop.convert_tool(tool)

@classmethod
def get_interoperability_class(cls, type: str) -> Type[Interoperable]:
if type not in cls._interoperability_classes:
raise ValueError(f"Interoperability class {type} not found")
return cls._interoperability_classes[type]

@classmethod
def supported_types(cls) -> List[str]:
return sorted(cls._interoperability_classes.keys())

@classmethod
def register_interoperability_class(cls, name: str, interoperability_class: Type[Interoperable]) -> None:
if not issubclass(interoperability_class, Interoperable):
raise ValueError(
f"Expected a class implementing `Interoperable` protocol, got {type(interoperability_class)}"
)

cls._interoperability_classes[name] = interoperability_class
14 changes: 14 additions & 0 deletions autogen/interop/interoperable.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Copyright (c) 2023 - 2024, Owners of https://github.com/ag2ai
#
# SPDX-License-Identifier: Apache-2.0

from typing import Any, Protocol, runtime_checkable

from ..tools import Tool

__all__ = ["Interoperable"]


@runtime_checkable
class Interoperable(Protocol):
def convert_tool(self, tool: Any) -> Tool: ...
2 changes: 1 addition & 1 deletion autogen/interop/langchain/langchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from langchain_core.tools import BaseTool as LangchainTool

from ...tools import Tool
from ..interoperability import Interoperable
from ..interoperable import Interoperable

__all__ = ["LangchainInteroperability"]

Expand Down
6 changes: 3 additions & 3 deletions test/interop/crewai/test_crewai.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,24 +4,24 @@

import os
import sys
import unittest
from tempfile import TemporaryDirectory
from unittest.mock import MagicMock

import pytest
from conftest import reason, skip_openai

if sys.version_info >= (3, 10) and sys.version_info < (3, 13):
from crewai_tools import FileReadTool
else:
FileReadTool = unittest.mock.MagicMock()
FileReadTool = MagicMock()

from autogen import AssistantAgent, UserProxyAgent
from autogen.interop import Interoperable

if sys.version_info >= (3, 10) and sys.version_info < (3, 13):
from autogen.interop.crewai import CrewAIInteroperability
else:
CrewAIInteroperability = unittest.mock.MagicMock()
CrewAIInteroperability = MagicMock()


# skip if python version is not in [3.10, 3.11, 3.12]
Expand Down
62 changes: 62 additions & 0 deletions test/interop/test_helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
# Copyright (c) 2023 - 2024, Owners of https://github.com/ag2ai
#
# SPDX-License-Identifier: Apache-2.0

import sys

import pytest

from autogen.interop import Interoperability, Interoperable
from autogen.interop.helpers import (
find_classes_implementing_protocol,
get_all_interoperability_classes,
import_submodules,
)


class TestHelpers:
@pytest.fixture(autouse=True)
def setup_method(self) -> None:
self.imported_modules = import_submodules("autogen.interop")

def test_import_submodules(self) -> None:
assert "autogen.interop.helpers" in self.imported_modules

def test_find_classes_implementing_protocol(self) -> None:
actual = find_classes_implementing_protocol(self.imported_modules, Interoperable)
print(f"test_find_classes_implementing_protocol: {actual=}")

assert Interoperability in actual
expected_count = 1

if sys.version_info >= (3, 10) and sys.version_info < (3, 13):
from autogen.interop.crewai import CrewAIInteroperability

assert CrewAIInteroperability in actual
expected_count += 1

if sys.version_info >= (3, 9):
from autogen.interop.langchain import LangchainInteroperability

assert LangchainInteroperability in actual
expected_count += 1

assert len(actual) == expected_count

def test_get_all_interoperability_classes(self) -> None:

actual = get_all_interoperability_classes()

if sys.version_info < (3, 9):
assert actual == {}

if sys.version_info >= (3, 10) and sys.version_info < (3, 13):
from autogen.interop.crewai import CrewAIInteroperability
from autogen.interop.langchain import LangchainInteroperability

assert actual == {"crewai": CrewAIInteroperability, "langchain": LangchainInteroperability}

if (sys.version_info >= (3, 9) and sys.version_info < (3, 10)) and sys.version_info >= (3, 13):
from autogen.interop.langchain import LangchainInteroperability

assert actual == {"langchain": LangchainInteroperability}
81 changes: 78 additions & 3 deletions test/interop/test_interoperability.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,83 @@
#
# SPDX-License-Identifier: Apache-2.0

from autogen.interop import Interoperable
import sys
from tempfile import TemporaryDirectory
from typing import Any

import pytest

def test_interoperable() -> None:
assert Interoperable is not None
from autogen.interop import Interoperability, Interoperable
from autogen.tools.tool import Tool


class TestInteroperability:
def test_supported_types(self) -> None:
actual = Interoperability.supported_types()

if sys.version_info < (3, 9):
assert actual == []

if sys.version_info >= (3, 9) and sys.version_info < (3, 10):
assert actual == ["langchain"]

if sys.version_info >= (3, 10) and sys.version_info < (3, 13):
assert actual == ["crewai", "langchain"]

if sys.version_info >= (3, 13):
assert actual == ["langchain"]

def test_register_interoperability_class(self) -> None:
org_interoperability_classes = Interoperability._interoperability_classes
try:

class MyInteroperability:
def convert_tool(self, tool: Any) -> Tool:
return Tool(name="test", description="test description", func=tool)

Interoperability.register_interoperability_class("my_interop", MyInteroperability)
assert Interoperability.get_interoperability_class("my_interop") == MyInteroperability

interop = Interoperability()
tool = interop.convert_tool(type="my_interop", tool=lambda x: x)
assert tool.name == "test"
assert tool.description == "test description"
assert tool.func("hello") == "hello"

finally:
Interoperability._interoperability_classes = org_interoperability_classes

@pytest.mark.skipif(
sys.version_info < (3, 10) or sys.version_info >= (3, 13), reason="Only Python 3.10, 3.11, 3.12 are supported"
)
def test_crewai(self) -> None:
from crewai_tools import FileReadTool

crewai_tool = FileReadTool()

interoperability = Interoperability()
tool = interoperability.convert_tool(type="crewai", tool=crewai_tool)

with TemporaryDirectory() as tmp_dir:
file_path = f"{tmp_dir}/test.txt"
with open(file_path, "w") as file:
file.write("Hello, World!")

assert tool.name == "Read_a_file_s_content"
assert (
tool.description
== "A tool that can be used to read a file's content. (IMPORTANT: When using arguments, put them all in an `args` dictionary)"
)

model_type = crewai_tool.args_schema

args = model_type(file_path=file_path)

assert tool.func(args=args) == "Hello, World!"

@pytest.mark.skipif(
sys.version_info < (3, 9), reason="Only Python 3.9 and above are supported for LangchainInteroperability"
)
@pytest.mark.skip(reason="This test is not yet implemented")
def test_langchain(self) -> None:
raise NotImplementedError("This test is not yet implemented")
9 changes: 9 additions & 0 deletions test/interop/test_interoperable.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# Copyright (c) 2023 - 2024, Owners of https://github.com/ag2ai
#
# SPDX-License-Identifier: Apache-2.0

from autogen.interop import Interoperable


def test_interoperable() -> None:
assert Interoperable is not None

0 comments on commit c929a27

Please sign in to comment.