From f7a5c194717f5d71e9aafc43d6b321f4746c16e9 Mon Sep 17 00:00:00 2001 From: Pastukhov Nikita Date: Wed, 12 Jun 2024 09:08:47 +0300 Subject: [PATCH] feat: nats filter JS subscription support (#1519) * feat: support NATS multiple subjects JS subscription * feat: NATS test client supports filter subsciption * refactor: add cache for NATS log subject calculation --- faststream/nats/broker/registrator.py | 6 ++-- faststream/nats/subscriber/factory.py | 16 +++++++++- faststream/nats/subscriber/usecase.py | 42 +++++++++++++++++++++++--- faststream/nats/testing.py | 5 ++- tests/brokers/nats/test_consume.py | 33 ++++++++++++++++++-- tests/brokers/nats/test_test_client.py | 25 ++++++++++++--- 6 files changed, 111 insertions(+), 16 deletions(-) diff --git a/faststream/nats/broker/registrator.py b/faststream/nats/broker/registrator.py index a77b439b98..bcd0bab0a2 100644 --- a/faststream/nats/broker/registrator.py +++ b/faststream/nats/broker/registrator.py @@ -42,7 +42,7 @@ def subscriber( # type: ignore[override] subject: Annotated[ str, Doc("NATS subject to subscribe."), - ], + ] = "", queue: Annotated[ str, Doc( @@ -209,7 +209,7 @@ def subscriber( # type: ignore[override] You can use it as a handler decorator `@broker.subscriber(...)`. """ - if stream := self._stream_builder.create(stream): + if (stream := self._stream_builder.create(stream)) and subject: stream.add_subject(subject) subscriber = cast( @@ -323,7 +323,7 @@ def publisher( # type: ignore[override] Or you can create a publisher object to call it lately - `broker.publisher(...).publish(...)`. """ - if stream := self._stream_builder.create(stream): + if (stream := self._stream_builder.create(stream)) and subject: stream.add_subject(subject) publisher = cast( diff --git a/faststream/nats/subscriber/factory.py b/faststream/nats/subscriber/factory.py index 2ae7c9b820..1161c66550 100644 --- a/faststream/nats/subscriber/factory.py +++ b/faststream/nats/subscriber/factory.py @@ -4,6 +4,7 @@ DEFAULT_SUB_PENDING_BYTES_LIMIT, DEFAULT_SUB_PENDING_MSGS_LIMIT, ) +from nats.js.api import ConsumerConfig from nats.js.client import ( DEFAULT_JS_SUB_PENDING_BYTES_LIMIT, DEFAULT_JS_SUB_PENDING_MSGS_LIMIT, @@ -80,6 +81,11 @@ def create_subscriber( if pull_sub is not None and stream is None: raise SetupError("Pull subscriber can be used only with a stream") + if not subject and not config: + raise SetupError("You must provide either `subject` or `config` option.") + + config = config or ConsumerConfig(filter_subjects=[]) + if stream: # TODO: pull & queue warning # TODO: push & durable warning @@ -91,7 +97,6 @@ def create_subscriber( or DEFAULT_JS_SUB_PENDING_BYTES_LIMIT, "durable": durable, "stream": stream.name, - "config": config, } if pull_sub is not None: @@ -120,6 +125,7 @@ def create_subscriber( if obj_watch is not None: return AsyncAPIObjStoreWatchSubscriber( subject=subject, + config=config, obj_watch=obj_watch, broker_dependencies=broker_dependencies, broker_middlewares=broker_middlewares, @@ -131,6 +137,7 @@ def create_subscriber( if kv_watch is not None: return AsyncAPIKeyValueWatchSubscriber( subject=subject, + config=config, kv_watch=kv_watch, broker_dependencies=broker_dependencies, broker_middlewares=broker_middlewares, @@ -144,6 +151,7 @@ def create_subscriber( return AsyncAPIConcurrentCoreSubscriber( max_workers=max_workers, subject=subject, + config=config, queue=queue, # basic args extra_options=extra_options, @@ -162,6 +170,7 @@ def create_subscriber( else: return AsyncAPICoreSubscriber( subject=subject, + config=config, queue=queue, # basic args extra_options=extra_options, @@ -185,6 +194,7 @@ def create_subscriber( pull_sub=pull_sub, stream=stream, subject=subject, + config=config, # basic args extra_options=extra_options, # Subscriber args @@ -204,6 +214,7 @@ def create_subscriber( max_workers=max_workers, stream=stream, subject=subject, + config=config, queue=queue, # basic args extra_options=extra_options, @@ -226,6 +237,7 @@ def create_subscriber( pull_sub=pull_sub, stream=stream, subject=subject, + config=config, # basic args extra_options=extra_options, # Subscriber args @@ -245,6 +257,7 @@ def create_subscriber( pull_sub=pull_sub, stream=stream, subject=subject, + config=config, # basic args extra_options=extra_options, # Subscriber args @@ -264,6 +277,7 @@ def create_subscriber( stream=stream, subject=subject, queue=queue, + config=config, # basic args extra_options=extra_options, # Subscriber args diff --git a/faststream/nats/subscriber/usecase.py b/faststream/nats/subscriber/usecase.py index 322ef41aa3..d64cc2cf2d 100644 --- a/faststream/nats/subscriber/usecase.py +++ b/faststream/nats/subscriber/usecase.py @@ -20,7 +20,7 @@ import anyio from fast_depends.dependencies import Depends from nats.errors import ConnectionClosedError, TimeoutError -from nats.js.api import ObjectInfo +from nats.js.api import ConsumerConfig, ObjectInfo from nats.js.kv import KeyValue from typing_extensions import Annotated, Doc, override @@ -73,6 +73,7 @@ def __init__( self, *, subject: str, + config: "ConsumerConfig", extra_options: Optional[AnyDict], # Subscriber args default_parser: "AsyncCallable", @@ -88,6 +89,7 @@ def __init__( include_in_schema: bool, ) -> None: self.subject = subject + self.config = config self.extra_options = extra_options or {} @@ -205,10 +207,20 @@ def build_log_context( def add_prefix(self, prefix: str) -> None: """Include Subscriber in router.""" - self.subject = "".join((prefix, self.subject)) + if self.subject: + self.subject = "".join((prefix, self.subject)) + else: + self.config.filter_subjects = [ + "".join((prefix, subject)) + for subject in (self.config.filter_subjects or ()) + ] + + @cached_property + def _resolved_subject_string(self) -> str: + return self.subject or ", ".join(self.config.filter_subjects or ()) def __hash__(self) -> int: - return self.get_routing_hash(self.subject) + return self.get_routing_hash(self._resolved_subject_string) @staticmethod def get_routing_hash( @@ -229,6 +241,7 @@ def __init__( self, *, subject: str, + config: "ConsumerConfig", # default args extra_options: Optional[AnyDict], # Subscriber args @@ -246,6 +259,7 @@ def __init__( ) -> None: super().__init__( subject=subject, + config=config, extra_options=extra_options, # subscriber args default_parser=default_parser, @@ -368,6 +382,7 @@ def __init__( *, # default args subject: str, + config: "ConsumerConfig", queue: str, extra_options: Optional[AnyDict], # Subscriber args @@ -387,6 +402,7 @@ def __init__( super().__init__( subject=subject, + config=config, extra_options=extra_options, # subscriber args default_parser=parser_.parse_message, @@ -439,6 +455,7 @@ def __init__( max_workers: int, # default args subject: str, + config: "ConsumerConfig", queue: str, extra_options: Optional[AnyDict], # Subscriber args @@ -456,6 +473,7 @@ def __init__( max_workers=max_workers, # basic args subject=subject, + config=config, queue=queue, extra_options=extra_options, # Propagated args @@ -494,6 +512,7 @@ def __init__( stream: "JStream", # default args subject: str, + config: "ConsumerConfig", queue: str, extra_options: Optional[AnyDict], # Subscriber args @@ -514,6 +533,7 @@ def __init__( super().__init__( subject=subject, + config=config, extra_options=extra_options, # subscriber args default_parser=parser_.parse_message, @@ -540,7 +560,7 @@ def get_log_context( """Log context factory using in `self.consume` scope.""" return self.build_log_context( message=message, - subject=self.subject, + subject=self._resolved_subject_string, queue=self.queue, stream=self.stream.name, ) @@ -560,6 +580,7 @@ async def _create_subscription( # type: ignore[override] subject=self.clear_subject, queue=self.queue, cb=self.consume, + config=self.config, **self.extra_options, ) @@ -574,6 +595,7 @@ def __init__( stream: "JStream", # default args subject: str, + config: "ConsumerConfig", queue: str, extra_options: Optional[AnyDict], # Subscriber args @@ -592,6 +614,7 @@ def __init__( # basic args stream=stream, subject=subject, + config=config, queue=queue, extra_options=extra_options, # Propagated args @@ -619,6 +642,7 @@ async def _create_subscription( # type: ignore[override] subject=self.clear_subject, queue=self.queue, cb=self._put_msg, + config=self.config, **self.extra_options, ) @@ -633,6 +657,7 @@ def __init__( stream: "JStream", # default args subject: str, + config: "ConsumerConfig", extra_options: Optional[AnyDict], # Subscriber args no_ack: bool, @@ -651,6 +676,7 @@ def __init__( # basic args stream=stream, subject=subject, + config=config, extra_options=extra_options, queue="", # Propagated args @@ -708,6 +734,7 @@ def __init__( pull_sub: "PullSub", stream: "JStream", subject: str, + config: "ConsumerConfig", extra_options: Optional[AnyDict], # Subscriber args no_ack: bool, @@ -726,6 +753,7 @@ def __init__( pull_sub=pull_sub, stream=stream, subject=subject, + config=config, extra_options=extra_options, # Propagated args no_ack=no_ack, @@ -765,6 +793,7 @@ def __init__( *, # default args subject: str, + config: "ConsumerConfig", stream: "JStream", pull_sub: "PullSub", extra_options: Optional[AnyDict], @@ -786,6 +815,7 @@ def __init__( super().__init__( subject=subject, + config=config, extra_options=extra_options, # subscriber args default_parser=parser.parse_batch, @@ -837,6 +867,7 @@ def __init__( self, *, subject: str, + config: "ConsumerConfig", kv_watch: "KvWatch", broker_dependencies: Iterable[Depends], broker_middlewares: Iterable["BrokerMiddleware[KeyValue.Entry]"], @@ -850,6 +881,7 @@ def __init__( super().__init__( subject=subject, + config=config, extra_options=None, no_ack=True, no_reply=True, @@ -941,6 +973,7 @@ def __init__( self, *, subject: str, + config: "ConsumerConfig", obj_watch: "ObjWatch", broker_dependencies: Iterable[Depends], broker_middlewares: Iterable["BrokerMiddleware[List[Msg]]"], @@ -955,6 +988,7 @@ def __init__( super().__init__( subject=subject, + config=config, extra_options=None, no_ack=True, no_reply=True, diff --git a/faststream/nats/testing.py b/faststream/nats/testing.py index 34230cb788..4d13333c5f 100644 --- a/faststream/nats/testing.py +++ b/faststream/nats/testing.py @@ -97,7 +97,10 @@ async def publish( # type: ignore[override] ): continue - if is_subject_match_wildcard(subject, handler.clear_subject): + if is_subject_match_wildcard(subject, handler.clear_subject) or any( + is_subject_match_wildcard(subject, filter_subject) + for filter_subject in (handler.config.filter_subjects or ()) + ): msg: Union[List[PatchedMessage], PatchedMessage] if (pull := getattr(handler, "pull_sub", None)) and pull.batch: msg = [incoming] diff --git a/tests/brokers/nats/test_consume.py b/tests/brokers/nats/test_consume.py index 60ac90a7f3..96e40f447b 100644 --- a/tests/brokers/nats/test_consume.py +++ b/tests/brokers/nats/test_consume.py @@ -1,11 +1,11 @@ import asyncio -from unittest.mock import patch +from unittest.mock import Mock, patch import pytest from nats.aio.msg import Msg from faststream.exceptions import AckMessage -from faststream.nats import JStream, NatsBroker, PullSub +from faststream.nats import ConsumerConfig, JStream, NatsBroker, PullSub from faststream.nats.annotations import NatsMessage from tests.brokers.base.consume import BrokerRealConsumeTestcase from tests.tools import spy_decorator @@ -40,6 +40,35 @@ def subscriber(m): assert event.is_set() + async def test_consume_with_filter( + self, + queue, + mock: Mock, + event: asyncio.Event, + ): + consume_broker = self.get_broker() + + @consume_broker.subscriber( + config=ConsumerConfig(filter_subjects=[f"{queue}.a"]), + stream=JStream(queue, subjects=[f"{queue}.*"]), + ) + def subscriber(m): + mock(m) + event.set() + + async with self.patch_broker(consume_broker) as br: + await br.start() + await asyncio.wait( + ( + asyncio.create_task(br.publish(2, f"{queue}.a")), + asyncio.create_task(event.wait()), + ), + timeout=3, + ) + + assert event.is_set() + mock.assert_called_once_with(2) + async def test_consume_pull( self, queue: str, diff --git a/tests/brokers/nats/test_test_client.py b/tests/brokers/nats/test_test_client.py index ebbd1c7887..9718b558b6 100644 --- a/tests/brokers/nats/test_test_client.py +++ b/tests/brokers/nats/test_test_client.py @@ -4,7 +4,7 @@ from faststream import BaseMiddleware from faststream.exceptions import SetupError -from faststream.nats import JStream, NatsBroker, PullSub, TestNatsBroker +from faststream.nats import ConsumerConfig, JStream, NatsBroker, PullSub, TestNatsBroker from tests.brokers.base.testclient import BrokerTestclientTestcase @@ -208,8 +208,6 @@ async def test_consume_batch( self, queue: str, stream: JStream, - event: asyncio.Event, - mock, ): broker = self.get_broker() @@ -219,9 +217,26 @@ async def test_consume_batch( pull_sub=PullSub(1, batch=True), ) def subscriber(m): - mock(m) - event.set() + pass async with TestNatsBroker(broker) as br: await br.publish("hello", queue) subscriber.mock.assert_called_once_with(["hello"]) + + async def test_consume_with_filter( + self, + queue, + ): + broker = self.get_broker() + + @broker.subscriber( + config=ConsumerConfig(filter_subjects=[f"{queue}.a"]), + stream=JStream(queue, subjects=[f"{queue}.*"]), + ) + def subscriber(m): + pass + + async with TestNatsBroker(broker) as br: + await br.publish(1, f"{queue}.b") + await br.publish(2, f"{queue}.a") + subscriber.mock.assert_called_once_with(2)