fix: correct spans linking in batches case (#1532)
* fix: add links for batches

* fix: strict get correlation_id from dict

* test: add batch telemetry tests for redis

* test: add batch telemetry tests for kafka

* test: refactor nats test_batch

* refactor: _get_span_links

* refactor: optimize OTEL links resolution

* lint: fix codespell

* lint: fix mypy

---------

Co-authored-by: Nikita Pastukhov <[email protected]>
draincoder and Lancetnik authored Jun 19, 2024
1 parent 1e87e76 commit 6bfca41
Showing 8 changed files with 441 additions and 25 deletions.
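To see the behavior this commit targets from a user's perspective: with the telemetry middleware installed, a batch publish should now produce consumer-side `create` spans that are linked back to the publishing span rather than detached from it. Below is a minimal sketch under the assumption that the Kafka middleware is importable as `faststream.kafka.opentelemetry.KafkaTelemetryMiddleware`; the topic name and wiring are illustrative, not part of this diff.

```python
from faststream import FastStream
from faststream.kafka import KafkaBroker
from faststream.kafka.opentelemetry import KafkaTelemetryMiddleware  # assumed import path
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import SimpleSpanProcessor
from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter

# Export spans to memory so they can be inspected after the run.
exporter = InMemorySpanExporter()
provider = TracerProvider()
provider.add_span_processor(SimpleSpanProcessor(exporter))

broker = KafkaBroker(middlewares=(KafkaTelemetryMiddleware(tracer_provider=provider),))
app = FastStream(broker)


@broker.subscriber("demo-topic", batch=True)  # illustrative topic name
async def handler(msgs: list) -> None:
    # With this fix, the batch "create" span exported for this handler
    # carries links back to the publisher's span(s).
    print(msgs)


@app.after_startup
async def publish() -> None:
    await broker.publish_batch(1, "hi", 3, topic="demo-topic")
```

Running this (for example with `faststream run`) and inspecting `exporter.get_finished_spans()` is essentially what the new tests in this commit automate.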
1 change: 1 addition & 0 deletions .codespell-whitelist.txt
@@ -1,2 +1,3 @@
dependant
unsecure
socio-economic
2 changes: 1 addition & 1 deletion docs/docs/en/release.md
@@ -324,7 +324,7 @@ You can find more information about it in the official [**aiokafka** doc](https:

`pattern` option was added too, but it is still experimental and does not support `Path`

3. [`Path`](https://faststream.airt.ai/latest/nats/message/#subject-pattern-access) feature performance was increased. Also, `Path` is suitable for NATS `PullSub` batch subscribtion as well now.
3. [`Path`](https://faststream.airt.ai/latest/nats/message/#subject-pattern-access) feature performance was increased. Also, `Path` is suitable for NATS `PullSub` batch subscription as well now.

```python
from faststream import NatsBroker, PullSub
2 changes: 2 additions & 0 deletions faststream/opentelemetry/consts.py
@@ -5,5 +5,7 @@ class MessageAction:
RECEIVE = "receive"


OTEL_SCHEMA = "https://opentelemetry.io/schemas/1.11.0"
ERROR_TYPE = "error.type"
MESSAGING_DESTINATION_PUBLISH_NAME = "messaging.destination_publish.name"
WITH_BATCH = "with_batch"
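The `WITH_BATCH` constant added here is carried between publisher and consumer as W3C baggage (see the middleware diff below), which is how the consumer side can recognize messages that were published as a batch even when they are consumed one at a time. A minimal standalone sketch of that round trip, using only the OpenTelemetry propagation API (the header dict is illustrative, not FastStream's actual carrier):

```python
from opentelemetry import baggage
from opentelemetry.baggage.propagation import W3CBaggagePropagator

propagator = W3CBaggagePropagator()

# Publisher side: flag the outgoing headers as coming from a batch publish.
headers: dict = {}
ctx = baggage.set_baggage("with_batch", True)
propagator.inject(headers, context=ctx)
# headers is now {"baggage": "with_batch=True"}

# Consumer side: read the flag back off the incoming headers.
with_batch = baggage.get_baggage("with_batch", propagator.extract(headers))
print(with_batch)  # "True" -- a string after extraction, but truthy, which is all the check needs
```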
102 changes: 87 additions & 15 deletions faststream/opentelemetry/middleware.py
@@ -1,34 +1,38 @@
import time
from collections import defaultdict
from copy import copy
from typing import TYPE_CHECKING, Any, Callable, Optional, Type
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Type, cast

from opentelemetry import context, metrics, propagate, trace
from opentelemetry import baggage, context, metrics, trace
from opentelemetry.baggage.propagation import W3CBaggagePropagator
from opentelemetry.context import Context
from opentelemetry.semconv.trace import SpanAttributes
from opentelemetry.trace import Link, Span
from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator

from faststream import BaseMiddleware
from faststream.opentelemetry.consts import (
ERROR_TYPE,
MESSAGING_DESTINATION_PUBLISH_NAME,
OTEL_SCHEMA,
WITH_BATCH,
MessageAction,
)
from faststream.opentelemetry.provider import TelemetrySettingsProvider

if TYPE_CHECKING:
from types import TracebackType

from opentelemetry.context import Context
from opentelemetry.metrics import Meter, MeterProvider
from opentelemetry.trace import Span, Tracer, TracerProvider
from opentelemetry.trace import Tracer, TracerProvider
from opentelemetry.util.types import Attributes

from faststream.broker.message import StreamMessage
from faststream.types import AnyDict, AsyncFunc, AsyncFuncAny


_OTEL_SCHEMA = "https://opentelemetry.io/schemas/1.11.0"


def _create_span_name(destination: str, action: str) -> str:
return f"{destination} {action}"
_BAGGAGE_PROPAGATOR = W3CBaggagePropagator()
_TRACE_PROPAGATOR = TraceContextTextMapPropagator()


class _MetricsContainer:
@@ -139,12 +143,15 @@ async def publish_scope(
# NOTE: if batch with single message?
if (msg_count := len((msg, *args))) > 1:
trace_attributes[SpanAttributes.MESSAGING_BATCH_MESSAGE_COUNT] = msg_count
_BAGGAGE_PROPAGATOR.inject(
headers, baggage.set_baggage(WITH_BATCH, True, context=current_context)
)

if self._current_span and self._current_span.is_recording():
current_context = trace.set_span_in_context(
self._current_span, current_context
)
propagate.inject(headers, context=self._origin_context)
_TRACE_PROPAGATOR.inject(headers, context=self._origin_context)

else:
create_span = self._tracer.start_span(
@@ -153,7 +160,7 @@ async def publish_scope(
attributes=trace_attributes,
)
current_context = trace.set_span_in_context(create_span)
propagate.inject(headers, context=current_context)
_TRACE_PROPAGATOR.inject(headers, context=current_context)
create_span.end()

start_time = time.perf_counter()
@@ -188,9 +195,14 @@ async def consume_scope(
if (provider := self.__settings_provider) is None:
return await call_next(msg)

current_context = propagate.extract(msg.headers)
destination_name = provider.get_consume_destination_name(msg)
if _is_batch_message(msg):
links = _get_msg_links(msg)
current_context = Context()
else:
links = None
current_context = _TRACE_PROPAGATOR.extract(msg.headers)

destination_name = provider.get_consume_destination_name(msg)
trace_attributes = provider.get_consume_attrs_from_message(msg)
metrics_attributes = {
SpanAttributes.MESSAGING_SYSTEM: provider.messaging_system,
@@ -202,6 +214,7 @@ async def consume_scope(
name=_create_span_name(destination_name, MessageAction.CREATE),
kind=trace.SpanKind.CONSUMER,
attributes=trace_attributes,
links=links,
)
current_context = trace.set_span_in_context(create_span)
create_span.end()
@@ -292,7 +305,7 @@ def _get_meter(
return metrics.get_meter(
__name__,
meter_provider=meter_provider,
schema_url=_OTEL_SCHEMA,
schema_url=OTEL_SCHEMA,
)
return meter

@@ -301,5 +314,64 @@ def _get_tracer(tracer_provider: Optional["TracerProvider"] = None) -> "Tracer":
return trace.get_tracer(
__name__,
tracer_provider=tracer_provider,
schema_url=_OTEL_SCHEMA,
schema_url=OTEL_SCHEMA,
)


def _create_span_name(destination: str, action: str) -> str:
return f"{destination} {action}"


def _is_batch_message(msg: "StreamMessage[Any]") -> bool:
with_batch = baggage.get_baggage(
WITH_BATCH, _BAGGAGE_PROPAGATOR.extract(msg.headers)
)
return bool(msg.batch_headers or with_batch)


def _get_msg_links(msg: "StreamMessage[Any]") -> List[Link]:
if not msg.batch_headers:
if (span := _get_span_from_headers(msg.headers)) is not None:
return [Link(span.get_span_context())]
else:
return []

links = {}
counter: Dict[str, int] = defaultdict(lambda: 0)

for headers in msg.batch_headers:
if (correlation_id := headers.get("correlation_id")) is None:
continue

counter[correlation_id] += 1

if (span := _get_span_from_headers(headers)) is None:
continue

attributes = _get_link_attributes(counter[correlation_id])

links[correlation_id] = Link(
span.get_span_context(),
attributes=attributes,
)

return list(links.values())


def _get_span_from_headers(headers: "AnyDict") -> Optional[Span]:
trace_context = _TRACE_PROPAGATOR.extract(headers)
if not len(trace_context):
return None

return cast(
Optional[Span],
next(iter(trace_context.values())),
)


def _get_link_attributes(message_count: int) -> "Attributes":
if message_count <= 1:
return {}
return {
SpanAttributes.MESSAGING_BATCH_MESSAGE_COUNT: message_count,
}
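The new `_get_span_from_headers` helper relies on the fact that a W3C `traceparent` header can be extracted back into a context whose single value is a remote, non-recording span, and that its `SpanContext` is all a `Link` needs. A standalone sketch of that step, using the canonical W3C example trace id (illustrative values, not taken from this diff):

```python
from opentelemetry.trace import Link
from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator

propagator = TraceContextTextMapPropagator()

# A W3C traceparent header as it could appear on one message of a batch.
headers = {"traceparent": "00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01"}

# Extraction yields a Context whose single value is a non-recording span
# wrapping the remote SpanContext; that SpanContext is the link target.
ctx = propagator.extract(headers)
remote_span = next(iter(ctx.values()))
link = Link(remote_span.get_span_context())

print(format(link.context.trace_id, "032x"))  # 4bf92f3577b34da6a3ce929d0e0e4736
```

Deduplication by `correlation_id` in `_get_msg_links` then ensures that a batch whose messages all came from one producer span yields a single link carrying `messaging.batch.message_count` rather than one link per message.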
114 changes: 113 additions & 1 deletion tests/opentelemetry/confluent/test_confluent.py
@@ -78,6 +78,8 @@ async def test_batch(
)
broker = self.broker_class(middlewares=(mid,))
expected_msg_count = 3
expected_link_count = 1
expected_link_attrs = {"messaging.batch.message_count": 3}

@broker.subscriber(queue, batch=True, **self.subscriber_kwargs)
async def handler(m):
@@ -96,7 +98,7 @@ async def handler(m):

metrics = self.get_metrics(metric_reader)
spans = self.get_spans(trace_exporter)
_, publish, process = spans
_, publish, create_process, process = spans

assert (
publish.attributes[SpanAttr.MESSAGING_BATCH_MESSAGE_COUNT]
@@ -106,11 +108,121 @@ async def handler(m):
process.attributes[SpanAttr.MESSAGING_BATCH_MESSAGE_COUNT]
== expected_msg_count
)
assert len(create_process.links) == expected_link_count
assert create_process.links[0].attributes == expected_link_attrs
self.assert_metrics(metrics, count=expected_msg_count)

assert event.is_set()
mock.assert_called_once_with([1, "hi", 3])

async def test_batch_publish_with_single_consume(
self,
queue: str,
meter_provider: MeterProvider,
metric_reader: InMemoryMetricReader,
tracer_provider: TracerProvider,
trace_exporter: InMemorySpanExporter,
):
mid = self.telemetry_middleware_class(
meter_provider=meter_provider, tracer_provider=tracer_provider
)
broker = self.broker_class(middlewares=(mid,))
msgs_queue = asyncio.Queue(maxsize=3)
expected_msg_count = 3
expected_link_count = 1
expected_span_count = 8
expected_pub_batch_count = 1

@broker.subscriber(queue, **self.subscriber_kwargs)
async def handler(msg):
await msgs_queue.put(msg)

broker = self.patch_broker(broker)

async with broker:
await broker.start()
await broker.publish_batch(1, "hi", 3, topic=queue)
result, _ = await asyncio.wait(
(
asyncio.create_task(msgs_queue.get()),
asyncio.create_task(msgs_queue.get()),
asyncio.create_task(msgs_queue.get()),
),
timeout=self.timeout,
)

metrics = self.get_metrics(metric_reader)
proc_dur, proc_msg, pub_dur, pub_msg = metrics
spans = self.get_spans(trace_exporter)
publish = spans[1]
create_processes = [spans[2], spans[4], spans[6]]

assert len(spans) == expected_span_count
assert (
publish.attributes[SpanAttr.MESSAGING_BATCH_MESSAGE_COUNT]
== expected_msg_count
)
for cp in create_processes:
assert len(cp.links) == expected_link_count

assert proc_msg.data.data_points[0].value == expected_msg_count
assert pub_msg.data.data_points[0].value == expected_msg_count
assert proc_dur.data.data_points[0].count == expected_msg_count
assert pub_dur.data.data_points[0].count == expected_pub_batch_count

assert {1, "hi", 3} == {r.result() for r in result}

async def test_single_publish_with_batch_consume(
self,
event: asyncio.Event,
queue: str,
mock: Mock,
meter_provider: MeterProvider,
metric_reader: InMemoryMetricReader,
tracer_provider: TracerProvider,
trace_exporter: InMemorySpanExporter,
):
mid = self.telemetry_middleware_class(
meter_provider=meter_provider, tracer_provider=tracer_provider
)
broker = self.broker_class(middlewares=(mid,))
expected_msg_count = 2
expected_link_count = 2
expected_span_count = 6
expected_process_batch_count = 1

@broker.subscriber(queue, batch=True, **self.subscriber_kwargs)
async def handler(m):
m.sort()
mock(m)
event.set()

broker = self.patch_broker(broker)

async with broker:
await broker.start()
tasks = (
asyncio.create_task(broker.publish("hi", topic=queue)),
asyncio.create_task(broker.publish("buy", topic=queue)),
asyncio.create_task(event.wait()),
)
await asyncio.wait(tasks, timeout=self.timeout)

metrics = self.get_metrics(metric_reader)
proc_dur, proc_msg, pub_dur, pub_msg = metrics
spans = self.get_spans(trace_exporter)
create_process = spans[-2]

assert len(spans) == expected_span_count
assert len(create_process.links) == expected_link_count
assert proc_msg.data.data_points[0].value == expected_msg_count
assert pub_msg.data.data_points[0].value == expected_msg_count
assert proc_dur.data.data_points[0].count == expected_process_batch_count
assert pub_dur.data.data_points[0].count == expected_msg_count

assert event.is_set()
mock.assert_called_once_with(["buy", "hi"])


@pytest.mark.confluent()
class TestPublishWithTelemetry(TestPublish):