Skip to content

Commit

Permalink
Merge pull request #238 from ag2ai/add-tool-imports-refactoring-rj
Browse files Browse the repository at this point in the history
Add tool imports refactoring
  • Loading branch information
davorrunje authored Dec 19, 2024
2 parents 5adbb53 + 3f644d0 commit d5f10ef
Show file tree
Hide file tree
Showing 7 changed files with 87 additions and 12 deletions.
2 changes: 1 addition & 1 deletion autogen/interop/crewai/crewai.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def _sanitize_name(s: str) -> str:


@register_interoperable_class("crewai")
class CrewAIInteroperability(Interoperable):
class CrewAIInteroperability:
"""
A class implementing the `Interoperable` protocol for converting CrewAI tools
to a general `Tool` format.
Expand Down
2 changes: 1 addition & 1 deletion autogen/interop/langchain/langchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@


@register_interoperable_class("langchain")
class LangChainInteroperability(Interoperable):
class LangChainInteroperability:
"""
A class implementing the `Interoperable` protocol for converting Langchain tools
into a general `Tool` format.
Expand Down
4 changes: 2 additions & 2 deletions autogen/interop/pydantic_ai/pydantic_ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,15 @@
from inspect import signature
from typing import Any, Callable, Optional

from ...tools import PydanticAITool as AG2PydanticAITool
from ..interoperability import Interoperable
from ..registry import register_interoperable_class
from .pydantic_ai_tool import PydanticAITool as AG2PydanticAITool

__all__ = ["PydanticAIInteroperability"]


@register_interoperable_class("pydanticai")
class PydanticAIInteroperability(Interoperable):
class PydanticAIInteroperability:
"""
A class implementing the `Interoperable` protocol for converting Pydantic AI tools
into a general `Tool` format.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,8 @@

from typing import Any, Callable, Dict

from autogen.agentchat.conversable_agent import ConversableAgent

from .tool import Tool
from ...agentchat.conversable_agent import ConversableAgent
from ...tools import Tool

__all__ = ["PydanticAITool"]

Expand Down
7 changes: 5 additions & 2 deletions autogen/tools/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from .pydantic_ai_tool import PydanticAITool
# Copyright (c) 2023 - 2024, Owners of https://github.com/ag2ai
#
# SPDX-License-Identifier: Apache-2.0

from .tool import Tool

__all__ = ["PydanticAITool", "Tool"]
__all__ = ["Tool"]
5 changes: 2 additions & 3 deletions autogen/tools/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,9 @@
#
# SPDX-License-Identifier: Apache-2.0

from typing import Any, Callable, Dict
from unittest.mock import MagicMock
from typing import Any, Callable

from autogen.agentchat.conversable_agent import ConversableAgent
from ..agentchat.conversable_agent import ConversableAgent

__all__ = ["Tool"]

Expand Down
74 changes: 74 additions & 0 deletions test/interop/pydantic_ai/test_pydantic_ai_tool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
# Copyright (c) 2023 - 2024, Owners of https://github.com/ag2ai
#
# SPDX-License-Identifier: Apache-2.0

import sys
import unittest

import pytest

from autogen import AssistantAgent

if sys.version_info >= (3, 9):
from pydantic_ai.tools import Tool

from autogen.interop.pydantic_ai.pydantic_ai_tool import PydanticAITool as AG2PydanticAITool
else:
Tool = unittest.mock.MagicMock()
AG2PydanticAITool = unittest.mock.MagicMock()


# skip if python version is not >= 3.9
@pytest.mark.skipif(
sys.version_info < (3, 9), reason="Only Python 3.9 and above are supported for LangchainInteroperability"
)
class TestPydanticAITool:
def test_register_for_llm(self) -> None:
def foobar(a: int, b: str, c: dict[str, list[float]]) -> str: # type: ignore[misc]
"""Get me foobar.
Args:
a: apple pie
b: banana cake
c: carrot smoothie
"""
return f"{a} {b} {c}"

tool = Tool(foobar)
ag2_tool = AG2PydanticAITool(
name=tool.name,
description=tool.description,
func=tool.function,
parameters_json_schema=tool._parameters_json_schema,
)
config_list = [{"model": "gpt-4o", "api_key": "abc"}]
chatbot = AssistantAgent(
name="chatbot",
llm_config={"config_list": config_list},
)
ag2_tool.register_for_llm(chatbot)
expected_tools = [
{
"type": "function",
"function": {
"name": "foobar",
"description": "Get me foobar.",
"parameters": {
"properties": {
"a": {"description": "apple pie", "title": "A", "type": "integer"},
"b": {"description": "banana cake", "title": "B", "type": "string"},
"c": {
"additionalProperties": {"items": {"type": "number"}, "type": "array"},
"description": "carrot smoothie",
"title": "C",
"type": "object",
},
},
"required": ["a", "b", "c"],
"type": "object",
"additionalProperties": False,
},
},
}
]
assert chatbot.llm_config["tools"] == expected_tools # type: ignore[index]

0 comments on commit d5f10ef

Please sign in to comment.