Fix/correct dynamic subscriber registration #1433

Merged (9 commits) on May 9, 2024
2 changes: 1 addition & 1 deletion .codespell-whitelist.txt
@@ -1 +1 @@
-dependant
+dependant
2 changes: 1 addition & 1 deletion faststream/asyncapi/site.py
@@ -102,7 +102,7 @@ def serve_app(
 ) -> None:
     """Serve the HTTPServer with AsyncAPI schema."""
     logger.info(f"HTTPServer running on http://{host}:{port} (Press CTRL+C to quit)")
-    logger.warn("Please, do not use it in production.")
+    logger.warning("Please, do not use it in production.")

     server.HTTPServer(
         (host, port),
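
`Logger.warn` is a deprecated alias for `Logger.warning` in the CPython standard library and emits a `DeprecationWarning` when called, so this rename is behavior-preserving. A standalone illustration (not part of the diff):

```python
import logging

logger = logging.getLogger("faststream.asyncapi")
logger.warning("Please, do not use it in production.")  # preferred spelling
# logger.warn(...) would log the same message but also emit a DeprecationWarning
```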
20 changes: 15 additions & 5 deletions faststream/confluent/broker/broker.py
@@ -23,7 +23,11 @@
 from faststream.broker.message import gen_cor_id
 from faststream.confluent.broker.logging import KafkaLoggingBroker
 from faststream.confluent.broker.registrator import KafkaRegistrator
-from faststream.confluent.client import AsyncConfluentProducer, _missing
+from faststream.confluent.client import (
+    AsyncConfluentConsumer,
+    AsyncConfluentProducer,
+    _missing,
+)
 from faststream.confluent.publisher.producer import AsyncConfluentFastProducer
 from faststream.confluent.schemas.params import ConsumerConnectionParams
 from faststream.confluent.security import parse_security
@@ -425,7 +429,7 @@ async def connect(
             Doc("Kafka addresses to connect."),
         ] = Parameter.empty,
         **kwargs: Any,
-    ) -> ConsumerConnectionParams:
+    ) -> Callable[..., AsyncConfluentConsumer]:
         if bootstrap_servers is not Parameter.empty:
             kwargs["bootstrap_servers"] = bootstrap_servers

@@ -437,7 +441,7 @@ async def _connect(  # type: ignore[override]
         *,
         client_id: str,
         **kwargs: Any,
-    ) -> ConsumerConnectionParams:
+    ) -> Callable[..., AsyncConfluentConsumer]:
         security_params = parse_security(self.security)
         kwargs.update(security_params)

@@ -450,7 +454,10 @@ async def _connect(  # type: ignore[override]
             producer=producer,
         )

-        return filter_by_dict(ConsumerConnectionParams, kwargs)
+        return partial(
+            AsyncConfluentConsumer,
+            **filter_by_dict(ConsumerConnectionParams, kwargs),
+        )

     async def start(self) -> None:
         await super().start()
@@ -464,7 +471,10 @@ async def start(self) -> None:

     @property
     def _subscriber_setup_extra(self) -> "AnyDict":
-        return {"client_id": self.client_id, "connection_data": self._connection or {}}
+        return {
+            "client_id": self.client_id,
+            "builder": self._connection,
+        }

     @override
     async def publish(  # type: ignore[override]
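
The substantive change here: `_connect` no longer returns bare `ConsumerConnectionParams`; it returns a `functools.partial` over `AsyncConfluentConsumer` with the validated connection kwargs pre-bound, and `_subscriber_setup_extra` hands that factory to subscribers as `builder`. A rough sketch of the pattern, with `filter_by_dict` and the consumer class replaced by hypothetical stand-ins:

```python
from functools import partial
from typing import Any, Callable, Dict, Type, TypedDict


class ConnectionParams(TypedDict, total=False):
    """Stand-in for ConsumerConnectionParams."""

    bootstrap_servers: str
    security_protocol: str


def filter_by_dict(td: Type[Any], kwargs: Dict[str, Any]) -> Dict[str, Any]:
    # Approximation of faststream's filter_by_dict helper:
    # drop any kwarg the TypedDict does not declare.
    return {k: v for k, v in kwargs.items() if k in td.__annotations__}


class Consumer:
    """Stand-in for AsyncConfluentConsumer."""

    def __init__(self, *topics: str, **config: Any) -> None:
        self.topics, self.config = topics, config


# What the new _connect effectively returns: a factory with the
# connection-level options already bound ...
builder: Callable[..., Consumer] = partial(
    Consumer,
    **filter_by_dict(ConnectionParams, {"bootstrap_servers": "localhost:9092", "junk": 1}),
)

# ... so each subscriber can add its own topics/group_id later, at start().
consumer = builder("my-topic", group_id="group-1")
assert "junk" not in consumer.config  # dropped by filter_by_dict
```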
6 changes: 3 additions & 3 deletions faststream/confluent/broker/logging.py
@@ -1,9 +1,9 @@
 import logging
 from inspect import Parameter
-from typing import TYPE_CHECKING, Any, ClassVar, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Callable, ClassVar, Optional, Tuple, Union

 from faststream.broker.core.usecase import BrokerUsecase
-from faststream.confluent.schemas.params import ConsumerConnectionParams
+from faststream.confluent.client import AsyncConfluentConsumer
 from faststream.log.logging import get_broker_logger

 if TYPE_CHECKING:
@@ -15,7 +15,7 @@
 class KafkaLoggingBroker(
     BrokerUsecase[
         Union["confluent_kafka.Message", Tuple["confluent_kafka.Message", ...]],
-        ConsumerConnectionParams,
+        Callable[..., AsyncConfluentConsumer],
     ]
 ):
     """A class that extends the LoggingMixin class and adds additional functionality for logging Kafka related information."""
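
The second type parameter of `BrokerUsecase` describes the object the broker keeps as its connection, so it moves in lock-step with the new `_connect` return type: `self._connection` now holds a consumer factory instead of a params mapping.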
47 changes: 21 additions & 26 deletions faststream/confluent/broker/registrator.py
@@ -1,4 +1,3 @@
-from functools import partial
 from typing import (
     TYPE_CHECKING,
     Any,
@@ -18,7 +17,6 @@

 from faststream.broker.core.abc import ABCBroker
 from faststream.broker.utils import default_filter
-from faststream.confluent.client import AsyncConfluentConsumer
 from faststream.confluent.publisher.asyncapi import AsyncAPIPublisher
 from faststream.confluent.subscriber.asyncapi import AsyncAPISubscriber
 from faststream.exceptions import SetupError
@@ -1235,37 +1233,34 @@ def subscriber(
         if not auto_commit and not group_id:
             raise SetupError("You should install `group_id` with manual commit mode")

-        builder = partial(
-            AsyncConfluentConsumer,
-            key_deserializer=key_deserializer,
-            value_deserializer=value_deserializer,
-            fetch_max_wait_ms=fetch_max_wait_ms,
-            fetch_max_bytes=fetch_max_bytes,
-            fetch_min_bytes=fetch_min_bytes,
-            max_partition_fetch_bytes=max_partition_fetch_bytes,
-            auto_offset_reset=auto_offset_reset,
-            enable_auto_commit=auto_commit,
-            auto_commit_interval_ms=auto_commit_interval_ms,
-            check_crcs=check_crcs,
-            partition_assignment_strategy=partition_assignment_strategy,
-            max_poll_interval_ms=max_poll_interval_ms,
-            rebalance_timeout_ms=rebalance_timeout_ms,
-            session_timeout_ms=session_timeout_ms,
-            heartbeat_interval_ms=heartbeat_interval_ms,
-            consumer_timeout_ms=consumer_timeout_ms,
-            max_poll_records=max_poll_records,
-            exclude_internal_topics=exclude_internal_topics,
-            isolation_level=isolation_level,
-        )
-
         subscriber = super().subscriber(
             AsyncAPISubscriber.create(
                 *topics,
                 batch=batch,
                 batch_timeout_ms=batch_timeout_ms,
                 max_records=max_records,
                 group_id=group_id,
-                builder=builder,
+                connection_data={
+                    "key_deserializer": key_deserializer,
+                    "value_deserializer": value_deserializer,
+                    "fetch_max_wait_ms": fetch_max_wait_ms,
+                    "fetch_max_bytes": fetch_max_bytes,
+                    "fetch_min_bytes": fetch_min_bytes,
+                    "max_partition_fetch_bytes": max_partition_fetch_bytes,
+                    "auto_offset_reset": auto_offset_reset,
+                    "enable_auto_commit": auto_commit,
+                    "auto_commit_interval_ms": auto_commit_interval_ms,
+                    "check_crcs": check_crcs,
+                    "partition_assignment_strategy": partition_assignment_strategy,
+                    "max_poll_interval_ms": max_poll_interval_ms,
+                    "rebalance_timeout_ms": rebalance_timeout_ms,
+                    "session_timeout_ms": session_timeout_ms,
+                    "heartbeat_interval_ms": heartbeat_interval_ms,
+                    "consumer_timeout_ms": consumer_timeout_ms,
+                    "max_poll_records": max_poll_records,
+                    "exclude_internal_topics": exclude_internal_topics,
+                    "isolation_level": isolation_level,
+                },
                 is_manual=not auto_commit,
                 # subscriber args
                 no_ack=no_ack,
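
This is the heart of the fix: `subscriber()` no longer pre-builds a consumer factory at registration time. It only records the consumer options as a plain `connection_data` dict, while the actual `AsyncConfluentConsumer` factory is injected by the broker during setup (via `_subscriber_setup_extra` above). A subscriber registered dynamically, after the broker has already connected, therefore still receives a live, fully configured builder.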
15 changes: 7 additions & 8 deletions faststream/confluent/subscriber/asyncapi.py
@@ -1,6 +1,5 @@
 from typing import (
     TYPE_CHECKING,
-    Callable,
     Dict,
     Iterable,
     Literal,
@@ -33,7 +32,7 @@
     from fast_depends.dependencies import Depends

     from faststream.broker.types import BrokerMiddleware
-    from faststream.confluent.client import AsyncConfluentConsumer
+    from faststream.types import AnyDict


 class AsyncAPISubscriber(LogicSubscriber[MsgType]):
@@ -77,7 +76,7 @@ def create(
         max_records: Optional[int],
         # Kafka information
         group_id: Optional[str],
-        builder: Callable[..., "AsyncConfluentConsumer"],
+        connection_data: "AnyDict",
         is_manual: bool,
         # Subscriber args
         no_ack: bool,
@@ -99,7 +98,7 @@ def create(
         max_records: Optional[int],
         # Kafka information
         group_id: Optional[str],
-        builder: Callable[..., "AsyncConfluentConsumer"],
+        connection_data: "AnyDict",
         is_manual: bool,
         # Subscriber args
         no_ack: bool,
@@ -121,7 +120,7 @@ def create(
         max_records: Optional[int],
         # Kafka information
         group_id: Optional[str],
-        builder: Callable[..., "AsyncConfluentConsumer"],
+        connection_data: "AnyDict",
         is_manual: bool,
         # Subscriber args
         no_ack: bool,
@@ -148,7 +147,7 @@ def create(
         max_records: Optional[int],
         # Kafka information
         group_id: Optional[str],
-        builder: Callable[..., "AsyncConfluentConsumer"],
+        connection_data: "AnyDict",
         is_manual: bool,
         # Subscriber args
         no_ack: bool,
@@ -171,7 +170,7 @@ def create(
                 batch_timeout_ms=batch_timeout_ms,
                 max_records=max_records,
                 group_id=group_id,
-                builder=builder,
+                connection_data=connection_data,
                 is_manual=is_manual,
                 no_ack=no_ack,
                 retry=retry,
@@ -185,7 +184,7 @@ def create(
         return AsyncAPIDefaultSubscriber(
             *topics,
             group_id=group_id,
-            builder=builder,
+            connection_data=connection_data,
             is_manual=is_manual,
             no_ack=no_ack,
             retry=retry,
25 changes: 14 additions & 11 deletions faststream/confluent/subscriber/usecase.py
@@ -19,7 +19,6 @@
 from faststream.broker.subscriber.usecase import SubscriberUsecase
 from faststream.broker.types import MsgType
 from faststream.confluent.parser import AsyncConfluentParser
-from faststream.confluent.schemas.params import ConsumerConnectionParams

 if TYPE_CHECKING:
     from fast_depends.dependencies import Depends
@@ -41,7 +40,9 @@ class LogicSubscriber(ABC, SubscriberUsecase[MsgType]):
     topics: Sequence[str]
     group_id: Optional[str]

+    builder: Optional[Callable[..., "AsyncConfluentConsumer"]]
     consumer: Optional["AsyncConfluentConsumer"]
+
     task: Optional["asyncio.Task[None]"]
     client_id: Optional[str]

@@ -50,7 +51,7 @@ def __init__(
         *topics: str,
         # Kafka information
         group_id: Optional[str],
-        builder: Callable[..., "AsyncConfluentConsumer"],
+        connection_data: "AnyDict",
         is_manual: bool,
         # Subscriber args
         default_parser: "AsyncCallable",
@@ -81,20 +82,20 @@ def __init__(
         self.group_id = group_id
         self.topics = topics
         self.is_manual = is_manual
-        self.builder = builder
+        self.builder = None
         self.consumer = None
         self.task = None

         # Setup it later
         self.client_id = ""
-        self.__connection_data = ConsumerConnectionParams()
+        self.__connection_data = connection_data

     @override
     def setup(  # type: ignore[override]
         self,
         *,
         client_id: Optional[str],
-        connection_data: "ConsumerConnectionParams",
+        builder: Callable[..., "AsyncConfluentConsumer"],
         # basic args
         logger: Optional["LoggerProto"],
         producer: Optional["ProducerProto"],
@@ -110,7 +111,7 @@ def setup(  # type: ignore[override]
         _call_decorators: Iterable["Decorator"],
     ) -> None:
         self.client_id = client_id
-        self.__connection_data = connection_data
+        self.builder = builder

         super().setup(
             logger=logger,
@@ -128,6 +129,8 @@ def setup(  # type: ignore[override]
     @override
     async def start(self) -> None:
         """Start the consumer."""
+        assert self.builder, "You should setup subscriber at first."  # nosec B101
+
         self.consumer = consumer = self.builder(
             *self.topics,
             group_id=self.group_id,
@@ -172,7 +175,7 @@ async def get_msg(self) -> Optional[MsgType]:
         raise NotImplementedError()

     async def _consume(self) -> None:
-        assert self.consumer, "You need to start handler first"  # nosec B101
+        assert self.consumer, "You should start subscriber at first."  # nosec B101

         connected = True
         while self.running:
@@ -219,7 +222,7 @@ def __init__(
         *topics: str,
         # Kafka information
         group_id: Optional[str],
-        builder: Callable[..., "AsyncConfluentConsumer"],
+        connection_data: "AnyDict",
         is_manual: bool,
         # Subscriber args
         no_ack: bool,
@@ -234,7 +237,7 @@ def __init__(
         super().__init__(
             *topics,
             group_id=group_id,
-            builder=builder,
+            connection_data=connection_data,
             is_manual=is_manual,
             # subscriber args
             default_parser=AsyncConfluentParser.parse_message,
@@ -278,7 +281,7 @@ def __init__(
         max_records: Optional[int],
         # Kafka information
         group_id: Optional[str],
-        builder: Callable[..., "AsyncConfluentConsumer"],
+        connection_data: "AnyDict",
         is_manual: bool,
         # Subscriber args
         no_ack: bool,
@@ -296,7 +299,7 @@ def __init__(
         super().__init__(
             *topics,
             group_id=group_id,
-            builder=builder,
+            connection_data=connection_data,
             is_manual=is_manual,
             # subscriber args
             default_parser=AsyncConfluentParser.parse_message_batch,
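
Taken together, the subscriber now has a three-phase lifecycle: construction stores the per-subscriber options, `setup()` injects the broker's connection-bound factory, and `start()` finally instantiates the consumer. A toy model of the flow (simplified signatures, not the real classes):

```python
from typing import Any, Callable, Dict, Optional


class Subscriber:
    """Toy model of the reworked LogicSubscriber lifecycle."""

    def __init__(
        self,
        *topics: str,
        group_id: Optional[str] = None,
        connection_data: Optional[Dict[str, Any]] = None,
    ) -> None:
        self.topics = topics
        self.group_id = group_id
        self._connection_data = connection_data or {}
        self.builder: Optional[Callable[..., Any]] = None  # injected in setup()
        self.consumer: Optional[Any] = None

    def setup(self, *, builder: Callable[..., Any]) -> None:
        # The broker passes its connection-bound consumer factory here.
        self.builder = builder

    async def start(self) -> None:
        assert self.builder, "You should setup subscriber at first."
        # Only now are per-subscriber options merged into the shared factory,
        # so the consumer is always created against a live connection.
        self.consumer = self.builder(
            *self.topics,
            group_id=self.group_id,
            **self._connection_data,
        )
```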
17 changes: 15 additions & 2 deletions faststream/confluent/testing.py
@@ -1,5 +1,6 @@
 from datetime import datetime
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple
+from unittest.mock import AsyncMock, MagicMock

 from typing_extensions import override

@@ -22,8 +23,13 @@ class TestKafkaBroker(TestBroker[KafkaBroker]):
     """A class to test Kafka brokers."""

     @staticmethod
-    async def _fake_connect(broker: KafkaBroker, *args: Any, **kwargs: Any) -> None:
+    async def _fake_connect(  # type: ignore[override]
+        broker: KafkaBroker,
+        *args: Any,
+        **kwargs: Any,
+    ) -> Callable[..., AsyncMock]:
         broker._producer = FakeProducer(broker)
+        return _fake_connection

     @staticmethod
     def create_publisher_fake_subscriber(
@@ -231,3 +237,10 @@ def build_message(
         timestamp_type=0 + 1,
         timestamp_ms=timestamp_ms or int(datetime.now().timestamp()),
     )
+
+
+def _fake_connection(*args: Any, **kwargs: Any) -> AsyncMock:
+    mock = AsyncMock()
+    mock.getone.return_value = MagicMock()
+    mock.getmany.return_value = [MagicMock()]
+    return mock
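
`_fake_connect` now honors the same contract as the real `connect()` by returning a builder: `_fake_connection` plays the role of the consumer factory and hands back `AsyncMock` consumers, which keeps the in-memory test broker on the same code path as production. A typical test, assuming FastStream's documented testing pattern (`TestKafkaBroker` as an async context manager, the handler's `.mock` attribute, and pytest-asyncio installed):

```python
import pytest

from faststream.confluent import KafkaBroker, TestKafkaBroker

broker = KafkaBroker("localhost:9092")


@broker.subscriber("greetings")
async def handle(msg: str) -> None:
    print(msg)


@pytest.mark.asyncio
async def test_handle() -> None:
    # No real Kafka needed: the broker's connect() is replaced by
    # _fake_connect, so subscribers get AsyncMock consumers.
    async with TestKafkaBroker(broker) as br:
        await br.publish("hi", topic="greetings")
        handle.mock.assert_called_once_with("hi")
```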