feat: add batch_headers for Confluent
Lancetnik committed May 16, 2024
1 parent 5b5fcd0 commit 437ce6a
Showing 2 changed files with 54 additions and 19 deletions.
41 changes: 22 additions & 19 deletions faststream/confluent/parser.py
@@ -1,4 +1,4 @@
-from typing import TYPE_CHECKING, Any, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple, Union
 
 from faststream.broker.message import decode_message, gen_cor_id
 from faststream.confluent.message import FAKE_CONSUMER, KafkaMessage
@@ -20,18 +20,14 @@ async def parse_message(
         message: "Message",
     ) -> "StreamMessage[Message]":
         """Parses a Kafka message."""
-        headers = {}
-        if message.headers() is not None:
-            for i, j in message.headers():  # type: ignore[union-attr]
-                if isinstance(j, str):
-                    headers[i] = j
-                else:
-                    headers[i] = j.decode()
+        headers = _parse_msg_headers(message.headers())
 
         body = message.value()
         offset = message.offset()
         _, timestamp = message.timestamp()
 
         handler: Optional["LogicSubscriber[Any]"] = context.get_local("handler_")
 
         return KafkaMessage(
             body=body,
             headers=headers,
@@ -49,28 +45,29 @@ async def parse_message_batch(
         message: Tuple["Message", ...],
     ) -> "StreamMessage[Tuple[Message, ...]]":
         """Parses a batch of messages from a Kafka consumer."""
+        body: List[Any] = []
+        batch_headers: List[Dict[str, str]] = []
+
         first = message[0]
         last = message[-1]
 
-        headers = {}
-        if first.headers() is not None:
-            for i, j in first.headers():  # type: ignore[union-attr]
-                if isinstance(j, str):
-                    headers[i] = j
-                else:
-                    headers[i] = j.decode()
-        body = [m.value() for m in message]
-        first_offset = first.offset()
-        last_offset = last.offset()
+        for m in message:
+            body.append(m.value())
+            batch_headers.append(_parse_msg_headers(m.headers()))
+
+        headers = next(iter(batch_headers), {})
 
         _, first_timestamp = first.timestamp()
 
         handler: Optional["LogicSubscriber[Any]"] = context.get_local("handler_")
 
         return KafkaMessage(
             body=body,
             headers=headers,
+            batch_headers=batch_headers,
             reply_to=headers.get("reply_to", ""),
             content_type=headers.get("content-type"),
-            message_id=f"{first_offset}-{last_offset}-{first_timestamp}",
+            message_id=f"{first.offset()}-{last.offset()}-{first_timestamp}",
             correlation_id=headers.get("correlation_id", gen_cor_id()),
             raw_message=message,
             consumer=getattr(handler, "consumer", None) or FAKE_CONSUMER,
@@ -91,3 +88,9 @@ async def decode_message_batch(
     ) -> "DecodedMessage":
         """Decode a batch of messages."""
         return [decode_message(await cls.parse_message(m)) for m in msg.raw_message]
+
+
+def _parse_msg_headers(
+    headers: Sequence[Tuple[str, Union[bytes, str]]],
+) -> Dict[str, str]:
+    return {i: j if isinstance(j, str) else j.decode() for i, j in headers}
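
For illustration, here is a standalone sketch of the new helper's behavior and of the batch fallback headers = next(iter(batch_headers), {}); the sample header keys and values below are invented for the demonstration:

from typing import Dict, Sequence, Tuple, Union

def _parse_msg_headers(
    headers: Sequence[Tuple[str, Union[bytes, str]]],
) -> Dict[str, str]:
    # bytes values are decoded to str; str values pass through unchanged
    return {i: j if isinstance(j, str) else j.decode() for i, j in headers}

# confluent-kafka hands headers over as (key, value) pairs, usually with bytes values
raw = [("content-type", b"application/json"), ("custom", "1")]
batch_headers = [_parse_msg_headers(raw)]

# the batch-level headers fall back to the first message's headers,
# or to an empty dict for an empty batch
headers = next(iter(batch_headers), {})
assert headers == {"content-type": "application/json", "custom": "1"}
assert next(iter([]), {}) == {}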
32 changes: 32 additions & 0 deletions tests/brokers/confluent/test_consume.py
@@ -39,6 +39,38 @@ async def handler(msg):
 
         assert [{1, "hi"}] == [set(r.result()) for r in result]
 
+    @pytest.mark.asyncio()
+    async def test_consume_batch_headers(
+        self, mock, event: asyncio.Event, queue: str, full_broker: KafkaBroker
+    ):
+        @full_broker.subscriber(queue, batch=True, **self.subscriber_kwargs)
+        def subscriber(m, msg: KafkaMessage):
+            check = all(
+                (
+                    msg.headers,
+                    [msg.headers] == msg.batch_headers,
+                    msg.headers.get("custom") == "1",
+                )
+            )
+            mock(check)
+            event.set()
+
+        async with full_broker:
+            await full_broker.start()
+
+            await asyncio.wait(
+                (
+                    asyncio.create_task(
+                        full_broker.publish("", queue, headers={"custom": "1"})
+                    ),
+                    asyncio.create_task(event.wait()),
+                ),
+                timeout=self.timeout,
+            )
+
+        assert event.is_set()
+        mock.assert_called_once_with(True)
+
     @pytest.mark.asyncio()
     @pytest.mark.slow()
     async def test_consume_ack(

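For downstream code, the new field shows up on batch subscribers: msg.headers keeps the first message's headers, while msg.batch_headers carries one dict per message. A minimal usage sketch (the broker address and topic name are illustrative placeholders, not part of this commit):

from typing import Dict, List

from faststream import FastStream
from faststream.confluent import KafkaBroker
from faststream.confluent.message import KafkaMessage

broker = KafkaBroker("localhost:9092")
app = FastStream(broker)

@broker.subscriber("demo-topic", batch=True)
async def handle(body: List[str], msg: KafkaMessage) -> None:
    # one headers dict per message, aligned with the decoded bodies
    per_message: List[Dict[str, str]] = msg.batch_headers
    for item, headers in zip(body, per_message):
        print(item, headers.get("custom"))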