Skip to content

Commit

Permalink
Feat: stage 1 — add typing and a mock class for the concurrent subscriber
Browse files Browse the repository at this point in the history
  • Loading branch information
Daniil Dumchenko committed Nov 11, 2024
1 parent fc61e91 commit bb8a1c9
Show file tree
Hide file tree
Showing 6 changed files with 122 additions and 26 deletions.
33 changes: 25 additions & 8 deletions faststream/kafka/broker/registrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
from faststream.kafka.subscriber.asyncapi import (
AsyncAPIBatchSubscriber,
AsyncAPIDefaultSubscriber,
AsyncAPIConcurrentDefaultSubscriber
)


Expand All @@ -57,7 +58,7 @@ class KafkaRegistrator(

_subscribers: Dict[
int,
Union["AsyncAPIBatchSubscriber", "AsyncAPIDefaultSubscriber"],
Union["AsyncAPIBatchSubscriber", "AsyncAPIDefaultSubscriber", "AsyncAPIConcurrentDefaultSubscriber"],
]
_publishers: Dict[
int,
Expand Down Expand Up @@ -1548,6 +1549,10 @@ def subscriber(
Iterable["SubscriberMiddleware[KafkaMessage]"],
Doc("Subscriber middlewares to wrap incoming message processing."),
] = (),
max_workers: Annotated[
int,
Doc("Number of workers to process messages concurrently."),
] = 1,
filter: Annotated[
"Filter[KafkaMessage]",
Doc(
Expand Down Expand Up @@ -1592,11 +1597,13 @@ def subscriber(
) -> Union[
"AsyncAPIDefaultSubscriber",
"AsyncAPIBatchSubscriber",
"AsyncAPIConcurrentDefaultSubscriber",
]:
subscriber = super().subscriber(
create_subscriber(
*topics,
batch=batch,
max_workers=max_workers,
batch_timeout_ms=batch_timeout_ms,
max_records=max_records,
group_id=group_id,
Expand Down Expand Up @@ -1648,13 +1655,23 @@ def subscriber(
)

else:
return cast("AsyncAPIDefaultSubscriber", subscriber).add_call(
filter_=filter,
parser_=parser or self._parser,
decoder_=decoder or self._decoder,
dependencies_=dependencies,
middlewares_=middlewares,
)
if max_workers > 1:
return cast("AsyncAPIConcurrentDefaultSubscriber", subscriber).add_call(
filter_=filter,
parser_=parser or self._parser,
decoder_=decoder or self._decoder,
dependencies_=dependencies,
middlewares_=middlewares,
max_workers=max_workers
)
else:
return cast("AsyncAPIDefaultSubscriber", subscriber).add_call(
filter_=filter,
parser_=parser or self._parser,
decoder_=decoder or self._decoder,
dependencies_=dependencies,
middlewares_=middlewares,
)

@overload # type: ignore[override]
def publisher(
Expand Down
12 changes: 11 additions & 1 deletion faststream/kafka/fastapi/fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
from faststream.kafka.subscriber.asyncapi import (
AsyncAPIBatchSubscriber,
AsyncAPIDefaultSubscriber,
AsyncAPIConcurrentDefaultSubscriber
)
from faststream.security import BaseSecurity
from faststream.types import AnyDict, LoggerProto
Expand Down Expand Up @@ -2618,13 +2619,19 @@ def subscriber(
"""
),
] = False,
max_workers: Annotated[
int,
Doc("Number of workers to process messages concurrently."),
] = 1,
) -> Union[
"AsyncAPIBatchSubscriber",
"AsyncAPIDefaultSubscriber",
"AsyncAPIConcurrentDefaultSubscriber"
]:
subscriber = super().subscriber(
*topics,
group_id=group_id,
max_workers=max_workers,
key_deserializer=key_deserializer,
value_deserializer=value_deserializer,
fetch_max_wait_ms=fetch_max_wait_ms,
Expand Down Expand Up @@ -2675,7 +2682,10 @@ def subscriber(
if batch:
return cast("AsyncAPIBatchSubscriber", subscriber)
else:
return cast("AsyncAPIDefaultSubscriber", subscriber)
if max_workers > 1:
return cast("AsyncAPIConcurrentDefaultSubscriber", subscriber)
else:
return cast("AsyncAPIDefaultSubscriber", subscriber)

@overload # type: ignore[override]
def publisher(
Expand Down
5 changes: 5 additions & 0 deletions faststream/kafka/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -525,11 +525,16 @@ def __init__(
bool,
Doc("Whetever to include operation in AsyncAPI schema or not."),
] = True,
max_workers: Annotated[
int,
Doc("Number of workers to process messages concurrently."),
] = 1,
) -> None:
super().__init__(
call,
*topics,
publishers=publishers,
max_workers=max_workers,
group_id=group_id,
key_deserializer=key_deserializer,
value_deserializer=value_deserializer,
Expand Down
8 changes: 8 additions & 0 deletions faststream/kafka/subscriber/asyncapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
BatchSubscriber,
DefaultSubscriber,
LogicSubscriber,
ConcurrentDefaultSubscriber
)

if TYPE_CHECKING:
Expand Down Expand Up @@ -69,3 +70,10 @@ class AsyncAPIBatchSubscriber(
AsyncAPISubscriber[Tuple["ConsumerRecord", ...]],
):
pass


class AsyncAPIConcurrentDefaultSubscriber(
    ConcurrentDefaultSubscriber,
    AsyncAPISubscriber["ConsumerRecord"],
):
    """Concurrent single-record Kafka subscriber with AsyncAPI schema support.

    Combines `ConcurrentDefaultSubscriber` (message processing) with
    `AsyncAPISubscriber["ConsumerRecord"]` (schema generation) via multiple
    inheritance, mirroring how `AsyncAPIDefaultSubscriber` above is composed.
    No extra behavior is added here; the class exists so the factory can
    return a concurrent subscriber that still appears in the AsyncAPI docs.
    """

    pass
60 changes: 43 additions & 17 deletions faststream/kafka/subscriber/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from faststream.kafka.subscriber.asyncapi import (
AsyncAPIBatchSubscriber,
AsyncAPIDefaultSubscriber,
AsyncAPIConcurrentDefaultSubscriber
)

if TYPE_CHECKING:
Expand Down Expand Up @@ -119,6 +120,7 @@ def create_subscriber(
partitions: Iterable["TopicPartition"],
is_manual: bool,
# Subscriber args
max_workers: int,
no_ack: bool,
no_reply: bool,
retry: bool,
Expand All @@ -133,10 +135,14 @@ def create_subscriber(
) -> Union[
"AsyncAPIDefaultSubscriber",
"AsyncAPIBatchSubscriber",
"AsyncAPIConcurrentDefaultSubscriber"
]:
if is_manual and not group_id:
raise SetupError("You must use `group_id` with manual commit mode.")

if is_manual and max_workers > 1:
raise SetupError("Max workers not work with manual commit mode.")

if not topics and not partitions and not pattern:
raise SetupError(
"You should provide either `topics` or `partitions` or `pattern`."
Expand Down Expand Up @@ -170,20 +176,40 @@ def create_subscriber(
)

else:
return AsyncAPIDefaultSubscriber(
*topics,
group_id=group_id,
listener=listener,
pattern=pattern,
connection_args=connection_args,
partitions=partitions,
is_manual=is_manual,
no_ack=no_ack,
no_reply=no_reply,
retry=retry,
broker_dependencies=broker_dependencies,
broker_middlewares=broker_middlewares,
title_=title_,
description_=description_,
include_in_schema=include_in_schema,
)
if max_workers > 1:
return AsyncAPIConcurrentDefaultSubscriber(
*topics,
max_workers=max_workers,
group_id=group_id,
listener=listener,
pattern=pattern,
connection_args=connection_args,
partitions=partitions,
is_manual=is_manual,
no_ack=no_ack,
no_reply=no_reply,
retry=retry,
broker_dependencies=broker_dependencies,
broker_middlewares=broker_middlewares,
title_=title_,
description_=description_,
include_in_schema=include_in_schema,
)
else:
return AsyncAPIDefaultSubscriber(
*topics,
group_id=group_id,
listener=listener,
pattern=pattern,
connection_args=connection_args,
partitions=partitions,
is_manual=is_manual,
no_ack=no_ack,
no_reply=no_reply,
retry=retry,
broker_dependencies=broker_dependencies,
broker_middlewares=broker_middlewares,
title_=title_,
description_=description_,
include_in_schema=include_in_schema,
)
30 changes: 30 additions & 0 deletions faststream/kafka/subscriber/usecase.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
TYPE_CHECKING,
Any,
Callable,
Coroutine,
Dict,
Iterable,
List,
Expand All @@ -20,6 +21,7 @@

from faststream.broker.publisher.fake import FakePublisher
from faststream.broker.subscriber.usecase import SubscriberUsecase
from faststream.broker.subscriber.mixins import ConcurrentMixin
from faststream.broker.types import (
AsyncCallable,
BrokerMiddleware,
Expand Down Expand Up @@ -471,3 +473,31 @@ def get_log_context(
topic=topic,
group_id=self.group_id,
)

class ConcurrentDefaultSubscriber(
    ConcurrentMixin,
    DefaultSubscriber["ConsumerRecord"]
):
    """Single-record Kafka subscriber that processes messages concurrently.

    Same contract as `DefaultSubscriber`, plus a `max_workers` option that is
    consumed by `ConcurrentMixin` to bound the number of in-flight handler
    tasks.  NOTE(review): assumes `ConcurrentMixin.__init__` accepts
    `max_workers` and cooperatively forwards the rest to `DefaultSubscriber`
    — confirm against the mixin's signature.
    """

    def __init__(
        self,
        *topics: str,
        max_workers: int,
        # Kafka information
        group_id: Optional[str],
        listener: Optional["ConsumerRebalanceListener"],
        pattern: Optional[str],
        connection_args: "AnyDict",
        partitions: Iterable["TopicPartition"],
        is_manual: bool,
        # Subscriber args
        no_ack: bool,
        no_reply: bool,
        retry: bool,
        broker_dependencies: Iterable["Depends"],
        broker_middlewares: Iterable["BrokerMiddleware[ConsumerRecord]"],
        # AsyncAPI args
        title_: Optional[str],
        description_: Optional[str],
        include_in_schema: bool,
    ) -> None:
        # Bug fix: the original body was `pass`, which silently discarded
        # every argument and never initialized the base classes, making the
        # subscriber unusable.  Forward everything through the cooperative
        # MRO chain instead.
        super().__init__(
            *topics,
            max_workers=max_workers,
            group_id=group_id,
            listener=listener,
            pattern=pattern,
            connection_args=connection_args,
            partitions=partitions,
            is_manual=is_manual,
            no_ack=no_ack,
            no_reply=no_reply,
            retry=retry,
            broker_dependencies=broker_dependencies,
            broker_middlewares=broker_middlewares,
            title_=title_,
            description_=description_,
            include_in_schema=include_in_schema,
        )

0 comments on commit bb8a1c9

Please sign in to comment.