Skip to content

Commit

Permalink
refactore: catch StopConsume at any level
Browse files Browse the repository at this point in the history
  • Loading branch information
Lancetnik committed Jan 13, 2024
1 parent 7fab3f2 commit cc5c7fb
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 51 deletions.
98 changes: 48 additions & 50 deletions faststream/broker/core/handler.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
from abc import abstractmethod
from contextlib import AsyncExitStack
from contextlib import AsyncExitStack, asynccontextmanager
from dataclasses import dataclass
from inspect import unwrap
from itertools import chain
from typing import (
TYPE_CHECKING,
Any,
AsyncContextManager,
AsyncIterator,
Awaitable,
Callable,
Dict,
Expand All @@ -17,7 +18,6 @@
Tuple,
Type,
Union,
cast,
)

from faststream.asyncapi.base import AsyncAPIOperation
Expand Down Expand Up @@ -186,40 +186,12 @@ async def close(self) -> None:
self.running = False
await self.lock.wait_release(self.graceful_timeout)

@property
def call_name(self) -> str:
"""Returns the name of the handler call."""
return to_camelcase(self.calls[0].call_name)

@property
def description(self) -> Optional[str]:
"""Returns the description of the handler."""
if self._description:
return self._description

if not self.calls: # pragma: no cover
return None

else:
return self.calls[0].description

def get_payloads(self) -> List[Tuple[AnyDict, str]]:
"""Get the payloads of the handler."""
payloads: List[Tuple[AnyDict, str]] = []

for h in self.calls:
body = parse_handler_params(
h.dependant,
prefix=f"{self._title or self.call_name}:Message",
)
payloads.append(
(
body,
to_camelcase(h.call_name),
)
)

return payloads
@asynccontextmanager
async def stop_scope(self) -> AsyncIterator[None]:
try:
yield
except StopConsume:
await self.close()

def add_call(
self,
Expand Down Expand Up @@ -295,8 +267,8 @@ async def consume(self, msg: MsgType) -> SendableMessage:
middlewares = []
async with AsyncExitStack() as stack:
stack.enter_context(self.lock)

stack.enter_context(context.scope("handler_", self))
await stack.enter_async_context(self.stop_scope())

for m in self.middlewares:
middleware = m(msg)
Expand All @@ -305,12 +277,7 @@ async def consume(self, msg: MsgType) -> SendableMessage:

cache = {}
for h in self.calls:
if (
message := cast(
"StreamMessage[MsgType]",
await h.is_suitable(msg, cache),
)
) is not None:
if (message := await h.is_suitable(msg, cache)) is not None:
await stack.enter_async_context(self.watcher(message))
stack.enter_context(context.scope("message", message))
stack.enter_context(
Expand All @@ -326,13 +293,7 @@ async def close_middlewares(
for m in middlewares:
await m.__aexit__(exc_type, exc_val, exec_tb)

try:
result_msg = cast(
SendableMessage, await h.call(message, middlewares)
)
except StopConsume:
await self.close()
return
result_msg = await h.call(message, middlewares)

async with AsyncExitStack() as pub_stack:
result_msg = result_msg
Expand All @@ -359,3 +320,40 @@ def make_response_publisher(
self, message: "StreamMessage[MsgType]"
) -> Sequence[PublisherProtocol]:
raise NotImplementedError()

# AsyncAPI methods

@property
def call_name(self) -> str:
"""Returns the name of the handler call."""
return to_camelcase(self.calls[0].call_name)

@property
def description(self) -> Optional[str]:
"""Returns the description of the handler."""
if self._description:
return self._description

if not self.calls: # pragma: no cover
return None

else:
return self.calls[0].description

def get_payloads(self) -> List[Tuple[AnyDict, str]]:
"""Get the payloads of the handler."""
payloads: List[Tuple[AnyDict, str]] = []

for h in self.calls:
body = parse_handler_params(
h.dependant,
prefix=f"{self._title or self.call_name}:Message",
)
payloads.append(
(
body,
to_camelcase(h.call_name),
)
)

return payloads
2 changes: 1 addition & 1 deletion faststream/nats/broker.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,7 +454,7 @@ async def publish( # type: ignore[override]
async with AsyncExitStack() as stack:
for m in self.middlewares:
message = await stack.enter_async_context(
m().publish_scope(message)
m(None).publish_scope(message)
)

return await publisher.publish(message, *args, **kwargs)
Expand Down

0 comments on commit cc5c7fb

Please sign in to comment.