-
Notifications
You must be signed in to change notification settings - Fork 121
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #238 from ag2ai/add-tool-imports-refactoring-rj
Add tool imports refactoring
- Loading branch information
Showing
7 changed files
with
87 additions
and
12 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,7 @@ | ||
from .pydantic_ai_tool import PydanticAITool | ||
# Copyright (c) 2023 - 2024, Owners of https://github.com/ag2ai | ||
# | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
from .tool import Tool | ||
|
||
__all__ = ["PydanticAITool", "Tool"] | ||
__all__ = ["Tool"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
# Copyright (c) 2023 - 2024, Owners of https://github.com/ag2ai | ||
# | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
import sys | ||
import unittest | ||
|
||
import pytest | ||
|
||
from autogen import AssistantAgent | ||
|
||
if sys.version_info >= (3, 9): | ||
from pydantic_ai.tools import Tool | ||
|
||
from autogen.interop.pydantic_ai.pydantic_ai_tool import PydanticAITool as AG2PydanticAITool | ||
else: | ||
Tool = unittest.mock.MagicMock() | ||
AG2PydanticAITool = unittest.mock.MagicMock() | ||
|
||
|
||
# skip if python version is not >= 3.9 | ||
@pytest.mark.skipif( | ||
sys.version_info < (3, 9), reason="Only Python 3.9 and above are supported for LangchainInteroperability" | ||
) | ||
class TestPydanticAITool: | ||
def test_register_for_llm(self) -> None: | ||
def foobar(a: int, b: str, c: dict[str, list[float]]) -> str: # type: ignore[misc] | ||
"""Get me foobar. | ||
Args: | ||
a: apple pie | ||
b: banana cake | ||
c: carrot smoothie | ||
""" | ||
return f"{a} {b} {c}" | ||
|
||
tool = Tool(foobar) | ||
ag2_tool = AG2PydanticAITool( | ||
name=tool.name, | ||
description=tool.description, | ||
func=tool.function, | ||
parameters_json_schema=tool._parameters_json_schema, | ||
) | ||
config_list = [{"model": "gpt-4o", "api_key": "abc"}] | ||
chatbot = AssistantAgent( | ||
name="chatbot", | ||
llm_config={"config_list": config_list}, | ||
) | ||
ag2_tool.register_for_llm(chatbot) | ||
expected_tools = [ | ||
{ | ||
"type": "function", | ||
"function": { | ||
"name": "foobar", | ||
"description": "Get me foobar.", | ||
"parameters": { | ||
"properties": { | ||
"a": {"description": "apple pie", "title": "A", "type": "integer"}, | ||
"b": {"description": "banana cake", "title": "B", "type": "string"}, | ||
"c": { | ||
"additionalProperties": {"items": {"type": "number"}, "type": "array"}, | ||
"description": "carrot smoothie", | ||
"title": "C", | ||
"type": "object", | ||
}, | ||
}, | ||
"required": ["a", "b", "c"], | ||
"type": "object", | ||
"additionalProperties": False, | ||
}, | ||
}, | ||
} | ||
] | ||
assert chatbot.llm_config["tools"] == expected_tools # type: ignore[index] |