Skip to content

Commit

Permalink
Make parenthesis optional for type routed agent message handler (auto…
Browse files Browse the repository at this point in the history
  • Loading branch information
jackgerrits authored Jun 13, 2024
1 parent 387aa6a commit 8dad8b0
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 6 deletions.
36 changes: 33 additions & 3 deletions src/agnext/components/_type_routed_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
get_args,
get_origin,
get_type_hints,
overload,
runtime_checkable,
)

Expand Down Expand Up @@ -74,12 +75,36 @@ async def __call__(self, message: ReceivesT, cancellation_token: CancellationTok

# NOTE: this works on concrete types and not inheritance
# TODO: Use a protocl for the outer function to check checked arg names


@overload
def message_handler(
strict: bool = True,
func: Callable[[Any, ReceivesT, CancellationToken], Coroutine[Any, Any, ProducesT]],
) -> MessageHandler[ReceivesT, ProducesT]: ...


@overload
def message_handler(
func: None = None,
*,
strict: bool = ...,
) -> Callable[
[Callable[[Any, ReceivesT, CancellationToken], Coroutine[Any, Any, ProducesT]]],
MessageHandler[ReceivesT, ProducesT],
]:
]: ...


def message_handler(
func: None | Callable[[Any, ReceivesT, CancellationToken], Coroutine[Any, Any, ProducesT]] = None,
*,
strict: bool = True,
) -> (
Callable[
[Callable[[Any, ReceivesT, CancellationToken], Coroutine[Any, Any, ProducesT]]],
MessageHandler[ReceivesT, ProducesT],
]
| MessageHandler[ReceivesT, ProducesT]
):
def decorator(
func: Callable[[Any, ReceivesT, CancellationToken], Coroutine[Any, Any, ProducesT]],
) -> MessageHandler[ReceivesT, ProducesT]:
Expand Down Expand Up @@ -128,7 +153,12 @@ async def wrapper(self: Any, message: ReceivesT, cancellation_token: Cancellatio

return wrapper_handler

return decorator
if func is None and not callable(func):
return decorator
elif callable(func):
return decorator(func)
else:
raise ValueError("Invalid arguments")


class TypeRoutedAgent(BaseAgent):
Expand Down
6 changes: 3 additions & 3 deletions tests/test_cancellation.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,11 @@ class MessageType:

class LongRunningAgent(TypeRoutedAgent): # type: ignore
def __init__(self, name: str, router: AgentRuntime) -> None: # type: ignore
super().__init__(name, "A long running agent", router)
super().__init__(name, "A long running agent", router)
self.called = False
self.cancelled = False

@message_handler() # type: ignore
@message_handler
async def on_new_message(self, message: MessageType, cancellation_token: CancellationToken) -> MessageType: # type: ignore
self.called = True
sleep = asyncio.ensure_future(asyncio.sleep(100))
Expand All @@ -40,7 +40,7 @@ def __init__(self, name: str, router: AgentRuntime, nested_agent: Agent) -> None
self.cancelled = False
self._nested_agent = nested_agent

@message_handler() # type: ignore
@message_handler
async def on_new_message(self, message: MessageType, cancellation_token: CancellationToken) -> MessageType: # type: ignore
self.called = True
response = self._send_message(message, self._nested_agent, cancellation_token=cancellation_token)
Expand Down

0 comments on commit 8dad8b0

Please sign in to comment.