Skip to content

Commit

Permalink
fix: Use separate thread for confluent kafka consumer.
Browse files Browse the repository at this point in the history
  • Loading branch information
DABND19 committed Dec 22, 2024
1 parent 8572446 commit e61ee3b
Showing 1 changed file with 44 additions and 21 deletions.
65 changes: 44 additions & 21 deletions faststream/confluent/client.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
import logging
from concurrent.futures import ThreadPoolExecutor
from contextlib import suppress
from time import time
from typing import (
Expand Down Expand Up @@ -314,19 +315,24 @@ def __init__(
self.config = final_config
self.consumer = Consumer(final_config, logger=self.logger) # type: ignore[call-arg]

# We shouldn't read messages and close consumer concurrently
# https://github.com/airtai/faststream/issues/1904#issuecomment-2506990895
self._lock = anyio.Lock()
# A single-threaded pool is used so that the consumer's commands are executed sequentially:
self._thread_pool = ThreadPoolExecutor(max_workers=1)

@property
def topics_to_create(self) -> List[str]:
    """Return the distinct topic names referenced by this consumer.

    The names come from ``self.topics`` and from the ``topic`` attribute of
    every entry in ``self.partitions``; duplicates are collapsed. The order
    of the returned list is not defined.
    """
    unique_topics = set(self.topics)
    unique_topics.update(partition.topic for partition in self.partitions)
    return list(unique_topics)

async def start(self) -> None:
"""Starts the Kafka consumer and subscribes to the specified topics."""
loop = asyncio.get_running_loop()

if self.allow_auto_create_topics:
await call_or_await(
create_topics, self.topics_to_create, self.config, self.logger
await loop.run_in_executor(
self._thread_pool,
create_topics,
self.topics_to_create,
self.config,
self.logger,
)

elif self.logger:
Expand All @@ -336,26 +342,39 @@ async def start(self) -> None:
)

if self.topics:
await call_or_await(self.consumer.subscribe, self.topics)
await loop.run_in_executor(
self._thread_pool, self.consumer.subscribe, self.topics
)

elif self.partitions:
await call_or_await(
self.consumer.assign, [p.to_confluent() for p in self.partitions]
await loop.run_in_executor(
self._thread_pool,
self.consumer.assign,
[p.to_confluent() for p in self.partitions],
)

else:
raise SetupError("You must provide either `topics` or `partitions` option.")

async def commit(self, asynchronous: bool = True) -> None:
    """Commit the offsets of all messages returned by the last poll operation.

    Args:
        asynchronous: When true (the default), the commit is issued directly
            from the calling coroutine with ``asynchronous=True``. When
            false, the commit is submitted once to the consumer's executor
            (``self._thread_pool``) and awaited until it completes.
    """
    if asynchronous:
        # Asynchronous commit is non-blocking:
        self.consumer.commit(asynchronous=True)
        return

    loop = asyncio.get_running_loop()
    # `run_in_executor` forwards positional arguments only, so the keyword is
    # bound in a closure rather than relying on the positional parameter
    # order of `Consumer.commit` (and passing explicit `None` placeholders).
    await loop.run_in_executor(
        self._thread_pool,
        lambda: self.consumer.commit(asynchronous=False),
    )

async def stop(self) -> None:
"""Stops the Kafka consumer and releases all resources."""
loop = asyncio.get_running_loop()

# NOTE: If we don't explicitly call commit and then close the consumer, the confluent consumer gets stuck.
# We are doing this to avoid the issue.
enable_auto_commit = self.config["enable.auto.commit"]
Expand All @@ -375,13 +394,14 @@ async def stop(self) -> None:
)

# Wrap the blocking call in an awaitable so that this method can be cancelled by a timeout
async with self._lock:
await call_or_await(self.consumer.close)
await loop.run_in_executor(self._thread_pool, self.consumer.close)

self._thread_pool.shutdown(wait=False)

async def getone(self, timeout: float = 0.1) -> Optional[Message]:
    """Consume a single message from Kafka.

    The blocking ``poll`` call is submitted exactly once to the consumer's
    executor (``self._thread_pool``), and the result is passed through
    ``check_msg_error`` before being returned.

    Args:
        timeout: Value forwarded to ``self.consumer.poll``.

    Returns:
        Whatever ``check_msg_error`` returns for the polled value.
    """
    loop = asyncio.get_running_loop()
    polled = await loop.run_in_executor(
        self._thread_pool,
        self.consumer.poll,
        timeout,
    )
    return check_msg_error(polled)

async def getmany(
Expand All @@ -390,21 +410,24 @@ async def getmany(
max_records: Optional[int] = 10,
) -> Tuple[Message, ...]:
"""Consumes a batch of messages from Kafka and groups them by topic and partition."""
async with self._lock:
raw_messages: List[Optional[Message]] = await call_or_await(
self.consumer.consume, # type: ignore[arg-type]
num_messages=max_records or 10,
timeout=timeout,
)

loop = asyncio.get_running_loop()
raw_messages: List[Optional[Message]] = await loop.run_in_executor(
self._thread_pool,
self.consumer.consume, # type: ignore[arg-type]
max_records or 10,
timeout,
)
return tuple(x for x in map(check_msg_error, raw_messages) if x is not None)

async def seek(self, topic: str, partition: int, offset: int) -> None:
    """Seek to the specified offset in the specified topic and partition.

    A ``TopicPartition`` is built from the arguments, converted with
    ``to_confluent()``, and handed to ``self.consumer.seek`` exactly once
    via the consumer's executor (``self._thread_pool``).

    Args:
        topic: Topic name used to build the ``TopicPartition``.
        partition: Partition number used to build the ``TopicPartition``.
        offset: Offset used to build the ``TopicPartition``.
    """
    loop = asyncio.get_running_loop()
    target = TopicPartition(
        topic=topic, partition=partition, offset=offset
    ).to_confluent()
    await loop.run_in_executor(self._thread_pool, self.consumer.seek, target)


def check_msg_error(msg: Optional[Message]) -> Optional[Message]:
Expand Down

0 comments on commit e61ee3b

Please sign in to comment.