From 3236756a45d66fa7709772efa5d82783734fc9b1 Mon Sep 17 00:00:00 2001 From: treaditup Date: Mon, 17 Jun 2024 01:53:45 +0300 Subject: [PATCH 1/9] fix: add links for batches --- faststream/opentelemetry/consts.py | 2 + faststream/opentelemetry/middleware.py | 70 ++++++++++++++++--- .../opentelemetry/confluent/test_confluent.py | 4 +- tests/opentelemetry/kafka/test_kafka.py | 4 +- tests/opentelemetry/nats/test_nats.py | 4 ++ tests/opentelemetry/redis/test_redis.py | 4 +- 6 files changed, 74 insertions(+), 14 deletions(-) diff --git a/faststream/opentelemetry/consts.py b/faststream/opentelemetry/consts.py index 2436d568ee..33a22644ed 100644 --- a/faststream/opentelemetry/consts.py +++ b/faststream/opentelemetry/consts.py @@ -5,5 +5,7 @@ class MessageAction: RECEIVE = "receive" +OTEL_SCHEMA = "https://opentelemetry.io/schemas/1.11.0" ERROR_TYPE = "error.type" MESSAGING_DESTINATION_PUBLISH_NAME = "messaging.destination_publish.name" +WITH_BATCH = "with_batch" diff --git a/faststream/opentelemetry/middleware.py b/faststream/opentelemetry/middleware.py index 7bb0519c68..9792928df1 100644 --- a/faststream/opentelemetry/middleware.py +++ b/faststream/opentelemetry/middleware.py @@ -1,14 +1,21 @@ import time +from collections import Counter from copy import copy -from typing import TYPE_CHECKING, Any, Callable, Optional, Type +from typing import TYPE_CHECKING, Any, Callable, List, Optional, Type -from opentelemetry import context, metrics, propagate, trace +from opentelemetry import baggage, context, metrics, trace +from opentelemetry.baggage.propagation import W3CBaggagePropagator from opentelemetry.semconv.trace import SpanAttributes +from opentelemetry.trace import Link, Span +from opentelemetry.context import Context +from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator from faststream import BaseMiddleware from faststream.opentelemetry.consts import ( ERROR_TYPE, MESSAGING_DESTINATION_PUBLISH_NAME, + OTEL_SCHEMA, + WITH_BATCH, MessageAction, ) from faststream.opentelemetry.provider import TelemetrySettingsProvider @@ -16,21 +23,55 @@ if TYPE_CHECKING: from types import TracebackType - from opentelemetry.context import Context from opentelemetry.metrics import Meter, MeterProvider - from opentelemetry.trace import Span, Tracer, TracerProvider + from opentelemetry.trace import Tracer, TracerProvider from faststream.broker.message import StreamMessage from faststream.types import AnyDict, AsyncFunc, AsyncFuncAny -_OTEL_SCHEMA = "https://opentelemetry.io/schemas/1.11.0" +_BAGGAGE_PROPAGATOR = W3CBaggagePropagator() +_TRACE_PROPAGATOR = TraceContextTextMapPropagator() def _create_span_name(destination: str, action: str) -> str: return f"{destination} {action}" +def _is_batch_message(msg: "StreamMessage[Any]") -> bool: + with_batch = baggage.get_baggage(WITH_BATCH, _BAGGAGE_PROPAGATOR.extract(msg.headers)) + return bool(msg.batch_headers or with_batch) + + +def _get_span_link(headers: "AnyDict", count: int) -> Optional[Link]: + trace_context = _TRACE_PROPAGATOR.extract(headers) + if not len(trace_context): + return None + span_context = next(iter(trace_context.values())) + if not isinstance(span_context, Span): + return None + attributes = {SpanAttributes.MESSAGING_BATCH_MESSAGE_COUNT: count} if count > 1 else None + return Link(span_context.get_span_context(), attributes=attributes) + + +def _get_span_links(msg: "StreamMessage[Any]") -> Optional[List[Link]]: + if not msg.batch_headers: + link = _get_span_link(msg.headers, 1) + return [link] if link else None + + links = [] + all_headers = {h["correlation_id"]: h for h in msg.batch_headers} + counted_links = Counter([h["correlation_id"] for h in msg.batch_headers]).most_common() + + for correlation_id, count in counted_links: + link = _get_span_link(all_headers[correlation_id], count) + if link is None: + continue + links.append(link) + + return links if links else None + + class _MetricsContainer: __slots__ = ( "include_messages_counters", @@ -139,12 +180,13 @@ async def publish_scope( # NOTE: if batch with single message? if (msg_count := len((msg, *args))) > 1: trace_attributes[SpanAttributes.MESSAGING_BATCH_MESSAGE_COUNT] = msg_count + _BAGGAGE_PROPAGATOR.inject(headers, baggage.set_baggage(WITH_BATCH, True, context=current_context)) if self._current_span and self._current_span.is_recording(): current_context = trace.set_span_in_context( self._current_span, current_context ) - propagate.inject(headers, context=self._origin_context) + _TRACE_PROPAGATOR.inject(headers, context=self._origin_context) else: create_span = self._tracer.start_span( @@ -153,7 +195,7 @@ async def publish_scope( attributes=trace_attributes, ) current_context = trace.set_span_in_context(create_span) - propagate.inject(headers, context=current_context) + _TRACE_PROPAGATOR.inject(headers, context=current_context) create_span.end() start_time = time.perf_counter() @@ -188,9 +230,14 @@ async def consume_scope( if (provider := self.__settings_provider) is None: return await call_next(msg) - current_context = propagate.extract(msg.headers) - destination_name = provider.get_consume_destination_name(msg) + if _is_batch_message(msg): + links = _get_span_links(msg) + current_context = Context() + else: + links = None + current_context = _TRACE_PROPAGATOR.extract(msg.headers) + destination_name = provider.get_consume_destination_name(msg) trace_attributes = provider.get_consume_attrs_from_message(msg) metrics_attributes = { SpanAttributes.MESSAGING_SYSTEM: provider.messaging_system, @@ -202,6 +249,7 @@ async def consume_scope( name=_create_span_name(destination_name, MessageAction.CREATE), kind=trace.SpanKind.CONSUMER, attributes=trace_attributes, + links=links, ) current_context = trace.set_span_in_context(create_span) create_span.end() @@ -292,7 +340,7 @@ def _get_meter( return metrics.get_meter( __name__, meter_provider=meter_provider, - schema_url=_OTEL_SCHEMA, + schema_url=OTEL_SCHEMA, ) return meter @@ -301,5 +349,5 @@ def _get_tracer(tracer_provider: Optional["TracerProvider"] = None) -> "Tracer": return trace.get_tracer( __name__, tracer_provider=tracer_provider, - schema_url=_OTEL_SCHEMA, + schema_url=OTEL_SCHEMA, ) diff --git a/tests/opentelemetry/confluent/test_confluent.py b/tests/opentelemetry/confluent/test_confluent.py index 3877d488ba..3402eff841 100644 --- a/tests/opentelemetry/confluent/test_confluent.py +++ b/tests/opentelemetry/confluent/test_confluent.py @@ -78,6 +78,7 @@ async def test_batch( ) broker = self.broker_class(middlewares=(mid,)) expected_msg_count = 3 + expected_link_count = 1 @broker.subscriber(queue, batch=True, **self.subscriber_kwargs) async def handler(m): @@ -96,7 +97,7 @@ async def handler(m): metrics = self.get_metrics(metric_reader) spans = self.get_spans(trace_exporter) - _, publish, process = spans + _, publish, create_process, process = spans assert ( publish.attributes[SpanAttr.MESSAGING_BATCH_MESSAGE_COUNT] @@ -106,6 +107,7 @@ async def handler(m): process.attributes[SpanAttr.MESSAGING_BATCH_MESSAGE_COUNT] == expected_msg_count ) + assert len(create_process.links) == expected_link_count self.assert_metrics(metrics, count=expected_msg_count) assert event.is_set() diff --git a/tests/opentelemetry/kafka/test_kafka.py b/tests/opentelemetry/kafka/test_kafka.py index 2142825098..6dfd0eaa9e 100644 --- a/tests/opentelemetry/kafka/test_kafka.py +++ b/tests/opentelemetry/kafka/test_kafka.py @@ -76,6 +76,7 @@ async def test_batch( ) broker = self.broker_class(middlewares=(mid,)) expected_msg_count = 3 + expected_link_count = 1 @broker.subscriber(queue, batch=True, **self.subscriber_kwargs) async def handler(m): @@ -94,7 +95,7 @@ async def handler(m): metrics = self.get_metrics(metric_reader) spans = self.get_spans(trace_exporter) - _, publish, process = spans + _, publish, create_process, process = spans assert ( publish.attributes[SpanAttr.MESSAGING_BATCH_MESSAGE_COUNT] @@ -104,6 +105,7 @@ async def handler(m): process.attributes[SpanAttr.MESSAGING_BATCH_MESSAGE_COUNT] == expected_msg_count ) + assert len(create_process.links) == expected_link_count self.assert_metrics(metrics, count=expected_msg_count) assert event.is_set() diff --git a/tests/opentelemetry/nats/test_nats.py b/tests/opentelemetry/nats/test_nats.py index db9b4ba48b..3a89699dc0 100644 --- a/tests/opentelemetry/nats/test_nats.py +++ b/tests/opentelemetry/nats/test_nats.py @@ -44,6 +44,7 @@ async def test_batch( ) broker = self.broker_class(middlewares=(mid,)) expected_msg_count = 3 + expected_span_count = 8 @broker.subscriber( queue, @@ -71,7 +72,10 @@ async def handler(m): proc_dur, proc_msg, pub_dur, pub_msg = metrics spans = self.get_spans(trace_exporter) process = spans[-1] + create_batch = spans[-2] + assert len(create_batch.links) == expected_msg_count + assert len(spans) == expected_span_count assert ( process.attributes[SpanAttr.MESSAGING_BATCH_MESSAGE_COUNT] == expected_msg_count diff --git a/tests/opentelemetry/redis/test_redis.py b/tests/opentelemetry/redis/test_redis.py index 71e079cbac..c2d06681d0 100644 --- a/tests/opentelemetry/redis/test_redis.py +++ b/tests/opentelemetry/redis/test_redis.py @@ -42,6 +42,7 @@ async def test_batch( ) broker = self.broker_class(middlewares=(mid,)) expected_msg_count = 3 + expected_link_count = 1 @broker.subscriber(list=ListSub(queue, batch=True), **self.subscriber_kwargs) async def handler(m): @@ -60,7 +61,7 @@ async def handler(m): metrics = self.get_metrics(metric_reader) spans = self.get_spans(trace_exporter) - _, publish, process = spans + _, publish, create_process, process = spans assert ( publish.attributes[SpanAttr.MESSAGING_BATCH_MESSAGE_COUNT] @@ -70,6 +71,7 @@ async def handler(m): process.attributes[SpanAttr.MESSAGING_BATCH_MESSAGE_COUNT] == expected_msg_count ) + assert len(create_process.links) == expected_link_count self.assert_metrics(metrics, count=expected_msg_count) assert event.is_set() From 550e0ee1d128af7f00b7f7cf1e8a26bc8995335f Mon Sep 17 00:00:00 2001 From: treaditup Date: Mon, 17 Jun 2024 03:06:43 +0300 Subject: [PATCH 2/9] fix: strict get correlation_id from dict --- faststream/opentelemetry/middleware.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/faststream/opentelemetry/middleware.py b/faststream/opentelemetry/middleware.py index 9792928df1..fd1fdd1eb2 100644 --- a/faststream/opentelemetry/middleware.py +++ b/faststream/opentelemetry/middleware.py @@ -59,12 +59,17 @@ def _get_span_links(msg: "StreamMessage[Any]") -> Optional[List[Link]]: link = _get_span_link(msg.headers, 1) return [link] if link else None - links = [] - all_headers = {h["correlation_id"]: h for h in msg.batch_headers} - counted_links = Counter([h["correlation_id"] for h in msg.batch_headers]).most_common() + links, headers_by_correlation, all_correlations = [], {}, [] - for correlation_id, count in counted_links: - link = _get_span_link(all_headers[correlation_id], count) + for headers in msg.batch_headers: + correlation_id = headers.get("correlation_id") + if correlation_id is None: + continue + headers_by_correlation[correlation_id] = headers + all_correlations.append(correlation_id) + + for correlation_id, count in Counter(all_correlations).most_common(): + link = _get_span_link(headers_by_correlation[correlation_id], count) if link is None: continue links.append(link) From bf35dcaa19414a402bc9c5e35adaee6c4bf9cc37 Mon Sep 17 00:00:00 2001 From: treaditup Date: Mon, 17 Jun 2024 15:16:04 +0300 Subject: [PATCH 3/9] test: add batch telemetry tests for redis --- faststream/opentelemetry/middleware.py | 14 ++- tests/opentelemetry/nats/test_nats.py | 3 +- tests/opentelemetry/redis/test_redis.py | 112 ++++++++++++++++++++++++ 3 files changed, 124 insertions(+), 5 deletions(-) diff --git a/faststream/opentelemetry/middleware.py b/faststream/opentelemetry/middleware.py index fd1fdd1eb2..3055341f89 100644 --- a/faststream/opentelemetry/middleware.py +++ b/faststream/opentelemetry/middleware.py @@ -5,9 +5,9 @@ from opentelemetry import baggage, context, metrics, trace from opentelemetry.baggage.propagation import W3CBaggagePropagator +from opentelemetry.context import Context from opentelemetry.semconv.trace import SpanAttributes from opentelemetry.trace import Link, Span -from opentelemetry.context import Context from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator from faststream import BaseMiddleware @@ -39,7 +39,9 @@ def _create_span_name(destination: str, action: str) -> str: def _is_batch_message(msg: "StreamMessage[Any]") -> bool: - with_batch = baggage.get_baggage(WITH_BATCH, _BAGGAGE_PROPAGATOR.extract(msg.headers)) + with_batch = baggage.get_baggage( + WITH_BATCH, _BAGGAGE_PROPAGATOR.extract(msg.headers) + ) return bool(msg.batch_headers or with_batch) @@ -50,7 +52,9 @@ def _get_span_link(headers: "AnyDict", count: int) -> Optional[Link]: span_context = next(iter(trace_context.values())) if not isinstance(span_context, Span): return None - attributes = {SpanAttributes.MESSAGING_BATCH_MESSAGE_COUNT: count} if count > 1 else None + attributes = ( + {SpanAttributes.MESSAGING_BATCH_MESSAGE_COUNT: count} if count > 1 else None + ) return Link(span_context.get_span_context(), attributes=attributes) @@ -185,7 +189,9 @@ async def publish_scope( # NOTE: if batch with single message? if (msg_count := len((msg, *args))) > 1: trace_attributes[SpanAttributes.MESSAGING_BATCH_MESSAGE_COUNT] = msg_count - _BAGGAGE_PROPAGATOR.inject(headers, baggage.set_baggage(WITH_BATCH, True, context=current_context)) + _BAGGAGE_PROPAGATOR.inject( + headers, baggage.set_baggage(WITH_BATCH, True, context=current_context) + ) if self._current_span and self._current_span.is_recording(): current_context = trace.set_span_in_context( diff --git a/tests/opentelemetry/nats/test_nats.py b/tests/opentelemetry/nats/test_nats.py index 3a89699dc0..69007ddfbb 100644 --- a/tests/opentelemetry/nats/test_nats.py +++ b/tests/opentelemetry/nats/test_nats.py @@ -45,6 +45,7 @@ async def test_batch( broker = self.broker_class(middlewares=(mid,)) expected_msg_count = 3 expected_span_count = 8 + expected_proc_batch_count = 1 @broker.subscriber( queue, @@ -82,7 +83,7 @@ async def handler(m): ) assert proc_msg.data.data_points[0].value == expected_msg_count assert pub_msg.data.data_points[0].value == expected_msg_count - assert proc_dur.data.data_points[0].count == 1 + assert proc_dur.data.data_points[0].count == expected_proc_batch_count assert pub_dur.data.data_points[0].count == expected_msg_count assert event.is_set() diff --git a/tests/opentelemetry/redis/test_redis.py b/tests/opentelemetry/redis/test_redis.py index c2d06681d0..31b9216e65 100644 --- a/tests/opentelemetry/redis/test_redis.py +++ b/tests/opentelemetry/redis/test_redis.py @@ -43,6 +43,7 @@ async def test_batch( broker = self.broker_class(middlewares=(mid,)) expected_msg_count = 3 expected_link_count = 1 + expected_link_attrs = {"messaging.batch.message_count": 3} @broker.subscriber(list=ListSub(queue, batch=True), **self.subscriber_kwargs) async def handler(m): @@ -72,11 +73,122 @@ async def handler(m): == expected_msg_count ) assert len(create_process.links) == expected_link_count + assert create_process.links[0].attributes == expected_link_attrs self.assert_metrics(metrics, count=expected_msg_count) assert event.is_set() mock.assert_called_once_with([1, "hi", 3]) + async def test_batch_publish_with_single_consume( + self, + queue: str, + meter_provider: MeterProvider, + metric_reader: InMemoryMetricReader, + tracer_provider: TracerProvider, + trace_exporter: InMemorySpanExporter, + ): + mid = self.telemetry_middleware_class( + meter_provider=meter_provider, tracer_provider=tracer_provider + ) + broker = self.broker_class(middlewares=(mid,)) + msgs_queue = asyncio.Queue(maxsize=3) + expected_msg_count = 3 + expected_link_count = 1 + expected_span_count = 8 + expected_pub_batch_count = 1 + + @broker.subscriber(list=ListSub(queue), **self.subscriber_kwargs) + async def handler(msg): + await msgs_queue.put(msg) + + broker = self.patch_broker(broker) + + async with broker: + await broker.start() + await broker.publish_batch(1, "hi", 3, list=queue) + result, _ = await asyncio.wait( + ( + asyncio.create_task(msgs_queue.get()), + asyncio.create_task(msgs_queue.get()), + asyncio.create_task(msgs_queue.get()), + ), + timeout=3, + ) + + metrics = self.get_metrics(metric_reader) + proc_dur, proc_msg, pub_dur, pub_msg = metrics + spans = self.get_spans(trace_exporter) + publish = spans[1] + create_processes = [spans[2], spans[4], spans[6]] + + assert len(spans) == expected_span_count + assert ( + publish.attributes[SpanAttr.MESSAGING_BATCH_MESSAGE_COUNT] + == expected_msg_count + ) + for cp in create_processes: + assert len(cp.links) == expected_link_count + + assert proc_msg.data.data_points[0].value == expected_msg_count + assert pub_msg.data.data_points[0].value == expected_msg_count + assert proc_dur.data.data_points[0].count == expected_msg_count + assert pub_dur.data.data_points[0].count == expected_pub_batch_count + + assert {1, "hi", 3} == {r.result() for r in result} + + async def test_single_publish_with_batch_consume( + self, + event: asyncio.Event, + queue: str, + mock: Mock, + meter_provider: MeterProvider, + metric_reader: InMemoryMetricReader, + tracer_provider: TracerProvider, + trace_exporter: InMemorySpanExporter, + ): + mid = self.telemetry_middleware_class( + meter_provider=meter_provider, tracer_provider=tracer_provider + ) + broker = self.broker_class(middlewares=(mid,)) + expected_msg_count = 2 + expected_link_count = 2 + expected_span_count = 6 + expected_process_batch_count = 1 + + @broker.subscriber(list=ListSub(queue, batch=True), **self.subscriber_kwargs) + async def handler(m): + m.sort() + mock(m) + event.set() + + broker = self.patch_broker(broker) + + async with broker: + tasks = ( + asyncio.create_task(broker.publish("hi", list=queue)), + asyncio.create_task(broker.publish("buy", list=queue)), + ) + await asyncio.wait(tasks, timeout=self.timeout) + await broker.start() + await asyncio.wait( + (asyncio.create_task(event.wait()),), timeout=self.timeout + ) + + metrics = self.get_metrics(metric_reader) + proc_dur, proc_msg, pub_dur, pub_msg = metrics + spans = self.get_spans(trace_exporter) + create_process = spans[-2] + + assert len(spans) == expected_span_count + assert len(create_process.links) == expected_link_count + assert proc_msg.data.data_points[0].value == expected_msg_count + assert pub_msg.data.data_points[0].value == expected_msg_count + assert proc_dur.data.data_points[0].count == expected_process_batch_count + assert pub_dur.data.data_points[0].count == expected_msg_count + + assert event.is_set() + mock.assert_called_once_with(["buy", "hi"]) + @pytest.mark.redis() class TestPublishWithTelemetry(TestPublish): From ac97e7559507bd177787deddb60f8b30f1e23d53 Mon Sep 17 00:00:00 2001 From: treaditup Date: Mon, 17 Jun 2024 15:45:15 +0300 Subject: [PATCH 4/9] test: add batch telemetry tests for kafka --- .../opentelemetry/confluent/test_confluent.py | 110 ++++++++++++++++++ tests/opentelemetry/kafka/test_kafka.py | 110 ++++++++++++++++++ 2 files changed, 220 insertions(+) diff --git a/tests/opentelemetry/confluent/test_confluent.py b/tests/opentelemetry/confluent/test_confluent.py index 3402eff841..930bf9aeaf 100644 --- a/tests/opentelemetry/confluent/test_confluent.py +++ b/tests/opentelemetry/confluent/test_confluent.py @@ -79,6 +79,7 @@ async def test_batch( broker = self.broker_class(middlewares=(mid,)) expected_msg_count = 3 expected_link_count = 1 + expected_link_attrs = {"messaging.batch.message_count": 3} @broker.subscriber(queue, batch=True, **self.subscriber_kwargs) async def handler(m): @@ -108,11 +109,120 @@ async def handler(m): == expected_msg_count ) assert len(create_process.links) == expected_link_count + assert create_process.links[0].attributes == expected_link_attrs self.assert_metrics(metrics, count=expected_msg_count) assert event.is_set() mock.assert_called_once_with([1, "hi", 3]) + async def test_batch_publish_with_single_consume( + self, + queue: str, + meter_provider: MeterProvider, + metric_reader: InMemoryMetricReader, + tracer_provider: TracerProvider, + trace_exporter: InMemorySpanExporter, + ): + mid = self.telemetry_middleware_class( + meter_provider=meter_provider, tracer_provider=tracer_provider + ) + broker = self.broker_class(middlewares=(mid,)) + msgs_queue = asyncio.Queue(maxsize=3) + expected_msg_count = 3 + expected_link_count = 1 + expected_span_count = 8 + expected_pub_batch_count = 1 + + @broker.subscriber(queue, **self.subscriber_kwargs) + async def handler(msg): + await msgs_queue.put(msg) + + broker = self.patch_broker(broker) + + async with broker: + await broker.start() + await broker.publish_batch(1, "hi", 3, topic=queue) + result, _ = await asyncio.wait( + ( + asyncio.create_task(msgs_queue.get()), + asyncio.create_task(msgs_queue.get()), + asyncio.create_task(msgs_queue.get()), + ), + timeout=self.timeout, + ) + + metrics = self.get_metrics(metric_reader) + proc_dur, proc_msg, pub_dur, pub_msg = metrics + spans = self.get_spans(trace_exporter) + publish = spans[1] + create_processes = [spans[2], spans[4], spans[6]] + + assert len(spans) == expected_span_count + assert ( + publish.attributes[SpanAttr.MESSAGING_BATCH_MESSAGE_COUNT] + == expected_msg_count + ) + for cp in create_processes: + assert len(cp.links) == expected_link_count + + assert proc_msg.data.data_points[0].value == expected_msg_count + assert pub_msg.data.data_points[0].value == expected_msg_count + assert proc_dur.data.data_points[0].count == expected_msg_count + assert pub_dur.data.data_points[0].count == expected_pub_batch_count + + assert {1, "hi", 3} == {r.result() for r in result} + + async def test_single_publish_with_batch_consume( + self, + event: asyncio.Event, + queue: str, + mock: Mock, + meter_provider: MeterProvider, + metric_reader: InMemoryMetricReader, + tracer_provider: TracerProvider, + trace_exporter: InMemorySpanExporter, + ): + mid = self.telemetry_middleware_class( + meter_provider=meter_provider, tracer_provider=tracer_provider + ) + broker = self.broker_class(middlewares=(mid,)) + expected_msg_count = 2 + expected_link_count = 2 + expected_span_count = 6 + expected_process_batch_count = 1 + + @broker.subscriber(queue, batch=True, **self.subscriber_kwargs) + async def handler(m): + m.sort() + mock(m) + event.set() + + broker = self.patch_broker(broker) + + async with broker: + await broker.start() + tasks = ( + asyncio.create_task(broker.publish("hi", topic=queue)), + asyncio.create_task(broker.publish("buy", topic=queue)), + asyncio.create_task(event.wait()), + ) + await asyncio.wait(tasks, timeout=self.timeout) + + metrics = self.get_metrics(metric_reader) + proc_dur, proc_msg, pub_dur, pub_msg = metrics + spans = self.get_spans(trace_exporter) + create_process = spans[-2] + + assert len(spans) == expected_span_count + assert len(create_process.links) == expected_link_count + assert proc_msg.data.data_points[0].value == expected_msg_count + assert pub_msg.data.data_points[0].value == expected_msg_count + assert proc_dur.data.data_points[0].count == expected_process_batch_count + assert pub_dur.data.data_points[0].count == expected_msg_count + + assert event.is_set() + mock.assert_called_once_with(["buy", "hi"]) + @pytest.mark.confluent() class TestPublishWithTelemetry(TestPublish): diff --git a/tests/opentelemetry/kafka/test_kafka.py b/tests/opentelemetry/kafka/test_kafka.py index 6dfd0eaa9e..c8f67b40b1 100644 --- a/tests/opentelemetry/kafka/test_kafka.py +++ b/tests/opentelemetry/kafka/test_kafka.py @@ -77,6 +77,7 @@ async def test_batch( broker = self.broker_class(middlewares=(mid,)) expected_msg_count = 3 expected_link_count = 1 + expected_link_attrs = {"messaging.batch.message_count": 3} @broker.subscriber(queue, batch=True, **self.subscriber_kwargs) async def handler(m): @@ -106,11 +107,120 @@ async def handler(m): == expected_msg_count ) assert len(create_process.links) == expected_link_count + assert create_process.links[0].attributes == expected_link_attrs self.assert_metrics(metrics, count=expected_msg_count) assert event.is_set() mock.assert_called_once_with([1, "hi", 3]) + async def test_batch_publish_with_single_consume( + self, + queue: str, + meter_provider: MeterProvider, + metric_reader: InMemoryMetricReader, + tracer_provider: TracerProvider, + trace_exporter: InMemorySpanExporter, + ): + mid = self.telemetry_middleware_class( + meter_provider=meter_provider, tracer_provider=tracer_provider + ) + broker = self.broker_class(middlewares=(mid,)) + msgs_queue = asyncio.Queue(maxsize=3) + expected_msg_count = 3 + expected_link_count = 1 + expected_span_count = 8 + expected_pub_batch_count = 1 + + @broker.subscriber(queue, **self.subscriber_kwargs) + async def handler(msg): + await msgs_queue.put(msg) + + broker = self.patch_broker(broker) + + async with broker: + await broker.start() + await broker.publish_batch(1, "hi", 3, topic=queue) + result, _ = await asyncio.wait( + ( + asyncio.create_task(msgs_queue.get()), + asyncio.create_task(msgs_queue.get()), + asyncio.create_task(msgs_queue.get()), + ), + timeout=self.timeout, + ) + + metrics = self.get_metrics(metric_reader) + proc_dur, proc_msg, pub_dur, pub_msg = metrics + spans = self.get_spans(trace_exporter) + publish = spans[1] + create_processes = [spans[2], spans[4], spans[6]] + + assert len(spans) == expected_span_count + assert ( + publish.attributes[SpanAttr.MESSAGING_BATCH_MESSAGE_COUNT] + == expected_msg_count + ) + for cp in create_processes: + assert len(cp.links) == expected_link_count + + assert proc_msg.data.data_points[0].value == expected_msg_count + assert pub_msg.data.data_points[0].value == expected_msg_count + assert proc_dur.data.data_points[0].count == expected_msg_count + assert pub_dur.data.data_points[0].count == expected_pub_batch_count + + assert {1, "hi", 3} == {r.result() for r in result} + + async def test_single_publish_with_batch_consume( + self, + event: asyncio.Event, + queue: str, + mock: Mock, + meter_provider: MeterProvider, + metric_reader: InMemoryMetricReader, + tracer_provider: TracerProvider, + trace_exporter: InMemorySpanExporter, + ): + mid = self.telemetry_middleware_class( + meter_provider=meter_provider, tracer_provider=tracer_provider + ) + broker = self.broker_class(middlewares=(mid,)) + expected_msg_count = 2 + expected_link_count = 2 + expected_span_count = 6 + expected_process_batch_count = 1 + + @broker.subscriber(queue, batch=True, **self.subscriber_kwargs) + async def handler(m): + m.sort() + mock(m) + event.set() + + broker = self.patch_broker(broker) + + async with broker: + await broker.start() + tasks = ( + asyncio.create_task(broker.publish("hi", topic=queue)), + asyncio.create_task(broker.publish("buy", topic=queue)), + asyncio.create_task(event.wait()), + ) + await asyncio.wait(tasks, timeout=self.timeout) + + metrics = self.get_metrics(metric_reader) + proc_dur, proc_msg, pub_dur, pub_msg = metrics + spans = self.get_spans(trace_exporter) + create_process = spans[-2] + + assert len(spans) == expected_span_count + assert len(create_process.links) == expected_link_count + assert proc_msg.data.data_points[0].value == expected_msg_count + assert pub_msg.data.data_points[0].value == expected_msg_count + assert proc_dur.data.data_points[0].count == expected_process_batch_count + assert pub_dur.data.data_points[0].count == expected_msg_count + + assert event.is_set() + mock.assert_called_once_with(["buy", "hi"]) + @pytest.mark.kafka() class TestPublishWithTelemetry(TestPublish): From d6d3a3099a004694312e0eb9eabb0691439fd111 Mon Sep 17 00:00:00 2001 From: treaditup Date: Mon, 17 Jun 2024 19:55:41 +0300 Subject: [PATCH 5/9] test: refactor nats test_batch --- tests/opentelemetry/nats/test_nats.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/tests/opentelemetry/nats/test_nats.py b/tests/opentelemetry/nats/test_nats.py index 69007ddfbb..5a6d335766 100644 --- a/tests/opentelemetry/nats/test_nats.py +++ b/tests/opentelemetry/nats/test_nats.py @@ -43,14 +43,14 @@ async def test_batch( meter_provider=meter_provider, tracer_provider=tracer_provider ) broker = self.broker_class(middlewares=(mid,)) - expected_msg_count = 3 - expected_span_count = 8 + expected_msg_count = 1 + expected_span_count = 4 expected_proc_batch_count = 1 @broker.subscriber( queue, stream=stream, - pull_sub=PullSub(3, batch=True, timeout=30.0), + pull_sub=PullSub(1, batch=True, timeout=30.0), **self.subscriber_kwargs, ) async def handler(m): @@ -62,9 +62,7 @@ async def handler(m): async with broker: await broker.start() tasks = ( - asyncio.create_task(broker.publish(1, queue)), asyncio.create_task(broker.publish("hi", queue)), - asyncio.create_task(broker.publish(3, queue)), asyncio.create_task(event.wait()), ) await asyncio.wait(tasks, timeout=self.timeout) @@ -87,7 +85,7 @@ async def handler(m): assert pub_dur.data.data_points[0].count == expected_msg_count assert event.is_set() - mock.assert_called_once_with([1, "hi", 3]) + mock.assert_called_once_with(["hi"]) @pytest.mark.nats() From 959c5e4559c73d9d24f75fd61f32ea62ee4d9e02 Mon Sep 17 00:00:00 2001 From: treaditup Date: Mon, 17 Jun 2024 23:17:19 +0300 Subject: [PATCH 6/9] refactor: _get_span_links --- faststream/opentelemetry/middleware.py | 37 ++++++++++++++------------ 1 file changed, 20 insertions(+), 17 deletions(-) diff --git a/faststream/opentelemetry/middleware.py b/faststream/opentelemetry/middleware.py index 3055341f89..bf9286bfd4 100644 --- a/faststream/opentelemetry/middleware.py +++ b/faststream/opentelemetry/middleware.py @@ -1,7 +1,7 @@ import time -from collections import Counter +from collections import defaultdict from copy import copy -from typing import TYPE_CHECKING, Any, Callable, List, Optional, Type +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Type from opentelemetry import baggage, context, metrics, trace from opentelemetry.baggage.propagation import W3CBaggagePropagator @@ -25,6 +25,7 @@ from opentelemetry.metrics import Meter, MeterProvider from opentelemetry.trace import Tracer, TracerProvider + from opentelemetry.util.types import Attributes from faststream.broker.message import StreamMessage from faststream.types import AnyDict, AsyncFunc, AsyncFuncAny @@ -45,40 +46,42 @@ def _is_batch_message(msg: "StreamMessage[Any]") -> bool: return bool(msg.batch_headers or with_batch) -def _get_span_link(headers: "AnyDict", count: int) -> Optional[Link]: +def _get_span_link(headers: "AnyDict", attributes: "Attributes") -> Optional[Link]: trace_context = _TRACE_PROPAGATOR.extract(headers) if not len(trace_context): return None span_context = next(iter(trace_context.values())) if not isinstance(span_context, Span): return None - attributes = ( - {SpanAttributes.MESSAGING_BATCH_MESSAGE_COUNT: count} if count > 1 else None - ) return Link(span_context.get_span_context(), attributes=attributes) -def _get_span_links(msg: "StreamMessage[Any]") -> Optional[List[Link]]: +def _get_link_attributes(message_count: int) -> "Attributes": + if message_count <= 1: + return {} + return {SpanAttributes.MESSAGING_BATCH_MESSAGE_COUNT: message_count} + + +def _get_span_links(msg: "StreamMessage[Any]") -> List[Link]: if not msg.batch_headers: - link = _get_span_link(msg.headers, 1) - return [link] if link else None + link = _get_span_link(msg.headers, {}) + return [link] if link else [] - links, headers_by_correlation, all_correlations = [], {}, [] + links = {} + counter: Dict[str, int] = defaultdict(lambda: 0) for headers in msg.batch_headers: correlation_id = headers.get("correlation_id") if correlation_id is None: continue - headers_by_correlation[correlation_id] = headers - all_correlations.append(correlation_id) - - for correlation_id, count in Counter(all_correlations).most_common(): - link = _get_span_link(headers_by_correlation[correlation_id], count) + counter[correlation_id] += 1 + attributes = _get_link_attributes(counter[correlation_id]) + link = _get_span_link(headers, attributes) if link is None: continue - links.append(link) + links[correlation_id] = link - return links if links else None + return list(links.values()) class _MetricsContainer: From 36997f26eea26f75710e42900312c0388b85fb7b Mon Sep 17 00:00:00 2001 From: Nikita Pastukhov Date: Wed, 19 Jun 2024 13:09:31 +0300 Subject: [PATCH 7/9] refactor: optimize OTEL links resolution --- faststream/opentelemetry/middleware.py | 105 +++++++++++++------------ 1 file changed, 55 insertions(+), 50 deletions(-) diff --git a/faststream/opentelemetry/middleware.py b/faststream/opentelemetry/middleware.py index bf9286bfd4..2da4635b89 100644 --- a/faststream/opentelemetry/middleware.py +++ b/faststream/opentelemetry/middleware.py @@ -35,55 +35,6 @@ _TRACE_PROPAGATOR = TraceContextTextMapPropagator() -def _create_span_name(destination: str, action: str) -> str: - return f"{destination} {action}" - - -def _is_batch_message(msg: "StreamMessage[Any]") -> bool: - with_batch = baggage.get_baggage( - WITH_BATCH, _BAGGAGE_PROPAGATOR.extract(msg.headers) - ) - return bool(msg.batch_headers or with_batch) - - -def _get_span_link(headers: "AnyDict", attributes: "Attributes") -> Optional[Link]: - trace_context = _TRACE_PROPAGATOR.extract(headers) - if not len(trace_context): - return None - span_context = next(iter(trace_context.values())) - if not isinstance(span_context, Span): - return None - return Link(span_context.get_span_context(), attributes=attributes) - - -def _get_link_attributes(message_count: int) -> "Attributes": - if message_count <= 1: - return {} - return {SpanAttributes.MESSAGING_BATCH_MESSAGE_COUNT: message_count} - - -def _get_span_links(msg: "StreamMessage[Any]") -> List[Link]: - if not msg.batch_headers: - link = _get_span_link(msg.headers, {}) - return [link] if link else [] - - links = {} - counter: Dict[str, int] = defaultdict(lambda: 0) - - for headers in msg.batch_headers: - correlation_id = headers.get("correlation_id") - if correlation_id is None: - continue - counter[correlation_id] += 1 - attributes = _get_link_attributes(counter[correlation_id]) - link = _get_span_link(headers, attributes) - if link is None: - continue - links[correlation_id] = link - - return list(links.values()) - - class _MetricsContainer: __slots__ = ( "include_messages_counters", @@ -245,7 +196,7 @@ async def consume_scope( return await call_next(msg) if _is_batch_message(msg): - links = _get_span_links(msg) + links = _get_msg_links(msg) current_context = Context() else: links = None @@ -365,3 +316,57 @@ def _get_tracer(tracer_provider: Optional["TracerProvider"] = None) -> "Tracer": tracer_provider=tracer_provider, schema_url=OTEL_SCHEMA, ) + + +def _create_span_name(destination: str, action: str) -> str: + return f"{destination} {action}" + + +def _is_batch_message(msg: "StreamMessage[Any]") -> bool: + with_batch = baggage.get_baggage( + WITH_BATCH, _BAGGAGE_PROPAGATOR.extract(msg.headers) + ) + return bool(msg.batch_headers or with_batch) + + +def _get_msg_links(msg: "StreamMessage[Any]") -> List[Link]: + if not msg.batch_headers: + if (span := _get_span_from_headers(msg.headers)) is not None: + return [Link(span.get_span_context())] + else: + return [] + + links = {} + counter: Dict[str, int] = defaultdict(lambda: 0) + + for headers in msg.batch_headers: + if (correlation_id := headers.get("correlation_id")) is None: + continue + + counter[correlation_id] += 1 + + if (span := _get_span_from_headers(headers)) is None: + continue + + attributes = _get_link_attributes(counter[correlation_id]) + + links[correlation_id] = Link( + span.get_span_context(), + attributes=attributes, + ) + + return list(links.values()) + + +def _get_span_from_headers(headers: "AnyDict", attributes: "Attributes") -> Optional[Span]: + trace_context = _TRACE_PROPAGATOR.extract(headers) + if not len(trace_context): + return None + + return next(iter(trace_context.values())) + + +def _get_link_attributes(message_count: int) -> "Attributes": + if message_count <= 1: + return {} + return {SpanAttributes.MESSAGING_BATCH_MESSAGE_COUNT: message_count} From b8f089b8cd245acb8f137b6f60138aab012ade95 Mon Sep 17 00:00:00 2001 From: Nikita Pastukhov Date: Wed, 19 Jun 2024 13:22:09 +0300 Subject: [PATCH 8/9] lint: fix codespell --- .codespell-whitelist.txt | 1 + docs/docs/en/release.md | 2 +- faststream/opentelemetry/middleware.py | 6 ++++-- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/.codespell-whitelist.txt b/.codespell-whitelist.txt index cd9d103e1e..0ec54623a2 100644 --- a/.codespell-whitelist.txt +++ b/.codespell-whitelist.txt @@ -1,2 +1,3 @@ dependant unsecure +socio-economic \ No newline at end of file diff --git a/docs/docs/en/release.md b/docs/docs/en/release.md index feadf44839..28037fe2c7 100644 --- a/docs/docs/en/release.md +++ b/docs/docs/en/release.md @@ -324,7 +324,7 @@ You can find more information about it in the official [**aiokafka** doc](https: `pattern` option was added too, but it is still experimental and does not support `Path` -3. [`Path`](https://faststream.airt.ai/latest/nats/message/#subject-pattern-access) feature performance was increased. Also, `Path` is suitable for NATS `PullSub` batch subscribtion as well now. +3. [`Path`](https://faststream.airt.ai/latest/nats/message/#subject-pattern-access) feature performance was increased. Also, `Path` is suitable for NATS `PullSub` batch subscription as well now. ```python from faststream import NatsBroker, PullSub diff --git a/faststream/opentelemetry/middleware.py b/faststream/opentelemetry/middleware.py index 2da4635b89..77fcade0df 100644 --- a/faststream/opentelemetry/middleware.py +++ b/faststream/opentelemetry/middleware.py @@ -358,7 +358,7 @@ def _get_msg_links(msg: "StreamMessage[Any]") -> List[Link]: return list(links.values()) -def _get_span_from_headers(headers: "AnyDict", attributes: "Attributes") -> Optional[Span]: +def _get_span_from_headers(headers: "AnyDict") -> Optional[Span]: trace_context = _TRACE_PROPAGATOR.extract(headers) if not len(trace_context): return None @@ -369,4 +369,6 @@ def _get_span_from_headers(headers: "AnyDict", attributes: "Attributes") -> Opti def _get_link_attributes(message_count: int) -> "Attributes": if message_count <= 1: return {} - return {SpanAttributes.MESSAGING_BATCH_MESSAGE_COUNT: message_count} + return { + SpanAttributes.MESSAGING_BATCH_MESSAGE_COUNT: message_count, + } From 305308ed51c287faca85624ae3e38477868d7116 Mon Sep 17 00:00:00 2001 From: Nikita Pastukhov Date: Wed, 19 Jun 2024 13:38:08 +0300 Subject: [PATCH 9/9] lint: fix mypy --- faststream/opentelemetry/middleware.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/faststream/opentelemetry/middleware.py b/faststream/opentelemetry/middleware.py index 77fcade0df..695cef979d 100644 --- a/faststream/opentelemetry/middleware.py +++ b/faststream/opentelemetry/middleware.py @@ -1,7 +1,7 @@ import time from collections import defaultdict from copy import copy -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Type +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Type, cast from opentelemetry import baggage, context, metrics, trace from opentelemetry.baggage.propagation import W3CBaggagePropagator @@ -363,7 +363,10 @@ def _get_span_from_headers(headers: "AnyDict") -> Optional[Span]: if not len(trace_context): return None - return next(iter(trace_context.values())) + return cast( + Optional[Span], + next(iter(trace_context.values())), + ) def _get_link_attributes(message_count: int) -> "Attributes":