From 7ddcababc19fad0d241304db3b7cc27adb5a56b9 Mon Sep 17 00:00:00 2001 From: Daniel Knell Date: Fri, 9 Aug 2024 11:55:22 +0000 Subject: [PATCH] feat: add return type for queries --- docs/conf.py | 2 +- docs/usage/dispatching-requests.md | 25 ++++++++++++++++++++++++- src/banshee/__init__.py | 3 ++- src/banshee/bus.py | 20 +++++++++++++++++--- tests/unit/test_message_bus.py | 2 +- tests/unit/test_traceable_bus.py | 2 +- 6 files changed, 46 insertions(+), 8 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index d95100e..1cce86f 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -85,7 +85,7 @@ ] intersphinx_mapping = { - "python": ("https://docs.python.org/3.8", None), + "python": ("https://docs.python.org/3.10", None), "injector": ("https://injector.readthedocs.io/en/latest", None), } diff --git a/docs/usage/dispatching-requests.md b/docs/usage/dispatching-requests.md index bfb7e11..537eb35 100644 --- a/docs/usage/dispatching-requests.md +++ b/docs/usage/dispatching-requests.md @@ -66,12 +66,31 @@ class GetUserQuery: user_id: int @registry.subscribe_to(GetUserQuery) -async def do_greeting(query: GetUserQuery) -> User: +async def get_user(query: GetUserQuery) -> User: return user_store.get(query.user_id) user = await bus.query(GetUserQuery(user_id=1)) ``` +#### Return types + +The return type of a query is {class}`~typing.Any`, this can be +overridden by inheriting from the generic {class}`~banshee.Query` class. + +```py +class GetUserQuery(banshee.Query[User]): + user_id: int + +@dataclasses.dataclass(freeze=True) +class GetUserQuery: + user_id: int + +@registry.subscribe_to(GetUserQuery) +async def get_user(query: GetUserQuery) -> User: + return user_store.get(query.user_id) + +user = await bus.query(GetUserQuery(user_id=1)) +``` ## Reference @@ -95,6 +114,10 @@ user = await bus.query(GetUserQuery(user_id=1)) .. autoclass:: banshee.Registry :show-inheritance: :members: + +.. autoclass:: banshee.Query + :show-inheritance: + :members: ``` ```{exception} banshee.ConfigurationError(message) diff --git a/src/banshee/__init__.py b/src/banshee/__init__.py index 2c1b0f9..ee8a69c 100644 --- a/src/banshee/__init__.py +++ b/src/banshee/__init__.py @@ -3,7 +3,7 @@ """ from banshee.builder import Builder -from banshee.bus import Bus, MessageBus +from banshee.bus import Bus, MessageBus, Query from banshee.context import Causation, Dispatch, HandleAfter, Identity from banshee.errors import ConfigurationError, DispatchError, MultipleErrors from banshee.message import HandleMessage, Message, Middleware, message_for @@ -48,6 +48,7 @@ "MessageInfo", "Middleware", "MultipleErrors", + "Query", "Registry", "SimpleHandlerFactory", "TraceableBus", diff --git a/src/banshee/bus.py b/src/banshee/bus.py index d62631f..cc9dc7c 100644 --- a/src/banshee/bus.py +++ b/src/banshee/bus.py @@ -13,6 +13,20 @@ #: T T = typing.TypeVar("T") +#: Return Type +ReturnType_co = typing.TypeVar( # pylint: disable=invalid-name + "ReturnType_co", + covariant=True, +) + + +class Query(typing.Protocol[ReturnType_co]): # pylint: disable=too-few-public-methods + """ + Query. + + A generic type for specifying the return type of a request. + """ + @typing.runtime_checkable class Bus(typing.Protocol): @@ -39,9 +53,9 @@ async def handle( async def query( self, - query: typing.Any, + query: Query[T], contexts: collections.abc.Iterable[typing.Any] | None = None, - ) -> typing.Any: + ) -> T: """ Query. @@ -68,7 +82,7 @@ async def query( f"multiple handlers for {type(message.request).__name__} found" ) - return dispatch_contexts[0].result + return typing.cast(T, dispatch_contexts[0].result) class MessageBus(Bus): diff --git a/tests/unit/test_message_bus.py b/tests/unit/test_message_bus.py index 1b3cd9e..21b92f0 100644 --- a/tests/unit/test_message_bus.py +++ b/tests/unit/test_message_bus.py @@ -105,7 +105,7 @@ async def middleware( bus = banshee.bus.MessageBus([middleware]) - result = await bus.query(_Request()) + result: object = await bus.query(_Request()) assert result is expected diff --git a/tests/unit/test_traceable_bus.py b/tests/unit/test_traceable_bus.py index ad624d6..10fcf3b 100644 --- a/tests/unit/test_traceable_bus.py +++ b/tests/unit/test_traceable_bus.py @@ -67,7 +67,7 @@ async def test_query_should_return_handler_result() -> None: bus = banshee.TraceableBus(inner) - result = await bus.query(request, contexts=[context1]) + result: object = await bus.query(request, contexts=[context1]) inner.handle.assert_awaited_once_with( banshee.message_for(request, contexts=[context1])