diff --git a/.codespell-whitelist.txt b/.codespell-whitelist.txt
index dcfed576bf..6b1a432b87 100644
--- a/.codespell-whitelist.txt
+++ b/.codespell-whitelist.txt
@@ -1 +1 @@
-dependant
+dependant
\ No newline at end of file
diff --git a/faststream/asyncapi/site.py b/faststream/asyncapi/site.py
index 73184f9bb4..fcc0aefea6 100644
--- a/faststream/asyncapi/site.py
+++ b/faststream/asyncapi/site.py
@@ -102,7 +102,7 @@ def serve_app(
 ) -> None:
     """Serve the HTTPServer with AsyncAPI schema."""
     logger.info(f"HTTPServer running on http://{host}:{port} (Press CTRL+C to quit)")
-    logger.warn("Please, do not use it in production.")
+    logger.warning("Please, do not use it in production.")
 
     server.HTTPServer(
         (host, port),
diff --git a/faststream/confluent/broker/broker.py b/faststream/confluent/broker/broker.py
index 8f9de15b09..9f31fbbb5e 100644
--- a/faststream/confluent/broker/broker.py
+++ b/faststream/confluent/broker/broker.py
@@ -23,7 +23,11 @@
 from faststream.broker.message import gen_cor_id
 from faststream.confluent.broker.logging import KafkaLoggingBroker
 from faststream.confluent.broker.registrator import KafkaRegistrator
-from faststream.confluent.client import AsyncConfluentProducer, _missing
+from faststream.confluent.client import (
+    AsyncConfluentConsumer,
+    AsyncConfluentProducer,
+    _missing,
+)
 from faststream.confluent.publisher.producer import AsyncConfluentFastProducer
 from faststream.confluent.schemas.params import ConsumerConnectionParams
 from faststream.confluent.security import parse_security
@@ -425,7 +429,7 @@ async def connect(
             Doc("Kafka addresses to connect."),
         ] = Parameter.empty,
         **kwargs: Any,
-    ) -> ConsumerConnectionParams:
+    ) -> Callable[..., AsyncConfluentConsumer]:
         if bootstrap_servers is not Parameter.empty:
             kwargs["bootstrap_servers"] = bootstrap_servers
 
@@ -437,7 +441,7 @@ async def _connect(  # type: ignore[override]
         *,
         client_id: str,
         **kwargs: Any,
-    ) -> ConsumerConnectionParams:
+    ) -> Callable[..., AsyncConfluentConsumer]:
         security_params = parse_security(self.security)
         kwargs.update(security_params)
 
@@ -450,7 +454,10 @@ async def _connect(  # type: ignore[override]
             producer=producer,
         )
 
-        return filter_by_dict(ConsumerConnectionParams, kwargs)
+        return partial(
+            AsyncConfluentConsumer,
+            **filter_by_dict(ConsumerConnectionParams, kwargs),
+        )
 
     async def start(self) -> None:
         await super().start()
@@ -464,7 +471,10 @@ async def start(self) -> None:
 
     @property
     def _subscriber_setup_extra(self) -> "AnyDict":
-        return {"client_id": self.client_id, "connection_data": self._connection or {}}
+        return {
+            "client_id": self.client_id,
+            "builder": self._connection,
+        }
 
     @override
     async def publish(  # type: ignore[override]
diff --git a/faststream/confluent/broker/logging.py b/faststream/confluent/broker/logging.py
index 9eebc89461..4fead65305 100644
--- a/faststream/confluent/broker/logging.py
+++ b/faststream/confluent/broker/logging.py
@@ -1,9 +1,9 @@
 import logging
 from inspect import Parameter
-from typing import TYPE_CHECKING, Any, ClassVar, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Callable, ClassVar, Optional, Tuple, Union
 
 from faststream.broker.core.usecase import BrokerUsecase
-from faststream.confluent.schemas.params import ConsumerConnectionParams
+from faststream.confluent.client import AsyncConfluentConsumer
 from faststream.log.logging import get_broker_logger
 
 if TYPE_CHECKING:
@@ -15,7 +15,7 @@
 class KafkaLoggingBroker(
     BrokerUsecase[
         Union["confluent_kafka.Message", Tuple["confluent_kafka.Message", ...]],
-        ConsumerConnectionParams,
+        Callable[..., AsyncConfluentConsumer],
     ]
 ):
     """A class that extends the LoggingMixin class and adds additional functionality for logging Kafka related information."""
diff --git a/faststream/confluent/broker/registrator.py b/faststream/confluent/broker/registrator.py
index 4fde6249c3..6306d10bd9 100644
--- a/faststream/confluent/broker/registrator.py
+++ b/faststream/confluent/broker/registrator.py
@@ -1,4 +1,3 @@
-from functools import partial
 from typing import (
     TYPE_CHECKING,
     Any,
@@ -18,7 +17,6 @@
 
 from faststream.broker.core.abc import ABCBroker
 from faststream.broker.utils import default_filter
-from faststream.confluent.client import AsyncConfluentConsumer
 from faststream.confluent.publisher.asyncapi import AsyncAPIPublisher
 from faststream.confluent.subscriber.asyncapi import AsyncAPISubscriber
 from faststream.exceptions import SetupError
@@ -1235,29 +1233,6 @@ def subscriber(
         if not auto_commit and not group_id:
             raise SetupError("You should install `group_id` with manual commit mode")
 
-        builder = partial(
-            AsyncConfluentConsumer,
-            key_deserializer=key_deserializer,
-            value_deserializer=value_deserializer,
-            fetch_max_wait_ms=fetch_max_wait_ms,
-            fetch_max_bytes=fetch_max_bytes,
-            fetch_min_bytes=fetch_min_bytes,
-            max_partition_fetch_bytes=max_partition_fetch_bytes,
-            auto_offset_reset=auto_offset_reset,
-            enable_auto_commit=auto_commit,
-            auto_commit_interval_ms=auto_commit_interval_ms,
-            check_crcs=check_crcs,
-            partition_assignment_strategy=partition_assignment_strategy,
-            max_poll_interval_ms=max_poll_interval_ms,
-            rebalance_timeout_ms=rebalance_timeout_ms,
-            session_timeout_ms=session_timeout_ms,
-            heartbeat_interval_ms=heartbeat_interval_ms,
-            consumer_timeout_ms=consumer_timeout_ms,
-            max_poll_records=max_poll_records,
-            exclude_internal_topics=exclude_internal_topics,
-            isolation_level=isolation_level,
-        )
-
         subscriber = super().subscriber(
             AsyncAPISubscriber.create(
                 *topics,
@@ -1265,7 +1240,27 @@ def subscriber(
                 batch_timeout_ms=batch_timeout_ms,
                 max_records=max_records,
                 group_id=group_id,
-                builder=builder,
+                connection_data={
+                    "key_deserializer": key_deserializer,
+                    "value_deserializer": value_deserializer,
+                    "fetch_max_wait_ms": fetch_max_wait_ms,
+                    "fetch_max_bytes": fetch_max_bytes,
+                    "fetch_min_bytes": fetch_min_bytes,
+                    "max_partition_fetch_bytes": max_partition_fetch_bytes,
+                    "auto_offset_reset": auto_offset_reset,
+                    "enable_auto_commit": auto_commit,
+                    "auto_commit_interval_ms": auto_commit_interval_ms,
+                    "check_crcs": check_crcs,
+                    "partition_assignment_strategy": partition_assignment_strategy,
+                    "max_poll_interval_ms": max_poll_interval_ms,
+                    "rebalance_timeout_ms": rebalance_timeout_ms,
+                    "session_timeout_ms": session_timeout_ms,
+                    "heartbeat_interval_ms": heartbeat_interval_ms,
+                    "consumer_timeout_ms": consumer_timeout_ms,
+                    "max_poll_records": max_poll_records,
+                    "exclude_internal_topics": exclude_internal_topics,
+                    "isolation_level": isolation_level,
+                },
                 is_manual=not auto_commit,
                 # subscriber args
                 no_ack=no_ack,
diff --git a/faststream/confluent/subscriber/asyncapi.py b/faststream/confluent/subscriber/asyncapi.py
index 8da47a800e..d31bfa05f2 100644
--- a/faststream/confluent/subscriber/asyncapi.py
+++ b/faststream/confluent/subscriber/asyncapi.py
@@ -1,6 +1,5 @@
 from typing import (
     TYPE_CHECKING,
-    Callable,
     Dict,
     Iterable,
     Literal,
@@ -33,7 +32,7 @@
     from fast_depends.dependencies import Depends
 
     from faststream.broker.types import BrokerMiddleware
-    from faststream.confluent.client import AsyncConfluentConsumer
+    from faststream.types import AnyDict
 
 
 class AsyncAPISubscriber(LogicSubscriber[MsgType]):
@@ -77,7 +76,7 @@ def create(
         max_records: Optional[int],
         # Kafka information
         group_id: Optional[str],
-        builder: Callable[..., "AsyncConfluentConsumer"],
+        connection_data: "AnyDict",
         is_manual: bool,
         # Subscriber args
         no_ack: bool,
@@ -99,7 +98,7 @@ def create(
         max_records: Optional[int],
         # Kafka information
         group_id: Optional[str],
-        builder: Callable[..., "AsyncConfluentConsumer"],
+        connection_data: "AnyDict",
         is_manual: bool,
         # Subscriber args
         no_ack: bool,
@@ -121,7 +120,7 @@ def create(
         max_records: Optional[int],
         # Kafka information
         group_id: Optional[str],
-        builder: Callable[..., "AsyncConfluentConsumer"],
+        connection_data: "AnyDict",
         is_manual: bool,
         # Subscriber args
         no_ack: bool,
@@ -148,7 +147,7 @@ def create(
         max_records: Optional[int],
         # Kafka information
         group_id: Optional[str],
-        builder: Callable[..., "AsyncConfluentConsumer"],
+        connection_data: "AnyDict",
         is_manual: bool,
         # Subscriber args
         no_ack: bool,
@@ -171,7 +170,7 @@ def create(
                 batch_timeout_ms=batch_timeout_ms,
                 max_records=max_records,
                 group_id=group_id,
-                builder=builder,
+                connection_data=connection_data,
                 is_manual=is_manual,
                 no_ack=no_ack,
                 retry=retry,
@@ -185,7 +184,7 @@ def create(
             return AsyncAPIDefaultSubscriber(
                 *topics,
                 group_id=group_id,
-                builder=builder,
+                connection_data=connection_data,
                 is_manual=is_manual,
                 no_ack=no_ack,
                 retry=retry,
diff --git a/faststream/confluent/subscriber/usecase.py b/faststream/confluent/subscriber/usecase.py
index d778086bae..e5e23ed710 100644
--- a/faststream/confluent/subscriber/usecase.py
+++ b/faststream/confluent/subscriber/usecase.py
@@ -19,7 +19,6 @@
 from faststream.broker.subscriber.usecase import SubscriberUsecase
 from faststream.broker.types import MsgType
 from faststream.confluent.parser import AsyncConfluentParser
-from faststream.confluent.schemas.params import ConsumerConnectionParams
 
 if TYPE_CHECKING:
     from fast_depends.dependencies import Depends
@@ -41,7 +40,9 @@ class LogicSubscriber(ABC, SubscriberUsecase[MsgType]):
 
     topics: Sequence[str]
     group_id: Optional[str]
+    builder: Optional[Callable[..., "AsyncConfluentConsumer"]]
     consumer: Optional["AsyncConfluentConsumer"]
+    task: Optional["asyncio.Task[None]"]
 
     client_id: Optional[str]
@@ -50,7 +51,7 @@ def __init__(
         self,
         *topics: str,
         # Kafka information
         group_id: Optional[str],
-        builder: Callable[..., "AsyncConfluentConsumer"],
+        connection_data: "AnyDict",
         is_manual: bool,
         # Subscriber args
         default_parser: "AsyncCallable",
@@ -81,20 +82,20 @@ def __init__(
         self.group_id = group_id
         self.topics = topics
         self.is_manual = is_manual
-        self.builder = builder
+        self.builder = None
         self.consumer = None
         self.task = None
 
         # Setup it later
         self.client_id = ""
-        self.__connection_data = ConsumerConnectionParams()
+        self.__connection_data = connection_data
 
     @override
     def setup(  # type: ignore[override]
         self,
         *,
         client_id: Optional[str],
-        connection_data: "ConsumerConnectionParams",
+        builder: Callable[..., "AsyncConfluentConsumer"],
         # basic args
         logger: Optional["LoggerProto"],
         producer: Optional["ProducerProto"],
@@ -110,7 +111,7 @@ def setup(  # type: ignore[override]
         _call_decorators: Iterable["Decorator"],
     ) -> None:
         self.client_id = client_id
-        self.__connection_data = connection_data
+        self.builder = builder
 
         super().setup(
             logger=logger,
@@ -128,6 +129,8 @@ def setup(  # type: ignore[override]
     @override
     async def start(self) -> None:
         """Start the consumer."""
+        assert self.builder, "You should setup subscriber at first."  # nosec B101
+
         self.consumer = consumer = self.builder(
             *self.topics,
             group_id=self.group_id,
@@ -172,7 +175,7 @@ async def get_msg(self) -> Optional[MsgType]:
         raise NotImplementedError()
 
     async def _consume(self) -> None:
-        assert self.consumer, "You need to start handler first"  # nosec B101
+        assert self.consumer, "You should start subscriber at first."  # nosec B101
 
         connected = True
         while self.running:
@@ -219,7 +222,7 @@ def __init__(
         self,
         *topics: str,
         # Kafka information
         group_id: Optional[str],
-        builder: Callable[..., "AsyncConfluentConsumer"],
+        connection_data: "AnyDict",
         is_manual: bool,
         # Subscriber args
         no_ack: bool,
@@ -234,7 +237,7 @@ def __init__(
         super().__init__(
             *topics,
             group_id=group_id,
-            builder=builder,
+            connection_data=connection_data,
             is_manual=is_manual,
             # subscriber args
             default_parser=AsyncConfluentParser.parse_message,
@@ -278,7 +281,7 @@ def __init__(
         max_records: Optional[int],
         # Kafka information
         group_id: Optional[str],
-        builder: Callable[..., "AsyncConfluentConsumer"],
+        connection_data: "AnyDict",
         is_manual: bool,
         # Subscriber args
         no_ack: bool,
@@ -296,7 +299,7 @@ def __init__(
         super().__init__(
             *topics,
             group_id=group_id,
-            builder=builder,
+            connection_data=connection_data,
             is_manual=is_manual,
             # subscriber args
             default_parser=AsyncConfluentParser.parse_message_batch,
diff --git a/faststream/confluent/testing.py b/faststream/confluent/testing.py
index 4559cbde8b..9420ff3aa5 100644
--- a/faststream/confluent/testing.py
+++ b/faststream/confluent/testing.py
@@ -1,5 +1,6 @@
 from datetime import datetime
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple
+from unittest.mock import AsyncMock, MagicMock
 
 from typing_extensions import override
 
@@ -22,8 +23,13 @@ class TestKafkaBroker(TestBroker[KafkaBroker]):
     """A class to test Kafka brokers."""
 
     @staticmethod
-    async def _fake_connect(broker: KafkaBroker, *args: Any, **kwargs: Any) -> None:
+    async def _fake_connect(  # type: ignore[override]
+        broker: KafkaBroker,
+        *args: Any,
+        **kwargs: Any,
+    ) -> Callable[..., AsyncMock]:
         broker._producer = FakeProducer(broker)
+        return _fake_connection
 
     @staticmethod
     def create_publisher_fake_subscriber(
@@ -231,3 +237,10 @@ def build_message(
         timestamp_type=0 + 1,
         timestamp_ms=timestamp_ms or int(datetime.now().timestamp()),
     )
+
+
+def _fake_connection(*args: Any, **kwargs: Any) -> AsyncMock:
+    mock = AsyncMock()
+    mock.getone.return_value = MagicMock()
+    mock.getmany.return_value = [MagicMock()]
+    return mock
diff --git a/faststream/kafka/broker/broker.py b/faststream/kafka/broker/broker.py
index 59d6e733d6..2a29796860 100644
--- a/faststream/kafka/broker/broker.py
+++ b/faststream/kafka/broker/broker.py
@@ -557,7 +557,7 @@ async def connect(  # type: ignore[override]
             Doc("Kafka addresses to connect."),
         ] = Parameter.empty,
         **kwargs: "Unpack[KafkaInitKwargs]",
-    ) -> ConsumerConnectionParams:
+    ) -> Callable[..., aiokafka.AIOKafkaConsumer]:
         """Connect to Kafka servers manually.
 
         Consumes the same with `KafkaBroker.__init__` arguments and overrides them.
@@ -579,7 +579,7 @@ async def _connect(  # type: ignore[override]
         *,
         client_id: str,
         **kwargs: Any,
-    ) -> ConsumerConnectionParams:
+    ) -> Callable[..., aiokafka.AIOKafkaConsumer]:
         security_params = parse_security(self.security)
         kwargs.update(security_params)
 
@@ -593,7 +593,10 @@ async def _connect(  # type: ignore[override]
             producer=producer,
         )
 
-        return filter_by_dict(ConsumerConnectionParams, kwargs)
+        return partial(
+            aiokafka.AIOKafkaConsumer,
+            **filter_by_dict(ConsumerConnectionParams, kwargs),
+        )
 
     async def start(self) -> None:
         """Connect broker to Kafka and startup all subscribers."""
@@ -610,7 +613,7 @@ async def start(self) -> None:
     def _subscriber_setup_extra(self) -> "AnyDict":
         return {
             "client_id": self.client_id,
-            "connection_args": self._connection or {},
+            "builder": self._connection,
         }
 
     @override
diff --git a/faststream/kafka/broker/logging.py b/faststream/kafka/broker/logging.py
index df828024da..16b1103b83 100644
--- a/faststream/kafka/broker/logging.py
+++ b/faststream/kafka/broker/logging.py
@@ -1,9 +1,8 @@
 import logging
 from inspect import Parameter
-from typing import TYPE_CHECKING, Any, ClassVar, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Callable, ClassVar, Optional, Tuple, Union
 
 from faststream.broker.core.usecase import BrokerUsecase
-from faststream.kafka.schemas.params import ConsumerConnectionParams
 from faststream.log.logging import get_broker_logger
 
 if TYPE_CHECKING:
@@ -15,7 +14,7 @@
 class KafkaLoggingBroker(
     BrokerUsecase[
         Union["aiokafka.ConsumerRecord", Tuple["aiokafka.ConsumerRecord", ...]],
-        ConsumerConnectionParams,
+        Callable[..., "aiokafka.AIOKafkaConsumer"],
     ]
 ):
     """A class that extends the LoggingMixin class and adds additional functionality for logging Kafka related information."""
diff --git a/faststream/kafka/broker/registrator.py b/faststream/kafka/broker/registrator.py
index 899e5828d5..afc69a459c 100644
--- a/faststream/kafka/broker/registrator.py
+++ b/faststream/kafka/broker/registrator.py
@@ -1,4 +1,3 @@
-from functools import partial
 from typing import (
     TYPE_CHECKING,
     Any,
@@ -14,13 +13,12 @@
     overload,
 )
 
-from aiokafka import AIOKafkaConsumer, ConsumerRecord
+from aiokafka import ConsumerRecord
 from aiokafka.coordinator.assignors.roundrobin import RoundRobinPartitionAssignor
 from typing_extensions import Annotated, Doc, deprecated, override
 
 from faststream.broker.core.abc import ABCBroker
 from faststream.broker.utils import default_filter
-from faststream.exceptions import SetupError
 from faststream.kafka.publisher.asyncapi import AsyncAPIPublisher
 from faststream.kafka.subscriber.asyncapi import AsyncAPISubscriber
@@ -1395,32 +1393,6 @@ def subscriber(
         "AsyncAPIDefaultSubscriber",
         "AsyncAPIBatchSubscriber",
     ]:
-        if not auto_commit and not group_id:
-            raise SetupError("You should install `group_id` with manual commit mode")
-
-        builder = partial(
-            AIOKafkaConsumer,
-            key_deserializer=key_deserializer,
-            value_deserializer=value_deserializer,
-            fetch_max_wait_ms=fetch_max_wait_ms,
-            fetch_max_bytes=fetch_max_bytes,
-            fetch_min_bytes=fetch_min_bytes,
-            max_partition_fetch_bytes=max_partition_fetch_bytes,
-            auto_offset_reset=auto_offset_reset,
-            enable_auto_commit=auto_commit,
-            auto_commit_interval_ms=auto_commit_interval_ms,
-            check_crcs=check_crcs,
-            partition_assignment_strategy=partition_assignment_strategy,
-            max_poll_interval_ms=max_poll_interval_ms,
-            rebalance_timeout_ms=rebalance_timeout_ms,
-            session_timeout_ms=session_timeout_ms,
-            heartbeat_interval_ms=heartbeat_interval_ms,
-            consumer_timeout_ms=consumer_timeout_ms,
-            max_poll_records=max_poll_records,
-            exclude_internal_topics=exclude_internal_topics,
-            isolation_level=isolation_level,
-        )
-
         subscriber = super().subscriber(
             AsyncAPISubscriber.create(
                 *topics,
@@ -1430,8 +1402,28 @@ def subscriber(
                 group_id=group_id,
                 listener=listener,
                 pattern=pattern,
+                connection_args={
+                    "key_deserializer": key_deserializer,
+                    "value_deserializer": value_deserializer,
+                    "fetch_max_wait_ms": fetch_max_wait_ms,
+                    "fetch_max_bytes": fetch_max_bytes,
+                    "fetch_min_bytes": fetch_min_bytes,
+                    "max_partition_fetch_bytes": max_partition_fetch_bytes,
+                    "auto_offset_reset": auto_offset_reset,
+                    "enable_auto_commit": auto_commit,
+                    "auto_commit_interval_ms": auto_commit_interval_ms,
+                    "check_crcs": check_crcs,
+                    "partition_assignment_strategy": partition_assignment_strategy,
+                    "max_poll_interval_ms": max_poll_interval_ms,
+                    "rebalance_timeout_ms": rebalance_timeout_ms,
+                    "session_timeout_ms": session_timeout_ms,
+                    "heartbeat_interval_ms": heartbeat_interval_ms,
+                    "consumer_timeout_ms": consumer_timeout_ms,
+                    "max_poll_records": max_poll_records,
+                    "exclude_internal_topics": exclude_internal_topics,
+                    "isolation_level": isolation_level,
+                },
                 partitions=partitions,
-                builder=builder,
                 is_manual=not auto_commit,
                 # subscriber args
                 no_ack=no_ack,
diff --git a/faststream/kafka/subscriber/asyncapi.py b/faststream/kafka/subscriber/asyncapi.py
index f2897d3fdf..ec31001633 100644
--- a/faststream/kafka/subscriber/asyncapi.py
+++ b/faststream/kafka/subscriber/asyncapi.py
@@ -1,6 +1,5 @@
 from typing import (
     TYPE_CHECKING,
-    Callable,
     Dict,
     Iterable,
     Literal,
@@ -30,11 +29,12 @@
 )
 
 if TYPE_CHECKING:
-    from aiokafka import AIOKafkaConsumer, ConsumerRecord, TopicPartition
+    from aiokafka import ConsumerRecord, TopicPartition
     from aiokafka.abc import ConsumerRebalanceListener
     from fast_depends.dependencies import Depends
 
     from faststream.broker.types import BrokerMiddleware
+    from faststream.types import AnyDict
 
 
 class AsyncAPISubscriber(LogicSubscriber[MsgType]):
@@ -80,8 +80,8 @@ def create(
         group_id: Optional[str],
         listener: Optional["ConsumerRebalanceListener"],
         pattern: Optional[str],
+        connection_args: "AnyDict",
         partitions: Iterable["TopicPartition"],
-        builder: Callable[..., "AIOKafkaConsumer"],
         is_manual: bool,
         # Subscriber args
         no_ack: bool,
@@ -105,8 +105,8 @@ def create(
         group_id: Optional[str],
         listener: Optional["ConsumerRebalanceListener"],
         pattern: Optional[str],
+        connection_args: "AnyDict",
         partitions: Iterable["TopicPartition"],
-        builder: Callable[..., "AIOKafkaConsumer"],
         is_manual: bool,
         # Subscriber args
         no_ack: bool,
@@ -130,8 +130,8 @@ def create(
         group_id: Optional[str],
         listener: Optional["ConsumerRebalanceListener"],
         pattern: Optional[str],
+        connection_args: "AnyDict",
         partitions: Iterable["TopicPartition"],
-        builder: Callable[..., "AIOKafkaConsumer"],
         is_manual: bool,
         # Subscriber args
         no_ack: bool,
@@ -160,8 +160,8 @@ def create(
         group_id: Optional[str],
         listener: Optional["ConsumerRebalanceListener"],
         pattern: Optional[str],
+        connection_args: "AnyDict",
         partitions: Iterable["TopicPartition"],
-        builder: Callable[..., "AIOKafkaConsumer"],
         is_manual: bool,
         # Subscriber args
         no_ack: bool,
@@ -178,6 +178,9 @@ def create(
         "AsyncAPIDefaultSubscriber",
         "AsyncAPIBatchSubscriber",
     ]:
+        if is_manual and not group_id:
+            raise SetupError("You should install `group_id` with manual commit mode")
+
         if not topics and not partitions and not pattern:
             raise SetupError(
                 "You should provide either `topics` or `partitions` or `pattern`."
@@ -197,8 +200,8 @@ def create(
                 group_id=group_id,
                 listener=listener,
                 pattern=pattern,
+                connection_args=connection_args,
                 partitions=partitions,
-                builder=builder,
                 is_manual=is_manual,
                 no_ack=no_ack,
                 retry=retry,
@@ -214,8 +217,8 @@ def create(
                 group_id=group_id,
                 listener=listener,
                 pattern=pattern,
+                connection_args=connection_args,
                 partitions=partitions,
-                builder=builder,
                 is_manual=is_manual,
                 no_ack=no_ack,
                 retry=retry,
diff --git a/faststream/kafka/subscriber/usecase.py b/faststream/kafka/subscriber/usecase.py
index 963f00c524..818922c48e 100644
--- a/faststream/kafka/subscriber/usecase.py
+++ b/faststream/kafka/subscriber/usecase.py
@@ -35,7 +35,6 @@
 
     from faststream.broker.message import StreamMessage
     from faststream.broker.publisher.proto import ProducerProto
-    from faststream.kafka.schemas.params import ConsumerConnectionParams
     from faststream.types import AnyDict, Decorator, LoggerProto
 
 
@@ -45,7 +44,9 @@ class LogicSubscriber(ABC, SubscriberUsecase[MsgType]):
 
     topics: Sequence[str]
     group_id: Optional[str]
+    builder: Optional[Callable[..., "AIOKafkaConsumer"]]
     consumer: Optional["AIOKafkaConsumer"]
+    task: Optional["asyncio.Task[None]"]
 
     client_id: Optional[str]
     batch: bool
@@ -55,7 +56,7 @@ def __init__(
         self,
         *topics: str,
         # Kafka information
         group_id: Optional[str],
-        builder: Callable[..., "AIOKafkaConsumer"],
+        connection_args: "AnyDict",
         listener: Optional["ConsumerRebalanceListener"],
         pattern: Optional[str],
         partitions: Iterable["TopicPartition"],
@@ -86,10 +87,12 @@ def __init__(
             include_in_schema=include_in_schema,
         )
 
-        self.group_id = group_id
         self.topics = topics
+        self.partitions = partitions
+        self.group_id = group_id
+
         self.is_manual = is_manual
-        self.builder = builder
+        self.builder = None
         self.consumer = None
         self.task = None
@@ -97,15 +100,14 @@ def __init__(
         # Setup it later
         self.client_id = ""
         self.__pattern = pattern
         self.__listener = listener
-        self.partitions = partitions
-        self.__connection_args: "ConsumerConnectionParams" = {}
+        self.__connection_args = connection_args
 
     @override
     def setup(  # type: ignore[override]
         self,
         *,
         client_id: Optional[str],
-        connection_args: "ConsumerConnectionParams",
+        builder: Callable[..., "AIOKafkaConsumer"],
         # basic args
         logger: Optional["LoggerProto"],
         producer: Optional["ProducerProto"],
@@ -121,7 +123,7 @@ def setup(  # type: ignore[override]
         _call_decorators: Iterable["Decorator"],
     ) -> None:
         self.client_id = client_id
-        self.__connection_args = connection_args
+        self.builder = builder
 
         super().setup(
             logger=logger,
@@ -138,6 +140,8 @@ def setup(  # type: ignore[override]
 
     async def start(self) -> None:
         """Start the consumer."""
+        assert self.builder, "You should setup subscriber at first."  # nosec B101
+
         self.consumer = consumer = self.builder(
             group_id=self.group_id,
             client_id=self.client_id,
@@ -192,7 +196,7 @@ async def get_msg(self) -> MsgType:
         raise NotImplementedError()
 
     async def _consume(self) -> None:
-        assert self.consumer, "You should setup subscriber at first."  # nosec B101
+        assert self.consumer, "You should start subscriber at first."  # nosec B101
 
         connected = True
         while self.running:
@@ -286,8 +290,8 @@ def __init__(
         group_id: Optional[str],
         listener: Optional["ConsumerRebalanceListener"],
         pattern: Optional[str],
+        connection_args: "AnyDict",
         partitions: Iterable["TopicPartition"],
-        builder: Callable[..., "AIOKafkaConsumer"],
         is_manual: bool,
         # Subscriber args
         no_ack: bool,
@@ -304,8 +308,8 @@ def __init__(
             group_id=group_id,
             listener=listener,
             pattern=pattern,
+            connection_args=connection_args,
             partitions=partitions,
-            builder=builder,
             is_manual=is_manual,
             # subscriber args
             default_parser=AioKafkaParser.parse_message,
@@ -336,8 +340,8 @@ def __init__(
         group_id: Optional[str],
         listener: Optional["ConsumerRebalanceListener"],
         pattern: Optional[str],
+        connection_args: "AnyDict",
         partitions: Iterable["TopicPartition"],
-        builder: Callable[..., "AIOKafkaConsumer"],
         is_manual: bool,
         # Subscriber args
         no_ack: bool,
@@ -359,8 +363,8 @@ def __init__(
             group_id=group_id,
             listener=listener,
             pattern=pattern,
+            connection_args=connection_args,
             partitions=partitions,
-            builder=builder,
             is_manual=is_manual,
             # subscriber args
             default_parser=AioKafkaParser.parse_message_batch,
diff --git a/faststream/kafka/testing.py b/faststream/kafka/testing.py
index e28056edf6..fd8b520332 100755
--- a/faststream/kafka/testing.py
+++ b/faststream/kafka/testing.py
@@ -1,5 +1,6 @@
 from datetime import datetime
-from typing import TYPE_CHECKING, Any, Dict, Optional
+from typing import TYPE_CHECKING, Any, Callable, Dict, Optional
+from unittest.mock import AsyncMock, MagicMock
 
 from aiokafka import ConsumerRecord
 from typing_extensions import override
@@ -24,8 +25,13 @@ class TestKafkaBroker(TestBroker[KafkaBroker]):
     """A class to test Kafka brokers."""
 
     @staticmethod
-    async def _fake_connect(broker: KafkaBroker, *args: Any, **kwargs: Any) -> None:
+    async def _fake_connect(  # type: ignore[override]
+        broker: KafkaBroker,
+        *args: Any,
+        **kwargs: Any,
+    ) -> Callable[..., AsyncMock]:
         broker._producer = FakeProducer(broker)
+        return _fake_connection
 
     @staticmethod
     def create_publisher_fake_subscriber(
@@ -201,3 +207,10 @@ def build_message(
         offset=0,
         headers=[(i, j.encode()) for i, j in headers.items()],
     )
+
+
+def _fake_connection(*args: Any, **kwargs: Any) -> AsyncMock:
+    mock = AsyncMock()
+    mock.subscribe = MagicMock
+    mock.assign = MagicMock
+    return mock
diff --git a/faststream/nats/testing.py b/faststream/nats/testing.py
index f106c93f9d..6681ba5b14 100644
--- a/faststream/nats/testing.py
+++ b/faststream/nats/testing.py
@@ -1,4 +1,5 @@
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
+from unittest.mock import AsyncMock
 
 from nats.aio.msg import Msg
 from typing_extensions import override
@@ -40,8 +41,14 @@ def f(msg: Any) -> None:
         return sub.calls[0].handler
 
     @staticmethod
-    async def _fake_connect(broker: NatsBroker, *args: Any, **kwargs: Any) -> None:
+    async def _fake_connect(  # type: ignore[override]
+        broker: NatsBroker,
+        *args: Any,
+        **kwargs: Any,
+    ) -> AsyncMock:
+        broker.stream = AsyncMock()  # type: ignore[assignment]
         broker._js_producer = broker._producer = FakeProducer(broker)  # type: ignore[assignment]
+        return AsyncMock()
 
     @staticmethod
     def remove_publisher_fake_subscriber(
diff --git a/faststream/redis/testing.py b/faststream/redis/testing.py
index 74541322f1..7d4a60da4e 100644
--- a/faststream/redis/testing.py
+++ b/faststream/redis/testing.py
@@ -1,5 +1,6 @@
 import re
 from typing import TYPE_CHECKING, Any, Optional, Sequence, Union
+from unittest.mock import AsyncMock, MagicMock
 
 from typing_extensions import override
@@ -49,12 +50,15 @@ def f(msg: Any) -> None:
         return sub.calls[0].handler
 
     @staticmethod
-    async def _fake_connect(
+    async def _fake_connect(  # type: ignore[override]
         broker: RedisBroker,
         *args: Any,
         **kwargs: Any,
-    ) -> None:
+    ) -> AsyncMock:
         broker._producer = FakeProducer(broker)  # type: ignore[assignment]
+        connection = MagicMock()
+        connection.pubsub.side_effect = AsyncMock
+        return connection
 
     @staticmethod
     def remove_publisher_fake_subscriber(
diff --git a/tests/brokers/base/consume.py b/tests/brokers/base/consume.py
index 654d3b19f8..fc3ad0956d 100644
--- a/tests/brokers/base/consume.py
+++ b/tests/brokers/base/consume.py
@@ -221,6 +221,35 @@ async def handler(m: Foo, dep: int = Depends(dependency), broker=Context()):
         assert event.is_set()
         mock.assert_called_once_with({"x": 1}, "100", consume_broker)
 
+    async def test_dynamic_sub(
+        self,
+        queue: str,
+        consume_broker: BrokerUsecase,
+        event: asyncio.Event,
+    ):
+        def subscriber(m):
+            event.set()
+
+        async with consume_broker:
+            await consume_broker.start()
+
+            sub = consume_broker.subscriber(queue, **self.subscriber_kwargs)
+            sub(subscriber)
+            consume_broker.setup_subscriber(sub)
+            await sub.start()
+
+            await asyncio.wait(
+                (
+                    asyncio.create_task(consume_broker.publish("hello", queue)),
+                    asyncio.create_task(event.wait()),
+                ),
+                timeout=self.timeout,
+            )
+
+            await sub.close()
+
+        assert event.is_set()
+
 
 @pytest.mark.asyncio()
 class BrokerRealConsumeTestcase(BrokerConsumeTestcase):
diff --git a/tests/cli/supervisors/test_base_reloader.py b/tests/cli/supervisors/test_base_reloader.py
index c143d39c9f..2a1c2fd6ed 100644
--- a/tests/cli/supervisors/test_base_reloader.py
+++ b/tests/cli/supervisors/test_base_reloader.py
@@ -14,7 +14,7 @@ def should_restart(self) -> bool:
         return True
 
 
-def empty():
+def empty(*args, **kwargs):
     pass