Skip to content

Commit

Permalink
fix: correct middlewares
Browse files Browse the repository at this point in the history
  • Loading branch information
Lancetnik committed Jan 13, 2024
1 parent c12df70 commit 1ee5d1e
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 56 deletions.
74 changes: 37 additions & 37 deletions faststream/broker/core/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
TYPE_CHECKING,
Any,
AsyncContextManager,
AsyncGenerator,
Awaitable,
Callable,
Dict,
Expand Down Expand Up @@ -85,15 +84,11 @@ def description(self) -> Optional[str]:
description = getattr(caller, "__doc__", None)
return description

async def call(
async def is_suitable(
self,
msg: MsgType,
cache: Dict[Any, Any],
extra_middlewares: Sequence["BaseMiddleware"],
) -> AsyncGenerator[
Union["StreamMessage[MsgType]", None, SendableMessage],
None,
]:
) -> Optional["StreamMessage[MsgType]"]:
message = cache[self.parser] = cache.get(
self.parser,
await self.parser(msg),
Expand All @@ -104,36 +99,40 @@ async def call(
)

if await self.filter(message):
yield message

result = None
async with AsyncExitStack() as consume_stack:
for middleware in chain(self.middlewares, extra_middlewares):
message.decoded_body = await consume_stack.enter_async_context(
middleware.consume_scope(message.decoded_body)
)
return message

try:
result = await self.handler.call_wrapped(message)
async def call(
self,
message: "StreamMessage[MsgType]",
extra_middlewares: Sequence["BaseMiddleware"],
) -> Optional[SendableMessage]:
assert message.decoded_body

result: SendableMessage = None
async with AsyncExitStack() as consume_stack:
for middleware in chain(self.middlewares, extra_middlewares):
message.decoded_body = await consume_stack.enter_async_context(
middleware.consume_scope(message.decoded_body)
)

except StopConsume:
self.handler.trigger()
raise
try:
result = await self.handler.call_wrapped(message)

except HandlerException:
self.handler.trigger()
raise
except StopConsume:
self.handler.trigger()
raise

except Exception as e:
self.handler.trigger(error=e)
raise e
except HandlerException:
self.handler.trigger()
raise

else:
self.handler.trigger(result=result[0] if result else None)
yield result
except Exception as e:
self.handler.trigger(error=e)
raise e

else:
yield None
else:
self.handler.trigger(result=result)
return result


class BaseHandler(AsyncAPIOperation, WrapHandlerMixin[MsgType]):
Expand Down Expand Up @@ -290,8 +289,7 @@ async def consume(self, msg: MsgType) -> SendableMessage:
Returns:
The sendable message.
"""
result: Optional[SendableMessage] = None
result_msg: SendableMessage = None
result_msg: Optional[SendableMessage] = None

if not self.running:
return result_msg
Expand All @@ -313,10 +311,10 @@ async def consume(self, msg: MsgType) -> SendableMessage:
if processed:
break

caller = h.call(msg, cache, middlewares)

if (
message := cast("StreamMessage[MsgType]", await caller.asend(None))
message := cast(
"StreamMessage[MsgType]", await h.is_suitable(msg, cache)
)
) is not None:
await stack.enter_async_context(self.watcher(message))
stack.enter_context(context.scope("message", message))
Expand All @@ -336,7 +334,9 @@ async def close_middlewares(
processed = True

try:
result_msg = cast(SendableMessage, await caller.asend(None))
result_msg = cast(
SendableMessage, await h.call(message, middlewares)
)
except StopConsume:
await self.close()
return
Expand Down
10 changes: 3 additions & 7 deletions faststream/broker/middlewares.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,13 +44,9 @@ class BaseMiddleware:
Asynchronous function to handle the after publish event.
"""

def __init__(self, msg: Any) -> None:
"""Initialize the class.
Args:
msg: Any message to be stored.
"""
self.msg = msg
def __init__(self) -> None:
"""Initialize the class."""
pass

async def on_receive(self) -> None:
pass
Expand Down
27 changes: 15 additions & 12 deletions tests/brokers/base/middlewares.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import asyncio
from typing import Type
from typing import Optional, Type
from unittest.mock import Mock

import pytest
Expand All @@ -25,35 +25,38 @@ async def test_local_middleware(
self, event: asyncio.Event, queue: str, mock: Mock, raw_broker
):
class mid(BaseMiddleware): # noqa: N801
async def on_receive(self):
mock.start(self.msg)
return await super().on_receive()
async def on_consume(self, msg):
mock.start(msg)
return await super().on_consume(msg)

async def after_processed(self, exc_type, exc_val, exec_tb):
async def after_consume(self, err: Optional[Exception]) -> None:
mock.end()
return await super().after_processed(exc_type, exc_val, exec_tb)
event.set()
return await super().after_consume(err)

broker = self.broker_class()

@broker.subscriber(queue, middlewares=(mid,))
@broker.subscriber(queue, middlewares=(mid(),))
async def handler(m):
event.set()
return ""
mock.inner(m)
return "end"

broker = self.patch_broker(raw_broker, broker)

async with broker:
await broker.start()
await asyncio.wait(
(
asyncio.create_task(broker.publish("", queue)),
asyncio.create_task(broker.publish("start", queue)),
asyncio.create_task(event.wait()),
),
timeout=3,
)

mock.start.assert_called_once_with("start")
mock.inner.assert_called_once_with("start")

assert event.is_set()
mock.start.assert_called_once()
mock.end.assert_called_once()

async def test_local_middleware_not_shared_between_subscribers(
Expand Down Expand Up @@ -236,7 +239,7 @@ async def after_processed(self, exc_type, exc_val, exec_tb):
return await super().after_processed(exc_type, exc_val, exec_tb)

broker = self.broker_class(
middlewares=(mid,),
middlewares=(mid(None),),
)

@broker.subscriber(queue)
Expand Down

0 comments on commit 1ee5d1e

Please sign in to comment.