From e61ee3bd6b3fdb80ae32449c71e8dc441306e7dc Mon Sep 17 00:00:00 2001
From: DABND19
Date: Sun, 22 Dec 2024 23:50:29 +0300
Subject: [PATCH] fix: Use separate thread for confluent kafka consumer.

---
 faststream/confluent/client.py | 65 +++++++++++++++++++++++-----------
 1 file changed, 44 insertions(+), 21 deletions(-)

diff --git a/faststream/confluent/client.py b/faststream/confluent/client.py
index 240f6f3257..9d18d69098 100644
--- a/faststream/confluent/client.py
+++ b/faststream/confluent/client.py
@@ -1,5 +1,6 @@
 import asyncio
 import logging
+from concurrent.futures import ThreadPoolExecutor
 from contextlib import suppress
 from time import time
 from typing import (
@@ -314,9 +315,8 @@ def __init__(
         self.config = final_config
         self.consumer = Consumer(final_config, logger=self.logger)  # type: ignore[call-arg]
 
-        # We shouldn't read messages and close consumer concurrently
-        # https://github.com/airtai/faststream/issues/1904#issuecomment-2506990895
-        self._lock = anyio.Lock()
+        # A pool with a single thread is used to execute the consumer's commands sequentially:
+        self._thread_pool = ThreadPoolExecutor(max_workers=1)
 
     @property
     def topics_to_create(self) -> List[str]:
@@ -324,9 +324,15 @@
 
     async def start(self) -> None:
         """Starts the Kafka consumer and subscribes to the specified topics."""
+        loop = asyncio.get_running_loop()
+
         if self.allow_auto_create_topics:
-            await call_or_await(
-                create_topics, self.topics_to_create, self.config, self.logger
+            await loop.run_in_executor(
+                self._thread_pool,
+                create_topics,
+                self.topics_to_create,
+                self.config,
+                self.logger,
             )
 
         elif self.logger:
@@ -336,11 +342,15 @@
             )
 
         if self.topics:
-            await call_or_await(self.consumer.subscribe, self.topics)
+            await loop.run_in_executor(
+                self._thread_pool, self.consumer.subscribe, self.topics
+            )
 
         elif self.partitions:
-            await call_or_await(
-                self.consumer.assign, [p.to_confluent() for p in self.partitions]
+            await loop.run_in_executor(
+                self._thread_pool,
+                self.consumer.assign,
+                [p.to_confluent() for p in self.partitions],
             )
 
         else:
@@ -348,14 +358,23 @@
 
     async def commit(self, asynchronous: bool = True) -> None:
         """Commits the offsets of all messages returned by the last poll operation."""
+        loop = asyncio.get_running_loop()
         if asynchronous:
             # Asynchronous commit is non-blocking:
             self.consumer.commit(asynchronous=True)
         else:
-            await call_or_await(self.consumer.commit, asynchronous=False)
+            await loop.run_in_executor(
+                self._thread_pool,
+                self.consumer.commit,
+                None,
+                None,
+                False,
+            )
 
     async def stop(self) -> None:
         """Stops the Kafka consumer and releases all resources."""
+        loop = asyncio.get_running_loop()
+
         # NOTE: If we don't explicitly call commit and then close the consumer, the confluent consumer gets stuck.
         # We are doing this to avoid the issue.
         enable_auto_commit = self.config["enable.auto.commit"]
@@ -375,13 +394,14 @@ async def stop(self) -> None:
             )
 
         # Wrap calls to async to make method cancelable by timeout
-        async with self._lock:
-            await call_or_await(self.consumer.close)
+        await loop.run_in_executor(self._thread_pool, self.consumer.close)
+
+        self._thread_pool.shutdown(wait=False)
 
     async def getone(self, timeout: float = 0.1) -> Optional[Message]:
         """Consumes a single message from Kafka."""
-        async with self._lock:
-            msg = await call_or_await(self.consumer.poll, timeout)
+        loop = asyncio.get_running_loop()
+        msg = await loop.run_in_executor(self._thread_pool, self.consumer.poll, timeout)
         return check_msg_error(msg)
 
     async def getmany(
@@ -390,21 +410,24 @@
         self,
         timeout: float,
         max_records: Optional[int] = 10,
     ) -> Tuple[Message, ...]:
         """Consumes a batch of messages from Kafka and groups them by topic and partition."""
-        async with self._lock:
-            raw_messages: List[Optional[Message]] = await call_or_await(
-                self.consumer.consume,  # type: ignore[arg-type]
-                num_messages=max_records or 10,
-                timeout=timeout,
-            )
+        loop = asyncio.get_running_loop()
+        raw_messages: List[Optional[Message]] = await loop.run_in_executor(
+            self._thread_pool,
+            self.consumer.consume,  # type: ignore[arg-type]
+            max_records or 10,
+            timeout,
+        )
         return tuple(x for x in map(check_msg_error, raw_messages) if x is not None)
 
     async def seek(self, topic: str, partition: int, offset: int) -> None:
         """Seeks to the specified offset in the specified topic and partition."""
+        loop = asyncio.get_running_loop()
         topic_partition = TopicPartition(
             topic=topic, partition=partition, offset=offset
         )
-        await call_or_await(self.consumer.seek, topic_partition.to_confluent())
+        await loop.run_in_executor(
+            self._thread_pool, self.consumer.seek, topic_partition.to_confluent()
+        )
 
 def check_msg_error(msg: Optional[Message]) -> Optional[Message]:
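
Note: the pattern this patch relies on reduces to a few lines. The sketch below is
illustrative only and not part of the patch: the class name SerializedConsumer and the
helper _run are made up, and it assumes the confluent-kafka package is installed. Every
blocking Consumer call is funneled through the same single-worker ThreadPoolExecutor, so
poll() and close() can never run concurrently, which is the guarantee the removed
anyio.Lock previously provided.

    import asyncio
    from concurrent.futures import ThreadPoolExecutor
    from typing import Any, Callable, Optional

    from confluent_kafka import Consumer, Message


    class SerializedConsumer:
        """Hypothetical sketch: serialize all blocking Consumer calls
        through a single-worker thread pool."""

        def __init__(self, config: dict) -> None:
            self._consumer = Consumer(config)
            # One worker means queued calls run strictly one after another,
            # so close() can never overlap with an in-flight poll().
            self._thread_pool = ThreadPoolExecutor(max_workers=1)

        async def _run(self, fn: Callable[..., Any], *args: Any) -> Any:
            # run_in_executor forwards positional arguments only.
            loop = asyncio.get_running_loop()
            return await loop.run_in_executor(self._thread_pool, fn, *args)

        async def getone(self, timeout: float = 0.1) -> Optional[Message]:
            return await self._run(self._consumer.poll, timeout)

        async def stop(self) -> None:
            # Queued behind any pending poll(), never concurrent with it.
            await self._run(self._consumer.close)
            self._thread_pool.shutdown(wait=False)

Unlike a lock held across an await, this ordering survives task cancellation: a call that
has already started keeps running to completion in the worker thread, so a close()
submitted after it cannot interleave with it.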
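Note also why the synchronous commit passes None, None, False positionally:
run_in_executor() does not accept keyword arguments for the target function, and those
values line up with confluent-kafka's Consumer.commit(message, offsets, asynchronous)
signature. A functools.partial expresses the same call with the keyword kept; a
hypothetical sketch, not what the patch does:

    import asyncio
    from concurrent.futures import ThreadPoolExecutor
    from functools import partial

    from confluent_kafka import Consumer


    async def commit_sync(thread_pool: ThreadPoolExecutor, consumer: Consumer) -> None:
        # partial binds asynchronous=False up front, avoiding the opaque
        # None, None, False positional form used in the patch.
        loop = asyncio.get_running_loop()
        await loop.run_in_executor(
            thread_pool, partial(consumer.commit, asynchronous=False)
        )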