Skip to content

Commit

Permalink
refactor: optimize OTEL links resolution
Browse files Browse the repository at this point in the history
  • Loading branch information
Lancetnik committed Jun 19, 2024
1 parent 959c5e4 commit 36997f2
Showing 1 changed file with 55 additions and 50 deletions.
105 changes: 55 additions & 50 deletions faststream/opentelemetry/middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,55 +35,6 @@
_TRACE_PROPAGATOR = TraceContextTextMapPropagator()


def _create_span_name(destination: str, action: str) -> str:
return f"{destination} {action}"


def _is_batch_message(msg: "StreamMessage[Any]") -> bool:
with_batch = baggage.get_baggage(
WITH_BATCH, _BAGGAGE_PROPAGATOR.extract(msg.headers)
)
return bool(msg.batch_headers or with_batch)


def _get_span_link(headers: "AnyDict", attributes: "Attributes") -> Optional[Link]:
trace_context = _TRACE_PROPAGATOR.extract(headers)
if not len(trace_context):
return None
span_context = next(iter(trace_context.values()))
if not isinstance(span_context, Span):
return None
return Link(span_context.get_span_context(), attributes=attributes)


def _get_link_attributes(message_count: int) -> "Attributes":
if message_count <= 1:
return {}
return {SpanAttributes.MESSAGING_BATCH_MESSAGE_COUNT: message_count}


def _get_span_links(msg: "StreamMessage[Any]") -> List[Link]:
if not msg.batch_headers:
link = _get_span_link(msg.headers, {})
return [link] if link else []

links = {}
counter: Dict[str, int] = defaultdict(lambda: 0)

for headers in msg.batch_headers:
correlation_id = headers.get("correlation_id")
if correlation_id is None:
continue
counter[correlation_id] += 1
attributes = _get_link_attributes(counter[correlation_id])
link = _get_span_link(headers, attributes)
if link is None:
continue
links[correlation_id] = link

return list(links.values())


class _MetricsContainer:
__slots__ = (
"include_messages_counters",
Expand Down Expand Up @@ -245,7 +196,7 @@ async def consume_scope(
return await call_next(msg)

if _is_batch_message(msg):
links = _get_span_links(msg)
links = _get_msg_links(msg)
current_context = Context()
else:
links = None
Expand Down Expand Up @@ -365,3 +316,57 @@ def _get_tracer(tracer_provider: Optional["TracerProvider"] = None) -> "Tracer":
tracer_provider=tracer_provider,
schema_url=OTEL_SCHEMA,
)


def _create_span_name(destination: str, action: str) -> str:
return f"{destination} {action}"


def _is_batch_message(msg: "StreamMessage[Any]") -> bool:
with_batch = baggage.get_baggage(
WITH_BATCH, _BAGGAGE_PROPAGATOR.extract(msg.headers)
)
return bool(msg.batch_headers or with_batch)


def _get_msg_links(msg: "StreamMessage[Any]") -> List[Link]:
if not msg.batch_headers:
if (span := _get_span_from_headers(msg.headers)) is not None:
return [Link(span.get_span_context())]
else:
return []

links = {}
counter: Dict[str, int] = defaultdict(lambda: 0)

for headers in msg.batch_headers:
if (correlation_id := headers.get("correlation_id")) is None:
continue

counter[correlation_id] += 1

if (span := _get_span_from_headers(headers)) is None:
continue

attributes = _get_link_attributes(counter[correlation_id])

links[correlation_id] = Link(
span.get_span_context(),
attributes=attributes,
)

return list(links.values())


def _get_span_from_headers(headers: "AnyDict", attributes: "Attributes") -> Optional[Span]:
trace_context = _TRACE_PROPAGATOR.extract(headers)
if not len(trace_context):
return None

return next(iter(trace_context.values()))


def _get_link_attributes(message_count: int) -> "Attributes":
if message_count <= 1:
return {}
return {SpanAttributes.MESSAGING_BATCH_MESSAGE_COUNT: message_count}

0 comments on commit 36997f2

Please sign in to comment.