Skip to content

Commit

Permalink
Fix middlewares order (#1935)
Browse files Browse the repository at this point in the history
* Fix middlewares order for publishing scope

* Missing middlewares

* ruff

* add tests

* ruff tests

* pre-commit

* Add test for `aenter`, `aexit`

* pre-commit

* Update tests

* Fix tests

* Fix tests again

* lint: fix mypy

* Update tests

* pre-commit

* Fix tests

* Update tests, again :)

* format tests

* Remove logger from tests

* format

* fix: correct NATS Publisher middlewares order

* fix: correct All Publisher middlewares order

* Try to fix types

* format

* chore: fix CI

* fix: correct confluent middlewares order

* tests: make order middlewares tests in-memory

---------

Co-authored-by: Nikita Pastukhov <[email protected]>
Co-authored-by: Pastukhov Nikita <[email protected]>
  • Loading branch information
3 people authored Nov 29, 2024
1 parent 90aaf57 commit a173bb5
Show file tree
Hide file tree
Showing 58 changed files with 685 additions and 266 deletions.
3 changes: 2 additions & 1 deletion faststream/broker/core/abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
Iterable,
Mapping,
Optional,
Sequence,
)

from faststream.broker.types import MsgType
Expand All @@ -30,7 +31,7 @@ def __init__(
*,
prefix: str,
dependencies: Iterable["Depends"],
middlewares: Iterable["BrokerMiddleware[MsgType]"],
middlewares: Sequence["BrokerMiddleware[MsgType]"],
parser: Optional["CustomCallable"],
decoder: Optional["CustomCallable"],
include_in_schema: Optional[bool],
Expand Down
8 changes: 4 additions & 4 deletions faststream/broker/core/usecase.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def __init__(
Doc("Dependencies to apply to all broker subscribers."),
],
middlewares: Annotated[
Iterable["BrokerMiddleware[MsgType]"],
Sequence["BrokerMiddleware[MsgType]"],
Doc("Middlewares to apply to all broker publishers/subscribers."),
],
graceful_timeout: Annotated[
Expand Down Expand Up @@ -342,7 +342,7 @@ async def publish(

publish = producer.publish

for m in self._middlewares:
for m in self._middlewares[::-1]:
publish = partial(m(None).publish_scope, publish)

return await publish(msg, correlation_id=correlation_id, **kwargs)
Expand All @@ -359,7 +359,7 @@ async def request(
assert producer, NOT_CONNECTED_YET # nosec B101

request = producer.request
for m in self._middlewares:
for m in self._middlewares[::-1]:
request = partial(m(None).publish_scope, request)

published_msg = await request(
Expand All @@ -370,7 +370,7 @@ async def request(

async with AsyncExitStack() as stack:
return_msg = return_input
for m in self._middlewares:
for m in self._middlewares[::-1]:
mid = m(published_msg)
await stack.enter_async_context(mid)
return_msg = partial(mid.consume_scope, return_msg)
Expand Down
2 changes: 1 addition & 1 deletion faststream/broker/fastapi/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ class StreamRouter(
def __init__(
self,
*connection_args: Any,
middlewares: Iterable["BrokerMiddleware[MsgType]"] = (),
middlewares: Sequence["BrokerMiddleware[MsgType]"] = (),
prefix: str = "",
tags: Optional[List[Union[str, Enum]]] = None,
dependencies: Optional[Sequence["params.Depends"]] = None,
Expand Down
8 changes: 4 additions & 4 deletions faststream/broker/publisher/fake.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from functools import partial
from itertools import chain
from typing import TYPE_CHECKING, Any, Iterable, Optional
from typing import TYPE_CHECKING, Any, Optional, Sequence

from faststream.broker.publisher.proto import BasePublisherProto

Expand All @@ -17,7 +17,7 @@ def __init__(
method: "AsyncFunc",
*,
publish_kwargs: "AnyDict",
middlewares: Iterable["PublisherMiddleware"] = (),
middlewares: Sequence["PublisherMiddleware"] = (),
) -> None:
"""Initialize an object."""
self.method = method
Expand All @@ -29,7 +29,7 @@ async def publish(
message: "SendableMessage",
*,
correlation_id: Optional[str] = None,
_extra_middlewares: Iterable["PublisherMiddleware"] = (),
_extra_middlewares: Sequence["PublisherMiddleware"] = (),
**kwargs: Any,
) -> Any:
"""Publish a message."""
Expand All @@ -51,7 +51,7 @@ async def request(
/,
*,
correlation_id: Optional[str] = None,
_extra_middlewares: Iterable["PublisherMiddleware"] = (),
_extra_middlewares: Sequence["PublisherMiddleware"] = (),
) -> Any:
raise NotImplementedError(
"`FakePublisher` can be used only to publish "
Expand Down
18 changes: 13 additions & 5 deletions faststream/broker/publisher/proto.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,13 @@
from abc import abstractmethod
from typing import TYPE_CHECKING, Any, Callable, Generic, Iterable, Optional, Protocol
from typing import (
TYPE_CHECKING,
Any,
Callable,
Generic,
Optional,
Protocol,
Sequence,
)

from typing_extensions import override

Expand Down Expand Up @@ -53,7 +61,7 @@ async def publish(
/,
*,
correlation_id: Optional[str] = None,
_extra_middlewares: Iterable["PublisherMiddleware"] = (),
_extra_middlewares: Sequence["PublisherMiddleware"] = (),
) -> Optional[Any]:
"""Publishes a message asynchronously."""
...
Expand All @@ -65,7 +73,7 @@ async def request(
/,
*,
correlation_id: Optional[str] = None,
_extra_middlewares: Iterable["PublisherMiddleware"] = (),
_extra_middlewares: Sequence["PublisherMiddleware"] = (),
) -> Optional[Any]:
"""Publishes a message synchronously."""
...
Expand All @@ -79,8 +87,8 @@ class PublisherProto(
):
schema_: Any

_broker_middlewares: Iterable["BrokerMiddleware[MsgType]"]
_middlewares: Iterable["PublisherMiddleware"]
_broker_middlewares: Sequence["BrokerMiddleware[MsgType]"]
_middlewares: Sequence["PublisherMiddleware"]
_producer: Optional["ProducerProto"]

@abstractmethod
Expand Down
6 changes: 3 additions & 3 deletions faststream/broker/publisher/usecase.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
TYPE_CHECKING,
Any,
Callable,
Iterable,
List,
Optional,
Sequence,
Tuple,
)
from unittest.mock import MagicMock
Expand Down Expand Up @@ -49,11 +49,11 @@ def __init__(
self,
*,
broker_middlewares: Annotated[
Iterable["BrokerMiddleware[MsgType]"],
Sequence["BrokerMiddleware[MsgType]"],
Doc("Top-level middlewares to use in direct `.publish` call."),
],
middlewares: Annotated[
Iterable["PublisherMiddleware"],
Sequence["PublisherMiddleware"],
Doc("Publisher middlewares."),
],
# AsyncAPI args
Expand Down
3 changes: 2 additions & 1 deletion faststream/broker/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
Callable,
Iterable,
Optional,
Sequence,
)

from faststream.broker.core.abc import ABCBroker
Expand Down Expand Up @@ -64,7 +65,7 @@ def __init__(
# base options
prefix: str,
dependencies: Iterable["Depends"],
middlewares: Iterable["BrokerMiddleware[MsgType]"],
middlewares: Sequence["BrokerMiddleware[MsgType]"],
parser: Optional["CustomCallable"],
decoder: Optional["CustomCallable"],
include_in_schema: Optional[bool],
Expand Down
5 changes: 3 additions & 2 deletions faststream/broker/subscriber/call_item.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
Generic,
Iterable,
Optional,
Sequence,
cast,
)

Expand Down Expand Up @@ -54,7 +55,7 @@ def __init__(
filter: "AsyncFilter[StreamMessage[MsgType]]",
item_parser: Optional["CustomCallable"],
item_decoder: Optional["CustomCallable"],
item_middlewares: Iterable["SubscriberMiddleware[StreamMessage[MsgType]]"],
item_middlewares: Sequence["SubscriberMiddleware[StreamMessage[MsgType]]"],
dependencies: Iterable["Depends"],
) -> None:
self.handler = handler
Expand Down Expand Up @@ -157,7 +158,7 @@ async def call(
"""Execute wrapped handler with consume middlewares."""
call: AsyncFuncAny = self.handler.call_wrapped

for middleware in chain(self.item_middlewares, _extra_middlewares):
for middleware in chain(self.item_middlewares[::-1], _extra_middlewares):
call = partial(middleware, call)

try:
Expand Down
15 changes: 12 additions & 3 deletions faststream/broker/subscriber/proto.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,14 @@
from abc import abstractmethod
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Iterable,
List,
Optional,
Sequence,
)

from typing_extensions import Self, override

Expand Down Expand Up @@ -33,7 +42,7 @@ class SubscriberProto(
running: bool

_broker_dependencies: Iterable["Depends"]
_broker_middlewares: Iterable["BrokerMiddleware[MsgType]"]
_broker_middlewares: Sequence["BrokerMiddleware[MsgType]"]
_producer: Optional["ProducerProto"]

@abstractmethod
Expand Down Expand Up @@ -98,6 +107,6 @@ def add_call(
filter_: "Filter[Any]",
parser_: "CustomCallable",
decoder_: "CustomCallable",
middlewares_: Iterable["SubscriberMiddleware[Any]"],
middlewares_: Sequence["SubscriberMiddleware[Any]"],
dependencies_: Iterable["Depends"],
) -> Self: ...
27 changes: 16 additions & 11 deletions faststream/broker/subscriber/usecase.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
Iterable,
List,
Optional,
Sequence,
Tuple,
Union,
overload,
Expand Down Expand Up @@ -53,11 +54,11 @@

class _CallOptions:
__slots__ = (
"filter",
"parser",
"decoder",
"middlewares",
"dependencies",
"filter",
"middlewares",
"parser",
)

def __init__(
Expand All @@ -66,7 +67,7 @@ def __init__(
filter: "Filter[Any]",
parser: Optional["CustomCallable"],
decoder: Optional["CustomCallable"],
middlewares: Iterable["SubscriberMiddleware[Any]"],
middlewares: Sequence["SubscriberMiddleware[Any]"],
dependencies: Iterable["Depends"],
) -> None:
self.filter = filter
Expand Down Expand Up @@ -98,7 +99,7 @@ def __init__(
no_reply: bool,
retry: Union[bool, int],
broker_dependencies: Iterable["Depends"],
broker_middlewares: Iterable["BrokerMiddleware[MsgType]"],
broker_middlewares: Sequence["BrokerMiddleware[MsgType]"],
default_parser: "AsyncCallable",
default_decoder: "AsyncCallable",
# AsyncAPI information
Expand Down Expand Up @@ -211,7 +212,7 @@ def add_call(
filter_: "Filter[Any]",
parser_: Optional["CustomCallable"],
decoder_: Optional["CustomCallable"],
middlewares_: Iterable["SubscriberMiddleware[Any]"],
middlewares_: Sequence["SubscriberMiddleware[Any]"],
dependencies_: Iterable["Depends"],
) -> Self:
self._call_options = _CallOptions(
Expand All @@ -231,7 +232,7 @@ def __call__(
filter: Optional["Filter[Any]"] = None,
parser: Optional["CustomCallable"] = None,
decoder: Optional["CustomCallable"] = None,
middlewares: Iterable["SubscriberMiddleware[Any]"] = (),
middlewares: Sequence["SubscriberMiddleware[Any]"] = (),
dependencies: Iterable["Depends"] = (),
) -> Callable[
[Callable[P_HandlerParams, T_HandlerReturn]],
Expand All @@ -246,7 +247,7 @@ def __call__(
filter: Optional["Filter[Any]"] = None,
parser: Optional["CustomCallable"] = None,
decoder: Optional["CustomCallable"] = None,
middlewares: Iterable["SubscriberMiddleware[Any]"] = (),
middlewares: Sequence["SubscriberMiddleware[Any]"] = (),
dependencies: Iterable["Depends"] = (),
) -> "HandlerCallWrapper[MsgType, P_HandlerParams, T_HandlerReturn]": ...

Expand All @@ -257,7 +258,7 @@ def __call__(
filter: Optional["Filter[Any]"] = None,
parser: Optional["CustomCallable"] = None,
decoder: Optional["CustomCallable"] = None,
middlewares: Iterable["SubscriberMiddleware[Any]"] = (),
middlewares: Sequence["SubscriberMiddleware[Any]"] = (),
dependencies: Iterable["Depends"] = (),
) -> Any:
if (options := self._call_options) is None:
Expand Down Expand Up @@ -367,7 +368,9 @@ async def process_message(self, msg: MsgType) -> "Response":
await h.call(
message=message,
# consumer middlewares
_extra_middlewares=(m.consume_scope for m in middlewares),
_extra_middlewares=(
m.consume_scope for m in middlewares[::-1]
),
)
)

Expand All @@ -382,7 +385,9 @@ async def process_message(self, msg: MsgType) -> "Response":
result_msg.body,
**result_msg.as_publish_kwargs(),
# publisher middlewares
_extra_middlewares=(m.publish_scope for m in middlewares),
_extra_middlewares=[
m.publish_scope for m in middlewares[::-1]
],
)

# Return data for tests
Expand Down
6 changes: 3 additions & 3 deletions faststream/broker/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
AsyncContextManager,
Awaitable,
Callable,
Iterable,
Optional,
Sequence,
Type,
Union,
cast,
Expand Down Expand Up @@ -37,7 +37,7 @@

async def process_msg(
msg: Optional[MsgType],
middlewares: Iterable["BrokerMiddleware[MsgType]"],
middlewares: Sequence["BrokerMiddleware[MsgType]"],
parser: Callable[[MsgType], Awaitable["StreamMessage[MsgType]"]],
decoder: Callable[["StreamMessage[MsgType]"], "Any"],
) -> Optional["StreamMessage[MsgType]"]:
Expand All @@ -50,7 +50,7 @@ async def process_msg(
Awaitable[StreamMessage[MsgType]],
] = return_input

for m in middlewares:
for m in middlewares[::-1]:
mid = m(msg)
await stack.enter_async_context(mid)
return_msg = partial(mid.consume_scope, return_msg)
Expand Down
Loading

0 comments on commit a173bb5

Please sign in to comment.