Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix middlewares order #1935

Merged
merged 28 commits into from
Nov 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion faststream/broker/core/abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
Iterable,
Mapping,
Optional,
Sequence,
)

from faststream.broker.types import MsgType
Expand All @@ -30,7 +31,7 @@ def __init__(
*,
prefix: str,
dependencies: Iterable["Depends"],
middlewares: Iterable["BrokerMiddleware[MsgType]"],
middlewares: Sequence["BrokerMiddleware[MsgType]"],
parser: Optional["CustomCallable"],
decoder: Optional["CustomCallable"],
include_in_schema: Optional[bool],
Expand Down
8 changes: 4 additions & 4 deletions faststream/broker/core/usecase.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def __init__(
Doc("Dependencies to apply to all broker subscribers."),
],
middlewares: Annotated[
Iterable["BrokerMiddleware[MsgType]"],
Sequence["BrokerMiddleware[MsgType]"],
Doc("Middlewares to apply to all broker publishers/subscribers."),
],
graceful_timeout: Annotated[
Expand Down Expand Up @@ -342,7 +342,7 @@ async def publish(

publish = producer.publish

for m in self._middlewares:
for m in self._middlewares[::-1]:
publish = partial(m(None).publish_scope, publish)

return await publish(msg, correlation_id=correlation_id, **kwargs)
Expand All @@ -359,7 +359,7 @@ async def request(
assert producer, NOT_CONNECTED_YET # nosec B101

request = producer.request
for m in self._middlewares:
for m in self._middlewares[::-1]:
request = partial(m(None).publish_scope, request)

published_msg = await request(
Expand All @@ -370,7 +370,7 @@ async def request(

async with AsyncExitStack() as stack:
return_msg = return_input
for m in self._middlewares:
for m in self._middlewares[::-1]:
mid = m(published_msg)
await stack.enter_async_context(mid)
return_msg = partial(mid.consume_scope, return_msg)
Expand Down
2 changes: 1 addition & 1 deletion faststream/broker/fastapi/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ class StreamRouter(
def __init__(
self,
*connection_args: Any,
middlewares: Iterable["BrokerMiddleware[MsgType]"] = (),
middlewares: Sequence["BrokerMiddleware[MsgType]"] = (),
prefix: str = "",
tags: Optional[List[Union[str, Enum]]] = None,
dependencies: Optional[Sequence["params.Depends"]] = None,
Expand Down
8 changes: 4 additions & 4 deletions faststream/broker/publisher/fake.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from functools import partial
from itertools import chain
from typing import TYPE_CHECKING, Any, Iterable, Optional
from typing import TYPE_CHECKING, Any, Optional, Sequence

from faststream.broker.publisher.proto import BasePublisherProto

Expand All @@ -17,7 +17,7 @@ def __init__(
method: "AsyncFunc",
*,
publish_kwargs: "AnyDict",
middlewares: Iterable["PublisherMiddleware"] = (),
middlewares: Sequence["PublisherMiddleware"] = (),
) -> None:
"""Initialize an object."""
self.method = method
Expand All @@ -29,7 +29,7 @@ async def publish(
message: "SendableMessage",
*,
correlation_id: Optional[str] = None,
_extra_middlewares: Iterable["PublisherMiddleware"] = (),
_extra_middlewares: Sequence["PublisherMiddleware"] = (),
**kwargs: Any,
) -> Any:
"""Publish a message."""
Expand All @@ -51,7 +51,7 @@ async def request(
/,
*,
correlation_id: Optional[str] = None,
_extra_middlewares: Iterable["PublisherMiddleware"] = (),
_extra_middlewares: Sequence["PublisherMiddleware"] = (),
) -> Any:
raise NotImplementedError(
"`FakePublisher` can be used only to publish "
Expand Down
18 changes: 13 additions & 5 deletions faststream/broker/publisher/proto.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,13 @@
from abc import abstractmethod
from typing import TYPE_CHECKING, Any, Callable, Generic, Iterable, Optional, Protocol
from typing import (
TYPE_CHECKING,
Any,
Callable,
Generic,
Optional,
Protocol,
Sequence,
)

from typing_extensions import override

Expand Down Expand Up @@ -53,7 +61,7 @@ async def publish(
/,
*,
correlation_id: Optional[str] = None,
_extra_middlewares: Iterable["PublisherMiddleware"] = (),
_extra_middlewares: Sequence["PublisherMiddleware"] = (),
) -> Optional[Any]:
"""Publishes a message asynchronously."""
...
Expand All @@ -65,7 +73,7 @@ async def request(
/,
*,
correlation_id: Optional[str] = None,
_extra_middlewares: Iterable["PublisherMiddleware"] = (),
_extra_middlewares: Sequence["PublisherMiddleware"] = (),
) -> Optional[Any]:
"""Publishes a message synchronously."""
...
Expand All @@ -79,8 +87,8 @@ class PublisherProto(
):
schema_: Any

_broker_middlewares: Iterable["BrokerMiddleware[MsgType]"]
_middlewares: Iterable["PublisherMiddleware"]
_broker_middlewares: Sequence["BrokerMiddleware[MsgType]"]
_middlewares: Sequence["PublisherMiddleware"]
_producer: Optional["ProducerProto"]

@abstractmethod
Expand Down
6 changes: 3 additions & 3 deletions faststream/broker/publisher/usecase.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
TYPE_CHECKING,
Any,
Callable,
Iterable,
List,
Optional,
Sequence,
Tuple,
)
from unittest.mock import MagicMock
Expand Down Expand Up @@ -49,11 +49,11 @@ def __init__(
self,
*,
broker_middlewares: Annotated[
Iterable["BrokerMiddleware[MsgType]"],
Sequence["BrokerMiddleware[MsgType]"],
Doc("Top-level middlewares to use in direct `.publish` call."),
],
middlewares: Annotated[
Iterable["PublisherMiddleware"],
Sequence["PublisherMiddleware"],
Doc("Publisher middlewares."),
],
# AsyncAPI args
Expand Down
3 changes: 2 additions & 1 deletion faststream/broker/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
Callable,
Iterable,
Optional,
Sequence,
)

from faststream.broker.core.abc import ABCBroker
Expand Down Expand Up @@ -64,7 +65,7 @@ def __init__(
# base options
prefix: str,
dependencies: Iterable["Depends"],
middlewares: Iterable["BrokerMiddleware[MsgType]"],
middlewares: Sequence["BrokerMiddleware[MsgType]"],
parser: Optional["CustomCallable"],
decoder: Optional["CustomCallable"],
include_in_schema: Optional[bool],
Expand Down
5 changes: 3 additions & 2 deletions faststream/broker/subscriber/call_item.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
Generic,
Iterable,
Optional,
Sequence,
cast,
)

Expand Down Expand Up @@ -54,7 +55,7 @@ def __init__(
filter: "AsyncFilter[StreamMessage[MsgType]]",
item_parser: Optional["CustomCallable"],
item_decoder: Optional["CustomCallable"],
item_middlewares: Iterable["SubscriberMiddleware[StreamMessage[MsgType]]"],
item_middlewares: Sequence["SubscriberMiddleware[StreamMessage[MsgType]]"],
dependencies: Iterable["Depends"],
) -> None:
self.handler = handler
Expand Down Expand Up @@ -157,7 +158,7 @@ async def call(
"""Execute wrapped handler with consume middlewares."""
call: AsyncFuncAny = self.handler.call_wrapped

for middleware in chain(self.item_middlewares, _extra_middlewares):
for middleware in chain(self.item_middlewares[::-1], _extra_middlewares):
call = partial(middleware, call)

try:
Expand Down
15 changes: 12 additions & 3 deletions faststream/broker/subscriber/proto.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,14 @@
from abc import abstractmethod
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Iterable,
List,
Optional,
Sequence,
)

from typing_extensions import Self, override

Expand Down Expand Up @@ -33,7 +42,7 @@ class SubscriberProto(
running: bool

_broker_dependencies: Iterable["Depends"]
_broker_middlewares: Iterable["BrokerMiddleware[MsgType]"]
_broker_middlewares: Sequence["BrokerMiddleware[MsgType]"]
_producer: Optional["ProducerProto"]

@abstractmethod
Expand Down Expand Up @@ -98,6 +107,6 @@ def add_call(
filter_: "Filter[Any]",
parser_: "CustomCallable",
decoder_: "CustomCallable",
middlewares_: Iterable["SubscriberMiddleware[Any]"],
middlewares_: Sequence["SubscriberMiddleware[Any]"],
dependencies_: Iterable["Depends"],
) -> Self: ...
27 changes: 16 additions & 11 deletions faststream/broker/subscriber/usecase.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
Iterable,
List,
Optional,
Sequence,
Tuple,
Union,
overload,
Expand Down Expand Up @@ -53,11 +54,11 @@

class _CallOptions:
__slots__ = (
"filter",
"parser",
"decoder",
"middlewares",
"dependencies",
"filter",
"middlewares",
"parser",
)

def __init__(
Expand All @@ -66,7 +67,7 @@ def __init__(
filter: "Filter[Any]",
parser: Optional["CustomCallable"],
decoder: Optional["CustomCallable"],
middlewares: Iterable["SubscriberMiddleware[Any]"],
middlewares: Sequence["SubscriberMiddleware[Any]"],
dependencies: Iterable["Depends"],
) -> None:
self.filter = filter
Expand Down Expand Up @@ -98,7 +99,7 @@ def __init__(
no_reply: bool,
retry: Union[bool, int],
broker_dependencies: Iterable["Depends"],
broker_middlewares: Iterable["BrokerMiddleware[MsgType]"],
broker_middlewares: Sequence["BrokerMiddleware[MsgType]"],
default_parser: "AsyncCallable",
default_decoder: "AsyncCallable",
# AsyncAPI information
Expand Down Expand Up @@ -211,7 +212,7 @@ def add_call(
filter_: "Filter[Any]",
parser_: Optional["CustomCallable"],
decoder_: Optional["CustomCallable"],
middlewares_: Iterable["SubscriberMiddleware[Any]"],
middlewares_: Sequence["SubscriberMiddleware[Any]"],
dependencies_: Iterable["Depends"],
) -> Self:
self._call_options = _CallOptions(
Expand All @@ -231,7 +232,7 @@ def __call__(
filter: Optional["Filter[Any]"] = None,
parser: Optional["CustomCallable"] = None,
decoder: Optional["CustomCallable"] = None,
middlewares: Iterable["SubscriberMiddleware[Any]"] = (),
middlewares: Sequence["SubscriberMiddleware[Any]"] = (),
dependencies: Iterable["Depends"] = (),
) -> Callable[
[Callable[P_HandlerParams, T_HandlerReturn]],
Expand All @@ -246,7 +247,7 @@ def __call__(
filter: Optional["Filter[Any]"] = None,
parser: Optional["CustomCallable"] = None,
decoder: Optional["CustomCallable"] = None,
middlewares: Iterable["SubscriberMiddleware[Any]"] = (),
middlewares: Sequence["SubscriberMiddleware[Any]"] = (),
dependencies: Iterable["Depends"] = (),
) -> "HandlerCallWrapper[MsgType, P_HandlerParams, T_HandlerReturn]": ...

Expand All @@ -257,7 +258,7 @@ def __call__(
filter: Optional["Filter[Any]"] = None,
parser: Optional["CustomCallable"] = None,
decoder: Optional["CustomCallable"] = None,
middlewares: Iterable["SubscriberMiddleware[Any]"] = (),
middlewares: Sequence["SubscriberMiddleware[Any]"] = (),
dependencies: Iterable["Depends"] = (),
) -> Any:
if (options := self._call_options) is None:
Expand Down Expand Up @@ -367,7 +368,9 @@ async def process_message(self, msg: MsgType) -> "Response":
await h.call(
message=message,
# consumer middlewares
_extra_middlewares=(m.consume_scope for m in middlewares),
_extra_middlewares=(
m.consume_scope for m in middlewares[::-1]
),
)
)

Expand All @@ -382,7 +385,9 @@ async def process_message(self, msg: MsgType) -> "Response":
result_msg.body,
**result_msg.as_publish_kwargs(),
# publisher middlewares
_extra_middlewares=(m.publish_scope for m in middlewares),
_extra_middlewares=[
m.publish_scope for m in middlewares[::-1]
],
)

# Return data for tests
Expand Down
6 changes: 3 additions & 3 deletions faststream/broker/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
AsyncContextManager,
Awaitable,
Callable,
Iterable,
Optional,
Sequence,
Type,
Union,
cast,
Expand Down Expand Up @@ -37,7 +37,7 @@

async def process_msg(
msg: Optional[MsgType],
middlewares: Iterable["BrokerMiddleware[MsgType]"],
middlewares: Sequence["BrokerMiddleware[MsgType]"],
parser: Callable[[MsgType], Awaitable["StreamMessage[MsgType]"]],
decoder: Callable[["StreamMessage[MsgType]"], "Any"],
) -> Optional["StreamMessage[MsgType]"]:
Expand All @@ -50,7 +50,7 @@ async def process_msg(
Awaitable[StreamMessage[MsgType]],
] = return_input

for m in middlewares:
for m in middlewares[::-1]:
mid = m(msg)
await stack.enter_async_context(mid)
return_msg = partial(mid.consume_scope, return_msg)
Expand Down
Loading
Loading