From 8cb530f65e59edbfd305252afd5af5ab7cdecf7d Mon Sep 17 00:00:00 2001 From: Jack Gerrits Date: Wed, 5 Jun 2024 08:51:49 -0400 Subject: [PATCH] Simplify handler decorator (#50) * Simplify handler decorator * add more tests * mypy * formatting * fix 3.10 and improve type handling of decorator * test fix * format --- examples/futures.py | 4 +- .../chat/agents/chat_completion_agent.py | 8 +- src/agnext/chat/agents/oai_assistant.py | 6 +- src/agnext/chat/patterns/group_chat.py | 8 +- src/agnext/chat/patterns/orchestrator_chat.py | 2 +- src/agnext/chat/patterns/two_agent_chat.py | 18 +++ src/agnext/components/type_routed_agent.py | 129 ++++++++++++++++-- tests/test_cancellation.py | 4 +- tests/test_intervention.py | 2 +- tests/test_types.py | 39 ++++++ 10 files changed, 191 insertions(+), 29 deletions(-) create mode 100644 src/agnext/chat/patterns/two_agent_chat.py create mode 100644 tests/test_types.py diff --git a/examples/futures.py b/examples/futures.py index 6547679486d1..9480757982cf 100644 --- a/examples/futures.py +++ b/examples/futures.py @@ -16,7 +16,7 @@ class Inner(TypeRoutedAgent): def __init__(self, name: str, router: AgentRuntime) -> None: super().__init__(name, router) - @message_handler(MessageType) + @message_handler() async def on_new_message(self, message: MessageType, cancellation_token: CancellationToken) -> MessageType: return MessageType(body=f"Inner: {message.body}", sender=self.name) @@ -26,7 +26,7 @@ def __init__(self, name: str, router: AgentRuntime, inner: Agent) -> None: super().__init__(name, router) self._inner = inner - @message_handler(MessageType) + @message_handler() async def on_new_message(self, message: MessageType, cancellation_token: CancellationToken) -> MessageType: inner_response = self._send_message(message, self._inner) inner_message = await inner_response diff --git a/src/agnext/chat/agents/chat_completion_agent.py b/src/agnext/chat/agents/chat_completion_agent.py index 78465b74d257..80ac13c26bda 100644 --- a/src/agnext/chat/agents/chat_completion_agent.py +++ b/src/agnext/chat/agents/chat_completion_agent.py @@ -38,17 +38,17 @@ def __init__( self._chat_messages: List[Message] = [] self._function_executor = function_executor - @message_handler(TextMessage) + @message_handler() async def on_text_message(self, message: TextMessage, cancellation_token: CancellationToken) -> None: # Add a user message. self._chat_messages.append(message) - @message_handler(Reset) + @message_handler() async def on_reset(self, message: Reset, cancellation_token: CancellationToken) -> None: # Reset the chat messages. self._chat_messages = [] - @message_handler(RespondNow) + @message_handler() async def on_respond_now( self, message: RespondNow, cancellation_token: CancellationToken ) -> TextMessage | FunctionCallMessage: @@ -101,7 +101,7 @@ async def on_respond_now( # Return the response. return final_response - @message_handler(FunctionCallMessage) + @message_handler() async def on_tool_call_message( self, message: FunctionCallMessage, cancellation_token: CancellationToken ) -> FunctionExecutionResultMessage: diff --git a/src/agnext/chat/agents/oai_assistant.py b/src/agnext/chat/agents/oai_assistant.py index d121fd3cf632..1dd437812ad8 100644 --- a/src/agnext/chat/agents/oai_assistant.py +++ b/src/agnext/chat/agents/oai_assistant.py @@ -24,7 +24,7 @@ def __init__( self._assistant_id = assistant_id self._thread_id = thread_id - @message_handler(TextMessage) + @message_handler() async def on_text_message(self, message: TextMessage, cancellation_token: CancellationToken) -> None: # Save the message to the thread. _ = await self._client.beta.threads.messages.create( @@ -34,7 +34,7 @@ async def on_text_message(self, message: TextMessage, cancellation_token: Cancel metadata={"sender": message.source}, ) - @message_handler(Reset) + @message_handler() async def on_reset(self, message: Reset, cancellation_token: CancellationToken) -> None: # Get all messages in this thread. all_msgs: List[str] = [] @@ -52,7 +52,7 @@ async def on_reset(self, message: Reset, cancellation_token: CancellationToken) status = await self._client.beta.threads.messages.delete(message_id=msg_id, thread_id=self._thread_id) assert status.deleted is True - @message_handler(RespondNow) + @message_handler() async def on_respond_now(self, message: RespondNow, cancellation_token: CancellationToken) -> TextMessage: # Handle response format. if message.response_format == ResponseFormat.json_object: diff --git a/src/agnext/chat/patterns/group_chat.py b/src/agnext/chat/patterns/group_chat.py index 2c08bceee4b7..f6498ba52800 100644 --- a/src/agnext/chat/patterns/group_chat.py +++ b/src/agnext/chat/patterns/group_chat.py @@ -35,16 +35,16 @@ def subscriptions(self) -> Sequence[type]: agent_sublists = [agent.subscriptions for agent in self._agents] return [Reset, RespondNow] + [item for sublist in agent_sublists for item in sublist] - @message_handler(Reset) + @message_handler() async def on_reset(self, message: Reset, cancellation_token: CancellationToken) -> None: self._history.clear() - @message_handler(RespondNow) + @message_handler() async def on_respond_now(self, message: RespondNow, cancellation_token: CancellationToken) -> Any: return self._output.get_output() - @message_handler(TextMessage) - async def on_text_message(self, message: Any, cancellation_token: CancellationToken) -> Any: + @message_handler() + async def on_text_message(self, message: TextMessage, cancellation_token: CancellationToken) -> Any: # TODO: how should we handle the group chat receiving a message while in the middle of a conversation? # Should this class disallow it? diff --git a/src/agnext/chat/patterns/orchestrator_chat.py b/src/agnext/chat/patterns/orchestrator_chat.py index 802c62d63aa1..420b771cb687 100644 --- a/src/agnext/chat/patterns/orchestrator_chat.py +++ b/src/agnext/chat/patterns/orchestrator_chat.py @@ -34,7 +34,7 @@ def __init__( def children(self) -> Sequence[str]: return [agent.name for agent in self._specialists] + [self._orchestrator.name] + [self._planner.name] - @message_handler(TextMessage) + @message_handler() async def on_text_message( self, message: TextMessage, diff --git a/src/agnext/chat/patterns/two_agent_chat.py b/src/agnext/chat/patterns/two_agent_chat.py new file mode 100644 index 000000000000..fb7d39ecdd1f --- /dev/null +++ b/src/agnext/chat/patterns/two_agent_chat.py @@ -0,0 +1,18 @@ +from agnext.chat.patterns.group_chat import GroupChat, GroupChatOutput + +from ...core import AgentRuntime +from ..agents.base import BaseChatAgent + + +class TwoAgentChat(GroupChat): + def __init__( + self, + name: str, + description: str, + runtime: AgentRuntime, + agent1: BaseChatAgent, + agent2: BaseChatAgent, + num_rounds: int, + output: GroupChatOutput, + ) -> None: + super().__init__(name, description, runtime, [agent1, agent2], num_rounds, output) diff --git a/src/agnext/components/type_routed_agent.py b/src/agnext/components/type_routed_agent.py index ba471fca5c18..74fc0360c0dc 100644 --- a/src/agnext/components/type_routed_agent.py +++ b/src/agnext/components/type_routed_agent.py @@ -1,28 +1,132 @@ -from typing import Any, Callable, Coroutine, Dict, NoReturn, Sequence, Type, TypeVar +import logging +from functools import wraps +from types import NoneType, UnionType +from typing import ( + Any, + Callable, + Coroutine, + Dict, + Literal, + NoReturn, + Optional, + Protocol, + Sequence, + Type, + TypeVar, + Union, + cast, + get_args, + get_origin, + get_type_hints, + runtime_checkable, +) from agnext.core import AgentRuntime, BaseAgent, CancellationToken from agnext.core.exceptions import CantHandleException -ReceivesT = TypeVar("ReceivesT") +logger = logging.getLogger("agnext") + +ReceivesT = TypeVar("ReceivesT", contravariant=True) ProducesT = TypeVar("ProducesT", covariant=True) # TODO: Generic typevar bound binding U to agent type # Can't do because python doesnt support it +def is_union(t: object) -> bool: + origin = get_origin(t) + return origin is Union or origin is UnionType + + +def is_optional(t: object) -> bool: + origin = get_origin(t) + return origin is Optional + + +# Special type to avoid the 3.10 vs 3.11+ difference of typing._SpecialForm vs typing.Any +class AnyType: + pass + + +def get_types(t: object) -> Sequence[Type[Any]] | None: + if is_union(t): + return get_args(t) + elif is_optional(t): + return tuple(list(get_args(t)) + [NoneType]) + elif t is Any: + return (AnyType,) + elif isinstance(t, type): + return (t,) + elif isinstance(t, NoneType): + return (NoneType,) + else: + return None + + +@runtime_checkable +class MessageHandler(Protocol[ReceivesT, ProducesT]): + target_types: Sequence[type] + produces_types: Sequence[type] + is_message_handler: Literal[True] + + async def __call__(self, message: ReceivesT, cancellation_token: CancellationToken) -> ProducesT: ... + + # NOTE: this works on concrete types and not inheritance +# TODO: Use a protocl for the outer function to check checked arg names def message_handler( - *target_types: Type[ReceivesT], + strict: bool = True, ) -> Callable[ - [Callable[[Any, ReceivesT, CancellationToken], Coroutine[Any, Any, ProducesT | None]]], - Callable[[Any, ReceivesT, CancellationToken], Coroutine[Any, Any, ProducesT | None]], + [Callable[[Any, ReceivesT, CancellationToken], Coroutine[Any, Any, ProducesT]]], + MessageHandler[ReceivesT, ProducesT], ]: def decorator( - func: Callable[[Any, ReceivesT, CancellationToken], Coroutine[Any, Any, ProducesT | None]], - ) -> Callable[[Any, ReceivesT, CancellationToken], Coroutine[Any, Any, ProducesT | None]]: + func: Callable[[Any, ReceivesT, CancellationToken], Coroutine[Any, Any, ProducesT]], + ) -> MessageHandler[ReceivesT, ProducesT]: + type_hints = get_type_hints(func) + if "message" not in type_hints: + raise AssertionError("message parameter not found in function signature") + + if "return" not in type_hints: + raise AssertionError("return not found in function signature") + + # Get the type of the message parameter + target_types = get_types(type_hints["message"]) + if target_types is None: + raise AssertionError("Message type not found") + + print(type_hints) + return_types = get_types(type_hints["return"]) + + if return_types is None: + raise AssertionError("Return type not found") + # Convert target_types to list and stash - func._target_types = list(target_types) # type: ignore - return func + + @wraps(func) + async def wrapper(self: Any, message: ReceivesT, cancellation_token: CancellationToken) -> ProducesT: + if strict: + if type(message) not in target_types: + raise CantHandleException(f"Message type {type(message)} not in target types {target_types}") + else: + logger.warning(f"Message type {type(message)} not in target types {target_types}") + + return_value = await func(self, message, cancellation_token) + + if strict: + if return_value is not AnyType and type(return_value) not in return_types: + raise ValueError(f"Return type {type(return_value)} not in return types {return_types}") + elif return_value is not AnyType: + logger.warning(f"Return type {type(return_value)} not in return types {return_types}") + + return return_value + + wrapper_handler = cast(MessageHandler[ReceivesT, ProducesT], wrapper) + wrapper_handler.target_types = list(target_types) + wrapper_handler.produces_types = list(return_types) + wrapper_handler.is_message_handler = True + + return wrapper_handler return decorator @@ -35,9 +139,10 @@ def __init__(self, name: str, router: AgentRuntime) -> None: for attr in dir(self): if callable(getattr(self, attr, None)): handler = getattr(self, attr) - if hasattr(handler, "_target_types"): - for target_type in handler._target_types: - self._handlers[target_type] = handler + if hasattr(handler, "is_message_handler"): + message_handler = cast(MessageHandler[Any, Any], handler) + for target_type in message_handler.target_types: + self._handlers[target_type] = message_handler super().__init__(name, router) diff --git a/tests/test_cancellation.py b/tests/test_cancellation.py index 3555a63ea3ed..2d44f724a370 100644 --- a/tests/test_cancellation.py +++ b/tests/test_cancellation.py @@ -22,7 +22,7 @@ def __init__(self, name: str, router: AgentRuntime) -> None: self.called = False self.cancelled = False - @message_handler(MessageType) + @message_handler() async def on_new_message(self, message: MessageType, cancellation_token: CancellationToken) -> MessageType: self.called = True sleep = asyncio.ensure_future(asyncio.sleep(100)) @@ -41,7 +41,7 @@ def __init__(self, name: str, router: AgentRuntime, nested_agent: Agent) -> None self.cancelled = False self._nested_agent = nested_agent - @message_handler(MessageType) + @message_handler() async def on_new_message(self, message: MessageType, cancellation_token: CancellationToken) -> MessageType: self.called = True response = self._send_message(message, self._nested_agent, cancellation_token=cancellation_token) diff --git a/tests/test_intervention.py b/tests/test_intervention.py index 1a5af25b24fd..e7780204f9bd 100644 --- a/tests/test_intervention.py +++ b/tests/test_intervention.py @@ -19,7 +19,7 @@ def __init__(self, name: str, router: AgentRuntime) -> None: self.num_calls = 0 - @message_handler(MessageType) + @message_handler() async def on_new_message(self, message: MessageType, cancellation_token: CancellationToken) -> MessageType: self.num_calls += 1 return message diff --git a/tests/test_types.py b/tests/test_types.py new file mode 100644 index 000000000000..d214eece26f7 --- /dev/null +++ b/tests/test_types.py @@ -0,0 +1,39 @@ +from types import NoneType +from typing import Any, Optional, Union + +from agnext.components.type_routed_agent import AnyType, get_types, message_handler +from agnext.core import CancellationToken + + +def test_get_types() -> None: + assert get_types(Union[int, str]) == (int, str) + assert get_types(int | str) == (int, str) + assert get_types(int) == (int,) + assert get_types(str) == (str,) + assert get_types("test") is None + assert get_types(Optional[int]) == (int, NoneType) + assert get_types(NoneType) == (NoneType,) + assert get_types(None) == (NoneType,) + + +def test_handler() -> None: + + class HandlerClass: + @message_handler() + async def handler(self, message: int, cancellation_token: CancellationToken) -> Any: + return None + + @message_handler() + async def handler2(self, message: str | bool, cancellation_token: CancellationToken) -> None: + return None + + assert HandlerClass.handler.target_types == [int] + assert HandlerClass.handler.produces_types == [AnyType] + + assert HandlerClass.handler2.target_types == [str, bool] + assert HandlerClass.handler2.produces_types == [NoneType] + +class HandlerClass: + @message_handler() + async def handler(self, message: int, cancellation_token: CancellationToken) -> Any: + return None