From 7e72bdc4af84339240bdb00057158cbee79ce223 Mon Sep 17 00:00:00 2001
From: Nikita Pastukhov
Date: Tue, 7 May 2024 22:52:49 +0300
Subject: [PATCH 1/6] fix: correct NATS dynamic subscriber registration

---
 faststream/nats/testing.py    |  9 ++++++++-
 tests/brokers/base/consume.py | 27 +++++++++++++++++++++++++++
 2 files changed, 35 insertions(+), 1 deletion(-)

diff --git a/faststream/nats/testing.py b/faststream/nats/testing.py
index f106c93f9d..6681ba5b14 100644
--- a/faststream/nats/testing.py
+++ b/faststream/nats/testing.py
@@ -1,4 +1,5 @@
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
+from unittest.mock import AsyncMock

 from nats.aio.msg import Msg
 from typing_extensions import override
@@ -40,8 +41,14 @@ def f(msg: Any) -> None:
         return sub.calls[0].handler

     @staticmethod
-    async def _fake_connect(broker: NatsBroker, *args: Any, **kwargs: Any) -> None:
+    async def _fake_connect(  # type: ignore[override]
+        broker: NatsBroker,
+        *args: Any,
+        **kwargs: Any,
+    ) -> AsyncMock:
+        broker.stream = AsyncMock()  # type: ignore[assignment]
         broker._js_producer = broker._producer = FakeProducer(broker)  # type: ignore[assignment]
+        return AsyncMock()

     @staticmethod
     def remove_publisher_fake_subscriber(
diff --git a/tests/brokers/base/consume.py b/tests/brokers/base/consume.py
index 654d3b19f8..68f4edc3e3 100644
--- a/tests/brokers/base/consume.py
+++ b/tests/brokers/base/consume.py
@@ -221,6 +221,33 @@ async def handler(m: Foo, dep: int = Depends(dependency), broker=Context()):
         assert event.is_set()
         mock.assert_called_once_with({"x": 1}, "100", consume_broker)

+    async def test_dynamic_sub(
+        self,
+        queue: str,
+        consume_broker: BrokerUsecase,
+        event: asyncio.Event,
+    ):
+        def subscriber(m):
+            event.set()
+
+        async with consume_broker:
+            await consume_broker.start()
+
+            sub = consume_broker.subscriber(queue, **self.subscriber_kwargs)
+            sub(subscriber)
+            consume_broker.setup_subscriber(sub)
+            await sub.start()
+
+            await asyncio.wait(
+                (
+                    asyncio.create_task(consume_broker.publish("hello", queue)),
+                    asyncio.create_task(event.wait()),
+                ),
+                timeout=self.timeout,
+            )
+
+        assert event.is_set()
+

 @pytest.mark.asyncio()
 class BrokerRealConsumeTestcase(BrokerConsumeTestcase):

From fbc74bbe8530fed6cf07a71bc436f5a52e9add45 Mon Sep 17 00:00:00 2001
From: Nikita Pastukhov
Date: Wed, 8 May 2024 20:03:47 +0300
Subject: [PATCH 2/6] tests: fix dynamic test for TestClient

---
 faststream/confluent/broker/broker.py       | 20 ++++++--
 faststream/confluent/broker/logging.py      |  6 +--
 faststream/confluent/broker/registrator.py  | 47 +++++++++----------
 faststream/confluent/subscriber/asyncapi.py | 15 +++---
 faststream/confluent/subscriber/usecase.py  | 25 +++++-----
 faststream/confluent/testing.py             | 14 +++++-
 faststream/kafka/broker/broker.py           | 11 +++--
 faststream/kafka/broker/logging.py          |  5 +-
 faststream/kafka/broker/registrator.py      | 52 +++++++++------------
 faststream/kafka/subscriber/asyncapi.py     | 20 ++++----
 faststream/kafka/subscriber/usecase.py      | 26 ++++++-----
 faststream/kafka/testing.py                 | 14 +++++-
 faststream/redis/testing.py                 |  8 +++-
 13 files changed, 148 insertions(+), 115 deletions(-)

diff --git a/faststream/confluent/broker/broker.py b/faststream/confluent/broker/broker.py
index 8f9de15b09..9f31fbbb5e 100644
--- a/faststream/confluent/broker/broker.py
+++ b/faststream/confluent/broker/broker.py
@@ -23,7 +23,11 @@
 from faststream.broker.message import gen_cor_id
 from faststream.confluent.broker.logging import KafkaLoggingBroker
 from faststream.confluent.broker.registrator import KafkaRegistrator
-from faststream.confluent.client import AsyncConfluentProducer, _missing +from faststream.confluent.client import ( + AsyncConfluentConsumer, + AsyncConfluentProducer, + _missing, +) from faststream.confluent.publisher.producer import AsyncConfluentFastProducer from faststream.confluent.schemas.params import ConsumerConnectionParams from faststream.confluent.security import parse_security @@ -425,7 +429,7 @@ async def connect( Doc("Kafka addresses to connect."), ] = Parameter.empty, **kwargs: Any, - ) -> ConsumerConnectionParams: + ) -> Callable[..., AsyncConfluentConsumer]: if bootstrap_servers is not Parameter.empty: kwargs["bootstrap_servers"] = bootstrap_servers @@ -437,7 +441,7 @@ async def _connect( # type: ignore[override] *, client_id: str, **kwargs: Any, - ) -> ConsumerConnectionParams: + ) -> Callable[..., AsyncConfluentConsumer]: security_params = parse_security(self.security) kwargs.update(security_params) @@ -450,7 +454,10 @@ async def _connect( # type: ignore[override] producer=producer, ) - return filter_by_dict(ConsumerConnectionParams, kwargs) + return partial( + AsyncConfluentConsumer, + **filter_by_dict(ConsumerConnectionParams, kwargs), + ) async def start(self) -> None: await super().start() @@ -464,7 +471,10 @@ async def start(self) -> None: @property def _subscriber_setup_extra(self) -> "AnyDict": - return {"client_id": self.client_id, "connection_data": self._connection or {}} + return { + "client_id": self.client_id, + "builder": self._connection, + } @override async def publish( # type: ignore[override] diff --git a/faststream/confluent/broker/logging.py b/faststream/confluent/broker/logging.py index 9eebc89461..4fead65305 100644 --- a/faststream/confluent/broker/logging.py +++ b/faststream/confluent/broker/logging.py @@ -1,9 +1,9 @@ import logging from inspect import Parameter -from typing import TYPE_CHECKING, Any, ClassVar, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Callable, ClassVar, Optional, Tuple, Union from faststream.broker.core.usecase import BrokerUsecase -from faststream.confluent.schemas.params import ConsumerConnectionParams +from faststream.confluent.client import AsyncConfluentConsumer from faststream.log.logging import get_broker_logger if TYPE_CHECKING: @@ -15,7 +15,7 @@ class KafkaLoggingBroker( BrokerUsecase[ Union["confluent_kafka.Message", Tuple["confluent_kafka.Message", ...]], - ConsumerConnectionParams, + Callable[..., AsyncConfluentConsumer], ] ): """A class that extends the LoggingMixin class and adds additional functionality for logging Kafka related information.""" diff --git a/faststream/confluent/broker/registrator.py b/faststream/confluent/broker/registrator.py index 4fde6249c3..6306d10bd9 100644 --- a/faststream/confluent/broker/registrator.py +++ b/faststream/confluent/broker/registrator.py @@ -1,4 +1,3 @@ -from functools import partial from typing import ( TYPE_CHECKING, Any, @@ -18,7 +17,6 @@ from faststream.broker.core.abc import ABCBroker from faststream.broker.utils import default_filter -from faststream.confluent.client import AsyncConfluentConsumer from faststream.confluent.publisher.asyncapi import AsyncAPIPublisher from faststream.confluent.subscriber.asyncapi import AsyncAPISubscriber from faststream.exceptions import SetupError @@ -1235,29 +1233,6 @@ def subscriber( if not auto_commit and not group_id: raise SetupError("You should install `group_id` with manual commit mode") - builder = partial( - AsyncConfluentConsumer, - key_deserializer=key_deserializer, - value_deserializer=value_deserializer, 
-            fetch_max_wait_ms=fetch_max_wait_ms,
-            fetch_max_bytes=fetch_max_bytes,
-            fetch_min_bytes=fetch_min_bytes,
-            max_partition_fetch_bytes=max_partition_fetch_bytes,
-            auto_offset_reset=auto_offset_reset,
-            enable_auto_commit=auto_commit,
-            auto_commit_interval_ms=auto_commit_interval_ms,
-            check_crcs=check_crcs,
-            partition_assignment_strategy=partition_assignment_strategy,
-            max_poll_interval_ms=max_poll_interval_ms,
-            rebalance_timeout_ms=rebalance_timeout_ms,
-            session_timeout_ms=session_timeout_ms,
-            heartbeat_interval_ms=heartbeat_interval_ms,
-            consumer_timeout_ms=consumer_timeout_ms,
-            max_poll_records=max_poll_records,
-            exclude_internal_topics=exclude_internal_topics,
-            isolation_level=isolation_level,
-        )
-
         subscriber = super().subscriber(
             AsyncAPISubscriber.create(
                 *topics,
                 batch=batch,
                 batch_timeout_ms=batch_timeout_ms,
                 max_records=max_records,
                 group_id=group_id,
-                builder=builder,
+                connection_data={
+                    "key_deserializer": key_deserializer,
+                    "value_deserializer": value_deserializer,
+                    "fetch_max_wait_ms": fetch_max_wait_ms,
+                    "fetch_max_bytes": fetch_max_bytes,
+                    "fetch_min_bytes": fetch_min_bytes,
+                    "max_partition_fetch_bytes": max_partition_fetch_bytes,
+                    "auto_offset_reset": auto_offset_reset,
+                    "enable_auto_commit": auto_commit,
+                    "auto_commit_interval_ms": auto_commit_interval_ms,
+                    "check_crcs": check_crcs,
+                    "partition_assignment_strategy": partition_assignment_strategy,
+                    "max_poll_interval_ms": max_poll_interval_ms,
+                    "rebalance_timeout_ms": rebalance_timeout_ms,
+                    "session_timeout_ms": session_timeout_ms,
+                    "heartbeat_interval_ms": heartbeat_interval_ms,
+                    "consumer_timeout_ms": consumer_timeout_ms,
+                    "max_poll_records": max_poll_records,
+                    "exclude_internal_topics": exclude_internal_topics,
+                    "isolation_level": isolation_level,
+                },
                 is_manual=not auto_commit,
                 # subscriber args
                 no_ack=no_ack,
diff --git a/faststream/confluent/subscriber/asyncapi.py b/faststream/confluent/subscriber/asyncapi.py
index 8da47a800e..d31bfa05f2 100644
--- a/faststream/confluent/subscriber/asyncapi.py
+++ b/faststream/confluent/subscriber/asyncapi.py
@@ -1,6 +1,5 @@
 from typing import (
     TYPE_CHECKING,
-    Callable,
     Dict,
     Iterable,
     Literal,
@@ -33,7 +32,7 @@
     from fast_depends.dependencies import Depends

     from faststream.broker.types import BrokerMiddleware
-    from faststream.confluent.client import AsyncConfluentConsumer
+    from faststream.types import AnyDict


 class AsyncAPISubscriber(LogicSubscriber[MsgType]):
@@ -77,7 +76,7 @@ def create(
         max_records: Optional[int],
         # Kafka information
         group_id: Optional[str],
-        builder: Callable[..., "AsyncConfluentConsumer"],
+        connection_data: "AnyDict",
         is_manual: bool,
         # Subscriber args
         no_ack: bool,
@@ -99,7 +98,7 @@ def create(
         max_records: Optional[int],
         # Kafka information
         group_id: Optional[str],
-        builder: Callable[..., "AsyncConfluentConsumer"],
+        connection_data: "AnyDict",
         is_manual: bool,
         # Subscriber args
         no_ack: bool,
@@ -121,7 +120,7 @@ def create(
         max_records: Optional[int],
         # Kafka information
         group_id: Optional[str],
-        builder: Callable[..., "AsyncConfluentConsumer"],
+        connection_data: "AnyDict",
         is_manual: bool,
         # Subscriber args
         no_ack: bool,
@@ -148,7 +147,7 @@ def create(
         max_records: Optional[int],
         # Kafka information
         group_id: Optional[str],
-        builder: Callable[..., "AsyncConfluentConsumer"],
+        connection_data: "AnyDict",
         is_manual: bool,
         # Subscriber args
         no_ack: bool,
@@ -171,7 +170,7 @@ def create(
                 batch_timeout_ms=batch_timeout_ms,
                 max_records=max_records,
                 group_id=group_id,
-                builder=builder,
+                connection_data=connection_data,
                 is_manual=is_manual,
                 no_ack=no_ack,
                 retry=retry,
@@ -185,7 +184,7 @@ def create(
             return AsyncAPIDefaultSubscriber(
                 *topics,
                 group_id=group_id,
-                builder=builder,
+                connection_data=connection_data,
                 is_manual=is_manual,
                 no_ack=no_ack,
                 retry=retry,
diff --git a/faststream/confluent/subscriber/usecase.py b/faststream/confluent/subscriber/usecase.py
index d778086bae..e5e23ed710 100644
--- a/faststream/confluent/subscriber/usecase.py
+++ b/faststream/confluent/subscriber/usecase.py
@@ -19,7 +19,6 @@
 from faststream.broker.subscriber.usecase import SubscriberUsecase
 from faststream.broker.types import MsgType
 from faststream.confluent.parser import AsyncConfluentParser
-from faststream.confluent.schemas.params import ConsumerConnectionParams

 if TYPE_CHECKING:
     from fast_depends.dependencies import Depends
@@ -41,7 +40,9 @@ class LogicSubscriber(ABC, SubscriberUsecase[MsgType]):

     topics: Sequence[str]
     group_id: Optional[str]
+    builder: Optional[Callable[..., "AsyncConfluentConsumer"]]
     consumer: Optional["AsyncConfluentConsumer"]
+    task: Optional["asyncio.Task[None]"]

     client_id: Optional[str]

@@ -50,7 +51,7 @@ def __init__(
         self,
         *topics: str,
         # Kafka information
         group_id: Optional[str],
-        builder: Callable[..., "AsyncConfluentConsumer"],
+        connection_data: "AnyDict",
         is_manual: bool,
         # Subscriber args
         default_parser: "AsyncCallable",
@@ -81,20 +82,20 @@ def __init__(
         self.group_id = group_id
         self.topics = topics
         self.is_manual = is_manual
-        self.builder = builder
+        self.builder = None
         self.consumer = None
         self.task = None

         # Setup it later
         self.client_id = ""
-        self.__connection_data = ConsumerConnectionParams()
+        self.__connection_data = connection_data

     @override
     def setup(  # type: ignore[override]
         self,
         *,
         client_id: Optional[str],
-        connection_data: "ConsumerConnectionParams",
+        builder: Callable[..., "AsyncConfluentConsumer"],
         # basic args
         logger: Optional["LoggerProto"],
         producer: Optional["ProducerProto"],
@@ -110,7 +111,7 @@ def setup(  # type: ignore[override]
         _call_decorators: Iterable["Decorator"],
     ) -> None:
         self.client_id = client_id
-        self.__connection_data = connection_data
+        self.builder = builder

         super().setup(
             logger=logger,
@@ -128,6 +129,8 @@ def setup(  # type: ignore[override]
     @override
     async def start(self) -> None:
         """Start the consumer."""
+        assert self.builder, "You should setup subscriber at first."  # nosec B101
+
         self.consumer = consumer = self.builder(
             *self.topics,
             group_id=self.group_id,
@@ -172,7 +175,7 @@ async def get_msg(self) -> Optional[MsgType]:
         raise NotImplementedError()

     async def _consume(self) -> None:
-        assert self.consumer, "You need to start handler first"  # nosec B101
+        assert self.consumer, "You should start subscriber at first."  # nosec B101
         connected = True
         while self.running:
@@ -219,7 +222,7 @@ def __init__(
         self,
         *topics: str,
         # Kafka information
         group_id: Optional[str],
-        builder: Callable[..., "AsyncConfluentConsumer"],
+        connection_data: "AnyDict",
         is_manual: bool,
         # Subscriber args
         no_ack: bool,
@@ -234,7 +237,7 @@ def __init__(
         super().__init__(
             *topics,
             group_id=group_id,
-            builder=builder,
+            connection_data=connection_data,
             is_manual=is_manual,
             # subscriber args
             default_parser=AsyncConfluentParser.parse_message,
@@ -278,7 +281,7 @@ def __init__(
         max_records: Optional[int],
         # Kafka information
         group_id: Optional[str],
-        builder: Callable[..., "AsyncConfluentConsumer"],
+        connection_data: "AnyDict",
         is_manual: bool,
         # Subscriber args
         no_ack: bool,
@@ -296,7 +299,7 @@ def __init__(
         super().__init__(
             *topics,
             group_id=group_id,
-            builder=builder,
+            connection_data=connection_data,
             is_manual=is_manual,
             # subscriber args
             default_parser=AsyncConfluentParser.parse_message_batch,
diff --git a/faststream/confluent/testing.py b/faststream/confluent/testing.py
index 4559cbde8b..3052828424 100644
--- a/faststream/confluent/testing.py
+++ b/faststream/confluent/testing.py
@@ -1,5 +1,6 @@
 from datetime import datetime
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple
+from unittest.mock import AsyncMock

 from typing_extensions import override
@@ -22,8 +23,13 @@ class TestKafkaBroker(TestBroker[KafkaBroker]):
     """A class to test Kafka brokers."""

     @staticmethod
-    async def _fake_connect(broker: KafkaBroker, *args: Any, **kwargs: Any) -> None:
+    async def _fake_connect(  # type: ignore[override]
+        broker: KafkaBroker,
+        *args: Any,
+        **kwargs: Any,
+    ) -> Callable[..., AsyncMock]:
         broker._producer = FakeProducer(broker)
+        return _fake_connection

     @staticmethod
     def create_publisher_fake_subscriber(
@@ -231,3 +237,7 @@ def build_message(
         timestamp_type=0 + 1,
         timestamp_ms=timestamp_ms or int(datetime.now().timestamp()),
     )
+
+
+def _fake_connection(*args: Any, **kwargs: Any) -> AsyncMock:
+    return AsyncMock()
diff --git a/faststream/kafka/broker/broker.py b/faststream/kafka/broker/broker.py
index 59d6e733d6..2a29796860 100644
--- a/faststream/kafka/broker/broker.py
+++ b/faststream/kafka/broker/broker.py
@@ -557,7 +557,7 @@ async def connect(  # type: ignore[override]
             Doc("Kafka addresses to connect."),
         ] = Parameter.empty,
         **kwargs: "Unpack[KafkaInitKwargs]",
-    ) -> ConsumerConnectionParams:
+    ) -> Callable[..., aiokafka.AIOKafkaConsumer]:
         """Connect to Kafka servers manually.

         Consumes the same with `KafkaBroker.__init__` arguments and overrides them.
@@ -579,7 +579,7 @@ async def _connect(  # type: ignore[override]
         *,
         client_id: str,
         **kwargs: Any,
-    ) -> ConsumerConnectionParams:
+    ) -> Callable[..., aiokafka.AIOKafkaConsumer]:
         security_params = parse_security(self.security)
         kwargs.update(security_params)

@@ -593,7 +593,10 @@ async def _connect(  # type: ignore[override]
             producer=producer,
         )

-        return filter_by_dict(ConsumerConnectionParams, kwargs)
+        return partial(
+            aiokafka.AIOKafkaConsumer,
+            **filter_by_dict(ConsumerConnectionParams, kwargs),
+        )

     async def start(self) -> None:
         """Connect broker to Kafka and startup all subscribers."""
@@ -610,7 +613,7 @@ async def start(self) -> None:
     def _subscriber_setup_extra(self) -> "AnyDict":
         return {
             "client_id": self.client_id,
-            "connection_args": self._connection or {},
+            "builder": self._connection,
         }

     @override
diff --git a/faststream/kafka/broker/logging.py b/faststream/kafka/broker/logging.py
index df828024da..16b1103b83 100644
--- a/faststream/kafka/broker/logging.py
+++ b/faststream/kafka/broker/logging.py
@@ -1,9 +1,8 @@
 import logging
 from inspect import Parameter
-from typing import TYPE_CHECKING, Any, ClassVar, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Callable, ClassVar, Optional, Tuple, Union

 from faststream.broker.core.usecase import BrokerUsecase
-from faststream.kafka.schemas.params import ConsumerConnectionParams
 from faststream.log.logging import get_broker_logger

 if TYPE_CHECKING:
@@ -15,7 +14,7 @@
 class KafkaLoggingBroker(
     BrokerUsecase[
         Union["aiokafka.ConsumerRecord", Tuple["aiokafka.ConsumerRecord", ...]],
-        ConsumerConnectionParams,
+        Callable[..., "aiokafka.AIOKafkaConsumer"],
     ]
 ):
     """A class that extends the LoggingMixin class and adds additional functionality for logging Kafka related information."""
diff --git a/faststream/kafka/broker/registrator.py b/faststream/kafka/broker/registrator.py
index bed606870a..dabfc9a1b0 100644
--- a/faststream/kafka/broker/registrator.py
+++ b/faststream/kafka/broker/registrator.py
@@ -1,4 +1,3 @@
-from functools import partial
 from typing import (
     TYPE_CHECKING,
     Any,
@@ -14,13 +13,12 @@
     overload,
 )

-from aiokafka import AIOKafkaConsumer, ConsumerRecord
+from aiokafka import ConsumerRecord
 from aiokafka.coordinator.assignors.roundrobin import RoundRobinPartitionAssignor
 from typing_extensions import Annotated, Doc, deprecated, override

 from faststream.broker.core.abc import ABCBroker
 from faststream.broker.utils import default_filter
-from faststream.exceptions import SetupError
 from faststream.kafka.publisher.asyncapi import AsyncAPIPublisher
 from faststream.kafka.subscriber.asyncapi import AsyncAPISubscriber
@@ -1367,32 +1365,6 @@ def subscriber(
         "AsyncAPIDefaultSubscriber",
         "AsyncAPIBatchSubscriber",
     ]:
-        if not auto_commit and not group_id:
-            raise SetupError("You should install `group_id` with manual commit mode")
-
-        builder = partial(
-            AIOKafkaConsumer,
-            key_deserializer=key_deserializer,
-            value_deserializer=value_deserializer,
-            fetch_max_wait_ms=fetch_max_wait_ms,
-            fetch_max_bytes=fetch_max_bytes,
-            fetch_min_bytes=fetch_min_bytes,
-            max_partition_fetch_bytes=max_partition_fetch_bytes,
-            auto_offset_reset=auto_offset_reset,
-            enable_auto_commit=auto_commit,
-            auto_commit_interval_ms=auto_commit_interval_ms,
-            check_crcs=check_crcs,
-            partition_assignment_strategy=partition_assignment_strategy,
-            max_poll_interval_ms=max_poll_interval_ms,
-            rebalance_timeout_ms=rebalance_timeout_ms,
-            session_timeout_ms=session_timeout_ms,
-            heartbeat_interval_ms=heartbeat_interval_ms,
-            consumer_timeout_ms=consumer_timeout_ms,
-            max_poll_records=max_poll_records,
-            exclude_internal_topics=exclude_internal_topics,
-            isolation_level=isolation_level,
-        )
-
         subscriber = super().subscriber(
             AsyncAPISubscriber.create(
                 *topics,
@@ -1402,7 +1374,27 @@ def subscriber(
                 batch=batch,
                 batch_timeout_ms=batch_timeout_ms,
                 max_records=max_records,
                 group_id=group_id,
                 listener=listener,
                 pattern=pattern,
-                builder=builder,
+                connection_args={
+                    "key_deserializer": key_deserializer,
+                    "value_deserializer": value_deserializer,
+                    "fetch_max_wait_ms": fetch_max_wait_ms,
+                    "fetch_max_bytes": fetch_max_bytes,
+                    "fetch_min_bytes": fetch_min_bytes,
+                    "max_partition_fetch_bytes": max_partition_fetch_bytes,
+                    "auto_offset_reset": auto_offset_reset,
+                    "enable_auto_commit": auto_commit,
+                    "auto_commit_interval_ms": auto_commit_interval_ms,
+                    "check_crcs": check_crcs,
+                    "partition_assignment_strategy": partition_assignment_strategy,
+                    "max_poll_interval_ms": max_poll_interval_ms,
+                    "rebalance_timeout_ms": rebalance_timeout_ms,
+                    "session_timeout_ms": session_timeout_ms,
+                    "heartbeat_interval_ms": heartbeat_interval_ms,
+                    "consumer_timeout_ms": consumer_timeout_ms,
+                    "max_poll_records": max_poll_records,
+                    "exclude_internal_topics": exclude_internal_topics,
+                    "isolation_level": isolation_level,
+                },
                 is_manual=not auto_commit,
                 # subscriber args
                 no_ack=no_ack,
diff --git a/faststream/kafka/subscriber/asyncapi.py b/faststream/kafka/subscriber/asyncapi.py
index 4453690cc1..8e1a9e5e21 100644
--- a/faststream/kafka/subscriber/asyncapi.py
+++ b/faststream/kafka/subscriber/asyncapi.py
@@ -1,6 +1,5 @@
 from typing import (
     TYPE_CHECKING,
-    Callable,
     Dict,
     Iterable,
     Literal,
@@ -22,6 +21,7 @@
 from faststream.asyncapi.schema.bindings import kafka
 from faststream.asyncapi.utils import resolve_payloads
 from faststream.broker.types import MsgType
+from faststream.exceptions import SetupError
 from faststream.kafka.subscriber.usecase import (
     BatchSubscriber,
     DefaultSubscriber,
@@ -29,11 +29,12 @@
 )

 if TYPE_CHECKING:
-    from aiokafka import AIOKafkaConsumer, ConsumerRecord
+    from aiokafka import ConsumerRecord
     from aiokafka.abc import ConsumerRebalanceListener
     from fast_depends.dependencies import Depends

     from faststream.broker.types import BrokerMiddleware
+    from faststream.types import AnyDict


 class AsyncAPISubscriber(LogicSubscriber[MsgType]):
@@ -79,7 +80,7 @@ def create(
         group_id: Optional[str],
         listener: Optional["ConsumerRebalanceListener"],
         pattern: Optional[str],
-        builder: Callable[..., "AIOKafkaConsumer"],
+        connection_args: "AnyDict",
         is_manual: bool,
         # Subscriber args
         no_ack: bool,
@@ -103,7 +104,7 @@ def create(
         group_id: Optional[str],
         listener: Optional["ConsumerRebalanceListener"],
         pattern: Optional[str],
-        builder: Callable[..., "AIOKafkaConsumer"],
+        connection_args: "AnyDict",
         is_manual: bool,
         # Subscriber args
         no_ack: bool,
@@ -127,7 +128,7 @@ def create(
         group_id: Optional[str],
         listener: Optional["ConsumerRebalanceListener"],
         pattern: Optional[str],
-        builder: Callable[..., "AIOKafkaConsumer"],
+        connection_args: "AnyDict",
         is_manual: bool,
         # Subscriber args
         no_ack: bool,
@@ -156,7 +157,7 @@ def create(
         group_id: Optional[str],
         listener: Optional["ConsumerRebalanceListener"],
         pattern: Optional[str],
-        builder: Callable[..., "AIOKafkaConsumer"],
+        connection_args: "AnyDict",
         is_manual: bool,
         # Subscriber args
         no_ack: bool,
@@ -173,6 +174,9 @@ def create(
         "AsyncAPIDefaultSubscriber",
         "AsyncAPIBatchSubscriber",
     ]:
+        if is_manual and not group_id:
+            raise SetupError("You should install `group_id` with manual commit mode")
+
         if batch:
             return AsyncAPIBatchSubscriber(
                 *topics,
@@ -181,7 +185,7 @@ def create(
                 group_id=group_id,
                 listener=listener,
                 pattern=pattern,
-                builder=builder,
+                connection_args=connection_args,
                 is_manual=is_manual,
                 no_ack=no_ack,
                 retry=retry,
@@ -197,7 +201,7 @@ def create(
                 group_id=group_id,
                 listener=listener,
                 pattern=pattern,
-                builder=builder,
+                connection_args=connection_args,
                 is_manual=is_manual,
                 no_ack=no_ack,
                 retry=retry,
diff --git a/faststream/kafka/subscriber/usecase.py b/faststream/kafka/subscriber/usecase.py
index 0a99702b98..6084ec6787 100644
--- a/faststream/kafka/subscriber/usecase.py
+++ b/faststream/kafka/subscriber/usecase.py
@@ -33,7 +33,6 @@
     from faststream.broker.message import StreamMessage
     from faststream.broker.publisher.proto import ProducerProto
-    from faststream.kafka.schemas.params import ConsumerConnectionParams
     from faststream.types import AnyDict, Decorator, LoggerProto

@@ -43,7 +42,9 @@ class LogicSubscriber(ABC, SubscriberUsecase[MsgType]):

     topics: Sequence[str]
     group_id: Optional[str]
+    builder: Optional[Callable[..., "AIOKafkaConsumer"]]
     consumer: Optional["AIOKafkaConsumer"]
+    task: Optional["asyncio.Task[None]"]

     client_id: Optional[str]
     batch: bool

@@ -53,7 +54,7 @@ def __init__(
         self,
         *topics: str,
         # Kafka information
         group_id: Optional[str],
-        builder: Callable[..., "AIOKafkaConsumer"],
+        connection_args: "AnyDict",
         listener: Optional["ConsumerRebalanceListener"],
         pattern: Optional[str],
         is_manual: bool,
@@ -86,7 +87,7 @@ def __init__(
         self.group_id = group_id
         self.topics = topics
         self.is_manual = is_manual
-        self.builder = builder
+        self.builder = None
         self.consumer = None
         self.task = None

@@ -94,14 +95,14 @@ def __init__(
         # Setup it later
         self.client_id = ""
         self.__pattern = pattern
         self.__listener = listener
-        self.__connection_args: "ConsumerConnectionParams" = {}
+        self.__connection_args = connection_args

     @override
     def setup(  # type: ignore[override]
         self,
         *,
         client_id: Optional[str],
-        connection_args: "ConsumerConnectionParams",
+        builder: Callable[..., "AIOKafkaConsumer"],
         # basic args
         logger: Optional["LoggerProto"],
         producer: Optional["ProducerProto"],
@@ -117,7 +118,7 @@ def setup(  # type: ignore[override]
         _call_decorators: Iterable["Decorator"],
     ) -> None:
         self.client_id = client_id
-        self.__connection_args = connection_args
+        self.builder = builder

         super().setup(
             logger=logger,
@@ -134,11 +135,14 @@ def setup(  # type: ignore[override]

     async def start(self) -> None:
         """Start the consumer."""
+        assert self.builder, "You should setup subscriber at first."  # nosec B101
+
         self.consumer = consumer = self.builder(
             group_id=self.group_id,
             client_id=self.client_id,
             **self.__connection_args,
         )
+
         consumer.subscribe(
             topics=self.topics,
             pattern=self.__pattern,
@@ -183,7 +187,7 @@ async def get_msg(self) -> MsgType:
         raise NotImplementedError()

     async def _consume(self) -> None:
-        assert self.consumer, "You should setup subscriber at first."  # nosec B101
+        assert self.consumer, "You should start subscriber at first."  # nosec B101
         connected = True
         while self.running:
@@ -260,7 +264,7 @@ def __init__(
         group_id: Optional[str],
         listener: Optional["ConsumerRebalanceListener"],
         pattern: Optional[str],
-        builder: Callable[..., "AIOKafkaConsumer"],
+        connection_args: "AnyDict",
         is_manual: bool,
         # Subscriber args
         no_ack: bool,
@@ -277,7 +281,7 @@ def __init__(
             group_id=group_id,
             listener=listener,
             pattern=pattern,
-            builder=builder,
+            connection_args=connection_args,
             is_manual=is_manual,
             # subscriber args
             default_parser=AioKafkaParser.parse_message,
@@ -308,7 +312,7 @@ def __init__(
         group_id: Optional[str],
         listener: Optional["ConsumerRebalanceListener"],
         pattern: Optional[str],
-        builder: Callable[..., "AIOKafkaConsumer"],
+        connection_args: "AnyDict",
         is_manual: bool,
         # Subscriber args
         no_ack: bool,
@@ -330,7 +334,7 @@ def __init__(
             group_id=group_id,
             listener=listener,
             pattern=pattern,
-            builder=builder,
+            connection_args=connection_args,
             is_manual=is_manual,
             # subscriber args
             default_parser=AioKafkaParser.parse_message_batch,
diff --git a/faststream/kafka/testing.py b/faststream/kafka/testing.py
index fb9e71417f..4397501a95 100644
--- a/faststream/kafka/testing.py
+++ b/faststream/kafka/testing.py
@@ -1,5 +1,6 @@
 from datetime import datetime
-from typing import TYPE_CHECKING, Any, Dict, Optional
+from typing import TYPE_CHECKING, Any, Callable, Dict, Optional
+from unittest.mock import AsyncMock

 from aiokafka import ConsumerRecord
 from typing_extensions import override
@@ -23,8 +24,13 @@ class TestKafkaBroker(TestBroker[KafkaBroker]):
     """A class to test Kafka brokers."""

     @staticmethod
-    async def _fake_connect(broker: KafkaBroker, *args: Any, **kwargs: Any) -> None:
+    async def _fake_connect(  # type: ignore[override]
+        broker: KafkaBroker,
+        *args: Any,
+        **kwargs: Any,
+    ) -> Callable[..., AsyncMock]:
         broker._producer = FakeProducer(broker)
+        return _fake_connection

     @staticmethod
     def create_publisher_fake_subscriber(
@@ -184,3 +190,7 @@ def build_message(
         offset=0,
         headers=[(i, j.encode()) for i, j in headers.items()],
     )
+
+
+def _fake_connection(*args: Any, **kwargs: Any) -> AsyncMock:
+    return AsyncMock()
diff --git a/faststream/redis/testing.py b/faststream/redis/testing.py
index 74541322f1..7d4a60da4e 100644
--- a/faststream/redis/testing.py
+++ b/faststream/redis/testing.py
@@ -1,5 +1,6 @@
 import re
 from typing import TYPE_CHECKING, Any, Optional, Sequence, Union
+from unittest.mock import AsyncMock, MagicMock

 from typing_extensions import override
@@ -49,12 +50,15 @@ def f(msg: Any) -> None:
         return sub.calls[0].handler

     @staticmethod
-    async def _fake_connect(
+    async def _fake_connect(  # type: ignore[override]
         broker: RedisBroker,
         *args: Any,
         **kwargs: Any,
-    ) -> None:
+    ) -> AsyncMock:
         broker._producer = FakeProducer(broker)  # type: ignore[assignment]
+        connection = MagicMock()
+        connection.pubsub.side_effect = AsyncMock
+        return connection

     @staticmethod
     def remove_publisher_fake_subscriber(

From 7178dd3af33cc2aa6bbc081ae07f57e7b63638a1 Mon Sep 17 00:00:00 2001
From: Nikita Pastukhov
Date: Wed, 8 May 2024 20:23:09 +0300
Subject: [PATCH 3/6] lint: fix deprecation warn

---
 faststream/asyncapi/site.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/faststream/asyncapi/site.py b/faststream/asyncapi/site.py
index 73184f9bb4..fcc0aefea6 100644
--- a/faststream/asyncapi/site.py
+++ b/faststream/asyncapi/site.py
@@ -102,7 +102,7 @@ def serve_app(
 ) -> None:
     """Serve the HTTPServer with AsyncAPI schema."""
     logger.info(f"HTTPServer running on http://{host}:{port} (Press CTRL+C to quit)")
-    logger.warn("Please, do not use it in production.")
logger.warn("Please, do not use it in production.") + logger.warning("Please, do not use it in production.") server.HTTPServer( (host, port), From 1227488ac755d3b83750052bd365a465b8a6349f Mon Sep 17 00:00:00 2001 From: Nikita Pastukhov Date: Wed, 8 May 2024 23:18:35 +0300 Subject: [PATCH 4/6] tests: fix reloader test --- tests/cli/supervisors/test_base_reloader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/cli/supervisors/test_base_reloader.py b/tests/cli/supervisors/test_base_reloader.py index c143d39c9f..2a1c2fd6ed 100644 --- a/tests/cli/supervisors/test_base_reloader.py +++ b/tests/cli/supervisors/test_base_reloader.py @@ -14,7 +14,7 @@ def should_restart(self) -> bool: return True -def empty(): +def empty(*args, **kwargs): pass From 2ad2d1f73a69fd7402f4393e23f4dcece3ba67fe Mon Sep 17 00:00:00 2001 From: Nikita Pastukhov Date: Thu, 9 May 2024 13:22:42 +0300 Subject: [PATCH 5/6] tests: fix kafka warnings --- faststream/confluent/testing.py | 7 +++++-- faststream/kafka/testing.py | 7 +++++-- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/faststream/confluent/testing.py b/faststream/confluent/testing.py index 3052828424..9420ff3aa5 100644 --- a/faststream/confluent/testing.py +++ b/faststream/confluent/testing.py @@ -1,6 +1,6 @@ from datetime import datetime from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple -from unittest.mock import AsyncMock +from unittest.mock import AsyncMock, MagicMock from typing_extensions import override @@ -240,4 +240,7 @@ def build_message( def _fake_connection(*args: Any, **kwargs: Any) -> AsyncMock: - return AsyncMock() + mock = AsyncMock() + mock.getone.return_value = MagicMock() + mock.getmany.return_value = [MagicMock()] + return mock diff --git a/faststream/kafka/testing.py b/faststream/kafka/testing.py index 89eb5f6c70..fd8b520332 100755 --- a/faststream/kafka/testing.py +++ b/faststream/kafka/testing.py @@ -1,6 +1,6 @@ from datetime import datetime from typing import TYPE_CHECKING, Any, Callable, Dict, Optional -from unittest.mock import AsyncMock +from unittest.mock import AsyncMock, MagicMock from aiokafka import ConsumerRecord from typing_extensions import override @@ -210,4 +210,7 @@ def build_message( def _fake_connection(*args: Any, **kwargs: Any) -> AsyncMock: - return AsyncMock() + mock = AsyncMock() + mock.subscribe = MagicMock + mock.assign = MagicMock + return mock From afcfd78ecc38a262ef7c76eca0a50e51c4185c2d Mon Sep 17 00:00:00 2001 From: Nikita Pastukhov Date: Thu, 9 May 2024 13:56:26 +0300 Subject: [PATCH 6/6] tests: fix hanging test --- .codespell-whitelist.txt | 2 +- tests/brokers/base/consume.py | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/.codespell-whitelist.txt b/.codespell-whitelist.txt index dcfed576bf..6b1a432b87 100644 --- a/.codespell-whitelist.txt +++ b/.codespell-whitelist.txt @@ -1 +1 @@ -dependant +dependant \ No newline at end of file diff --git a/tests/brokers/base/consume.py b/tests/brokers/base/consume.py index 68f4edc3e3..fc3ad0956d 100644 --- a/tests/brokers/base/consume.py +++ b/tests/brokers/base/consume.py @@ -246,6 +246,8 @@ def subscriber(m): timeout=self.timeout, ) + await sub.close() + assert event.is_set()