Implement async confluent-kafka producer
The new `AsyncKafkaProducer` is implemented as a wrapper around
`KafkaProducer`, with an async `send` method and a poll-thread that
continuously polls the underlying producer so buffered messages are sent
in the background, making the result of `send` an awaitable
`asyncio.Future`.

This follows the confluent-kafka asyncio example:
https://github.com/confluentinc/confluent-kafka-python/blob/master/examples/asyncio_example.py
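
For illustration, a minimal usage sketch of the new producer; the broker address and topic name are hypothetical, and the `key`/`value` send parameters are assumed from the integration tests below:

    import asyncio

    from karapace.kafka.producer import AsyncKafkaProducer

    async def main() -> None:
        # Instantiating the producer starts the poll-thread immediately.
        producer = AsyncKafkaProducer(bootstrap_servers="localhost:9092")
        try:
            # `send` returns an asyncio.Future; awaiting that future yields
            # the delivered confluent_kafka.Message.
            fut = await producer.send("example-topic", key=b"key", value=b"value")
            message = await fut
            print(message.topic(), message.partition(), message.offset())
        finally:
            # `stop` signals the poll-thread to exit and joins it.
            await producer.stop()

    asyncio.run(main())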
Mátyás Kuti committed Jan 18, 2024
1 parent 59eaff2 · commit 699136c
Showing 7 changed files with 155 additions and 36 deletions.
11 changes: 10 additions & 1 deletion karapace/kafka/common.py
@@ -71,9 +71,12 @@ def token_with_expiry(self, config: str | None) -> tuple[str, int | None]:


class KafkaClientParams(TypedDict, total=False):
acks: int | None
client_id: str | None
connections_max_idle_ms: int | None
max_block_ms: int | None
compression_type: str | None
linger_ms: int | None
message_max_bytes: int | None
metadata_max_age_ms: int | None
retries: int | None
sasl_mechanism: str | None
@@ -83,6 +86,7 @@ class KafkaClientParams(TypedDict, total=False):
socket_timeout_ms: int | None
ssl_cafile: str | None
ssl_certfile: str | None
ssl_crlfile: str | None
ssl_keyfile: str | None
sasl_oauth_token_provider: TokenWithExpiryProvider
# Consumer-only
@@ -121,8 +125,12 @@ def _get_config_from_params(self, bootstrap_servers: Iterable[str] | str, **para

config: dict[str, int | str | Callable | None] = {
"bootstrap.servers": bootstrap_servers,
"acks": params.get("acks"),
"client.id": params.get("client_id"),
"connections.max.idle.ms": params.get("connections_max_idle_ms"),
"compression.type": params.get("compression_type"),
"linger.ms": params.get("linger_ms"),
"message.max.bytes": params.get("message_max_bytes"),
"metadata.max.age.ms": params.get("metadata_max_age_ms"),
"retries": params.get("retries"),
"sasl.mechanism": params.get("sasl_mechanism"),
@@ -132,6 +140,7 @@
"socket.timeout.ms": params.get("socket_timeout_ms"),
"ssl.ca.location": params.get("ssl_cafile"),
"ssl.certificate.location": params.get("ssl_certfile"),
"ssl.crl.location": params.get("ssl_crlfile"),
"ssl.key.location": params.get("ssl_keyfile"),
"error_cb": self._error_callback,
# Consumer-only
43 changes: 42 additions & 1 deletion karapace/kafka/producer.py
@@ -5,15 +5,18 @@

from __future__ import annotations

from collections.abc import Iterable
from concurrent.futures import Future
from confluent_kafka import Message, Producer
from confluent_kafka.admin import PartitionMetadata
from confluent_kafka.error import KafkaError, KafkaException
from functools import partial
from karapace.kafka.common import _KafkaConfigMixin, raise_from_kafkaexception, translate_from_kafkaerror
from karapace.kafka.common import _KafkaConfigMixin, KafkaClientParams, raise_from_kafkaexception, translate_from_kafkaerror
from threading import Event, Thread
from typing import cast, TypedDict
from typing_extensions import Unpack

import asyncio
import logging

LOG = logging.getLogger(__name__)
@@ -59,3 +62,41 @@ def partitions_for(self, topic: str) -> dict[int, PartitionMetadata]:
return self.list_topics(topic).topics[topic].partitions
except KafkaException as exc:
raise_from_kafkaexception(exc)


class AsyncKafkaProducer:
"""An async wrapper around `KafkaProducer` built on top of confluent-kafka.
Starting `AsyncKafkaProducer` instantiates a `KafkaProducer` and starts a poll-thread.
The poll-thread continuously polls the underlying producer so buffered messages
are sent and asyncio futures returned by the `send` method can be awaited.
"""

def __init__(
self,
bootstrap_servers: Iterable[str] | str,
loop: asyncio.AbstractEventLoop | None = None,
**params: Unpack[KafkaClientParams],
) -> None:
self.loop = loop or asyncio.get_running_loop()
self.producer = KafkaProducer(bootstrap_servers, **params)

self.stopped = Event()
self.poll_thread = Thread(target=self.poll_loop)
self.poll_thread.start()

async def stop(self) -> None:
self.stopped.set()
self.poll_thread.join()

def poll_loop(self) -> None:
"""Target of the poll-thread."""
while not self.stopped.is_set():
self.producer.poll(0.1)

async def send(self, topic: str, **params: Unpack[ProducerSendParams]) -> asyncio.Future[Message]:
return asyncio.wrap_future(
self.producer.send(topic, **params),
loop=self.loop,
)
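
The synchronous `KafkaProducer.send` used above returns a `concurrent.futures.Future` (hence the `asyncio.wrap_future` bridge). Its body is outside this hunk; a minimal sketch of the delivery-callback-to-future pattern it presumably relies on, using only public confluent-kafka APIs:

    from __future__ import annotations

    from concurrent.futures import Future

    from confluent_kafka import KafkaError, KafkaException, Message, Producer

    def send(producer: Producer, topic: str, value: bytes) -> Future[Message]:
        """Sketch: resolve a Future from confluent-kafka's delivery callback."""
        result: Future[Message] = Future()

        def on_delivery(error: KafkaError | None, message: Message) -> None:
            # Invoked on whichever thread calls `Producer.poll` -- here, the
            # poll-thread started by `AsyncKafkaProducer`.
            if error is not None:
                result.set_exception(KafkaException(error))
            else:
                result.set_result(message)

        producer.produce(topic, value=value, on_delivery=on_delivery)
        return result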
58 changes: 27 additions & 31 deletions karapace/kafka_rest_apis/__init__.py
@@ -1,5 +1,3 @@
from aiokafka import AIOKafkaProducer
from aiokafka.errors import KafkaConnectionError
from binascii import Error as B64DecodeError
from collections import namedtuple
from confluent_kafka.error import KafkaException
@@ -13,9 +11,10 @@
TopicAuthorizationFailedError,
UnknownTopicOrPartitionError,
)
from karapace.config import Config, create_client_ssl_context
from karapace.config import Config
from karapace.errors import InvalidSchema
from karapace.kafka.admin import KafkaAdminClient
from karapace.kafka.producer import AsyncKafkaProducer
from karapace.kafka_rest_apis.authentication import (
get_auth_config_from_header,
get_expiration_time_from_header,
@@ -36,7 +35,7 @@
SchemaRetrievalError,
)
from karapace.typing import NameStrategy, SchemaId, Subject, SubjectType
from karapace.utils import convert_to_int, json_encode, KarapaceKafkaClient
from karapace.utils import convert_to_int, json_encode
from typing import Callable, Dict, List, Optional, Tuple, Union

import asyncio
@@ -441,7 +440,7 @@ def __init__(
self._auth_expiry = auth_expiry

self._async_producer_lock = asyncio.Lock()
self._async_producer: Optional[AIOKafkaProducer] = None
self._async_producer: Optional[AsyncKafkaProducer] = None
self.naming_strategy = NameStrategy(self.config["name_strategy"])

def __str__(self) -> str:
@@ -461,12 +460,12 @@ def auth_expiry(self) -> datetime.datetime:
def num_consumers(self) -> int:
return len(self.consumer_manager.consumers)

async def _maybe_create_async_producer(self) -> AIOKafkaProducer:
async def _maybe_create_async_producer(self) -> AsyncKafkaProducer:
if self._async_producer is not None:
return self._async_producer

if self.config["producer_acks"] == "all":
acks = "all"
acks = -1
else:
acks = int(self.config["producer_acks"])

@@ -477,27 +476,23 @@ async def _maybe_create_async_producer(self) -> AIOKafkaProducer:

log.info("Creating async producer")

# Don't retry if creating the SSL context fails, likely a configuration issue with
# ciphers or certificate chains
ssl_context = create_client_ssl_context(self.config)

# Don't retry if instantiating the producer fails, likely a configuration error.
producer = AIOKafkaProducer(
acks=acks,
bootstrap_servers=self.config["bootstrap_uri"],
compression_type=self.config["producer_compression_type"],
connections_max_idle_ms=self.config["connections_max_idle_ms"],
linger_ms=self.config["producer_linger_ms"],
max_request_size=self.config["producer_max_request_size"],
metadata_max_age_ms=self.config["metadata_max_age_ms"],
security_protocol=self.config["security_protocol"],
ssl_context=ssl_context,
**get_kafka_client_auth_parameters_from_config(self.config),
)

try:
await producer.start()
except KafkaConnectionError:
producer = AsyncKafkaProducer(
acks=acks,
bootstrap_servers=self.config["bootstrap_uri"],
compression_type=self.config["producer_compression_type"],
connections_max_idle_ms=self.config["connections_max_idle_ms"],
linger_ms=self.config["producer_linger_ms"],
message_max_bytes=self.config["producer_max_request_size"],
metadata_max_age_ms=self.config["metadata_max_age_ms"],
security_protocol=self.config["security_protocol"],
ssl_cafile=self.config["ssl_cafile"],
ssl_certfile=self.config["ssl_certfile"],
ssl_keyfile=self.config["ssl_keyfile"],
ssl_crlfile=self.config["ssl_crlfile"],
**get_kafka_client_auth_parameters_from_config(self.config),
)
except (NoBrokersAvailable, AuthenticationFailedError):
if retry:
log.exception("Unable to connect to the bootstrap servers, retrying")
else:
@@ -645,10 +640,8 @@ def init_admin_client(self):
ssl_cafile=self.config["ssl_cafile"],
ssl_certfile=self.config["ssl_certfile"],
ssl_keyfile=self.config["ssl_keyfile"],
api_version=(1, 0, 0),
metadata_max_age_ms=self.config["metadata_max_age_ms"],
connections_max_idle_ms=self.config["connections_max_idle_ms"],
kafka_client=KarapaceKafkaClient,
**get_kafka_client_auth_parameters_from_config(self.config, async_client=False),
)
break
@@ -1069,8 +1062,11 @@ async def produce_messages(self, *, topic: str, prepared_records: List) -> List:
if not isinstance(result, Exception):
produce_results.append(
{
"offset": result.offset if result else -1,
"partition": result.topic_partition.partition if result else 0,
# In case the offset is not available, `confluent_kafka.Message.offset()` is
# `None`. To preserve backwards compatibility, we replace this with -1.
# -1 was the default `aiokafka` behaviour.
"offset": result.offset() if result and result.offset() is not None else -1,
"partition": result.partition() if result else 0,
}
)

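As a standalone illustration of the offset fallback described in the comment in the hunk above (the helper name `to_produce_result` is hypothetical; `result` is a delivered `confluent_kafka.Message`):

    from confluent_kafka import Message

    def to_produce_result(result: Message) -> dict:
        # confluent-kafka reports an unavailable offset as None, whereas
        # aiokafka used -1; keep -1 for backwards compatibility.
        offset = result.offset()
        return {
            "offset": offset if offset is not None else -1,
            "partition": result.partition(),
        }
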
3 changes: 3 additions & 0 deletions karapace/kafka_rest_apis/authentication.py
@@ -142,6 +142,9 @@ class SimpleOauthTokenProviderAsync(AbstractTokenProviderAsync):
async def token(self) -> str:
return self._token

def token_with_expiry(self, _config: str | None = None) -> tuple[str, int | None]:
return (self._token, get_expiration_timestamp_from_jwt(self._token))


class SASLOauthParams(TypedDict):
sasl_mechanism: str
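`get_expiration_timestamp_from_jwt`, used by the new `token_with_expiry` above, is defined elsewhere in this module and is not part of the diff; a plausible sketch, assuming PyJWT (as in the unit tests below) and an unverified read of the `exp` claim:

    from __future__ import annotations

    import jwt

    def get_expiration_timestamp_from_jwt(token: str) -> int | None:
        # Only the expiry claim is read; the signature is deliberately
        # not verified here.
        claims = jwt.decode(token, options={"verify_signature": False})
        return claims.get("exp")
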
64 changes: 63 additions & 1 deletion tests/integration/kafka/test_producer.py
@@ -7,9 +7,12 @@

from confluent_kafka.admin import NewTopic
from kafka.errors import MessageSizeTooLargeError, UnknownTopicOrPartitionError
from karapace.kafka.producer import KafkaProducer
from karapace.kafka.producer import AsyncKafkaProducer, KafkaProducer
from karapace.kafka.types import Timestamp
from tests.integration.utils.kafka_server import KafkaServers
from typing import Iterator

import asyncio
import pytest
import time

@@ -71,3 +74,62 @@ def test_partitions_for(self, producer: KafkaProducer, new_topic: NewTopic) -> N
assert partitions[0].id == 0
assert partitions[0].replicas == [1]
assert partitions[0].isrs == [1]


@pytest.fixture(scope="function", name="asyncproducer")
async def fixture_asyncproducer(
kafka_servers: KafkaServers,
loop: asyncio.AbstractEventLoop,
) -> Iterator[AsyncKafkaProducer]:
try:
asyncproducer = AsyncKafkaProducer(bootstrap_servers=kafka_servers.bootstrap_servers, loop=loop)
yield asyncproducer
finally:
await asyncproducer.stop()


class TestAsyncSend:
async def test_async_send(self, asyncproducer: AsyncKafkaProducer, new_topic: NewTopic) -> None:
key = b"key"
value = b"value"
partition = 0
timestamp = int(time.time() * 1000)
headers = [("something", b"123"), (None, "foobar")]

aiofut = await asyncproducer.send(
new_topic.topic,
key=key,
value=value,
partition=partition,
timestamp=timestamp,
headers=headers,
)
message = await aiofut

assert message.offset() == 0
assert message.partition() == partition
assert message.topic() == new_topic.topic
assert message.key() == key
assert message.value() == value
assert message.timestamp()[0] == Timestamp.CREATE_TIME
assert message.timestamp()[1] == timestamp

async def test_async_send_raises_for_unknown_topic(self, asyncproducer: AsyncKafkaProducer) -> None:
aiofut = await asyncproducer.send("nonexistent")

with pytest.raises(UnknownTopicOrPartitionError):
_ = await aiofut

async def test_async_send_raises_for_unknown_partition(
self, asyncproducer: AsyncKafkaProducer, new_topic: NewTopic
) -> None:
aiofut = await asyncproducer.send(new_topic.topic, partition=99)

with pytest.raises(UnknownTopicOrPartitionError):
_ = await aiofut

async def test_async_send_raises_for_too_large_message(
self, asyncproducer: AsyncKafkaProducer, new_topic: NewTopic
) -> None:
with pytest.raises(MessageSizeTooLargeError):
await asyncproducer.send(new_topic.topic, value=b"x" * 1000001)
4 changes: 2 additions & 2 deletions tests/integration/test_rest.py
@@ -227,9 +227,9 @@ async def test_internal(rest_async: KafkaRest | None, admin_client: KafkaAdminCl
assert len(results) == 1
for result in results:
assert "error" in result, "Invalid result missing 'error' key"
assert result["error"] == "Unrecognized partition"
assert result["error"] == "This request is for a topic or partition that does not exist on this broker."
assert "error_code" in result, "Invalid result missing 'error_code' key"
assert result["error_code"] == 1
assert result["error_code"] == 2

assert rest_async_proxy.all_empty({"records": [{"key": {"foo": "bar"}}]}, "key") is False
assert rest_async_proxy.all_empty({"records": [{"value": {"foo": "bar"}}]}, "value") is False
8 changes: 8 additions & 0 deletions tests/unit/test_authentication.py
@@ -120,6 +120,14 @@ async def test_simple_oauth_token_provider_async_returns_configured_token() -> N
assert await token_provider.token() == "TOKEN"


def test_simple_oauth_token_provider_async_returns_configured_token_and_expiry() -> None:
expiry_timestamp = 1697013997
token = jwt.encode({"exp": expiry_timestamp}, "secret")
token_provider = SimpleOauthTokenProviderAsync(token)

assert token_provider.token_with_expiry() == (token, expiry_timestamp)


def test_get_client_auth_parameters_from_config_sasl_plain() -> None:
config = set_config_defaults(
{"sasl_mechanism": "PLAIN", "sasl_plain_username": "username", "sasl_plain_password": "password"}
