-
Notifications
You must be signed in to change notification settings - Fork 121
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge remote-tracking branch 'origin/add-tool-imports' into add-tool-…
…imports-pydantic-ai2
- Loading branch information
Showing
10 changed files
with
258 additions
and
15 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
# Copyright (c) 2023 - 2024, Owners of https://github.com/ag2ai | ||
# | ||
# SPDX-License-Identifier: Apache-2.0 | ||
# | ||
|
||
import importlib | ||
import inspect | ||
import logging | ||
import pkgutil | ||
import sys | ||
from typing import Any, Dict, List, Set, Type | ||
|
||
from .interoperable import Interoperable | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
def import_submodules(package_name: str) -> List[str]: | ||
package = importlib.import_module(package_name) | ||
imported_modules: List[str] = [] | ||
for loader, module_name, is_pkg in pkgutil.walk_packages(package.__path__, package.__name__ + "."): | ||
try: | ||
importlib.import_module(module_name) | ||
|
||
imported_modules.append(module_name) | ||
except Exception as e: | ||
logger.info(f"Error importing {module_name}, most likely perfectly fine: {e}") | ||
|
||
return imported_modules | ||
|
||
|
||
def find_classes_implementing_protocol(imported_modules: List[str], protocol: Type[Any]) -> List[Type[Any]]: | ||
implementing_classes: Set[Type[Any]] = set() | ||
for module in imported_modules: | ||
for _, obj in inspect.getmembers(sys.modules[module], inspect.isclass): | ||
if issubclass(obj, protocol) and obj is not protocol: | ||
implementing_classes.add(obj) | ||
|
||
return list(implementing_classes) | ||
|
||
|
||
def get_all_interoperability_classes() -> Dict[str, Type[Interoperable]]: | ||
imported_modules = import_submodules("autogen.interop") | ||
classes = find_classes_implementing_protocol(imported_modules, Interoperable) | ||
|
||
# check that all classes finish with 'Interoperability' | ||
for cls in classes: | ||
if not cls.__name__.endswith("Interoperability"): | ||
raise RuntimeError(f"Class {cls} does not end with 'Interoperability'") | ||
|
||
retval = { | ||
cls.__name__.split("Interoperability")[0].lower(): cls for cls in classes if cls.__name__ != "Interoperability" | ||
} | ||
|
||
return retval |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,13 +1,40 @@ | ||
# Copyright (c) 2023 - 2024, Owners of https://github.com/ag2ai | ||
# | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
from typing import Any, Protocol, runtime_checkable | ||
from typing import Any, Dict, List, Type | ||
|
||
from ..tools import Tool | ||
from .helpers import get_all_interoperability_classes | ||
from .interoperable import Interoperable | ||
|
||
__all__ = ["Interoperable"] | ||
|
||
@runtime_checkable | ||
class Interoperable(Protocol): | ||
def convert_tool(self, tool: Any) -> Tool: ... | ||
class Interoperability: | ||
_interoperability_classes: Dict[str, Type[Interoperable]] = get_all_interoperability_classes() | ||
|
||
def __init__(self) -> None: | ||
pass | ||
|
||
def convert_tool(self, *, tool: Any, type: str) -> Tool: | ||
interop_cls = self.get_interoperability_class(type) | ||
interop = interop_cls() | ||
return interop.convert_tool(tool) | ||
|
||
@classmethod | ||
def get_interoperability_class(cls, type: str) -> Type[Interoperable]: | ||
if type not in cls._interoperability_classes: | ||
raise ValueError(f"Interoperability class {type} not found") | ||
return cls._interoperability_classes[type] | ||
|
||
@classmethod | ||
def supported_types(cls) -> List[str]: | ||
return sorted(cls._interoperability_classes.keys()) | ||
|
||
@classmethod | ||
def register_interoperability_class(cls, name: str, interoperability_class: Type[Interoperable]) -> None: | ||
if not issubclass(interoperability_class, Interoperable): | ||
raise ValueError( | ||
f"Expected a class implementing `Interoperable` protocol, got {type(interoperability_class)}" | ||
) | ||
|
||
cls._interoperability_classes[name] = interoperability_class |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
# Copyright (c) 2023 - 2024, Owners of https://github.com/ag2ai | ||
# | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
from typing import Any, Protocol, runtime_checkable | ||
|
||
from ..tools import Tool | ||
|
||
__all__ = ["Interoperable"] | ||
|
||
|
||
@runtime_checkable | ||
class Interoperable(Protocol): | ||
def convert_tool(self, tool: Any) -> Tool: ... |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
# Copyright (c) 2023 - 2024, Owners of https://github.com/ag2ai | ||
# | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
import sys | ||
|
||
import pytest | ||
|
||
from autogen.interop import Interoperability, Interoperable | ||
from autogen.interop.helpers import ( | ||
find_classes_implementing_protocol, | ||
get_all_interoperability_classes, | ||
import_submodules, | ||
) | ||
|
||
|
||
class TestHelpers: | ||
@pytest.fixture(autouse=True) | ||
def setup_method(self) -> None: | ||
self.imported_modules = import_submodules("autogen.interop") | ||
|
||
def test_import_submodules(self) -> None: | ||
assert "autogen.interop.helpers" in self.imported_modules | ||
|
||
def test_find_classes_implementing_protocol(self) -> None: | ||
actual = find_classes_implementing_protocol(self.imported_modules, Interoperable) | ||
print(f"test_find_classes_implementing_protocol: {actual=}") | ||
|
||
assert Interoperability in actual | ||
expected_count = 1 | ||
|
||
if sys.version_info >= (3, 10) and sys.version_info < (3, 13): | ||
from autogen.interop.crewai import CrewAIInteroperability | ||
|
||
assert CrewAIInteroperability in actual | ||
expected_count += 1 | ||
|
||
if sys.version_info >= (3, 9): | ||
from autogen.interop.langchain import LangchainInteroperability | ||
|
||
assert LangchainInteroperability in actual | ||
expected_count += 1 | ||
|
||
assert len(actual) == expected_count | ||
|
||
def test_get_all_interoperability_classes(self) -> None: | ||
|
||
actual = get_all_interoperability_classes() | ||
|
||
if sys.version_info < (3, 9): | ||
assert actual == {} | ||
|
||
if sys.version_info >= (3, 10) and sys.version_info < (3, 13): | ||
from autogen.interop.crewai import CrewAIInteroperability | ||
from autogen.interop.langchain import LangchainInteroperability | ||
|
||
assert actual == {"crewai": CrewAIInteroperability, "langchain": LangchainInteroperability} | ||
|
||
if (sys.version_info >= (3, 9) and sys.version_info < (3, 10)) and sys.version_info >= (3, 13): | ||
from autogen.interop.langchain import LangchainInteroperability | ||
|
||
assert actual == {"langchain": LangchainInteroperability} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
# Copyright (c) 2023 - 2024, Owners of https://github.com/ag2ai | ||
# | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
from autogen.interop import Interoperable | ||
|
||
|
||
def test_interoperable() -> None: | ||
assert Interoperable is not None |