diff --git a/faststream/_internal/utils/functions.py b/faststream/_internal/utils/functions.py index e8cb60d696..efea5541d6 100644 --- a/faststream/_internal/utils/functions.py +++ b/faststream/_internal/utils/functions.py @@ -1,9 +1,12 @@ +import asyncio from collections.abc import AsyncIterator, Awaitable, Iterator +from concurrent.futures import Executor from contextlib import asynccontextmanager, contextmanager -from functools import wraps +from functools import partial, wraps from typing import ( Any, Callable, + Optional, TypeVar, Union, cast, @@ -80,3 +83,8 @@ def drop_response_type(model: CallModel) -> CallModel: async def return_input(x: Any) -> Any: return x + + +async def run_in_executor(executor: Optional[Executor], func: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T: + loop = asyncio.get_running_loop() + return await loop.run_in_executor(executor, partial(func, *args, **kwargs)) diff --git a/faststream/confluent/client.py b/faststream/confluent/client.py index 385a1b4389..471b71c368 100644 --- a/faststream/confluent/client.py +++ b/faststream/confluent/client.py @@ -1,6 +1,7 @@ import asyncio import logging from collections.abc import Iterable, Sequence +from concurrent.futures import ThreadPoolExecutor from contextlib import suppress from time import time from typing import ( @@ -17,7 +18,7 @@ from faststream._internal.constants import EMPTY from faststream._internal.log import logger as faststream_logger -from faststream._internal.utils.functions import call_or_await +from faststream._internal.utils.functions import call_or_await, run_in_executor from faststream.confluent import config as config_module from faststream.confluent.schemas import TopicPartition from faststream.exceptions import SetupError @@ -314,9 +315,8 @@ def __init__( self.config = final_config self.consumer = Consumer(final_config, logger=self.logger_state.logger.logger) # type: ignore[call-arg] - # We shouldn't read messages and close consumer concurrently - # https://github.com/airtai/faststream/issues/1904#issuecomment-2506990895 - self._lock = anyio.Lock() + # A pool with single thread is used in order to execute the commands of the consumer sequentially: + self._thread_pool = ThreadPoolExecutor(max_workers=1) @property def topics_to_create(self) -> list[str]: @@ -325,11 +325,12 @@ def topics_to_create(self) -> list[str]: async def start(self) -> None: """Starts the Kafka consumer and subscribes to the specified topics.""" if self.allow_auto_create_topics: - await call_or_await( + await run_in_executor( + self._thread_pool, create_topics, - self.topics_to_create, - self.config, - self.logger_state.logger.logger, + topics=self.topics_to_create, + config=self.config, + logger_=self.logger_state.logger.logger, ) else: @@ -339,10 +340,13 @@ async def start(self) -> None: ) if self.topics: - await call_or_await(self.consumer.subscribe, self.topics) + await run_in_executor( + self._thread_pool, self.consumer.subscribe, topics=self.topics + ) elif self.partitions: - await call_or_await( + await run_in_executor( + self._thread_pool, self.consumer.assign, [p.to_confluent() for p in self.partitions], ) @@ -353,7 +357,7 @@ async def start(self) -> None: async def commit(self, asynchronous: bool = True) -> None: """Commits the offsets of all messages returned by the last poll operation.""" - await call_or_await(self.consumer.commit, asynchronous=asynchronous) + await run_in_executor(self._thread_pool, self.consumer.commit, asynchronous=asynchronous) async def stop(self) -> None: """Stops the Kafka consumer and releases all resources.""" @@ -376,13 +380,13 @@ async def stop(self) -> None: ) # Wrap calls to async to make method cancelable by timeout - async with self._lock: - await call_or_await(self.consumer.close) + await run_in_executor(self._thread_pool, self.consumer.close) + + self._thread_pool.shutdown(wait=False) async def getone(self, timeout: float = 0.1) -> Optional[Message]: """Consumes a single message from Kafka.""" - async with self._lock: - msg = await call_or_await(self.consumer.poll, timeout) + msg = await run_in_executor(self._thread_pool, self.consumer.poll, timeout) return check_msg_error(msg) async def getmany( @@ -391,13 +395,12 @@ async def getmany( max_records: Optional[int] = 10, ) -> tuple[Message, ...]: """Consumes a batch of messages from Kafka and groups them by topic and partition.""" - async with self._lock: - raw_messages: list[Optional[Message]] = await call_or_await( - self.consumer.consume, # type: ignore[arg-type] - num_messages=max_records or 10, - timeout=timeout, - ) - + raw_messages: list[Optional[Message]] = await run_in_executor( + self._thread_pool, + self.consumer.consume, # type: ignore[arg-type] + num_messages=max_records or 10, + timeout=timeout, + ) return tuple(x for x in map(check_msg_error, raw_messages) if x is not None) async def seek(self, topic: str, partition: int, offset: int) -> None: @@ -407,7 +410,7 @@ async def seek(self, topic: str, partition: int, offset: int) -> None: partition=partition, offset=offset, ) - await call_or_await(self.consumer.seek, topic_partition.to_confluent()) + await run_in_executor(self._thread_pool, self.consumer.seek, topic_partition.to_confluent()) def check_msg_error(msg: Optional[Message]) -> Optional[Message]: diff --git a/faststream/confluent/subscriber/usecase.py b/faststream/confluent/subscriber/usecase.py index 7f94ec1b7b..a2a0392978 100644 --- a/faststream/confluent/subscriber/usecase.py +++ b/faststream/confluent/subscriber/usecase.py @@ -131,12 +131,12 @@ async def start(self) -> None: self.add_task(self._consume()) async def close(self) -> None: + await super().close() + if self.consumer is not None: await self.consumer.stop() self.consumer = None - await super().close() - @override async def get_one( self, @@ -335,18 +335,14 @@ def __init__( async def get_msg(self) -> Optional[tuple["Message", ...]]: assert self.consumer, "You should setup subscriber at first." # nosec B101 - - messages = await self.consumer.getmany( - timeout=self.polling_interval, - max_records=self.max_records, + return ( + await self.consumer.getmany( + timeout=self.polling_interval, + max_records=self.max_records, + ) + or None ) - if not messages: # TODO: why we are sleeping here? - await anyio.sleep(self.polling_interval) - return None - - return messages - def get_log_context( self, message: Optional["StreamMessage[tuple[Message, ...]]"],