From b9c7dd442d1a015bc0aff92450aefd675e2e52f0 Mon Sep 17 00:00:00 2001 From: Flosckow <66554425+Flosckow@users.noreply.github.com> Date: Wed, 11 Dec 2024 02:41:00 +0700 Subject: [PATCH] Feat: replace subscribers (#1976) * Feat: replace subscribers * Fix: lint * Fix: add mixin * Fix: lint * Fix: add markers * chore: fix tests --------- Co-authored-by: Daniil Dumchenko Co-authored-by: Pastukhov Nikita Co-authored-by: Nikita Pastukhov --- faststream/redis/subscriber/specified.py | 8 +- faststream/redis/subscriber/usecase.py | 771 ------------------ .../redis/subscriber/usecases/__init__.py | 19 + faststream/redis/subscriber/usecases/basic.py | 155 ++++ .../subscriber/usecases/channel_subscriber.py | 155 ++++ .../subscriber/usecases/list_subscriber.py | 221 +++++ .../subscriber/usecases/stream_subscriber.py | 336 ++++++++ faststream/redis/testing.py | 10 +- 8 files changed, 896 insertions(+), 779 deletions(-) delete mode 100644 faststream/redis/subscriber/usecase.py create mode 100644 faststream/redis/subscriber/usecases/__init__.py create mode 100644 faststream/redis/subscriber/usecases/basic.py create mode 100644 faststream/redis/subscriber/usecases/channel_subscriber.py create mode 100644 faststream/redis/subscriber/usecases/list_subscriber.py create mode 100644 faststream/redis/subscriber/usecases/stream_subscriber.py diff --git a/faststream/redis/subscriber/specified.py b/faststream/redis/subscriber/specified.py index e943a80aeb..30591233db 100644 --- a/faststream/redis/subscriber/specified.py +++ b/faststream/redis/subscriber/specified.py @@ -3,10 +3,14 @@ ) from faststream.redis.schemas import ListSub, StreamSub from faststream.redis.schemas.proto import RedisSpecificationProtocol -from faststream.redis.subscriber.usecase import ( - BatchListSubscriber, +from faststream.redis.subscriber.usecases.channel_subscriber import ( ChannelSubscriber, +) +from faststream.redis.subscriber.usecases.list_subscriber import ( + BatchListSubscriber, ListSubscriber, +) +from faststream.redis.subscriber.usecases.stream_subscriber import ( StreamBatchSubscriber, StreamSubscriber, ) diff --git a/faststream/redis/subscriber/usecase.py b/faststream/redis/subscriber/usecase.py deleted file mode 100644 index f1c9bb880f..0000000000 --- a/faststream/redis/subscriber/usecase.py +++ /dev/null @@ -1,771 +0,0 @@ -import math -from abc import abstractmethod -from collections.abc import Awaitable, Iterable, Sequence -from contextlib import suppress -from copy import deepcopy -from typing import ( - TYPE_CHECKING, - Any, - Callable, - Optional, -) - -import anyio -from redis.asyncio.client import ( - PubSub as RPubSub, - Redis, -) -from redis.exceptions import ResponseError -from typing_extensions import TypeAlias, override - -from faststream._internal.subscriber.mixins import TasksMixin -from faststream._internal.subscriber.usecase import SubscriberUsecase -from faststream._internal.subscriber.utils import process_msg -from faststream.middlewares import AckPolicy -from faststream.redis.message import ( - BatchListMessage, - BatchStreamMessage, - DefaultListMessage, - DefaultStreamMessage, - PubSubMessage, - RedisListMessage, - RedisMessage, - RedisStreamMessage, - UnifyRedisDict, -) -from faststream.redis.parser import ( - RedisBatchListParser, - RedisBatchStreamParser, - RedisListParser, - RedisPubSubParser, - RedisStreamParser, -) -from faststream.redis.publisher.fake import RedisFakePublisher -from faststream.redis.schemas import ListSub, PubSub, StreamSub - -if TYPE_CHECKING: - from fast_depends.dependencies import Dependant - - from faststream._internal.basic_types import AnyDict - from faststream._internal.publisher.proto import BasePublisherProto - from faststream._internal.state import BrokerState - from faststream._internal.types import ( - AsyncCallable, - BrokerMiddleware, - CustomCallable, - ) - from faststream.message import StreamMessage as BrokerStreamMessage - - -TopicName: TypeAlias = bytes -Offset: TypeAlias = bytes - - -class LogicSubscriber(TasksMixin, SubscriberUsecase[UnifyRedisDict]): - """A class to represent a Redis handler.""" - - _client: Optional["Redis[bytes]"] - - def __init__( - self, - *, - default_parser: "AsyncCallable", - default_decoder: "AsyncCallable", - # Subscriber args - ack_policy: "AckPolicy", - no_reply: bool, - broker_dependencies: Iterable["Dependant"], - broker_middlewares: Sequence["BrokerMiddleware[UnifyRedisDict]"], - ) -> None: - super().__init__( - default_parser=default_parser, - default_decoder=default_decoder, - # Propagated options - ack_policy=ack_policy, - no_reply=no_reply, - broker_middlewares=broker_middlewares, - broker_dependencies=broker_dependencies, - ) - - self._client = None - - @override - def _setup( # type: ignore[override] - self, - *, - connection: Optional["Redis[bytes]"], - # basic args - extra_context: "AnyDict", - # broker options - broker_parser: Optional["CustomCallable"], - broker_decoder: Optional["CustomCallable"], - # dependant args - state: "BrokerState", - ) -> None: - self._client = connection - - super()._setup( - extra_context=extra_context, - broker_parser=broker_parser, - broker_decoder=broker_decoder, - state=state, - ) - - def _make_response_publisher( - self, - message: "BrokerStreamMessage[UnifyRedisDict]", - ) -> Sequence["BasePublisherProto"]: - return ( - RedisFakePublisher( - self._state.get().producer, - channel=message.reply_to, - ), - ) - - @override - async def start( - self, - *args: Any, - ) -> None: - if self.tasks: - return - - await super().start() - - start_signal = anyio.Event() - - if self.calls: - self.add_task(self._consume(*args, start_signal=start_signal)) - - with anyio.fail_after(3.0): - await start_signal.wait() - - else: - start_signal.set() - - async def _consume(self, *args: Any, start_signal: anyio.Event) -> None: - connected = True - - while self.running: - try: - await self._get_msgs(*args) - - except Exception: # noqa: PERF203 - if connected: - connected = False - await anyio.sleep(5) - - else: - if not connected: - connected = True - - finally: - if not start_signal.is_set(): - with suppress(Exception): - start_signal.set() - - @abstractmethod - async def _get_msgs(self, *args: Any) -> None: - raise NotImplementedError - - @staticmethod - def build_log_context( - message: Optional["BrokerStreamMessage[Any]"], - channel: str = "", - ) -> dict[str, str]: - return { - "channel": channel, - "message_id": getattr(message, "message_id", ""), - } - - -class ChannelSubscriber(LogicSubscriber): - subscription: Optional[RPubSub] - - def __init__( - self, - *, - channel: "PubSub", - # Subscriber args - no_reply: bool, - broker_dependencies: Iterable["Dependant"], - broker_middlewares: Sequence["BrokerMiddleware[UnifyRedisDict]"], - ) -> None: - parser = RedisPubSubParser(pattern=channel.path_regex) - super().__init__( - default_parser=parser.parse_message, - default_decoder=parser.decode_message, - # Propagated options - ack_policy=AckPolicy.DO_NOTHING, - no_reply=no_reply, - broker_middlewares=broker_middlewares, - broker_dependencies=broker_dependencies, - ) - - self.channel = channel - self.subscription = None - - def get_log_context( - self, - message: Optional["BrokerStreamMessage[Any]"], - ) -> dict[str, str]: - return self.build_log_context( - message=message, - channel=self.channel.name, - ) - - @override - async def start(self) -> None: - if self.subscription: - return - - assert self._client, "You should setup subscriber at first." # nosec B101 - - self.subscription = psub = self._client.pubsub() - - if self.channel.pattern: - await psub.psubscribe(self.channel.name) - else: - await psub.subscribe(self.channel.name) - - await super().start(psub) - - async def close(self) -> None: - if self.subscription is not None: - await self.subscription.unsubscribe() - await self.subscription.aclose() # type: ignore[attr-defined] - self.subscription = None - - await super().close() - - @override - async def get_one( # type: ignore[override] - self, - *, - timeout: float = 5.0, - ) -> "Optional[RedisMessage]": - assert self.subscription, "You should start subscriber at first." # nosec B101 - assert ( # nosec B101 - not self.calls - ), "You can't use `get_one` method if subscriber has registered handlers." - - sleep_interval = timeout / 10 - - raw_message: Optional[PubSubMessage] = None - - with anyio.move_on_after(timeout): - while (raw_message := await self._get_message(self.subscription)) is None: # noqa: ASYNC110 - await anyio.sleep(sleep_interval) - - context = self._state.get().di_state.context - - msg: Optional[RedisMessage] = await process_msg( # type: ignore[assignment] - msg=raw_message, - middlewares=( - m(raw_message, context=context) for m in self._broker_middlewares - ), - parser=self._parser, - decoder=self._decoder, - ) - return msg - - async def _get_message(self, psub: RPubSub) -> Optional[PubSubMessage]: - raw_msg = await psub.get_message( - ignore_subscribe_messages=True, - timeout=self.channel.polling_interval, - ) - - if raw_msg: - return PubSubMessage( - type=raw_msg["type"], - data=raw_msg["data"], - channel=raw_msg["channel"].decode(), - pattern=raw_msg["pattern"], - ) - - return None - - async def _get_msgs(self, psub: RPubSub) -> None: - if msg := await self._get_message(psub): - await self.consume(msg) # type: ignore[arg-type] - - def add_prefix(self, prefix: str) -> None: - new_ch = deepcopy(self.channel) - new_ch.name = f"{prefix}{new_ch.name}" - self.channel = new_ch - - -class _ListHandlerMixin(LogicSubscriber): - def __init__( - self, - *, - list: ListSub, - default_parser: "AsyncCallable", - default_decoder: "AsyncCallable", - # Subscriber args - ack_policy: "AckPolicy", - no_reply: bool, - broker_dependencies: Iterable["Dependant"], - broker_middlewares: Sequence["BrokerMiddleware[UnifyRedisDict]"], - ) -> None: - super().__init__( - default_parser=default_parser, - default_decoder=default_decoder, - # Propagated options - ack_policy=ack_policy, - no_reply=no_reply, - broker_middlewares=broker_middlewares, - broker_dependencies=broker_dependencies, - ) - - self.list_sub = list - - def get_log_context( - self, - message: Optional["BrokerStreamMessage[Any]"], - ) -> dict[str, str]: - return self.build_log_context( - message=message, - channel=self.list_sub.name, - ) - - @override - async def _consume( # type: ignore[override] - self, - client: "Redis[bytes]", - *, - start_signal: "anyio.Event", - ) -> None: - start_signal.set() - await super()._consume(client, start_signal=start_signal) - - @override - async def start(self) -> None: - if self.tasks: - return - - assert self._client, "You should setup subscriber at first." # nosec B101 - - await super().start(self._client) - - @override - async def get_one( # type: ignore[override] - self, - *, - timeout: float = 5.0, - ) -> "Optional[RedisListMessage]": - assert self._client, "You should start subscriber at first." # nosec B101 - assert ( # nosec B101 - not self.calls - ), "You can't use `get_one` method if subscriber has registered handlers." - - sleep_interval = timeout / 10 - raw_message = None - - with anyio.move_on_after(timeout): - while ( # noqa: ASYNC110 - raw_message := await self._client.lpop(name=self.list_sub.name) - ) is None: - await anyio.sleep(sleep_interval) - - if not raw_message: - return None - - redis_incoming_msg = DefaultListMessage( - type="list", - data=raw_message, - channel=self.list_sub.name, - ) - - context = self._state.get().di_state.context - - msg: RedisListMessage = await process_msg( # type: ignore[assignment] - msg=redis_incoming_msg, - middlewares=( - m(redis_incoming_msg, context=context) for m in self._broker_middlewares - ), - parser=self._parser, - decoder=self._decoder, - ) - return msg - - def add_prefix(self, prefix: str) -> None: - new_list = deepcopy(self.list_sub) - new_list.name = f"{prefix}{new_list.name}" - self.list_sub = new_list - - -class ListSubscriber(_ListHandlerMixin): - def __init__( - self, - *, - list: ListSub, - # Subscriber args - no_reply: bool, - broker_dependencies: Iterable["Dependant"], - broker_middlewares: Sequence["BrokerMiddleware[UnifyRedisDict]"], - ) -> None: - parser = RedisListParser() - super().__init__( - list=list, - default_parser=parser.parse_message, - default_decoder=parser.decode_message, - # Propagated options - ack_policy=AckPolicy.DO_NOTHING, - no_reply=no_reply, - broker_middlewares=broker_middlewares, - broker_dependencies=broker_dependencies, - ) - - async def _get_msgs(self, client: "Redis[bytes]") -> None: - raw_msg = await client.blpop( - self.list_sub.name, - timeout=self.list_sub.polling_interval, - ) - - if raw_msg: - _, msg_data = raw_msg - - msg = DefaultListMessage( - type="list", - data=msg_data, - channel=self.list_sub.name, - ) - - await self.consume(msg) - - -class BatchListSubscriber(_ListHandlerMixin): - def __init__( - self, - *, - list: ListSub, - # Subscriber args - no_reply: bool, - broker_dependencies: Iterable["Dependant"], - broker_middlewares: Sequence["BrokerMiddleware[UnifyRedisDict]"], - ) -> None: - parser = RedisBatchListParser() - super().__init__( - list=list, - default_parser=parser.parse_message, - default_decoder=parser.decode_message, - # Propagated options - ack_policy=AckPolicy.DO_NOTHING, - no_reply=no_reply, - broker_middlewares=broker_middlewares, - broker_dependencies=broker_dependencies, - ) - - async def _get_msgs(self, client: "Redis[bytes]") -> None: - raw_msgs = await client.lpop( - name=self.list_sub.name, - count=self.list_sub.max_records, - ) - - if raw_msgs: - msg = BatchListMessage( - type="blist", - channel=self.list_sub.name, - data=raw_msgs, - ) - - await self.consume(msg) # type: ignore[arg-type] - - else: - await anyio.sleep(self.list_sub.polling_interval) - - -class _StreamHandlerMixin(LogicSubscriber): - def __init__( - self, - *, - stream: StreamSub, - default_parser: "AsyncCallable", - default_decoder: "AsyncCallable", - # Subscriber args - ack_policy: "AckPolicy", - no_reply: bool, - broker_dependencies: Iterable["Dependant"], - broker_middlewares: Sequence["BrokerMiddleware[UnifyRedisDict]"], - ) -> None: - super().__init__( - default_parser=default_parser, - default_decoder=default_decoder, - # Propagated options - ack_policy=ack_policy, - no_reply=no_reply, - broker_middlewares=broker_middlewares, - broker_dependencies=broker_dependencies, - ) - - self.stream_sub = stream - self.last_id = stream.last_id - - def get_log_context( - self, - message: Optional["BrokerStreamMessage[Any]"], - ) -> dict[str, str]: - return self.build_log_context( - message=message, - channel=self.stream_sub.name, - ) - - @override - async def start(self) -> None: - if self.tasks: - return - - assert self._client, "You should setup subscriber at first." # nosec B101 - - client = self._client - - self.extra_watcher_options.update( - redis=client, - group=self.stream_sub.group, - ) - - stream = self.stream_sub - - read: Callable[ - [str], - Awaitable[ - tuple[ - tuple[ - TopicName, - tuple[ - tuple[ - Offset, - dict[bytes, bytes], - ], - ..., - ], - ], - ..., - ], - ], - ] - - if stream.group and stream.consumer: - try: - await client.xgroup_create( - name=stream.name, - id=self.last_id, - groupname=stream.group, - mkstream=True, - ) - except ResponseError as e: - if "already exists" not in str(e): - raise - - def read( - _: str, - ) -> Awaitable[ - tuple[ - tuple[ - TopicName, - tuple[ - tuple[ - Offset, - dict[bytes, bytes], - ], - ..., - ], - ], - ..., - ], - ]: - return client.xreadgroup( - groupname=stream.group, - consumername=stream.consumer, - streams={stream.name: ">"}, - count=stream.max_records, - block=stream.polling_interval, - noack=stream.no_ack, - ) - - else: - - def read( - last_id: str, - ) -> Awaitable[ - tuple[ - tuple[ - TopicName, - tuple[ - tuple[ - Offset, - dict[bytes, bytes], - ], - ..., - ], - ], - ..., - ], - ]: - return client.xread( - {stream.name: last_id}, - block=stream.polling_interval, - count=stream.max_records, - ) - - await super().start(read) - - @override - async def get_one( # type: ignore[override] - self, - *, - timeout: float = 5.0, - ) -> "Optional[RedisStreamMessage]": - assert self._client, "You should start subscriber at first." # nosec B101 - assert ( # nosec B101 - not self.calls - ), "You can't use `get_one` method if subscriber has registered handlers." - - stream_message = await self._client.xread( - {self.stream_sub.name: self.last_id}, - block=math.ceil(timeout * 1000), - count=1, - ) - - if not stream_message: - return None - - ((stream_name, ((message_id, raw_message),)),) = stream_message - - self.last_id = message_id.decode() - - redis_incoming_msg = DefaultStreamMessage( - type="stream", - channel=stream_name.decode(), - message_ids=[message_id], - data=raw_message, - ) - - context = self._state.get().di_state.context - - msg: RedisStreamMessage = await process_msg( # type: ignore[assignment] - msg=redis_incoming_msg, - middlewares=( - m(redis_incoming_msg, context=context) for m in self._broker_middlewares - ), - parser=self._parser, - decoder=self._decoder, - ) - return msg - - def add_prefix(self, prefix: str) -> None: - new_stream = deepcopy(self.stream_sub) - new_stream.name = f"{prefix}{new_stream.name}" - self.stream_sub = new_stream - - -class StreamSubscriber(_StreamHandlerMixin): - def __init__( - self, - *, - stream: StreamSub, - # Subscriber args - ack_policy: "AckPolicy", - no_reply: bool, - broker_dependencies: Iterable["Dependant"], - broker_middlewares: Sequence["BrokerMiddleware[UnifyRedisDict]"], - ) -> None: - parser = RedisStreamParser() - super().__init__( - stream=stream, - default_parser=parser.parse_message, - default_decoder=parser.decode_message, - # Propagated options - ack_policy=ack_policy, - no_reply=no_reply, - broker_middlewares=broker_middlewares, - broker_dependencies=broker_dependencies, - ) - - async def _get_msgs( - self, - read: Callable[ - [str], - Awaitable[ - tuple[ - tuple[ - TopicName, - tuple[ - tuple[ - Offset, - dict[bytes, bytes], - ], - ..., - ], - ], - ..., - ], - ], - ], - ) -> None: - for stream_name, msgs in await read(self.last_id): - if msgs: - self.last_id = msgs[-1][0].decode() - - for message_id, raw_msg in msgs: - msg = DefaultStreamMessage( - type="stream", - channel=stream_name.decode(), - message_ids=[message_id], - data=raw_msg, - ) - - await self.consume(msg) # type: ignore[arg-type] - - -class StreamBatchSubscriber(_StreamHandlerMixin): - def __init__( - self, - *, - stream: StreamSub, - # Subscriber args - ack_policy: "AckPolicy", - no_reply: bool, - broker_dependencies: Iterable["Dependant"], - broker_middlewares: Sequence["BrokerMiddleware[UnifyRedisDict]"], - ) -> None: - parser = RedisBatchStreamParser() - super().__init__( - stream=stream, - default_parser=parser.parse_message, - default_decoder=parser.decode_message, - # Propagated options - ack_policy=ack_policy, - no_reply=no_reply, - broker_middlewares=broker_middlewares, - broker_dependencies=broker_dependencies, - ) - - async def _get_msgs( - self, - read: Callable[ - [str], - Awaitable[ - tuple[tuple[bytes, tuple[tuple[bytes, dict[bytes, bytes]], ...]], ...], - ], - ], - ) -> None: - for stream_name, msgs in await read(self.last_id): - if msgs: - self.last_id = msgs[-1][0].decode() - - data: list[dict[bytes, bytes]] = [] - ids: list[bytes] = [] - for message_id, i in msgs: - data.append(i) - ids.append(message_id) - - msg = BatchStreamMessage( - type="bstream", - channel=stream_name.decode(), - data=data, - message_ids=ids, - ) - - await self.consume(msg) # type: ignore[arg-type] diff --git a/faststream/redis/subscriber/usecases/__init__.py b/faststream/redis/subscriber/usecases/__init__.py new file mode 100644 index 0000000000..32ad97f400 --- /dev/null +++ b/faststream/redis/subscriber/usecases/__init__.py @@ -0,0 +1,19 @@ +from .basic import LogicSubscriber +from .channel_subscriber import ChannelSubscriber +from .list_subscriber import BatchListSubscriber, ListSubscriber, _ListHandlerMixin +from .stream_subscriber import ( + StreamBatchSubscriber, + StreamSubscriber, + _StreamHandlerMixin, +) + +__all__ = ( + "BatchListSubscriber", + "ChannelSubscriber", + "ListSubscriber", + "LogicSubscriber", + "StreamBatchSubscriber", + "StreamSubscriber", + "_ListHandlerMixin", + "_StreamHandlerMixin", +) diff --git a/faststream/redis/subscriber/usecases/basic.py b/faststream/redis/subscriber/usecases/basic.py new file mode 100644 index 0000000000..a5592f3730 --- /dev/null +++ b/faststream/redis/subscriber/usecases/basic.py @@ -0,0 +1,155 @@ +from abc import abstractmethod +from collections.abc import Iterable, Sequence +from contextlib import suppress +from typing import ( + TYPE_CHECKING, + Any, + Optional, +) + +import anyio +from typing_extensions import TypeAlias, override + +from faststream._internal.subscriber.mixins import TasksMixin +from faststream._internal.subscriber.usecase import SubscriberUsecase +from faststream.redis.message import ( + UnifyRedisDict, +) +from faststream.redis.publisher.fake import RedisFakePublisher + +if TYPE_CHECKING: + from fast_depends.dependencies import Dependant + from redis.asyncio.client import Redis + + from faststream._internal.basic_types import AnyDict + from faststream._internal.publisher.proto import BasePublisherProto + from faststream._internal.state import BrokerState, Pointer + from faststream._internal.types import ( + AsyncCallable, + BrokerMiddleware, + CustomCallable, + ) + from faststream.message import StreamMessage as BrokerStreamMessage + from faststream.middlewares import AckPolicy + + +TopicName: TypeAlias = bytes +Offset: TypeAlias = bytes + + +class LogicSubscriber(TasksMixin, SubscriberUsecase[UnifyRedisDict]): + """A class to represent a Redis handler.""" + + _client: Optional["Redis[bytes]"] + + def __init__( + self, + *, + default_parser: "AsyncCallable", + default_decoder: "AsyncCallable", + # Subscriber args + ack_policy: "AckPolicy", + no_reply: bool, + broker_dependencies: Iterable["Dependant"], + broker_middlewares: Sequence["BrokerMiddleware[UnifyRedisDict]"], + ) -> None: + super().__init__( + default_parser=default_parser, + default_decoder=default_decoder, + # Propagated options + ack_policy=ack_policy, + no_reply=no_reply, + broker_middlewares=broker_middlewares, + broker_dependencies=broker_dependencies, + ) + + self._client = None + + @override + def _setup( # type: ignore[override] + self, + *, + connection: Optional["Redis[bytes]"], + # basic args + extra_context: "AnyDict", + # broker options + broker_parser: Optional["CustomCallable"], + broker_decoder: Optional["CustomCallable"], + # dependant args + state: "Pointer[BrokerState]", + ) -> None: + self._client = connection + + super()._setup( + extra_context=extra_context, + broker_parser=broker_parser, + broker_decoder=broker_decoder, + state=state, + ) + + def _make_response_publisher( + self, + message: "BrokerStreamMessage[UnifyRedisDict]", + ) -> Sequence["BasePublisherProto"]: + return ( + RedisFakePublisher( + self._state.get().producer, + channel=message.reply_to, + ), + ) + + @override + async def start( + self, + *args: Any, + ) -> None: + if self.tasks: + return + + await super().start() + + start_signal = anyio.Event() + + if self.calls: + self.add_task(self._consume(*args, start_signal=start_signal)) + + with anyio.fail_after(3.0): + await start_signal.wait() + + else: + start_signal.set() + + async def _consume(self, *args: Any, start_signal: anyio.Event) -> None: + connected = True + + while self.running: + try: + await self._get_msgs(*args) + + except Exception: # noqa: PERF203 + if connected: + connected = False + await anyio.sleep(5) + + else: + if not connected: + connected = True + + finally: + if not start_signal.is_set(): + with suppress(Exception): + start_signal.set() + + @abstractmethod + async def _get_msgs(self, *args: Any) -> None: + raise NotImplementedError + + @staticmethod + def build_log_context( + message: Optional["BrokerStreamMessage[Any]"], + channel: str = "", + ) -> dict[str, str]: + return { + "channel": channel, + "message_id": getattr(message, "message_id", ""), + } diff --git a/faststream/redis/subscriber/usecases/channel_subscriber.py b/faststream/redis/subscriber/usecases/channel_subscriber.py new file mode 100644 index 0000000000..e7c261ad07 --- /dev/null +++ b/faststream/redis/subscriber/usecases/channel_subscriber.py @@ -0,0 +1,155 @@ +from collections.abc import Iterable, Sequence +from copy import deepcopy +from typing import ( + TYPE_CHECKING, + Any, + Optional, +) + +import anyio +from redis.asyncio.client import ( + PubSub as RPubSub, +) +from typing_extensions import TypeAlias, override + +from faststream._internal.subscriber.utils import process_msg +from faststream.middlewares import AckPolicy +from faststream.redis.message import ( + PubSubMessage, + RedisMessage, + UnifyRedisDict, +) +from faststream.redis.parser import ( + RedisPubSubParser, +) + +from .basic import LogicSubscriber + +if TYPE_CHECKING: + from fast_depends.dependencies import Dependant + + from faststream._internal.types import ( + BrokerMiddleware, + ) + from faststream.message import StreamMessage as BrokerStreamMessage + from faststream.redis.schemas import PubSub + + +TopicName: TypeAlias = bytes +Offset: TypeAlias = bytes + + +class ChannelSubscriber(LogicSubscriber): + subscription: Optional[RPubSub] + + def __init__( + self, + *, + channel: "PubSub", + # Subscriber args + no_reply: bool, + broker_dependencies: Iterable["Dependant"], + broker_middlewares: Sequence["BrokerMiddleware[UnifyRedisDict]"], + ) -> None: + parser = RedisPubSubParser(pattern=channel.path_regex) + super().__init__( + default_parser=parser.parse_message, + default_decoder=parser.decode_message, + # Propagated options + ack_policy=AckPolicy.DO_NOTHING, + no_reply=no_reply, + broker_middlewares=broker_middlewares, + broker_dependencies=broker_dependencies, + ) + + self.channel = channel + self.subscription = None + + def get_log_context( + self, + message: Optional["BrokerStreamMessage[Any]"], + ) -> dict[str, str]: + return self.build_log_context( + message=message, + channel=self.channel.name, + ) + + @override + async def start(self) -> None: + if self.subscription: + return + + assert self._client, "You should setup subscriber at first." # nosec B101 + + self.subscription = psub = self._client.pubsub() + + if self.channel.pattern: + await psub.psubscribe(self.channel.name) + else: + await psub.subscribe(self.channel.name) + + await super().start(psub) + + async def close(self) -> None: + if self.subscription is not None: + await self.subscription.unsubscribe() + await self.subscription.aclose() # type: ignore[attr-defined] + self.subscription = None + + await super().close() + + @override + async def get_one( + self, + *, + timeout: float = 5.0, + ) -> "Optional[RedisMessage]": + assert self.subscription, "You should start subscriber at first." # nosec B101 + assert ( # nosec B101 + not self.calls + ), "You can't use `get_one` method if subscriber has registered handlers." + + sleep_interval = timeout / 10 + + raw_message: Optional[PubSubMessage] = None + + with anyio.move_on_after(timeout): + while (raw_message := await self._get_message(self.subscription)) is None: # noqa: ASYNC110 + await anyio.sleep(sleep_interval) + + context = self._state.get().di_state.context + + msg: Optional[RedisMessage] = await process_msg( # type: ignore[assignment] + msg=raw_message, + middlewares=( + m(raw_message, context=context) for m in self._broker_middlewares + ), + parser=self._parser, + decoder=self._decoder, + ) + return msg + + async def _get_message(self, psub: RPubSub) -> Optional[PubSubMessage]: + raw_msg = await psub.get_message( + ignore_subscribe_messages=True, + timeout=self.channel.polling_interval, + ) + + if raw_msg: + return PubSubMessage( + type=raw_msg["type"], + data=raw_msg["data"], + channel=raw_msg["channel"].decode(), + pattern=raw_msg["pattern"], + ) + + return None + + async def _get_msgs(self, psub: RPubSub) -> None: + if msg := await self._get_message(psub): + await self.consume(msg) + + def add_prefix(self, prefix: str) -> None: + new_ch = deepcopy(self.channel) + new_ch.name = f"{prefix}{new_ch.name}" + self.channel = new_ch diff --git a/faststream/redis/subscriber/usecases/list_subscriber.py b/faststream/redis/subscriber/usecases/list_subscriber.py new file mode 100644 index 0000000000..8c6558398f --- /dev/null +++ b/faststream/redis/subscriber/usecases/list_subscriber.py @@ -0,0 +1,221 @@ +from collections.abc import Iterable, Sequence +from copy import deepcopy +from typing import ( + TYPE_CHECKING, + Any, + Optional, +) + +import anyio +from typing_extensions import TypeAlias, override + +from faststream._internal.subscriber.utils import process_msg +from faststream.middlewares import AckPolicy +from faststream.redis.message import ( + BatchListMessage, + DefaultListMessage, + RedisListMessage, + UnifyRedisDict, +) +from faststream.redis.parser import ( + RedisBatchListParser, + RedisListParser, +) +from faststream.redis.schemas import ListSub + +from .basic import LogicSubscriber + +if TYPE_CHECKING: + from fast_depends.dependencies import Dependant + from redis.asyncio.client import Redis + + from faststream._internal.types import ( + AsyncCallable, + BrokerMiddleware, + ) + from faststream.message import StreamMessage as BrokerStreamMessage + + +TopicName: TypeAlias = bytes +Offset: TypeAlias = bytes + + +class _ListHandlerMixin(LogicSubscriber): + def __init__( + self, + *, + list: ListSub, + default_parser: "AsyncCallable", + default_decoder: "AsyncCallable", + # Subscriber args + ack_policy: "AckPolicy", + no_reply: bool, + broker_dependencies: Iterable["Dependant"], + broker_middlewares: Sequence["BrokerMiddleware[UnifyRedisDict]"], + ) -> None: + super().__init__( + default_parser=default_parser, + default_decoder=default_decoder, + # Propagated options + ack_policy=ack_policy, + no_reply=no_reply, + broker_middlewares=broker_middlewares, + broker_dependencies=broker_dependencies, + ) + + self.list_sub = list + + def get_log_context( + self, + message: Optional["BrokerStreamMessage[Any]"], + ) -> dict[str, str]: + return self.build_log_context( + message=message, + channel=self.list_sub.name, + ) + + @override + async def _consume( # type: ignore[override] + self, + client: "Redis[bytes]", + *, + start_signal: "anyio.Event", + ) -> None: + start_signal.set() + await super()._consume(client, start_signal=start_signal) + + @override + async def start(self) -> None: + if self.tasks: + return + + assert self._client, "You should setup subscriber at first." # nosec B101 + + await super().start(self._client) + + @override + async def get_one( + self, + *, + timeout: float = 5.0, + ) -> "Optional[RedisListMessage]": + assert self._client, "You should start subscriber at first." # nosec B101 + assert ( # nosec B101 + not self.calls + ), "You can't use `get_one` method if subscriber has registered handlers." + + sleep_interval = timeout / 10 + raw_message = None + + with anyio.move_on_after(timeout): + while ( # noqa: ASYNC110 + raw_message := await self._client.lpop(name=self.list_sub.name) + ) is None: + await anyio.sleep(sleep_interval) + + if not raw_message: + return None + + redis_incoming_msg = DefaultListMessage( + type="list", + data=raw_message, + channel=self.list_sub.name, + ) + + context = self._state.get().di_state.context + + msg: RedisListMessage = await process_msg( # type: ignore[assignment] + msg=redis_incoming_msg, + middlewares=( + m(redis_incoming_msg, context=context) for m in self._broker_middlewares + ), + parser=self._parser, + decoder=self._decoder, + ) + return msg + + def add_prefix(self, prefix: str) -> None: + new_list = deepcopy(self.list_sub) + new_list.name = f"{prefix}{new_list.name}" + self.list_sub = new_list + + +class ListSubscriber(_ListHandlerMixin): + def __init__( + self, + *, + list: ListSub, + # Subscriber args + no_reply: bool, + broker_dependencies: Iterable["Dependant"], + broker_middlewares: Sequence["BrokerMiddleware[UnifyRedisDict]"], + ) -> None: + parser = RedisListParser() + super().__init__( + list=list, + default_parser=parser.parse_message, + default_decoder=parser.decode_message, + # Propagated options + ack_policy=AckPolicy.DO_NOTHING, + no_reply=no_reply, + broker_middlewares=broker_middlewares, + broker_dependencies=broker_dependencies, + ) + + async def _get_msgs(self, client: "Redis[bytes]") -> None: + raw_msg = await client.blpop( + self.list_sub.name, + timeout=self.list_sub.polling_interval, + ) + + if raw_msg: + _, msg_data = raw_msg + + msg = DefaultListMessage( + type="list", + data=msg_data, + channel=self.list_sub.name, + ) + + await self.consume(msg) + + +class BatchListSubscriber(_ListHandlerMixin): + def __init__( + self, + *, + list: ListSub, + # Subscriber args + no_reply: bool, + broker_dependencies: Iterable["Dependant"], + broker_middlewares: Sequence["BrokerMiddleware[UnifyRedisDict]"], + ) -> None: + parser = RedisBatchListParser() + super().__init__( + list=list, + default_parser=parser.parse_message, + default_decoder=parser.decode_message, + # Propagated options + ack_policy=AckPolicy.DO_NOTHING, + no_reply=no_reply, + broker_middlewares=broker_middlewares, + broker_dependencies=broker_dependencies, + ) + + async def _get_msgs(self, client: "Redis[bytes]") -> None: + raw_msgs = await client.lpop( + name=self.list_sub.name, + count=self.list_sub.max_records, + ) + + if raw_msgs: + msg = BatchListMessage( + type="blist", + channel=self.list_sub.name, + data=raw_msgs, + ) + + await self.consume(msg) + + else: + await anyio.sleep(self.list_sub.polling_interval) diff --git a/faststream/redis/subscriber/usecases/stream_subscriber.py b/faststream/redis/subscriber/usecases/stream_subscriber.py new file mode 100644 index 0000000000..44c3a0c28b --- /dev/null +++ b/faststream/redis/subscriber/usecases/stream_subscriber.py @@ -0,0 +1,336 @@ +import math +from collections.abc import Awaitable, Iterable, Sequence +from copy import deepcopy +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Optional, +) + +from redis.exceptions import ResponseError +from typing_extensions import TypeAlias, override + +from faststream._internal.subscriber.utils import process_msg +from faststream.redis.message import ( + BatchStreamMessage, + DefaultStreamMessage, + RedisStreamMessage, + UnifyRedisDict, +) +from faststream.redis.parser import ( + RedisBatchStreamParser, + RedisStreamParser, +) +from faststream.redis.schemas import StreamSub + +from .basic import LogicSubscriber + +if TYPE_CHECKING: + from fast_depends.dependencies import Dependant + + from faststream._internal.types import ( + AsyncCallable, + BrokerMiddleware, + ) + from faststream.message import StreamMessage as BrokerStreamMessage + from faststream.middlewares import AckPolicy + + +TopicName: TypeAlias = bytes +Offset: TypeAlias = bytes + + +class _StreamHandlerMixin(LogicSubscriber): + def __init__( + self, + *, + stream: StreamSub, + default_parser: "AsyncCallable", + default_decoder: "AsyncCallable", + # Subscriber args + ack_policy: "AckPolicy", + no_reply: bool, + broker_dependencies: Iterable["Dependant"], + broker_middlewares: Sequence["BrokerMiddleware[UnifyRedisDict]"], + ) -> None: + super().__init__( + default_parser=default_parser, + default_decoder=default_decoder, + # Propagated options + ack_policy=ack_policy, + no_reply=no_reply, + broker_middlewares=broker_middlewares, + broker_dependencies=broker_dependencies, + ) + + self.stream_sub = stream + self.last_id = stream.last_id + + def get_log_context( + self, + message: Optional["BrokerStreamMessage[Any]"], + ) -> dict[str, str]: + return self.build_log_context( + message=message, + channel=self.stream_sub.name, + ) + + @override + async def start(self) -> None: + if self.tasks: + return + + assert self._client, "You should setup subscriber at first." # nosec B101 + + client = self._client + + self.extra_watcher_options.update( + redis=client, + group=self.stream_sub.group, + ) + + stream = self.stream_sub + + read: Callable[ + [str], + Awaitable[ + tuple[ + tuple[ + TopicName, + tuple[ + tuple[ + Offset, + dict[bytes, bytes], + ], + ..., + ], + ], + ..., + ], + ], + ] + + if stream.group and stream.consumer: + try: + await client.xgroup_create( + name=stream.name, + id=self.last_id, + groupname=stream.group, + mkstream=True, + ) + except ResponseError as e: + if "already exists" not in str(e): + raise + + def read( + _: str, + ) -> Awaitable[ + tuple[ + tuple[ + TopicName, + tuple[ + tuple[ + Offset, + dict[bytes, bytes], + ], + ..., + ], + ], + ..., + ], + ]: + return client.xreadgroup( + groupname=stream.group, + consumername=stream.consumer, + streams={stream.name: ">"}, + count=stream.max_records, + block=stream.polling_interval, + noack=stream.no_ack, + ) + + else: + + def read( + last_id: str, + ) -> Awaitable[ + tuple[ + tuple[ + TopicName, + tuple[ + tuple[ + Offset, + dict[bytes, bytes], + ], + ..., + ], + ], + ..., + ], + ]: + return client.xread( + {stream.name: last_id}, + block=stream.polling_interval, + count=stream.max_records, + ) + + await super().start(read) + + @override + async def get_one( + self, + *, + timeout: float = 5.0, + ) -> "Optional[RedisStreamMessage]": + assert self._client, "You should start subscriber at first." # nosec B101 + assert ( # nosec B101 + not self.calls + ), "You can't use `get_one` method if subscriber has registered handlers." + + stream_message = await self._client.xread( + {self.stream_sub.name: self.last_id}, + block=math.ceil(timeout * 1000), + count=1, + ) + + if not stream_message: + return None + + ((stream_name, ((message_id, raw_message),)),) = stream_message + + self.last_id = message_id.decode() + + redis_incoming_msg = DefaultStreamMessage( + type="stream", + channel=stream_name.decode(), + message_ids=[message_id], + data=raw_message, + ) + + context = self._state.get().di_state.context + + msg: RedisStreamMessage = await process_msg( # type: ignore[assignment] + msg=redis_incoming_msg, + middlewares=( + m(redis_incoming_msg, context=context) for m in self._broker_middlewares + ), + parser=self._parser, + decoder=self._decoder, + ) + return msg + + def add_prefix(self, prefix: str) -> None: + new_stream = deepcopy(self.stream_sub) + new_stream.name = f"{prefix}{new_stream.name}" + self.stream_sub = new_stream + + +class StreamSubscriber(_StreamHandlerMixin): + def __init__( + self, + *, + stream: StreamSub, + # Subscriber args + ack_policy: "AckPolicy", + no_reply: bool, + broker_dependencies: Iterable["Dependant"], + broker_middlewares: Sequence["BrokerMiddleware[UnifyRedisDict]"], + ) -> None: + parser = RedisStreamParser() + super().__init__( + stream=stream, + default_parser=parser.parse_message, + default_decoder=parser.decode_message, + # Propagated options + ack_policy=ack_policy, + no_reply=no_reply, + broker_middlewares=broker_middlewares, + broker_dependencies=broker_dependencies, + ) + + async def _get_msgs( + self, + read: Callable[ + [str], + Awaitable[ + tuple[ + tuple[ + TopicName, + tuple[ + tuple[ + Offset, + dict[bytes, bytes], + ], + ..., + ], + ], + ..., + ], + ], + ], + ) -> None: + for stream_name, msgs in await read(self.last_id): + if msgs: + self.last_id = msgs[-1][0].decode() + + for message_id, raw_msg in msgs: + msg = DefaultStreamMessage( + type="stream", + channel=stream_name.decode(), + message_ids=[message_id], + data=raw_msg, + ) + + await self.consume(msg) + + +class StreamBatchSubscriber(_StreamHandlerMixin): + def __init__( + self, + *, + stream: StreamSub, + # Subscriber args + ack_policy: "AckPolicy", + no_reply: bool, + broker_dependencies: Iterable["Dependant"], + broker_middlewares: Sequence["BrokerMiddleware[UnifyRedisDict]"], + ) -> None: + parser = RedisBatchStreamParser() + super().__init__( + stream=stream, + default_parser=parser.parse_message, + default_decoder=parser.decode_message, + # Propagated options + ack_policy=ack_policy, + no_reply=no_reply, + broker_middlewares=broker_middlewares, + broker_dependencies=broker_dependencies, + ) + + async def _get_msgs( + self, + read: Callable[ + [str], + Awaitable[ + tuple[tuple[bytes, tuple[tuple[bytes, dict[bytes, bytes]], ...]], ...], + ], + ], + ) -> None: + for stream_name, msgs in await read(self.last_id): + if msgs: + self.last_id = msgs[-1][0].decode() + + data: list[dict[bytes, bytes]] = [] + ids: list[bytes] = [] + for message_id, i in msgs: + data.append(i) + ids.append(message_id) + + msg = BatchStreamMessage( + type="bstream", + channel=stream_name.decode(), + data=data, + message_ids=ids, + ) + + await self.consume(msg) diff --git a/faststream/redis/testing.py b/faststream/redis/testing.py index 558a0fd5ae..26537284da 100644 --- a/faststream/redis/testing.py +++ b/faststream/redis/testing.py @@ -31,16 +31,14 @@ from faststream.redis.publisher.producer import RedisFastProducer from faststream.redis.response import DestinationType, RedisPublishCommand from faststream.redis.schemas import INCORRECT_SETUP_MSG -from faststream.redis.subscriber.usecase import ( - ChannelSubscriber, - LogicSubscriber, - _ListHandlerMixin, - _StreamHandlerMixin, -) +from faststream.redis.subscriber.usecases.channel_subscriber import ChannelSubscriber +from faststream.redis.subscriber.usecases.list_subscriber import _ListHandlerMixin +from faststream.redis.subscriber.usecases.stream_subscriber import _StreamHandlerMixin if TYPE_CHECKING: from faststream._internal.basic_types import AnyDict, SendableMessage from faststream.redis.publisher.specified import SpecificationPublisher + from faststream.redis.subscriber.usecases.basic import LogicSubscriber __all__ = ("TestRedisBroker",)