From 437ce6a55dc67c92d208d4502ac82f4ce4195772 Mon Sep 17 00:00:00 2001
From: Nikita Pastukhov
Date: Thu, 16 May 2024 19:31:44 +0300
Subject: [PATCH] feat: add batch_headers for Confluent

---
 faststream/confluent/parser.py          | 41 ++++++++++++++++++++++-------------------
 tests/brokers/confluent/test_consume.py | 32 ++++++++++++++++++++++++++++++++
 2 files changed, 54 insertions(+), 19 deletions(-)

diff --git a/faststream/confluent/parser.py b/faststream/confluent/parser.py
index a4858247ac..8541ceb4f0 100644
--- a/faststream/confluent/parser.py
+++ b/faststream/confluent/parser.py
@@ -1,4 +1,4 @@
-from typing import TYPE_CHECKING, Any, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple, Union
 
 from faststream.broker.message import decode_message, gen_cor_id
 from faststream.confluent.message import FAKE_CONSUMER, KafkaMessage
@@ -20,18 +20,14 @@ async def parse_message(
         message: "Message",
     ) -> "StreamMessage[Message]":
         """Parses a Kafka message."""
-        headers = {}
-        if message.headers() is not None:
-            for i, j in message.headers():  # type: ignore[union-attr]
-                if isinstance(j, str):
-                    headers[i] = j
-                else:
-                    headers[i] = j.decode()
+        headers = _parse_msg_headers(message.headers() or ())  # headers() may be None
+
         body = message.value()
         offset = message.offset()
         _, timestamp = message.timestamp()
 
         handler: Optional["LogicSubscriber[Any]"] = context.get_local("handler_")
+
         return KafkaMessage(
             body=body,
             headers=headers,
@@ -49,28 +45,29 @@ async def parse_message_batch(
         message: Tuple["Message", ...],
     ) -> "StreamMessage[Tuple[Message, ...]]":
         """Parses a batch of messages from a Kafka consumer."""
+        body: List[Any] = []
+        batch_headers: List[Dict[str, str]] = []
+
         first = message[0]
         last = message[-1]
-        headers = {}
-        if first.headers() is not None:
-            for i, j in first.headers():  # type: ignore[union-attr]
-                if isinstance(j, str):
-                    headers[i] = j
-                else:
-                    headers[i] = j.decode()
-        body = [m.value() for m in message]
-        first_offset = first.offset()
-        last_offset = last.offset()
+
+        for m in message:
+            body.append(m.value())
+            batch_headers.append(_parse_msg_headers(m.headers() or ()))
+
+        headers = next(iter(batch_headers), {})  # first message headers, kept for backward compatibility
+
         _, first_timestamp = first.timestamp()
 
         handler: Optional["LogicSubscriber[Any]"] = context.get_local("handler_")
+
         return KafkaMessage(
             body=body,
             headers=headers,
+            batch_headers=batch_headers,
             reply_to=headers.get("reply_to", ""),
             content_type=headers.get("content-type"),
-            message_id=f"{first_offset}-{last_offset}-{first_timestamp}",
+            message_id=f"{first.offset()}-{last.offset()}-{first_timestamp}",
             correlation_id=headers.get("correlation_id", gen_cor_id()),
             raw_message=message,
             consumer=getattr(handler, "consumer", None) or FAKE_CONSUMER,
@@ -91,3 +88,9 @@ async def decode_message_batch(
     ) -> "DecodedMessage":
         """Decode a batch of messages."""
         return [decode_message(await cls.parse_message(m)) for m in msg.raw_message]
+
+
+def _parse_msg_headers(
+    headers: Sequence[Tuple[str, Union[bytes, str]]],
+) -> Dict[str, str]:
+    return {i: j if isinstance(j, str) else j.decode() for i, j in headers}
diff --git a/tests/brokers/confluent/test_consume.py b/tests/brokers/confluent/test_consume.py
index fb612d66d0..2c471c6e73 100644
--- a/tests/brokers/confluent/test_consume.py
+++ b/tests/brokers/confluent/test_consume.py
@@ -39,6 +39,38 @@ async def handler(msg):
 
         assert [{1, "hi"}] == [set(r.result()) for r in result]
 
+    @pytest.mark.asyncio()
+    async def test_consume_batch_headers(
+        self, mock, event: asyncio.Event, queue: str, full_broker: KafkaBroker
+    ):
+        @full_broker.subscriber(queue, batch=True, **self.subscriber_kwargs)
+        def subscriber(m, msg: KafkaMessage):
+            check = all(
+                (
+                    msg.headers,
+                    [msg.headers] == msg.batch_headers,
+                    msg.headers.get("custom") == "1",
+                )
+            )
+            mock(check)
+            event.set()
+
+        async with full_broker:
+            await full_broker.start()
+
+            await asyncio.wait(
+                (
+                    asyncio.create_task(
+                        full_broker.publish("", queue, headers={"custom": "1"})
+                    ),
+                    asyncio.create_task(event.wait()),
+                ),
+                timeout=self.timeout,
+            )
+
+        assert event.is_set()
+        mock.assert_called_once_with(True)
+
     @pytest.mark.asyncio()
     @pytest.mark.slow()
     async def test_consume_ack(
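
Reviewer note (not part of the patch): a minimal consumer sketch showing how the new field can be used. The broker URL and topic name are placeholders, and the import assumes `faststream.confluent` re-exports the `KafkaMessage` annotation the way the test above uses it; `msg.batch_headers` carries one decoded header dict per message, aligned with the batch `body`, while `msg.headers` still exposes the first message's headers for backward compatibility.

```python
from faststream import FastStream
from faststream.confluent import KafkaBroker, KafkaMessage

broker = KafkaBroker("localhost:9092")  # placeholder bootstrap server
app = FastStream(broker)


@broker.subscriber("in-topic", batch=True)  # placeholder topic name
async def handle(body: list, msg: KafkaMessage):
    # body and msg.batch_headers are aligned one-to-one, so each
    # decoded payload can be paired with its own message headers.
    for payload, headers in zip(body, msg.batch_headers):
        print(payload, headers.get("custom"))
```

Publishing with `await broker.publish("hi", "in-topic", headers={"custom": "1"})` should then print each payload next to its own `custom` header value, mirroring what the new test asserts.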