Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: correct spans linking in batches case #1532

Merged
merged 9 commits into from
Jun 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .codespell-whitelist.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
dependant
unsecure
socio-economic
2 changes: 1 addition & 1 deletion docs/docs/en/release.md
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,7 @@ You can find more information about it in the official [**aiokafka** doc](https:

`pattern` option was added too, but it is still experimental and does not support `Path`

3. [`Path`](https://faststream.airt.ai/latest/nats/message/#subject-pattern-access) feature performance was increased. Also, `Path` is suitable for NATS `PullSub` batch subscribtion as well now.
3. [`Path`](https://faststream.airt.ai/latest/nats/message/#subject-pattern-access) feature performance was increased. Also, `Path` is suitable for NATS `PullSub` batch subscription as well now.

```python
from faststream import NatsBroker, PullSub
Expand Down
2 changes: 2 additions & 0 deletions faststream/opentelemetry/consts.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,7 @@ class MessageAction:
RECEIVE = "receive"


OTEL_SCHEMA = "https://opentelemetry.io/schemas/1.11.0"
ERROR_TYPE = "error.type"
MESSAGING_DESTINATION_PUBLISH_NAME = "messaging.destination_publish.name"
WITH_BATCH = "with_batch"
102 changes: 87 additions & 15 deletions faststream/opentelemetry/middleware.py
Original file line number Diff line number Diff line change
@@ -1,34 +1,38 @@
import time
from collections import defaultdict
from copy import copy
from typing import TYPE_CHECKING, Any, Callable, Optional, Type
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Type, cast

from opentelemetry import context, metrics, propagate, trace
from opentelemetry import baggage, context, metrics, trace
from opentelemetry.baggage.propagation import W3CBaggagePropagator
from opentelemetry.context import Context
from opentelemetry.semconv.trace import SpanAttributes
from opentelemetry.trace import Link, Span
from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator

from faststream import BaseMiddleware
from faststream.opentelemetry.consts import (
ERROR_TYPE,
MESSAGING_DESTINATION_PUBLISH_NAME,
OTEL_SCHEMA,
WITH_BATCH,
MessageAction,
)
from faststream.opentelemetry.provider import TelemetrySettingsProvider

if TYPE_CHECKING:
from types import TracebackType

from opentelemetry.context import Context
from opentelemetry.metrics import Meter, MeterProvider
from opentelemetry.trace import Span, Tracer, TracerProvider
from opentelemetry.trace import Tracer, TracerProvider
from opentelemetry.util.types import Attributes

from faststream.broker.message import StreamMessage
from faststream.types import AnyDict, AsyncFunc, AsyncFuncAny


_OTEL_SCHEMA = "https://opentelemetry.io/schemas/1.11.0"


def _create_span_name(destination: str, action: str) -> str:
return f"{destination} {action}"
_BAGGAGE_PROPAGATOR = W3CBaggagePropagator()
_TRACE_PROPAGATOR = TraceContextTextMapPropagator()


class _MetricsContainer:
Expand Down Expand Up @@ -139,12 +143,15 @@ async def publish_scope(
# NOTE: if batch with single message?
if (msg_count := len((msg, *args))) > 1:
trace_attributes[SpanAttributes.MESSAGING_BATCH_MESSAGE_COUNT] = msg_count
_BAGGAGE_PROPAGATOR.inject(
headers, baggage.set_baggage(WITH_BATCH, True, context=current_context)
)

if self._current_span and self._current_span.is_recording():
current_context = trace.set_span_in_context(
self._current_span, current_context
)
propagate.inject(headers, context=self._origin_context)
_TRACE_PROPAGATOR.inject(headers, context=self._origin_context)

else:
create_span = self._tracer.start_span(
Expand All @@ -153,7 +160,7 @@ async def publish_scope(
attributes=trace_attributes,
)
current_context = trace.set_span_in_context(create_span)
propagate.inject(headers, context=current_context)
_TRACE_PROPAGATOR.inject(headers, context=current_context)
create_span.end()

start_time = time.perf_counter()
Expand Down Expand Up @@ -188,9 +195,14 @@ async def consume_scope(
if (provider := self.__settings_provider) is None:
return await call_next(msg)

current_context = propagate.extract(msg.headers)
destination_name = provider.get_consume_destination_name(msg)
if _is_batch_message(msg):
links = _get_msg_links(msg)
current_context = Context()
else:
links = None
current_context = _TRACE_PROPAGATOR.extract(msg.headers)

destination_name = provider.get_consume_destination_name(msg)
trace_attributes = provider.get_consume_attrs_from_message(msg)
metrics_attributes = {
SpanAttributes.MESSAGING_SYSTEM: provider.messaging_system,
Expand All @@ -202,6 +214,7 @@ async def consume_scope(
name=_create_span_name(destination_name, MessageAction.CREATE),
kind=trace.SpanKind.CONSUMER,
attributes=trace_attributes,
links=links,
)
current_context = trace.set_span_in_context(create_span)
create_span.end()
Expand Down Expand Up @@ -292,7 +305,7 @@ def _get_meter(
return metrics.get_meter(
__name__,
meter_provider=meter_provider,
schema_url=_OTEL_SCHEMA,
schema_url=OTEL_SCHEMA,
)
return meter

Expand All @@ -301,5 +314,64 @@ def _get_tracer(tracer_provider: Optional["TracerProvider"] = None) -> "Tracer":
return trace.get_tracer(
__name__,
tracer_provider=tracer_provider,
schema_url=_OTEL_SCHEMA,
schema_url=OTEL_SCHEMA,
)


def _create_span_name(destination: str, action: str) -> str:
return f"{destination} {action}"


def _is_batch_message(msg: "StreamMessage[Any]") -> bool:
with_batch = baggage.get_baggage(
WITH_BATCH, _BAGGAGE_PROPAGATOR.extract(msg.headers)
)
return bool(msg.batch_headers or with_batch)


def _get_msg_links(msg: "StreamMessage[Any]") -> List[Link]:
if not msg.batch_headers:
if (span := _get_span_from_headers(msg.headers)) is not None:
return [Link(span.get_span_context())]
else:
return []

links = {}
counter: Dict[str, int] = defaultdict(lambda: 0)

for headers in msg.batch_headers:
if (correlation_id := headers.get("correlation_id")) is None:
continue

counter[correlation_id] += 1

if (span := _get_span_from_headers(headers)) is None:
continue

attributes = _get_link_attributes(counter[correlation_id])

links[correlation_id] = Link(
span.get_span_context(),
attributes=attributes,
)

return list(links.values())


def _get_span_from_headers(headers: "AnyDict") -> Optional[Span]:
trace_context = _TRACE_PROPAGATOR.extract(headers)
if not len(trace_context):
return None

return cast(
Optional[Span],
next(iter(trace_context.values())),
)


def _get_link_attributes(message_count: int) -> "Attributes":
if message_count <= 1:
return {}
return {
SpanAttributes.MESSAGING_BATCH_MESSAGE_COUNT: message_count,
}
114 changes: 113 additions & 1 deletion tests/opentelemetry/confluent/test_confluent.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,8 @@ async def test_batch(
)
broker = self.broker_class(middlewares=(mid,))
expected_msg_count = 3
expected_link_count = 1
expected_link_attrs = {"messaging.batch.message_count": 3}

@broker.subscriber(queue, batch=True, **self.subscriber_kwargs)
async def handler(m):
Expand All @@ -96,7 +98,7 @@ async def handler(m):

metrics = self.get_metrics(metric_reader)
spans = self.get_spans(trace_exporter)
_, publish, process = spans
_, publish, create_process, process = spans

assert (
publish.attributes[SpanAttr.MESSAGING_BATCH_MESSAGE_COUNT]
Expand All @@ -106,11 +108,121 @@ async def handler(m):
process.attributes[SpanAttr.MESSAGING_BATCH_MESSAGE_COUNT]
== expected_msg_count
)
assert len(create_process.links) == expected_link_count
assert create_process.links[0].attributes == expected_link_attrs
self.assert_metrics(metrics, count=expected_msg_count)

assert event.is_set()
mock.assert_called_once_with([1, "hi", 3])

async def test_batch_publish_with_single_consume(
self,
queue: str,
meter_provider: MeterProvider,
metric_reader: InMemoryMetricReader,
tracer_provider: TracerProvider,
trace_exporter: InMemorySpanExporter,
):
mid = self.telemetry_middleware_class(
meter_provider=meter_provider, tracer_provider=tracer_provider
)
broker = self.broker_class(middlewares=(mid,))
msgs_queue = asyncio.Queue(maxsize=3)
expected_msg_count = 3
expected_link_count = 1
expected_span_count = 8
expected_pub_batch_count = 1

@broker.subscriber(queue, **self.subscriber_kwargs)
async def handler(msg):
await msgs_queue.put(msg)

broker = self.patch_broker(broker)

async with broker:
await broker.start()
await broker.publish_batch(1, "hi", 3, topic=queue)
result, _ = await asyncio.wait(
(
asyncio.create_task(msgs_queue.get()),
asyncio.create_task(msgs_queue.get()),
asyncio.create_task(msgs_queue.get()),
),
timeout=self.timeout,
)

metrics = self.get_metrics(metric_reader)
proc_dur, proc_msg, pub_dur, pub_msg = metrics
spans = self.get_spans(trace_exporter)
publish = spans[1]
create_processes = [spans[2], spans[4], spans[6]]

assert len(spans) == expected_span_count
assert (
publish.attributes[SpanAttr.MESSAGING_BATCH_MESSAGE_COUNT]
== expected_msg_count
)
for cp in create_processes:
assert len(cp.links) == expected_link_count

assert proc_msg.data.data_points[0].value == expected_msg_count
assert pub_msg.data.data_points[0].value == expected_msg_count
assert proc_dur.data.data_points[0].count == expected_msg_count
assert pub_dur.data.data_points[0].count == expected_pub_batch_count

assert {1, "hi", 3} == {r.result() for r in result}

async def test_single_publish_with_batch_consume(
self,
event: asyncio.Event,
queue: str,
mock: Mock,
meter_provider: MeterProvider,
metric_reader: InMemoryMetricReader,
tracer_provider: TracerProvider,
trace_exporter: InMemorySpanExporter,
):
mid = self.telemetry_middleware_class(
meter_provider=meter_provider, tracer_provider=tracer_provider
)
broker = self.broker_class(middlewares=(mid,))
expected_msg_count = 2
expected_link_count = 2
expected_span_count = 6
expected_process_batch_count = 1

@broker.subscriber(queue, batch=True, **self.subscriber_kwargs)
async def handler(m):
m.sort()
mock(m)
event.set()

broker = self.patch_broker(broker)

async with broker:
await broker.start()
tasks = (
asyncio.create_task(broker.publish("hi", topic=queue)),
asyncio.create_task(broker.publish("buy", topic=queue)),
asyncio.create_task(event.wait()),
)
await asyncio.wait(tasks, timeout=self.timeout)

metrics = self.get_metrics(metric_reader)
proc_dur, proc_msg, pub_dur, pub_msg = metrics
spans = self.get_spans(trace_exporter)
create_process = spans[-2]

assert len(spans) == expected_span_count
assert len(create_process.links) == expected_link_count
assert proc_msg.data.data_points[0].value == expected_msg_count
assert pub_msg.data.data_points[0].value == expected_msg_count
assert proc_dur.data.data_points[0].count == expected_process_batch_count
assert pub_dur.data.data_points[0].count == expected_msg_count

assert event.is_set()
mock.assert_called_once_with(["buy", "hi"])


@pytest.mark.confluent()
class TestPublishWithTelemetry(TestPublish):
Expand Down
Loading
Loading