Skip to content

Commit

Permalink
Simplify handler decorator (autogenhub#50)
Browse files Browse the repository at this point in the history
* Simplify handler decorator

* add more tests

* mypy

* formatting

* fix 3.10 and improve type handling of decorator

* test fix

* format
  • Loading branch information
jackgerrits authored Jun 5, 2024
1 parent ad513d5 commit 8cb530f
Show file tree
Hide file tree
Showing 10 changed files with 191 additions and 29 deletions.
4 changes: 2 additions & 2 deletions examples/futures.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ class Inner(TypeRoutedAgent):
def __init__(self, name: str, router: AgentRuntime) -> None:
super().__init__(name, router)

@message_handler(MessageType)
@message_handler()
async def on_new_message(self, message: MessageType, cancellation_token: CancellationToken) -> MessageType:
return MessageType(body=f"Inner: {message.body}", sender=self.name)

Expand All @@ -26,7 +26,7 @@ def __init__(self, name: str, router: AgentRuntime, inner: Agent) -> None:
super().__init__(name, router)
self._inner = inner

@message_handler(MessageType)
@message_handler()
async def on_new_message(self, message: MessageType, cancellation_token: CancellationToken) -> MessageType:
inner_response = self._send_message(message, self._inner)
inner_message = await inner_response
Expand Down
8 changes: 4 additions & 4 deletions src/agnext/chat/agents/chat_completion_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,17 +38,17 @@ def __init__(
self._chat_messages: List[Message] = []
self._function_executor = function_executor

@message_handler(TextMessage)
@message_handler()
async def on_text_message(self, message: TextMessage, cancellation_token: CancellationToken) -> None:
# Add a user message.
self._chat_messages.append(message)

@message_handler(Reset)
@message_handler()
async def on_reset(self, message: Reset, cancellation_token: CancellationToken) -> None:
# Reset the chat messages.
self._chat_messages = []

@message_handler(RespondNow)
@message_handler()
async def on_respond_now(
self, message: RespondNow, cancellation_token: CancellationToken
) -> TextMessage | FunctionCallMessage:
Expand Down Expand Up @@ -101,7 +101,7 @@ async def on_respond_now(
# Return the response.
return final_response

@message_handler(FunctionCallMessage)
@message_handler()
async def on_tool_call_message(
self, message: FunctionCallMessage, cancellation_token: CancellationToken
) -> FunctionExecutionResultMessage:
Expand Down
6 changes: 3 additions & 3 deletions src/agnext/chat/agents/oai_assistant.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def __init__(
self._assistant_id = assistant_id
self._thread_id = thread_id

@message_handler(TextMessage)
@message_handler()
async def on_text_message(self, message: TextMessage, cancellation_token: CancellationToken) -> None:
# Save the message to the thread.
_ = await self._client.beta.threads.messages.create(
Expand All @@ -34,7 +34,7 @@ async def on_text_message(self, message: TextMessage, cancellation_token: Cancel
metadata={"sender": message.source},
)

@message_handler(Reset)
@message_handler()
async def on_reset(self, message: Reset, cancellation_token: CancellationToken) -> None:
# Get all messages in this thread.
all_msgs: List[str] = []
Expand All @@ -52,7 +52,7 @@ async def on_reset(self, message: Reset, cancellation_token: CancellationToken)
status = await self._client.beta.threads.messages.delete(message_id=msg_id, thread_id=self._thread_id)
assert status.deleted is True

@message_handler(RespondNow)
@message_handler()
async def on_respond_now(self, message: RespondNow, cancellation_token: CancellationToken) -> TextMessage:
# Handle response format.
if message.response_format == ResponseFormat.json_object:
Expand Down
8 changes: 4 additions & 4 deletions src/agnext/chat/patterns/group_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,16 +35,16 @@ def subscriptions(self) -> Sequence[type]:
agent_sublists = [agent.subscriptions for agent in self._agents]
return [Reset, RespondNow] + [item for sublist in agent_sublists for item in sublist]

@message_handler(Reset)
@message_handler()
async def on_reset(self, message: Reset, cancellation_token: CancellationToken) -> None:
self._history.clear()

@message_handler(RespondNow)
@message_handler()
async def on_respond_now(self, message: RespondNow, cancellation_token: CancellationToken) -> Any:
return self._output.get_output()

@message_handler(TextMessage)
async def on_text_message(self, message: Any, cancellation_token: CancellationToken) -> Any:
@message_handler()
async def on_text_message(self, message: TextMessage, cancellation_token: CancellationToken) -> Any:
# TODO: how should we handle the group chat receiving a message while in the middle of a conversation?
# Should this class disallow it?

Expand Down
2 changes: 1 addition & 1 deletion src/agnext/chat/patterns/orchestrator_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def __init__(
def children(self) -> Sequence[str]:
return [agent.name for agent in self._specialists] + [self._orchestrator.name] + [self._planner.name]

@message_handler(TextMessage)
@message_handler()
async def on_text_message(
self,
message: TextMessage,
Expand Down
18 changes: 18 additions & 0 deletions src/agnext/chat/patterns/two_agent_chat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from agnext.chat.patterns.group_chat import GroupChat, GroupChatOutput

from ...core import AgentRuntime
from ..agents.base import BaseChatAgent


class TwoAgentChat(GroupChat):
def __init__(
self,
name: str,
description: str,
runtime: AgentRuntime,
agent1: BaseChatAgent,
agent2: BaseChatAgent,
num_rounds: int,
output: GroupChatOutput,
) -> None:
super().__init__(name, description, runtime, [agent1, agent2], num_rounds, output)
129 changes: 117 additions & 12 deletions src/agnext/components/type_routed_agent.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,132 @@
from typing import Any, Callable, Coroutine, Dict, NoReturn, Sequence, Type, TypeVar
import logging
from functools import wraps
from types import NoneType, UnionType
from typing import (
Any,
Callable,
Coroutine,
Dict,
Literal,
NoReturn,
Optional,
Protocol,
Sequence,
Type,
TypeVar,
Union,
cast,
get_args,
get_origin,
get_type_hints,
runtime_checkable,
)

from agnext.core import AgentRuntime, BaseAgent, CancellationToken
from agnext.core.exceptions import CantHandleException

ReceivesT = TypeVar("ReceivesT")
logger = logging.getLogger("agnext")

ReceivesT = TypeVar("ReceivesT", contravariant=True)
ProducesT = TypeVar("ProducesT", covariant=True)

# TODO: Generic typevar bound binding U to agent type
# Can't do because python doesnt support it


def is_union(t: object) -> bool:
origin = get_origin(t)
return origin is Union or origin is UnionType


def is_optional(t: object) -> bool:
origin = get_origin(t)
return origin is Optional


# Special type to avoid the 3.10 vs 3.11+ difference of typing._SpecialForm vs typing.Any
class AnyType:
pass


def get_types(t: object) -> Sequence[Type[Any]] | None:
if is_union(t):
return get_args(t)
elif is_optional(t):
return tuple(list(get_args(t)) + [NoneType])
elif t is Any:
return (AnyType,)
elif isinstance(t, type):
return (t,)
elif isinstance(t, NoneType):
return (NoneType,)
else:
return None


@runtime_checkable
class MessageHandler(Protocol[ReceivesT, ProducesT]):
target_types: Sequence[type]
produces_types: Sequence[type]
is_message_handler: Literal[True]

async def __call__(self, message: ReceivesT, cancellation_token: CancellationToken) -> ProducesT: ...


# NOTE: this works on concrete types and not inheritance
# TODO: Use a protocl for the outer function to check checked arg names
def message_handler(
*target_types: Type[ReceivesT],
strict: bool = True,
) -> Callable[
[Callable[[Any, ReceivesT, CancellationToken], Coroutine[Any, Any, ProducesT | None]]],
Callable[[Any, ReceivesT, CancellationToken], Coroutine[Any, Any, ProducesT | None]],
[Callable[[Any, ReceivesT, CancellationToken], Coroutine[Any, Any, ProducesT]]],
MessageHandler[ReceivesT, ProducesT],
]:
def decorator(
func: Callable[[Any, ReceivesT, CancellationToken], Coroutine[Any, Any, ProducesT | None]],
) -> Callable[[Any, ReceivesT, CancellationToken], Coroutine[Any, Any, ProducesT | None]]:
func: Callable[[Any, ReceivesT, CancellationToken], Coroutine[Any, Any, ProducesT]],
) -> MessageHandler[ReceivesT, ProducesT]:
type_hints = get_type_hints(func)
if "message" not in type_hints:
raise AssertionError("message parameter not found in function signature")

if "return" not in type_hints:
raise AssertionError("return not found in function signature")

# Get the type of the message parameter
target_types = get_types(type_hints["message"])
if target_types is None:
raise AssertionError("Message type not found")

print(type_hints)
return_types = get_types(type_hints["return"])

if return_types is None:
raise AssertionError("Return type not found")

# Convert target_types to list and stash
func._target_types = list(target_types) # type: ignore
return func

@wraps(func)
async def wrapper(self: Any, message: ReceivesT, cancellation_token: CancellationToken) -> ProducesT:
if strict:
if type(message) not in target_types:
raise CantHandleException(f"Message type {type(message)} not in target types {target_types}")
else:
logger.warning(f"Message type {type(message)} not in target types {target_types}")

return_value = await func(self, message, cancellation_token)

if strict:
if return_value is not AnyType and type(return_value) not in return_types:
raise ValueError(f"Return type {type(return_value)} not in return types {return_types}")
elif return_value is not AnyType:
logger.warning(f"Return type {type(return_value)} not in return types {return_types}")

return return_value

wrapper_handler = cast(MessageHandler[ReceivesT, ProducesT], wrapper)
wrapper_handler.target_types = list(target_types)
wrapper_handler.produces_types = list(return_types)
wrapper_handler.is_message_handler = True

return wrapper_handler

return decorator

Expand All @@ -35,9 +139,10 @@ def __init__(self, name: str, router: AgentRuntime) -> None:
for attr in dir(self):
if callable(getattr(self, attr, None)):
handler = getattr(self, attr)
if hasattr(handler, "_target_types"):
for target_type in handler._target_types:
self._handlers[target_type] = handler
if hasattr(handler, "is_message_handler"):
message_handler = cast(MessageHandler[Any, Any], handler)
for target_type in message_handler.target_types:
self._handlers[target_type] = message_handler

super().__init__(name, router)

Expand Down
4 changes: 2 additions & 2 deletions tests/test_cancellation.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def __init__(self, name: str, router: AgentRuntime) -> None:
self.called = False
self.cancelled = False

@message_handler(MessageType)
@message_handler()
async def on_new_message(self, message: MessageType, cancellation_token: CancellationToken) -> MessageType:
self.called = True
sleep = asyncio.ensure_future(asyncio.sleep(100))
Expand All @@ -41,7 +41,7 @@ def __init__(self, name: str, router: AgentRuntime, nested_agent: Agent) -> None
self.cancelled = False
self._nested_agent = nested_agent

@message_handler(MessageType)
@message_handler()
async def on_new_message(self, message: MessageType, cancellation_token: CancellationToken) -> MessageType:
self.called = True
response = self._send_message(message, self._nested_agent, cancellation_token=cancellation_token)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_intervention.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def __init__(self, name: str, router: AgentRuntime) -> None:
self.num_calls = 0


@message_handler(MessageType)
@message_handler()
async def on_new_message(self, message: MessageType, cancellation_token: CancellationToken) -> MessageType:
self.num_calls += 1
return message
Expand Down
39 changes: 39 additions & 0 deletions tests/test_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from types import NoneType
from typing import Any, Optional, Union

from agnext.components.type_routed_agent import AnyType, get_types, message_handler
from agnext.core import CancellationToken


def test_get_types() -> None:
assert get_types(Union[int, str]) == (int, str)
assert get_types(int | str) == (int, str)
assert get_types(int) == (int,)
assert get_types(str) == (str,)
assert get_types("test") is None
assert get_types(Optional[int]) == (int, NoneType)
assert get_types(NoneType) == (NoneType,)
assert get_types(None) == (NoneType,)


def test_handler() -> None:

class HandlerClass:
@message_handler()
async def handler(self, message: int, cancellation_token: CancellationToken) -> Any:
return None

@message_handler()
async def handler2(self, message: str | bool, cancellation_token: CancellationToken) -> None:
return None

assert HandlerClass.handler.target_types == [int]
assert HandlerClass.handler.produces_types == [AnyType]

assert HandlerClass.handler2.target_types == [str, bool]
assert HandlerClass.handler2.produces_types == [NoneType]

class HandlerClass:
@message_handler()
async def handler(self, message: int, cancellation_token: CancellationToken) -> Any:
return None

0 comments on commit 8cb530f

Please sign in to comment.