Skip to content

Commit

Permalink
Pick up trace context from run labels (#16346)
Browse files Browse the repository at this point in the history
Co-authored-by: nate nowack <[email protected]>
  • Loading branch information
bunchesofdonald and zzstoatzz authored Dec 12, 2024
1 parent 8f33592 commit 3abe9d0
Show file tree
Hide file tree
Showing 6 changed files with 169 additions and 44 deletions.
16 changes: 3 additions & 13 deletions src/prefect/flow_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -664,15 +664,10 @@ def initialize_run(self):
self._telemetry.start_span(
name=self.flow.name,
run=self.flow_run,
client=self.client,
parameters=self.parameters,
parent_labels=parent_labels,
)
carrier = self._telemetry.propagate_traceparent()
if carrier:
self.client.update_flow_run_labels(
flow_run_id=self.flow_run.id,
labels={LABELS_TRACEPARENT_KEY: carrier[TRACEPARENT_KEY]},
)

try:
yield self
Expand Down Expand Up @@ -1233,18 +1228,13 @@ async def initialize_run(self):
if parent_flow_run and parent_flow_run.flow_run:
parent_labels = parent_flow_run.flow_run.labels

self._telemetry.start_span(
await self._telemetry.async_start_span(
name=self.flow.name,
run=self.flow_run,
client=self.client,
parameters=self.parameters,
parent_labels=parent_labels,
)
carrier = self._telemetry.propagate_traceparent()
if carrier:
await self.client.update_flow_run_labels(
flow_run_id=self.flow_run.id,
labels={LABELS_TRACEPARENT_KEY: carrier[TRACEPARENT_KEY]},
)

try:
yield self
Expand Down
4 changes: 3 additions & 1 deletion src/prefect/task_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -705,6 +705,7 @@ def initialize_run(
self._telemetry.start_span(
run=self.task_run,
name=self.task.name,
client=self.client,
parameters=self.parameters,
parent_labels=parent_labels,
)
Expand Down Expand Up @@ -1243,9 +1244,10 @@ async def initialize_run(
if parent_flow_run_context and parent_flow_run_context.flow_run:
parent_labels = parent_flow_run_context.flow_run.labels

self._telemetry.start_span(
await self._telemetry.async_start_span(
run=self.task_run,
name=self.task.name,
client=self.client,
parameters=self.parameters,
parent_labels=parent_labels,
)
Expand Down
2 changes: 1 addition & 1 deletion src/prefect/telemetry/instrumentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def _url_join(base_url: str, path: str) -> str:

def setup_exporters(
api_url: str, api_key: str
) -> tuple[TracerProvider, MeterProvider, "LoggerProvider"]:
) -> "tuple[TracerProvider, MeterProvider, LoggerProvider]":
account_id, workspace_id = extract_account_and_workspace_id(api_url)
telemetry_url = _url_join(api_url, "telemetry/")

Expand Down
17 changes: 10 additions & 7 deletions src/prefect/telemetry/processors.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
import time
from threading import Event, Lock, Thread
from typing import Dict, Optional
from typing import TYPE_CHECKING, Dict, Optional

from opentelemetry.context import Context
from opentelemetry.sdk.trace import ReadableSpan, Span, SpanProcessor
from opentelemetry.sdk.trace.export import SpanExporter
from opentelemetry.sdk.trace import Span, SpanProcessor

if TYPE_CHECKING:
from opentelemetry.sdk.trace import ReadableSpan, Span
from opentelemetry.sdk.trace.export import SpanExporter


class InFlightSpanProcessor(SpanProcessor):
def __init__(self, span_exporter: SpanExporter):
def __init__(self, span_exporter: "SpanExporter"):
self.span_exporter = span_exporter
self._in_flight: Dict[int, Span] = {}
self._lock = Lock()
Expand All @@ -26,7 +29,7 @@ def _export_periodically(self) -> None:
if to_export:
self.span_exporter.export(to_export)

def _readable_span(self, span: Span) -> ReadableSpan:
def _readable_span(self, span: "Span") -> "ReadableSpan":
readable = span._readable_span()
readable._end_time = time.time_ns()
readable._attributes = {
Expand All @@ -35,13 +38,13 @@ def _readable_span(self, span: Span) -> ReadableSpan:
}
return readable

def on_start(self, span: Span, parent_context: Optional[Context] = None) -> None:
def on_start(self, span: "Span", parent_context: Optional[Context] = None) -> None:
if not span.context or not span.context.trace_flags.sampled:
return
with self._lock:
self._in_flight[span.context.span_id] = span

def on_end(self, span: ReadableSpan) -> None:
def on_end(self, span: "ReadableSpan") -> None:
if not span.context or not span.context.trace_flags.sampled:
return
with self._lock:
Expand Down
90 changes: 84 additions & 6 deletions src/prefect/telemetry/run_telemetry.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,20 @@
import time
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Dict, Optional, Union
from typing import TYPE_CHECKING, Any, Optional, Union

from opentelemetry import propagate, trace
from opentelemetry.context import Context
from opentelemetry.propagators.textmap import Setter
from opentelemetry.trace import (
Span,
Status,
StatusCode,
get_tracer,
)
from typing_extensions import TypeAlias

import prefect
from prefect.client.orchestration import PrefectClient, SyncPrefectClient
from prefect.client.schemas import FlowRun, TaskRun
from prefect.client.schemas.objects import State
from prefect.context import FlowRunContext
Expand All @@ -23,6 +26,8 @@
LABELS_TRACEPARENT_KEY = "__OTEL_TRACEPARENT"
TRACEPARENT_KEY = "traceparent"

FlowOrTaskRun: TypeAlias = Union[FlowRun, TaskRun]


class OTELSetter(Setter[KeyValueLabels]):
"""
Expand All @@ -44,13 +49,47 @@ class RunTelemetry:
)
span: Optional[Span] = None

async def async_start_span(
    self,
    run: FlowOrTaskRun,
    client: PrefectClient,
    name: Optional[str] = None,
    parameters: Optional[dict[str, Any]] = None,
    parent_labels: Optional[dict[str, Any]] = None,
):
    """
    Start a span for a run and, when needed, persist its traceparent
    through the asynchronous client.

    Args:
        run: The flow run or task run the span is started for.
        client: Async client used to update the run's labels.
        name: Optional span name, forwarded to `_start_span`.
        parameters: Optional run parameters, forwarded to `_start_span`.
        parent_labels: Optional labels of the parent run, forwarded to
            `_start_span`.

    Returns:
        The span returned by `_start_span`.
    """
    # Decide this *before* starting the span: `_start_span` may write the
    # traceparent into `run.labels`, which would change the answer.
    needs_label_update = self._should_set_traceparent(run)
    traceparent, span = self._start_span(run, name, parameters, parent_labels)

    if not (needs_label_update and traceparent):
        return span

    await client.update_flow_run_labels(
        run.id, {LABELS_TRACEPARENT_KEY: traceparent}
    )
    return span

def start_span(
self,
run: Union[TaskRun, FlowRun],
run: FlowOrTaskRun,
client: SyncPrefectClient,
name: Optional[str] = None,
parameters: Optional[Dict[str, Any]] = None,
parent_labels: Optional[Dict[str, Any]] = None,
parameters: Optional[dict[str, Any]] = None,
parent_labels: Optional[dict[str, Any]] = None,
):
should_set_traceparent = self._should_set_traceparent(run)
traceparent, span = self._start_span(run, name, parameters, parent_labels)

if should_set_traceparent and traceparent:
client.update_flow_run_labels(run.id, {LABELS_TRACEPARENT_KEY: traceparent})

return span

def _start_span(
self,
run: FlowOrTaskRun,
name: Optional[str] = None,
parameters: Optional[dict[str, Any]] = None,
parent_labels: Optional[dict[str, Any]] = None,
) -> tuple[Optional[str], Span]:
"""
Start a span for a task run.
"""
Expand All @@ -62,10 +101,15 @@ def start_span(
f"prefect.run.parameter.{k}": type(v).__name__
for k, v in parameters.items()
}
run_type = "task" if isinstance(run, TaskRun) else "flow"

traceparent, context = self._traceparent_and_context_from_labels(
{**parent_labels, **run.labels}
)
run_type = self._run_type(run)

self.span = self._tracer.start_span(
name=name or run.name,
context=context,
attributes={
f"prefect.{run_type}.name": name or run.name,
"prefect.run.type": run_type,
Expand All @@ -75,7 +119,41 @@ def start_span(
**parent_labels,
},
)
return self.span

if not traceparent:
traceparent = self._traceparent_from_span(self.span)

if traceparent and LABELS_TRACEPARENT_KEY not in run.labels:
run.labels[LABELS_TRACEPARENT_KEY] = traceparent

return traceparent, self.span

def _run_type(self, run: FlowOrTaskRun) -> str:
    """Return "task" when `run` is a `TaskRun`, otherwise "flow"."""
    if isinstance(run, TaskRun):
        return "task"
    return "flow"

def _should_set_traceparent(self, run: FlowOrTaskRun) -> bool:
    """
    Tell whether the run's labels must be updated with a traceparent.

    A flow run that does not yet carry a traceparent label needs one so
    that it is propagated to child runs. Task runs are updated via events,
    so they do not need to be updated via the client in the same way.
    """
    if LABELS_TRACEPARENT_KEY in run.labels:
        return False
    return self._run_type(run) == "flow"

def _traceparent_and_context_from_labels(
    self, labels: Optional[KeyValueLabels]
) -> tuple[Optional[str], Optional[Context]]:
    """
    Look up the traceparent label and turn it into a trace context.

    Returns `(None, None)` when `labels` is empty/None or has no
    traceparent entry; otherwise returns the traceparent as a string
    together with the context extracted from it.
    """
    if not labels:
        return None, None
    if LABELS_TRACEPARENT_KEY not in labels:
        return None, None

    label_value = labels[LABELS_TRACEPARENT_KEY]
    # The carrier holds the label value exactly as stored; only the returned
    # traceparent is coerced to `str`.
    extracted = propagate.extract({TRACEPARENT_KEY: label_value})
    return str(label_value), extracted

def _traceparent_from_span(self, span: Span) -> Optional[str]:
    """
    Build the traceparent value for `span`.

    The span is placed into a context, that context is injected into an
    empty carrier, and the carrier's traceparent entry is returned (or
    `None` if injection did not produce one).
    """
    span_context = trace.set_span_in_context(span)
    headers = {}
    propagate.inject(headers, context=span_context)
    return headers.get(TRACEPARENT_KEY)

def end_span_on_success(self) -> None:
"""
Expand Down
84 changes: 68 additions & 16 deletions tests/telemetry/test_instrumentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,15 @@
from prefect import flow, task
from prefect.client.orchestration import SyncPrefectClient
from prefect.context import FlowRunContext
from prefect.task_engine import (
run_task_async,
run_task_sync,
)
from prefect.flow_engine import run_flow_async, run_flow_sync
from prefect.task_engine import run_task_async, run_task_sync
from prefect.telemetry.bootstrap import setup_telemetry
from prefect.telemetry.instrumentation import (
extract_account_and_workspace_id,
)
from prefect.telemetry.logging import get_log_handler
from prefect.telemetry.processors import InFlightSpanProcessor
from prefect.telemetry.run_telemetry import LABELS_TRACEPARENT_KEY


def test_extract_account_and_workspace_id_valid_url(
Expand Down Expand Up @@ -181,6 +180,67 @@ async def engine_type(
) -> Literal["async", "sync"]:
return request.param

async def test_traceparent_propagates_from_server_side(
    self,
    engine_type: Literal["async", "sync"],
    instrumentation: InstrumentationTester,
    sync_prefect_client: SyncPrefectClient,
):
    """Test that a traceparent already present in the flow run's labels is
    preserved and used as the parent context of the flow run's span."""

    # The trace ID and span ID are the decimal forms of the second and third
    # fields of the traceparent; keep the three values in sync.
    server_traceparent = "00-ec8af70b445d54387035c27eb182dd22-4a593d8fa95f1902-01"
    server_trace_id = 314419354619557650326501540139523824930
    server_span_id = 5357380918965115138

    @flow
    async def my_async_flow():
        pass

    @flow
    def my_sync_flow():
        pass

    if engine_type == "async":
        the_flow = my_async_flow
    else:
        the_flow = my_sync_flow

    flow_run = sync_prefect_client.create_flow_run(the_flow)  # type: ignore

    # Give the flow run a traceparent. This can occur when the server has
    # already created a trace for the run, likely because it was Late.
    sync_prefect_client.update_flow_run_labels(
        flow_run.id,
        {LABELS_TRACEPARENT_KEY: server_traceparent},
    )

    # Re-read the run so the object handed to the engine carries the label.
    flow_run = sync_prefect_client.read_flow_run(flow_run.id)
    assert flow_run.labels[LABELS_TRACEPARENT_KEY] == server_traceparent

    if engine_type == "async":
        await run_flow_async(the_flow, flow_run=flow_run)  # type: ignore
    else:
        run_flow_sync(the_flow, flow_run=flow_run)  # type: ignore

    # The label on the run object passed to the engine still holds the
    # server-provided value after the run.
    assert flow_run.labels[LABELS_TRACEPARENT_KEY] == server_traceparent

    spans = instrumentation.get_finished_spans()
    assert len(spans) == 1
    span = spans[0]

    # The span joins the server's trace ...
    span_context = span.get_span_context()
    assert span_context is not None
    assert span_context.trace_id == server_trace_id

    # ... and its parent is the span identified by the traceparent.
    assert span.parent is not None
    assert span.parent.trace_id == server_trace_id
    assert span.parent.span_id == server_span_id

async def test_flow_run_creates_and_stores_otel_traceparent(
self,
engine_type: Literal["async", "sync"],
Expand Down Expand Up @@ -249,20 +309,12 @@ def sync_child_flow() -> str:
return "hello from child"

@flow(name="parent-flow")
async def async_parent_flow() -> str:
# Set OTEL context in the parent flow's labels
flow_run = FlowRunContext.get().flow_run
mock_traceparent = "00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01"
flow_run.labels["__OTEL_TRACEPARENT"] = mock_traceparent
return await async_child_flow()
async def async_parent_flow():
await async_child_flow()

@flow(name="parent-flow")
def sync_parent_flow() -> str:
# Set OTEL context in the parent flow's labels
flow_run = FlowRunContext.get().flow_run
mock_traceparent = "00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01"
flow_run.labels["__OTEL_TRACEPARENT"] = mock_traceparent
return sync_child_flow()
def sync_parent_flow():
sync_child_flow()

parent_flow = async_parent_flow if engine_type == "async" else sync_parent_flow
await parent_flow() if engine_type == "async" else parent_flow()
Expand Down

0 comments on commit 3abe9d0

Please sign in to comment.