Skip to content

Commit

Permalink
Fix function tools (autogenhub#57)
Browse files Browse the repository at this point in the history
* Fix function tools

* lint
  • Loading branch information
jackgerrits authored Jun 7, 2024
1 parent 06ba5d3 commit c6360fe
Show file tree
Hide file tree
Showing 4 changed files with 211 additions and 33 deletions.
28 changes: 6 additions & 22 deletions src/agnext/components/_function_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
Type,
TypeVar,
Union,
cast,
get_args,
get_origin,
)
Expand Down Expand Up @@ -67,7 +68,8 @@ def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature:
)
for param in signature.parameters.values()
]
typed_signature = inspect.Signature(typed_params)
return_annotation = get_typed_annotation(signature.return_annotation, globalns)
typed_signature = inspect.Signature(typed_params, return_annotation=return_annotation)
return typed_signature


Expand Down Expand Up @@ -313,7 +315,7 @@ def normalize_annotated_type(type_hint: Type[Any]) -> Type[Any]:


def args_base_model_from_signature(name: str, sig: inspect.Signature) -> Type[BaseModel]:
fields: List[tuple[str, Any]] = []
fields: Dict[str, tuple[Type[Any], Any]] = {}
for name, param in sig.parameters.items():
# This is handled externally
if name == "cancellation_token":
Expand All @@ -326,24 +328,6 @@ def args_base_model_from_signature(name: str, sig: inspect.Signature) -> Type[Ba
description = type2description(name, param.annotation)
default_value = param.default if param.default is not inspect.Parameter.empty else PydanticUndefined

fields.append((name, (type, Field(default=default_value, description=description))))
fields[name] = (type, Field(default=default_value, description=description))

return create_model(name, *fields)


def return_value_base_model_from_signature(name: str, sig: inspect.Signature) -> Type[BaseModel]:
if issubclass(BaseModel, sig.return_annotation):
return sig.return_annotation # type: ignore

fields: List[tuple[str, Any]] = []
for name, param in sig.return_annotation:
if param.annotation is inspect.Parameter.empty:
raise ValueError("No annotation")

type = normalize_annotated_type(param.annotation)
description = type2description(name, param.annotation)
default_value = param.default if param.default is not inspect.Parameter.empty else PydanticUndefined

fields.append((name, (type, Field(default=default_value, description=description))))

return create_model(name, *fields)
return cast(BaseModel, create_model(name, **fields)) # type: ignore
20 changes: 16 additions & 4 deletions src/agnext/components/tools/_base.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
from abc import ABC, abstractmethod
from typing import Any, Dict, Generic, Mapping, Protocol, Type, TypeVar

Expand All @@ -20,11 +21,13 @@ def schema(self) -> Mapping[str, Any]: ...

def args_type(self) -> Type[BaseModel]: ...

def return_type(self) -> Type[BaseModel]: ...
def return_type(self) -> Type[Any]: ...

def state_type(self) -> Type[BaseModel] | None: ...

async def run_json(self, args: Mapping[str, Any], cancellation_token: CancellationToken) -> BaseModel: ...
def return_value_as_string(self, value: Any) -> str: ...

async def run_json(self, args: Mapping[str, Any], cancellation_token: CancellationToken) -> Any: ...

def save_state_json(self) -> Mapping[str, Any]: ...

Expand Down Expand Up @@ -63,16 +66,25 @@ def description(self) -> str:
def args_type(self) -> Type[BaseModel]:
return self._args_type

def return_type(self) -> Type[BaseModel]:
def return_type(self) -> Type[Any]:
return self._return_type

def state_type(self) -> Type[BaseModel] | None:
return None

def return_value_as_string(self, value: Any) -> str:
if isinstance(value, BaseModel):
dumped = value.model_dump()
if isinstance(dumped, dict):
return json.dumps(dumped)
return str(dumped)

return str(value)

@abstractmethod
async def run(self, args: ArgsT, cancellation_token: CancellationToken) -> ReturnT: ...

async def run_json(self, args: Mapping[str, Any], cancellation_token: CancellationToken) -> BaseModel:
async def run_json(self, args: Mapping[str, Any], cancellation_token: CancellationToken) -> Any:
return_value = await self.run(self._args_type.model_validate(args), cancellation_token)
return return_value

Expand Down
9 changes: 4 additions & 5 deletions src/agnext/components/tools/_function_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from .._function_utils import (
args_base_model_from_signature,
get_typed_signature,
return_value_base_model_from_signature,
)
from ._base import BaseTool

Expand All @@ -19,12 +18,12 @@ def __init__(self, func: Callable[..., Any], description: str, name: str | None
signature = get_typed_signature(func)
func_name = name or func.__name__
args_model = args_base_model_from_signature(func_name + "args", signature)
return_model = return_value_base_model_from_signature(func_name + "return", signature)
return_type = signature.return_annotation
self._has_cancellation_support = "cancellation_token" in signature.parameters

super().__init__(args_model, return_model, func_name, description)
super().__init__(args_model, return_type, func_name, description)

async def run(self, args: BaseModel, cancellation_token: CancellationToken) -> BaseModel:
async def run(self, args: BaseModel, cancellation_token: CancellationToken) -> Any:
if asyncio.iscoroutinefunction(self._func):
if self._has_cancellation_support:
result = await self._func(**args.model_dump(), cancellation_token=cancellation_token)
Expand All @@ -42,5 +41,5 @@ async def run(self, args: BaseModel, cancellation_token: CancellationToken) -> B
cancellation_token.link_future(future)
result = await future

assert isinstance(result, BaseModel)
assert isinstance(result, self.return_type())
return result
187 changes: 185 additions & 2 deletions tests/test_tools.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@

import inspect
from typing import Annotated

import pytest
from agnext.components.tools import BaseTool
from agnext.components._function_utils import get_typed_signature
from agnext.components.tools import BaseTool, FunctionTool
from agnext.core import CancellationToken
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, model_serializer
from pydantic_core import PydanticUndefined


class MyArgs(BaseModel):
Expand Down Expand Up @@ -58,3 +63,181 @@ def test_tool_properties()-> None:
assert tool.args_type() == MyArgs
assert tool.return_type() == MyResult
assert tool.state_type() is None

def test_get_typed_signature()-> None:
def my_function() -> str:
return "result"

sig = get_typed_signature(my_function)
assert isinstance(sig, inspect.Signature)
assert len(sig.parameters) == 0
assert sig.return_annotation == str

def test_get_typed_signature_annotated()-> None:
def my_function() -> Annotated[str, "The return type"]:
return "result"

sig = get_typed_signature(my_function)
assert isinstance(sig, inspect.Signature)
assert len(sig.parameters) == 0
assert sig.return_annotation == Annotated[str, "The return type"]

def test_get_typed_signature_string()-> None:
def my_function() -> "str":
return "result"

sig = get_typed_signature(my_function)
assert isinstance(sig, inspect.Signature)
assert len(sig.parameters) == 0
assert sig.return_annotation == str


def test_func_tool()-> None:
def my_function() -> str:
return "result"

tool = FunctionTool(my_function, description="Function tool.")
assert tool.name == "my_function"
assert tool.description == "Function tool."
assert issubclass(tool.args_type(), BaseModel)
assert issubclass(tool.return_type(), str)
assert tool.state_type() is None

def test_func_tool_annotated_arg()-> None:
def my_function(my_arg: Annotated[str, "test description"]) -> str:
return "result"

tool = FunctionTool(my_function, description="Function tool.")
assert tool.name == "my_function"
assert tool.description == "Function tool."
assert issubclass(tool.args_type(), BaseModel)
assert issubclass(tool.return_type(), str)
assert tool.args_type().model_fields["my_arg"].description == "test description"
assert tool.args_type().model_fields["my_arg"].annotation == str
assert tool.args_type().model_fields["my_arg"].is_required() is True
assert tool.args_type().model_fields["my_arg"].default is PydanticUndefined
assert len(tool.args_type().model_fields) == 1
assert tool.return_type() == str
assert tool.state_type() is None

def test_func_tool_return_annotated()-> None:
def my_function() -> Annotated[str, "test description"]:
return "result"

tool = FunctionTool(my_function, description="Function tool.")
assert tool.name == "my_function"
assert tool.description == "Function tool."
assert issubclass(tool.args_type(), BaseModel)
assert tool.return_type() == Annotated[str, "test description"]
assert tool.state_type() is None

def test_func_tool_no_args()-> None:
def my_function() -> str:
return "result"

tool = FunctionTool(my_function, description="Function tool.")
assert tool.name == "my_function"
assert tool.description == "Function tool."
assert issubclass(tool.args_type(), BaseModel)
assert len(tool.args_type().model_fields) == 0
assert tool.return_type() == str
assert tool.state_type() is None

def test_func_tool_return_none()-> None:
def my_function() -> None:
return None

tool = FunctionTool(my_function, description="Function tool.")
assert tool.name == "my_function"
assert tool.description == "Function tool."
assert issubclass(tool.args_type(), BaseModel)
assert tool.return_type() is None
assert tool.state_type() is None

def test_func_tool_return_base_model()-> None:
def my_function() -> MyResult:
return MyResult(result="value")

tool = FunctionTool(my_function, description="Function tool.")
assert tool.name == "my_function"
assert tool.description == "Function tool."
assert issubclass(tool.args_type(), BaseModel)
assert tool.return_type() is MyResult
assert tool.state_type() is None

@pytest.mark.asyncio
async def test_func_call_tool()-> None:
def my_function() -> str:
return "result"

tool = FunctionTool(my_function, description="Function tool.")
result = await tool.run_json({}, CancellationToken())
assert result == "result"

@pytest.mark.asyncio
async def test_func_call_tool_base_model()-> None:
def my_function() -> MyResult:
return MyResult(result="value")

tool = FunctionTool(my_function, description="Function tool.")
result = await tool.run_json({}, CancellationToken())
assert isinstance(result, MyResult)
assert result.result == "value"


@pytest.mark.asyncio
async def test_func_call_tool_with_arg_base_model()-> None:
def my_function(arg: str) -> MyResult:
return MyResult(result="value")

tool = FunctionTool(my_function, description="Function tool.")
result = await tool.run_json({"arg": "test"}, CancellationToken())
assert isinstance(result, MyResult)
assert result.result == "value"

@pytest.mark.asyncio
async def test_func_str_res()-> None:
def my_function(arg: str) -> str:
return "test"

tool = FunctionTool(my_function, description="Function tool.")
result = await tool.run_json({"arg": "test"}, CancellationToken())
assert tool.return_value_as_string(result) == "test"

@pytest.mark.asyncio
async def test_func_base_model_res()-> None:


def my_function(arg: str) -> MyResult:
return MyResult(result="test")

tool = FunctionTool(my_function, description="Function tool.")
result = await tool.run_json({"arg": "test"}, CancellationToken())
assert tool.return_value_as_string(result) == '{"result": "test"}'

@pytest.mark.asyncio
async def test_func_base_model_custom_dump_res()-> None:

class MyResultCustomDump(BaseModel):
result: str = Field(description="The other description.")

@model_serializer
def ser_model(self) -> str:
return "custom: " + self.result


def my_function(arg: str) -> MyResultCustomDump:
return MyResultCustomDump(result="test")

tool = FunctionTool(my_function, description="Function tool.")
result = await tool.run_json({"arg": "test"}, CancellationToken())
assert tool.return_value_as_string(result) == "custom: test"

@pytest.mark.asyncio
async def test_func_int_res()-> None:
def my_function(arg: int) -> int:
return arg

tool = FunctionTool(my_function, description="Function tool.")
result = await tool.run_json({"arg": 5}, CancellationToken())
assert tool.return_value_as_string(result) == "5"

0 comments on commit c6360fe

Please sign in to comment.