diff --git a/tests/brokers/base/middlewares.py b/tests/brokers/base/middlewares.py index d16c4861ea..fc965cf21a 100644 --- a/tests/brokers/base/middlewares.py +++ b/tests/brokers/base/middlewares.py @@ -91,8 +91,8 @@ async def publish_scope(self, call_next, msg, *args, **kwargs): publisher = broker.publisher( queue, middlewares=[ - MiddleMiddleware(None, context=None).publish_scope, - InnerMiddleware(None, context=None).publish_scope, + MiddleMiddleware(None).publish_scope, + InnerMiddleware(None).publish_scope, ], ) @@ -178,8 +178,8 @@ async def consume_scope(self, call_next, msg): args, kwargs = self.get_subscriber_params( queue, middlewares=[ - MiddleMiddleware(None, context=None).consume_scope, - InnerMiddleware(None, context=None).consume_scope, + MiddleMiddleware(None).consume_scope, + InnerMiddleware(None).consume_scope, ], ) @@ -236,6 +236,58 @@ async def handler(msg): assert [c.args[0] for c in mock.call_args_list] == ["outer", "middle", "inner"] + async def test_aenter_aexit(self, queue: str, mock: Mock): + class InnerMiddleware(BaseMiddleware): + async def __aenter__(self): + mock.enter_inner() + mock.sub("inner") + return self + + async def __aexit__( + self, + exc_type = None, + exc_val = None, + exc_tb = None, + ): + mock.exit_inner() + mock.pub("inner") + return await self.after_processed(exc_type, exc_val, exc_tb) + + class OuterMiddleware(BaseMiddleware): + async def __aenter__(self): + mock.enter_inner() + mock.sub("outer") + return self + + async def __aexit__( + self, + exc_type=None, + exc_val=None, + exc_tb=None, + ): + mock.exit_inner() + mock.pub("outer") + return await self.after_processed(exc_type, exc_val, exc_tb) + + broker = self.broker_class(middlewares=[OuterMiddleware, InnerMiddleware]) + + args, kwargs = self.get_subscriber_params(queue) + + @broker.subscriber(*args, **kwargs) + async def handler(msg): + pass + + async with self.patch_broker(broker) as br: + await br.publish(None, queue) + + mock.consume_inner.assert_called_once() + mock.consume_outer.assert_called_once() + mock.publish_inner.assert_called_once() + mock.publish_outer.assert_called_once() + + assert [c.args[0] for c in mock.sub.call_args_list] == ["outer", "inner"] + assert [c.args[0] for c in mock.pub.call_args_list] == ["outer", "inner"] + @pytest.mark.asyncio class LocalMiddlewareTestcase(BaseTestcaseConfig):