From bf35dcaa19414a402bc9c5e35adaee6c4bf9cc37 Mon Sep 17 00:00:00 2001 From: treaditup Date: Mon, 17 Jun 2024 15:16:04 +0300 Subject: [PATCH] test: add batch telemetry tests for redis --- faststream/opentelemetry/middleware.py | 14 ++- tests/opentelemetry/nats/test_nats.py | 3 +- tests/opentelemetry/redis/test_redis.py | 112 ++++++++++++++++++++++++ 3 files changed, 124 insertions(+), 5 deletions(-) diff --git a/faststream/opentelemetry/middleware.py b/faststream/opentelemetry/middleware.py index fd1fdd1eb2..3055341f89 100644 --- a/faststream/opentelemetry/middleware.py +++ b/faststream/opentelemetry/middleware.py @@ -5,9 +5,9 @@ from opentelemetry import baggage, context, metrics, trace from opentelemetry.baggage.propagation import W3CBaggagePropagator +from opentelemetry.context import Context from opentelemetry.semconv.trace import SpanAttributes from opentelemetry.trace import Link, Span -from opentelemetry.context import Context from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator from faststream import BaseMiddleware @@ -39,7 +39,9 @@ def _create_span_name(destination: str, action: str) -> str: def _is_batch_message(msg: "StreamMessage[Any]") -> bool: - with_batch = baggage.get_baggage(WITH_BATCH, _BAGGAGE_PROPAGATOR.extract(msg.headers)) + with_batch = baggage.get_baggage( + WITH_BATCH, _BAGGAGE_PROPAGATOR.extract(msg.headers) + ) return bool(msg.batch_headers or with_batch) @@ -50,7 +52,9 @@ def _get_span_link(headers: "AnyDict", count: int) -> Optional[Link]: span_context = next(iter(trace_context.values())) if not isinstance(span_context, Span): return None - attributes = {SpanAttributes.MESSAGING_BATCH_MESSAGE_COUNT: count} if count > 1 else None + attributes = ( + {SpanAttributes.MESSAGING_BATCH_MESSAGE_COUNT: count} if count > 1 else None + ) return Link(span_context.get_span_context(), attributes=attributes) @@ -185,7 +189,9 @@ async def publish_scope( # NOTE: if batch with single message? if (msg_count := len((msg, *args))) > 1: trace_attributes[SpanAttributes.MESSAGING_BATCH_MESSAGE_COUNT] = msg_count - _BAGGAGE_PROPAGATOR.inject(headers, baggage.set_baggage(WITH_BATCH, True, context=current_context)) + _BAGGAGE_PROPAGATOR.inject( + headers, baggage.set_baggage(WITH_BATCH, True, context=current_context) + ) if self._current_span and self._current_span.is_recording(): current_context = trace.set_span_in_context( diff --git a/tests/opentelemetry/nats/test_nats.py b/tests/opentelemetry/nats/test_nats.py index 3a89699dc0..69007ddfbb 100644 --- a/tests/opentelemetry/nats/test_nats.py +++ b/tests/opentelemetry/nats/test_nats.py @@ -45,6 +45,7 @@ async def test_batch( broker = self.broker_class(middlewares=(mid,)) expected_msg_count = 3 expected_span_count = 8 + expected_proc_batch_count = 1 @broker.subscriber( queue, @@ -82,7 +83,7 @@ async def handler(m): ) assert proc_msg.data.data_points[0].value == expected_msg_count assert pub_msg.data.data_points[0].value == expected_msg_count - assert proc_dur.data.data_points[0].count == 1 + assert proc_dur.data.data_points[0].count == expected_proc_batch_count assert pub_dur.data.data_points[0].count == expected_msg_count assert event.is_set() diff --git a/tests/opentelemetry/redis/test_redis.py b/tests/opentelemetry/redis/test_redis.py index c2d06681d0..31b9216e65 100644 --- a/tests/opentelemetry/redis/test_redis.py +++ b/tests/opentelemetry/redis/test_redis.py @@ -43,6 +43,7 @@ async def test_batch( broker = self.broker_class(middlewares=(mid,)) expected_msg_count = 3 expected_link_count = 1 + expected_link_attrs = {"messaging.batch.message_count": 3} @broker.subscriber(list=ListSub(queue, batch=True), **self.subscriber_kwargs) async def handler(m): @@ -72,11 +73,122 @@ async def handler(m): == expected_msg_count ) assert len(create_process.links) == expected_link_count + assert create_process.links[0].attributes == expected_link_attrs self.assert_metrics(metrics, count=expected_msg_count) assert event.is_set() mock.assert_called_once_with([1, "hi", 3]) + async def test_batch_publish_with_single_consume( + self, + queue: str, + meter_provider: MeterProvider, + metric_reader: InMemoryMetricReader, + tracer_provider: TracerProvider, + trace_exporter: InMemorySpanExporter, + ): + mid = self.telemetry_middleware_class( + meter_provider=meter_provider, tracer_provider=tracer_provider + ) + broker = self.broker_class(middlewares=(mid,)) + msgs_queue = asyncio.Queue(maxsize=3) + expected_msg_count = 3 + expected_link_count = 1 + expected_span_count = 8 + expected_pub_batch_count = 1 + + @broker.subscriber(list=ListSub(queue), **self.subscriber_kwargs) + async def handler(msg): + await msgs_queue.put(msg) + + broker = self.patch_broker(broker) + + async with broker: + await broker.start() + await broker.publish_batch(1, "hi", 3, list=queue) + result, _ = await asyncio.wait( + ( + asyncio.create_task(msgs_queue.get()), + asyncio.create_task(msgs_queue.get()), + asyncio.create_task(msgs_queue.get()), + ), + timeout=3, + ) + + metrics = self.get_metrics(metric_reader) + proc_dur, proc_msg, pub_dur, pub_msg = metrics + spans = self.get_spans(trace_exporter) + publish = spans[1] + create_processes = [spans[2], spans[4], spans[6]] + + assert len(spans) == expected_span_count + assert ( + publish.attributes[SpanAttr.MESSAGING_BATCH_MESSAGE_COUNT] + == expected_msg_count + ) + for cp in create_processes: + assert len(cp.links) == expected_link_count + + assert proc_msg.data.data_points[0].value == expected_msg_count + assert pub_msg.data.data_points[0].value == expected_msg_count + assert proc_dur.data.data_points[0].count == expected_msg_count + assert pub_dur.data.data_points[0].count == expected_pub_batch_count + + assert {1, "hi", 3} == {r.result() for r in result} + + async def test_single_publish_with_batch_consume( + self, + event: asyncio.Event, + queue: str, + mock: Mock, + meter_provider: MeterProvider, + metric_reader: InMemoryMetricReader, + tracer_provider: TracerProvider, + trace_exporter: InMemorySpanExporter, + ): + mid = self.telemetry_middleware_class( + meter_provider=meter_provider, tracer_provider=tracer_provider + ) + broker = self.broker_class(middlewares=(mid,)) + expected_msg_count = 2 + expected_link_count = 2 + expected_span_count = 6 + expected_process_batch_count = 1 + + @broker.subscriber(list=ListSub(queue, batch=True), **self.subscriber_kwargs) + async def handler(m): + m.sort() + mock(m) + event.set() + + broker = self.patch_broker(broker) + + async with broker: + tasks = ( + asyncio.create_task(broker.publish("hi", list=queue)), + asyncio.create_task(broker.publish("buy", list=queue)), + ) + await asyncio.wait(tasks, timeout=self.timeout) + await broker.start() + await asyncio.wait( + (asyncio.create_task(event.wait()),), timeout=self.timeout + ) + + metrics = self.get_metrics(metric_reader) + proc_dur, proc_msg, pub_dur, pub_msg = metrics + spans = self.get_spans(trace_exporter) + create_process = spans[-2] + + assert len(spans) == expected_span_count + assert len(create_process.links) == expected_link_count + assert proc_msg.data.data_points[0].value == expected_msg_count + assert pub_msg.data.data_points[0].value == expected_msg_count + assert proc_dur.data.data_points[0].count == expected_process_batch_count + assert pub_dur.data.data_points[0].count == expected_msg_count + + assert event.is_set() + mock.assert_called_once_with(["buy", "hi"]) + @pytest.mark.redis() class TestPublishWithTelemetry(TestPublish):