Skip to content

Commit

Permalink
Nats PullStreamSubscriber get_one
Browse files Browse the repository at this point in the history
  • Loading branch information
KrySeyt committed Aug 27, 2024
1 parent e4d7079 commit f6d136b
Showing 1 changed file with 89 additions and 2 deletions.
91 changes: 89 additions & 2 deletions faststream/nats/subscriber/usecase.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,10 +157,11 @@ def clear_subject(self) -> str:
async def start(self) -> None:
"""Create NATS subscription and start consume tasks."""
assert self._connection, NOT_CONNECTED_YET # nosec B101
await super().start()

if not self.calls:
return None
return

await super().start()

await self._create_subscription(connection=self._connection)

Expand Down Expand Up @@ -613,6 +614,7 @@ def get_log_context(
],
) -> Dict[str, str]:
"""Log context factory using in `self.consume` scope."""

return self.build_log_context(
message=message,
subject=self._resolved_subject_string,
Expand Down Expand Up @@ -752,6 +754,35 @@ def __init__(
include_in_schema=include_in_schema,
)

async def get_one(self, *, timeout: float = 5) -> Optional[NatsMessage]:
if not self.subscription:
await self._create_subscription(connection=self._connection)

try:
raw_message ,= await self.subscription.fetch(
batch=1,
timeout=timeout,
)
except TimeoutError:
raw_message = None

if not raw_message:
return None

async with AsyncExitStack() as stack:
return_msg: Callable[[NatsMessage], Awaitable[NatsMessage]] = (
return_input
)

for m in self._broker_middlewares:
mid = m(raw_message)
await stack.enter_async_context(mid)
return_msg = partial(mid.consume_scope, return_msg)

parsed_msg = await self._parser(raw_message)
parsed_msg._decoded_body = await self._decoder(parsed_msg)
return await return_msg(parsed_msg)

@override
async def _create_subscription( # type: ignore[override]
self,
Expand Down Expand Up @@ -832,6 +863,36 @@ def __init__(
include_in_schema=include_in_schema,
)

async def get_one(self, *, timeout: float = 5) -> Optional[NatsMessage]:
if not self.subscription:
self.subscription = await self._connection.pull_subscribe(
subject=self.clear_subject,
config=self.config,
**self.extra_options,
)

raw_message ,= await self.subscription.fetch(
batch=1,
timeout=timeout,
)

if not raw_message:
return None

async with AsyncExitStack() as stack:
return_msg: Callable[[NatsMessage], Awaitable[NatsMessage]] = (
return_input
)

for m in self._broker_middlewares:
mid = m(raw_message)
await stack.enter_async_context(mid)
return_msg = partial(mid.consume_scope, return_msg)

parsed_msg = await self._parser(raw_message)
parsed_msg._decoded_body = await self._decoder(parsed_msg)
return await return_msg(parsed_msg)

@override
async def _create_subscription( # type: ignore[override]
self,
Expand Down Expand Up @@ -901,6 +962,32 @@ def __init__(
include_in_schema=include_in_schema,
)

async def get_one(self, *, timeout: float = 5) -> Optional[NatsMessage]:
if not self.subscription:
await self._create_subscription(connection=self._connection)

raw_message ,= await self.subscription.fetch(
batch=1,
timeout=timeout,
)

if not raw_message:
return None

async with AsyncExitStack() as stack:
return_msg: Callable[[NatsMessage], Awaitable[NatsMessage]] = (
return_input
)

for m in self._broker_middlewares:
mid = m(raw_message)
await stack.enter_async_context(mid)
return_msg = partial(mid.consume_scope, return_msg)

parsed_msg = await self._parser(raw_message)
parsed_msg._decoded_body = await self._decoder(parsed_msg)
return await return_msg(parsed_msg)

@override
async def _create_subscription( # type: ignore[override]
self,
Expand Down

0 comments on commit f6d136b

Please sign in to comment.