test: add batch telemetry tests for redis
draincoder committed Jun 17, 2024
1 parent 550e0ee commit bf35dca
Showing 3 changed files with 124 additions and 5 deletions.
14 changes: 10 additions & 4 deletions faststream/opentelemetry/middleware.py
@@ -5,9 +5,9 @@

from opentelemetry import baggage, context, metrics, trace
from opentelemetry.baggage.propagation import W3CBaggagePropagator
from opentelemetry.context import Context
from opentelemetry.semconv.trace import SpanAttributes
from opentelemetry.trace import Link, Span
from opentelemetry.context import Context
from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator

from faststream import BaseMiddleware
@@ -39,7 +39,9 @@ def _create_span_name(destination: str, action: str) -> str:


def _is_batch_message(msg: "StreamMessage[Any]") -> bool:
    with_batch = baggage.get_baggage(WITH_BATCH, _BAGGAGE_PROPAGATOR.extract(msg.headers))
    with_batch = baggage.get_baggage(
        WITH_BATCH, _BAGGAGE_PROPAGATOR.extract(msg.headers)
    )
    return bool(msg.batch_headers or with_batch)


@@ -50,7 +52,9 @@ def _get_span_link(headers: "AnyDict", count: int) -> Optional[Link]:
    span_context = next(iter(trace_context.values()))
    if not isinstance(span_context, Span):
        return None
    attributes = {SpanAttributes.MESSAGING_BATCH_MESSAGE_COUNT: count} if count > 1 else None
    attributes = (
        {SpanAttributes.MESSAGING_BATCH_MESSAGE_COUNT: count} if count > 1 else None
    )
    return Link(span_context.get_span_context(), attributes=attributes)


@@ -185,7 +189,9 @@ async def publish_scope(
        # NOTE: if batch with single message?
        if (msg_count := len((msg, *args))) > 1:
            trace_attributes[SpanAttributes.MESSAGING_BATCH_MESSAGE_COUNT] = msg_count
            _BAGGAGE_PROPAGATOR.inject(headers, baggage.set_baggage(WITH_BATCH, True, context=current_context))
            _BAGGAGE_PROPAGATOR.inject(
                headers, baggage.set_baggage(WITH_BATCH, True, context=current_context)
            )

        if self._current_span and self._current_span.is_recording():
            current_context = trace.set_span_in_context(
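Note: the hunks above touch both halves of the batch flag's round trip: publish_scope injects a WITH_BATCH entry into the outgoing headers via W3C baggage, and _is_batch_message reads it back on the consumer side. A minimal sketch of that round trip, using only the opentelemetry-api package (the "with_batch" key is illustrative, standing in for the middleware's WITH_BATCH constant):

from opentelemetry import baggage
from opentelemetry.baggage.propagation import W3CBaggagePropagator

propagator = W3CBaggagePropagator()
WITH_BATCH = "with_batch"  # illustrative key; the middleware's actual constant may differ

# Publish side (cf. publish_scope): mark the outgoing batch in the message headers.
headers = {}
batch_context = baggage.set_baggage(WITH_BATCH, True)
propagator.inject(headers, batch_context)
print(headers)  # {'baggage': 'with_batch=True'}

# Consume side (cf. _is_batch_message): recover the flag from the incoming headers.
restored = propagator.extract(headers)
print(baggage.get_baggage(WITH_BATCH, restored))  # 'True'

The flag survives only as the string "True", which is still truthy, so the bool(msg.batch_headers or with_batch) check behaves the same for propagated batches.
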
3 changes: 2 additions & 1 deletion tests/opentelemetry/nats/test_nats.py
@@ -45,6 +45,7 @@ async def test_batch(
        broker = self.broker_class(middlewares=(mid,))
        expected_msg_count = 3
        expected_span_count = 8
        expected_proc_batch_count = 1

        @broker.subscriber(
            queue,
@@ -82,7 +83,7 @@ async def handler(m):
        )
        assert proc_msg.data.data_points[0].value == expected_msg_count
        assert pub_msg.data.data_points[0].value == expected_msg_count
        assert proc_dur.data.data_points[0].count == 1
        assert proc_dur.data.data_points[0].count == expected_proc_batch_count
        assert pub_dur.data.data_points[0].count == expected_msg_count

        assert event.is_set()
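Note: the link assertions in the redis tests below exercise _get_span_link from the middleware change above: the consumer-side "process" span is tied to the publisher's span through a Link that carries messaging.batch.message_count when more than one message was published. A minimal sketch of that shape with the OpenTelemetry SDK (span names here are illustrative, not the middleware's):

from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import SimpleSpanProcessor
from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter
from opentelemetry.semconv.trace import SpanAttributes
from opentelemetry.trace import Link

exporter = InMemorySpanExporter()
provider = TracerProvider()
provider.add_span_processor(SimpleSpanProcessor(exporter))
tracer = provider.get_tracer("example")

# Publisher-side span; its context is what the consumer-side span links to.
with tracer.start_as_current_span("test-list publish") as publish_span:
    link = Link(
        publish_span.get_span_context(),
        attributes={SpanAttributes.MESSAGING_BATCH_MESSAGE_COUNT: 3},
    )

# Consumer-side span carries the link rather than a parent/child relation.
with tracer.start_as_current_span("test-list process", links=[link]):
    pass

process_span = exporter.get_finished_spans()[-1]
assert process_span.links[0].attributes == {"messaging.batch.message_count": 3}

This is the same shape as the new expected_link_attrs assertion added to test_batch below.
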
112 changes: 112 additions & 0 deletions tests/opentelemetry/redis/test_redis.py
@@ -43,6 +43,7 @@ async def test_batch(
        broker = self.broker_class(middlewares=(mid,))
        expected_msg_count = 3
        expected_link_count = 1
        expected_link_attrs = {"messaging.batch.message_count": 3}

        @broker.subscriber(list=ListSub(queue, batch=True), **self.subscriber_kwargs)
        async def handler(m):
@@ -72,11 +73,122 @@ async def handler(m):
            == expected_msg_count
        )
        assert len(create_process.links) == expected_link_count
        assert create_process.links[0].attributes == expected_link_attrs
        self.assert_metrics(metrics, count=expected_msg_count)

        assert event.is_set()
        mock.assert_called_once_with([1, "hi", 3])

    async def test_batch_publish_with_single_consume(
        self,
        queue: str,
        meter_provider: MeterProvider,
        metric_reader: InMemoryMetricReader,
        tracer_provider: TracerProvider,
        trace_exporter: InMemorySpanExporter,
    ):
        mid = self.telemetry_middleware_class(
            meter_provider=meter_provider, tracer_provider=tracer_provider
        )
        broker = self.broker_class(middlewares=(mid,))
        msgs_queue = asyncio.Queue(maxsize=3)
        expected_msg_count = 3
        expected_link_count = 1
        expected_span_count = 8
        expected_pub_batch_count = 1

        @broker.subscriber(list=ListSub(queue), **self.subscriber_kwargs)
        async def handler(msg):
            await msgs_queue.put(msg)

        broker = self.patch_broker(broker)

        async with broker:
            await broker.start()
            await broker.publish_batch(1, "hi", 3, list=queue)
            result, _ = await asyncio.wait(
                (
                    asyncio.create_task(msgs_queue.get()),
                    asyncio.create_task(msgs_queue.get()),
                    asyncio.create_task(msgs_queue.get()),
                ),
                timeout=3,
            )

        metrics = self.get_metrics(metric_reader)
        proc_dur, proc_msg, pub_dur, pub_msg = metrics
        spans = self.get_spans(trace_exporter)
        publish = spans[1]
        create_processes = [spans[2], spans[4], spans[6]]

        assert len(spans) == expected_span_count
        assert (
            publish.attributes[SpanAttr.MESSAGING_BATCH_MESSAGE_COUNT]
            == expected_msg_count
        )
        for cp in create_processes:
            assert len(cp.links) == expected_link_count

        assert proc_msg.data.data_points[0].value == expected_msg_count
        assert pub_msg.data.data_points[0].value == expected_msg_count
        assert proc_dur.data.data_points[0].count == expected_msg_count
        assert pub_dur.data.data_points[0].count == expected_pub_batch_count

        assert {1, "hi", 3} == {r.result() for r in result}

    async def test_single_publish_with_batch_consume(
        self,
        event: asyncio.Event,
        queue: str,
        mock: Mock,
        meter_provider: MeterProvider,
        metric_reader: InMemoryMetricReader,
        tracer_provider: TracerProvider,
        trace_exporter: InMemorySpanExporter,
    ):
        mid = self.telemetry_middleware_class(
            meter_provider=meter_provider, tracer_provider=tracer_provider
        )
        broker = self.broker_class(middlewares=(mid,))
        expected_msg_count = 2
        expected_link_count = 2
        expected_span_count = 6
        expected_process_batch_count = 1

        @broker.subscriber(list=ListSub(queue, batch=True), **self.subscriber_kwargs)
        async def handler(m):
            m.sort()
            mock(m)
            event.set()

        broker = self.patch_broker(broker)

        async with broker:
            tasks = (
                asyncio.create_task(broker.publish("hi", list=queue)),
                asyncio.create_task(broker.publish("buy", list=queue)),
            )
            await asyncio.wait(tasks, timeout=self.timeout)
            await broker.start()
            await asyncio.wait(
                (asyncio.create_task(event.wait()),), timeout=self.timeout
            )

        metrics = self.get_metrics(metric_reader)
        proc_dur, proc_msg, pub_dur, pub_msg = metrics
        spans = self.get_spans(trace_exporter)
        create_process = spans[-2]

        assert len(spans) == expected_span_count
        assert len(create_process.links) == expected_link_count
        assert proc_msg.data.data_points[0].value == expected_msg_count
        assert pub_msg.data.data_points[0].value == expected_msg_count
        assert proc_dur.data.data_points[0].count == expected_process_batch_count
        assert pub_dur.data.data_points[0].count == expected_msg_count

        assert event.is_set()
        mock.assert_called_once_with(["buy", "hi"])


@pytest.mark.redis()
class TestPublishWithTelemetry(TestPublish):
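Note on the metric assertions: the tests read .value on the message counters but .count on the duration histograms, because a batch operation increments the counter by the number of messages while recording a single duration sample. A minimal sketch of that difference with the OpenTelemetry SDK's in-memory reader (the instrument names are illustrative, not necessarily the ones the middleware registers):

from opentelemetry.sdk.metrics import MeterProvider
from opentelemetry.sdk.metrics.export import InMemoryMetricReader

reader = InMemoryMetricReader()
provider = MeterProvider(metric_readers=[reader])
meter = provider.get_meter("example")

# Illustrative instruments standing in for the middleware's counter/histogram pair.
processed_messages = meter.create_counter("messaging.process.messages")
process_duration = meter.create_histogram("messaging.process.duration")

# One batch of three messages: the counter grows by the message count,
# while the histogram records one duration for the whole batch operation.
processed_messages.add(3)
process_duration.record(0.005)

for resource_metrics in reader.get_metrics_data().resource_metrics:
    for scope_metrics in resource_metrics.scope_metrics:
        for metric in scope_metrics.metrics:
            point = metric.data.data_points[0]
            # Counter points expose .value; histogram points expose .count.
            print(metric.name, getattr(point, "value", None), getattr(point, "count", None))

This is why expected_pub_batch_count and expected_process_batch_count stay at 1 while the corresponding message counters equal the batch size.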
