Skip to content

Commit

Permalink
feat: nats filter JS subscription support (#1519)
Browse files Browse the repository at this point in the history
* feat: support NATS multiple subjects JS subscription

* feat: NATS test client supports filter subsciption

* refactor: add cache for NATS log subject calculation
  • Loading branch information
Lancetnik authored Jun 12, 2024
1 parent f054192 commit f7a5c19
Show file tree
Hide file tree
Showing 6 changed files with 111 additions and 16 deletions.
6 changes: 3 additions & 3 deletions faststream/nats/broker/registrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def subscriber( # type: ignore[override]
subject: Annotated[
str,
Doc("NATS subject to subscribe."),
],
] = "",
queue: Annotated[
str,
Doc(
Expand Down Expand Up @@ -209,7 +209,7 @@ def subscriber( # type: ignore[override]
You can use it as a handler decorator `@broker.subscriber(...)`.
"""
if stream := self._stream_builder.create(stream):
if (stream := self._stream_builder.create(stream)) and subject:
stream.add_subject(subject)

subscriber = cast(
Expand Down Expand Up @@ -323,7 +323,7 @@ def publisher( # type: ignore[override]
Or you can create a publisher object to call it lately - `broker.publisher(...).publish(...)`.
"""
if stream := self._stream_builder.create(stream):
if (stream := self._stream_builder.create(stream)) and subject:
stream.add_subject(subject)

publisher = cast(
Expand Down
16 changes: 15 additions & 1 deletion faststream/nats/subscriber/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
DEFAULT_SUB_PENDING_BYTES_LIMIT,
DEFAULT_SUB_PENDING_MSGS_LIMIT,
)
from nats.js.api import ConsumerConfig
from nats.js.client import (
DEFAULT_JS_SUB_PENDING_BYTES_LIMIT,
DEFAULT_JS_SUB_PENDING_MSGS_LIMIT,
Expand Down Expand Up @@ -80,6 +81,11 @@ def create_subscriber(
if pull_sub is not None and stream is None:
raise SetupError("Pull subscriber can be used only with a stream")

if not subject and not config:
raise SetupError("You must provide either `subject` or `config` option.")

config = config or ConsumerConfig(filter_subjects=[])

if stream:
# TODO: pull & queue warning
# TODO: push & durable warning
Expand All @@ -91,7 +97,6 @@ def create_subscriber(
or DEFAULT_JS_SUB_PENDING_BYTES_LIMIT,
"durable": durable,
"stream": stream.name,
"config": config,
}

if pull_sub is not None:
Expand Down Expand Up @@ -120,6 +125,7 @@ def create_subscriber(
if obj_watch is not None:
return AsyncAPIObjStoreWatchSubscriber(
subject=subject,
config=config,
obj_watch=obj_watch,
broker_dependencies=broker_dependencies,
broker_middlewares=broker_middlewares,
Expand All @@ -131,6 +137,7 @@ def create_subscriber(
if kv_watch is not None:
return AsyncAPIKeyValueWatchSubscriber(
subject=subject,
config=config,
kv_watch=kv_watch,
broker_dependencies=broker_dependencies,
broker_middlewares=broker_middlewares,
Expand All @@ -144,6 +151,7 @@ def create_subscriber(
return AsyncAPIConcurrentCoreSubscriber(
max_workers=max_workers,
subject=subject,
config=config,
queue=queue,
# basic args
extra_options=extra_options,
Expand All @@ -162,6 +170,7 @@ def create_subscriber(
else:
return AsyncAPICoreSubscriber(
subject=subject,
config=config,
queue=queue,
# basic args
extra_options=extra_options,
Expand All @@ -185,6 +194,7 @@ def create_subscriber(
pull_sub=pull_sub,
stream=stream,
subject=subject,
config=config,
# basic args
extra_options=extra_options,
# Subscriber args
Expand All @@ -204,6 +214,7 @@ def create_subscriber(
max_workers=max_workers,
stream=stream,
subject=subject,
config=config,
queue=queue,
# basic args
extra_options=extra_options,
Expand All @@ -226,6 +237,7 @@ def create_subscriber(
pull_sub=pull_sub,
stream=stream,
subject=subject,
config=config,
# basic args
extra_options=extra_options,
# Subscriber args
Expand All @@ -245,6 +257,7 @@ def create_subscriber(
pull_sub=pull_sub,
stream=stream,
subject=subject,
config=config,
# basic args
extra_options=extra_options,
# Subscriber args
Expand All @@ -264,6 +277,7 @@ def create_subscriber(
stream=stream,
subject=subject,
queue=queue,
config=config,
# basic args
extra_options=extra_options,
# Subscriber args
Expand Down
42 changes: 38 additions & 4 deletions faststream/nats/subscriber/usecase.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import anyio
from fast_depends.dependencies import Depends
from nats.errors import ConnectionClosedError, TimeoutError
from nats.js.api import ObjectInfo
from nats.js.api import ConsumerConfig, ObjectInfo
from nats.js.kv import KeyValue
from typing_extensions import Annotated, Doc, override

Expand Down Expand Up @@ -73,6 +73,7 @@ def __init__(
self,
*,
subject: str,
config: "ConsumerConfig",
extra_options: Optional[AnyDict],
# Subscriber args
default_parser: "AsyncCallable",
Expand All @@ -88,6 +89,7 @@ def __init__(
include_in_schema: bool,
) -> None:
self.subject = subject
self.config = config

self.extra_options = extra_options or {}

Expand Down Expand Up @@ -205,10 +207,20 @@ def build_log_context(

def add_prefix(self, prefix: str) -> None:
"""Include Subscriber in router."""
self.subject = "".join((prefix, self.subject))
if self.subject:
self.subject = "".join((prefix, self.subject))
else:
self.config.filter_subjects = [
"".join((prefix, subject))
for subject in (self.config.filter_subjects or ())
]

@cached_property
def _resolved_subject_string(self) -> str:
return self.subject or ", ".join(self.config.filter_subjects or ())

def __hash__(self) -> int:
return self.get_routing_hash(self.subject)
return self.get_routing_hash(self._resolved_subject_string)

@staticmethod
def get_routing_hash(
Expand All @@ -229,6 +241,7 @@ def __init__(
self,
*,
subject: str,
config: "ConsumerConfig",
# default args
extra_options: Optional[AnyDict],
# Subscriber args
Expand All @@ -246,6 +259,7 @@ def __init__(
) -> None:
super().__init__(
subject=subject,
config=config,
extra_options=extra_options,
# subscriber args
default_parser=default_parser,
Expand Down Expand Up @@ -368,6 +382,7 @@ def __init__(
*,
# default args
subject: str,
config: "ConsumerConfig",
queue: str,
extra_options: Optional[AnyDict],
# Subscriber args
Expand All @@ -387,6 +402,7 @@ def __init__(

super().__init__(
subject=subject,
config=config,
extra_options=extra_options,
# subscriber args
default_parser=parser_.parse_message,
Expand Down Expand Up @@ -439,6 +455,7 @@ def __init__(
max_workers: int,
# default args
subject: str,
config: "ConsumerConfig",
queue: str,
extra_options: Optional[AnyDict],
# Subscriber args
Expand All @@ -456,6 +473,7 @@ def __init__(
max_workers=max_workers,
# basic args
subject=subject,
config=config,
queue=queue,
extra_options=extra_options,
# Propagated args
Expand Down Expand Up @@ -494,6 +512,7 @@ def __init__(
stream: "JStream",
# default args
subject: str,
config: "ConsumerConfig",
queue: str,
extra_options: Optional[AnyDict],
# Subscriber args
Expand All @@ -514,6 +533,7 @@ def __init__(

super().__init__(
subject=subject,
config=config,
extra_options=extra_options,
# subscriber args
default_parser=parser_.parse_message,
Expand All @@ -540,7 +560,7 @@ def get_log_context(
"""Log context factory using in `self.consume` scope."""
return self.build_log_context(
message=message,
subject=self.subject,
subject=self._resolved_subject_string,
queue=self.queue,
stream=self.stream.name,
)
Expand All @@ -560,6 +580,7 @@ async def _create_subscription( # type: ignore[override]
subject=self.clear_subject,
queue=self.queue,
cb=self.consume,
config=self.config,
**self.extra_options,
)

Expand All @@ -574,6 +595,7 @@ def __init__(
stream: "JStream",
# default args
subject: str,
config: "ConsumerConfig",
queue: str,
extra_options: Optional[AnyDict],
# Subscriber args
Expand All @@ -592,6 +614,7 @@ def __init__(
# basic args
stream=stream,
subject=subject,
config=config,
queue=queue,
extra_options=extra_options,
# Propagated args
Expand Down Expand Up @@ -619,6 +642,7 @@ async def _create_subscription( # type: ignore[override]
subject=self.clear_subject,
queue=self.queue,
cb=self._put_msg,
config=self.config,
**self.extra_options,
)

Expand All @@ -633,6 +657,7 @@ def __init__(
stream: "JStream",
# default args
subject: str,
config: "ConsumerConfig",
extra_options: Optional[AnyDict],
# Subscriber args
no_ack: bool,
Expand All @@ -651,6 +676,7 @@ def __init__(
# basic args
stream=stream,
subject=subject,
config=config,
extra_options=extra_options,
queue="",
# Propagated args
Expand Down Expand Up @@ -708,6 +734,7 @@ def __init__(
pull_sub: "PullSub",
stream: "JStream",
subject: str,
config: "ConsumerConfig",
extra_options: Optional[AnyDict],
# Subscriber args
no_ack: bool,
Expand All @@ -726,6 +753,7 @@ def __init__(
pull_sub=pull_sub,
stream=stream,
subject=subject,
config=config,
extra_options=extra_options,
# Propagated args
no_ack=no_ack,
Expand Down Expand Up @@ -765,6 +793,7 @@ def __init__(
*,
# default args
subject: str,
config: "ConsumerConfig",
stream: "JStream",
pull_sub: "PullSub",
extra_options: Optional[AnyDict],
Expand All @@ -786,6 +815,7 @@ def __init__(

super().__init__(
subject=subject,
config=config,
extra_options=extra_options,
# subscriber args
default_parser=parser.parse_batch,
Expand Down Expand Up @@ -837,6 +867,7 @@ def __init__(
self,
*,
subject: str,
config: "ConsumerConfig",
kv_watch: "KvWatch",
broker_dependencies: Iterable[Depends],
broker_middlewares: Iterable["BrokerMiddleware[KeyValue.Entry]"],
Expand All @@ -850,6 +881,7 @@ def __init__(

super().__init__(
subject=subject,
config=config,
extra_options=None,
no_ack=True,
no_reply=True,
Expand Down Expand Up @@ -941,6 +973,7 @@ def __init__(
self,
*,
subject: str,
config: "ConsumerConfig",
obj_watch: "ObjWatch",
broker_dependencies: Iterable[Depends],
broker_middlewares: Iterable["BrokerMiddleware[List[Msg]]"],
Expand All @@ -955,6 +988,7 @@ def __init__(

super().__init__(
subject=subject,
config=config,
extra_options=None,
no_ack=True,
no_reply=True,
Expand Down
5 changes: 4 additions & 1 deletion faststream/nats/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,10 @@ async def publish( # type: ignore[override]
):
continue

if is_subject_match_wildcard(subject, handler.clear_subject):
if is_subject_match_wildcard(subject, handler.clear_subject) or any(
is_subject_match_wildcard(subject, filter_subject)
for filter_subject in (handler.config.filter_subjects or ())
):
msg: Union[List[PatchedMessage], PatchedMessage]
if (pull := getattr(handler, "pull_sub", None)) and pull.batch:
msg = [incoming]
Expand Down
Loading

0 comments on commit f7a5c19

Please sign in to comment.