From dce8eca583a2322d96aef3acba5988a6e6642bf7 Mon Sep 17 00:00:00 2001 From: Jean Luciano Date: Thu, 12 Dec 2024 10:52:32 -0600 Subject: [PATCH] Update flow run instrumentation to use `RunTelemetry` class (#16233) Co-authored-by: Chris Pickett --- src/prefect/flow_engine.py | 146 +++++------ src/prefect/server/models/flow_runs.py | 3 + src/prefect/task_engine.py | 54 ++-- src/prefect/telemetry/run_telemetry.py | 95 +++++-- tests/telemetry/test_instrumentation.py | 253 +++++++++++++++++- tests/test_flow_engine.py | 328 +----------------------- 6 files changed, 413 insertions(+), 466 deletions(-) diff --git a/src/prefect/flow_engine.py b/src/prefect/flow_engine.py index c37154a09cdf..45a18c35246f 100644 --- a/src/prefect/flow_engine.py +++ b/src/prefect/flow_engine.py @@ -24,10 +24,8 @@ from anyio import CancelScope from opentelemetry import propagate, trace -from opentelemetry.trace import Tracer, get_tracer from typing_extensions import ParamSpec -import prefect from prefect import Task from prefect.client.orchestration import PrefectClient, SyncPrefectClient, get_client from prefect.client.schemas import FlowRun, TaskRun @@ -72,7 +70,12 @@ exception_to_failed_state, return_value_to_state, ) -from prefect.telemetry.run_telemetry import OTELSetter +from prefect.telemetry.run_telemetry import ( + LABELS_TRACEPARENT_KEY, + TRACEPARENT_KEY, + OTELSetter, + RunTelemetry, +) from prefect.types import KeyValueLabels from prefect.utilities._engine import get_hook_name, resolve_custom_flow_run_name from prefect.utilities.annotations import NotSet @@ -95,8 +98,6 @@ P = ParamSpec("P") R = TypeVar("R") -LABELS_TRACEPARENT_KEY = "__OTEL_TRACEPARENT" -TRACEPARENT_KEY = "traceparent" class FlowRunTimeoutError(TimeoutError): @@ -136,10 +137,7 @@ class BaseFlowRunEngine(Generic[P, R]): _is_started: bool = False short_circuit: bool = False _flow_run_name_set: bool = False - _tracer: Tracer = field( - default_factory=lambda: get_tracer("prefect", prefect.__version__) - ) - _span: Optional[trace.Span] = None + _telemetry: RunTelemetry = field(default_factory=RunTelemetry) def __post_init__(self): if self.flow is None and self.flow_run_id is None: @@ -152,21 +150,6 @@ def __post_init__(self): def state(self) -> State: return self.flow_run.state # type: ignore - def _end_span_on_success(self): - if not self._span: - return - self._span.set_status(trace.Status(trace.StatusCode.OK)) - self._span.end(time.time_ns()) - self._span = None - - def _end_span_on_error(self, exc: BaseException, description: Optional[str]): - if not self._span: - return - self._span.record_exception(exc) - self._span.set_status(trace.Status(trace.StatusCode.ERROR, description)) - self._span.end(time.time_ns()) - self._span = None - def is_running(self) -> bool: if getattr(self, "flow_run", None) is None: return False @@ -185,6 +168,7 @@ def _update_otel_labels( self, span: trace.Span, client: Union[SyncPrefectClient, PrefectClient] ): parent_flow_run_ctx = FlowRunContext.get() + if parent_flow_run_ctx and parent_flow_run_ctx.flow_run: if traceparent := parent_flow_run_ctx.flow_run.labels.get( LABELS_TRACEPARENT_KEY @@ -194,6 +178,7 @@ def _update_otel_labels( carrier={TRACEPARENT_KEY: traceparent}, setter=OTELSetter(), ) + else: carrier: KeyValueLabels = {} propagate.get_global_textmap().inject( @@ -315,16 +300,7 @@ def set_state(self, state: State, force: bool = False) -> State: self.flow_run.state_name = state.name # type: ignore self.flow_run.state_type = state.type # type: ignore - if self._span: - self._span.add_event( - state.name or state.type, - { - "prefect.state.message": state.message or "", - "prefect.state.type": state.type, - "prefect.state.name": state.name or state.type, - "prefect.state.id": str(state.id), - }, - ) + self._telemetry.update_state(state) return state def result(self, raise_on_failure: bool = True) -> "Union[R, State, None]": @@ -374,7 +350,7 @@ def handle_success(self, result: R) -> R: self.set_state(terminal_state) self._return_value = resolved_result - self._end_span_on_success() + self._telemetry.end_span_on_success() return result @@ -406,8 +382,8 @@ def handle_exception( ) state = self.set_state(Running()) self._raised = exc - - self._end_span_on_error(exc, state.message) + self._telemetry.record_exception(exc) + self._telemetry.end_span_on_failure(state.message) return state @@ -426,8 +402,8 @@ def handle_timeout(self, exc: TimeoutError) -> None: ) self.set_state(state) self._raised = exc - - self._end_span_on_error(exc, message) + self._telemetry.record_exception(exc) + self._telemetry.end_span_on_failure(message) def handle_crash(self, exc: BaseException) -> None: state = run_coro_as_sync(exception_to_crashed_state(exc)) @@ -435,8 +411,8 @@ def handle_crash(self, exc: BaseException) -> None: self.logger.debug("Crash details:", exc_info=exc) self.set_state(state, force=True) self._raised = exc - - self._end_span_on_error(exc, state.message if state else "") + self._telemetry.record_exception(exc) + self._telemetry.end_span_on_failure(state.message if state else None) def load_subflow_run( self, @@ -680,20 +656,23 @@ def initialize_run(self): flow_version=self.flow.version, empirical_policy=self.flow_run.empirical_policy, ) - - span = self._tracer.start_span( - name=self.flow_run.name, - attributes={ - **self.flow_run.labels, - "prefect.run.type": "flow", - "prefect.run.id": str(self.flow_run.id), - "prefect.tags": self.flow_run.tags, - "prefect.flow.name": self.flow.name, - }, + parent_flow_run = FlowRunContext.get() + parent_labels = {} + if parent_flow_run and parent_flow_run.flow_run: + parent_labels = parent_flow_run.flow_run.labels + + self._telemetry.start_span( + name=self.flow.name, + run=self.flow_run, + parameters=self.parameters, + parent_labels=parent_labels, ) - self._update_otel_labels(span, self.client) - - self._span = span + carrier = self._telemetry.propagate_traceparent() + if carrier: + self.client.update_flow_run_labels( + flow_run_id=self.flow_run.id, + labels={LABELS_TRACEPARENT_KEY: carrier[TRACEPARENT_KEY]}, + ) try: yield self @@ -736,7 +715,9 @@ def initialize_run(self): @contextmanager def start(self) -> Generator[None, None, None]: with self.initialize_run(): - with trace.use_span(self._span) if self._span else nullcontext(): + with trace.use_span( + self._telemetry.span + ) if self._telemetry.span else nullcontext(): self.begin_run() if self.state.is_running(): @@ -892,16 +873,7 @@ async def set_state(self, state: State, force: bool = False) -> State: self.flow_run.state_name = state.name # type: ignore self.flow_run.state_type = state.type # type: ignore - if self._span: - self._span.add_event( - state.name or state.type, - { - "prefect.state.message": state.message or "", - "prefect.state.type": state.type, - "prefect.state.name": state.name or state.type, - "prefect.state.id": str(state.id), - }, - ) + self._telemetry.update_state(state) return state async def result(self, raise_on_failure: bool = True) -> "Union[R, State, None]": @@ -949,7 +921,7 @@ async def handle_success(self, result: R) -> R: await self.set_state(terminal_state) self._return_value = resolved_result - self._end_span_on_success() + self._telemetry.end_span_on_success() return result @@ -979,8 +951,8 @@ async def handle_exception( ) state = await self.set_state(Running()) self._raised = exc - - self._end_span_on_error(exc, state.message) + self._telemetry.record_exception(exc) + self._telemetry.end_span_on_failure(state.message) return state @@ -1000,7 +972,8 @@ async def handle_timeout(self, exc: TimeoutError) -> None: await self.set_state(state) self._raised = exc - self._end_span_on_error(exc, message) + self._telemetry.record_exception(exc) + self._telemetry.end_span_on_failure(message) async def handle_crash(self, exc: BaseException) -> None: # need to shield from asyncio cancellation to ensure we update the state @@ -1012,7 +985,8 @@ async def handle_crash(self, exc: BaseException) -> None: await self.set_state(state, force=True) self._raised = exc - self._end_span_on_error(exc, state.message) + self._telemetry.record_exception(exc) + self._telemetry.end_span_on_failure(state.message) async def load_subflow_run( self, @@ -1254,19 +1228,23 @@ async def initialize_run(self): flow_version=self.flow.version, empirical_policy=self.flow_run.empirical_policy, ) - - span = self._tracer.start_span( - name=self.flow_run.name, - attributes={ - **self.flow_run.labels, - "prefect.run.type": "flow", - "prefect.run.id": str(self.flow_run.id), - "prefect.tags": self.flow_run.tags, - "prefect.flow.name": self.flow.name, - }, + parent_flow_run = FlowRunContext.get() + parent_labels = {} + if parent_flow_run and parent_flow_run.flow_run: + parent_labels = parent_flow_run.flow_run.labels + + self._telemetry.start_span( + name=self.flow.name, + run=self.flow_run, + parameters=self.parameters, + parent_labels=parent_labels, ) - self._update_otel_labels(span, self.client) - self._span = span + carrier = self._telemetry.propagate_traceparent() + if carrier: + await self.client.update_flow_run_labels( + flow_run_id=self.flow_run.id, + labels={LABELS_TRACEPARENT_KEY: carrier[TRACEPARENT_KEY]}, + ) try: yield self @@ -1309,7 +1287,9 @@ async def initialize_run(self): @asynccontextmanager async def start(self) -> AsyncGenerator[None, None]: async with self.initialize_run(): - with trace.use_span(self._span) if self._span else nullcontext(): + with trace.use_span( + self._telemetry.span + ) if self._telemetry.span else nullcontext(): await self.begin_run() if self.state.is_running(): diff --git a/src/prefect/server/models/flow_runs.py b/src/prefect/server/models/flow_runs.py index 5db2ff750956..a454fff8f1d4 100644 --- a/src/prefect/server/models/flow_runs.py +++ b/src/prefect/server/models/flow_runs.py @@ -52,6 +52,9 @@ logger = get_logger("flow_runs") +logger = get_logger("flow_runs") + + T = TypeVar("T", bound=tuple) diff --git a/src/prefect/task_engine.py b/src/prefect/task_engine.py index d6b834d6bca6..15053d5016ae 100644 --- a/src/prefect/task_engine.py +++ b/src/prefect/task_engine.py @@ -4,7 +4,7 @@ import threading import time from asyncio import CancelledError -from contextlib import ExitStack, asynccontextmanager, contextmanager +from contextlib import ExitStack, asynccontextmanager, contextmanager, nullcontext from dataclasses import dataclass, field from functools import partial from textwrap import dedent @@ -523,7 +523,7 @@ def handle_success(self, result: R, transaction: Transaction) -> R: self.set_state(terminal_state) self._return_value = result - self._telemetry.end_span_on_success(terminal_state.message) + self._telemetry.end_span_on_success() return result def handle_retry(self, exc: Exception) -> bool: @@ -586,7 +586,7 @@ def handle_exception(self, exc: Exception) -> None: self.record_terminal_state_timing(state) self.set_state(state) self._raised = exc - self._telemetry.end_span_on_failure(state.message) + self._telemetry.end_span_on_failure(state.message if state else None) def handle_timeout(self, exc: TimeoutError) -> None: if not self.handle_retry(exc): @@ -612,7 +612,7 @@ def handle_crash(self, exc: BaseException) -> None: self.set_state(state, force=True) self._raised = exc self._telemetry.record_exception(exc) - self._telemetry.end_span_on_failure(state.message) + self._telemetry.end_span_on_failure(state.message if state else None) @contextmanager def setup_run_context(self, client: Optional[SyncPrefectClient] = None): @@ -670,7 +670,7 @@ def initialize_run( with SyncClientContext.get_or_create() as client_ctx: self._client = client_ctx.client self._is_started = True - flow_run_context = FlowRunContext.get() + parent_flow_run_context = FlowRunContext.get() parent_task_run_context = TaskRunContext.get() try: @@ -679,7 +679,7 @@ def initialize_run( self.task.create_local_run( id=task_run_id, parameters=self.parameters, - flow_run_context=flow_run_context, + flow_run_context=parent_flow_run_context, parent_task_run_context=parent_task_run_context, wait_for=self.wait_for, extra_task_inputs=dependencies, @@ -697,11 +697,16 @@ def initialize_run( self.logger.debug( f"Created task run {self.task_run.name!r} for task {self.task.name!r}" ) - labels = ( - flow_run_context.flow_run.labels if flow_run_context else {} - ) + + parent_labels = {} + if parent_flow_run_context and parent_flow_run_context.flow_run: + parent_labels = parent_flow_run_context.flow_run.labels + self._telemetry.start_span( - self.task_run, self.parameters, labels + run=self.task_run, + name=self.task.name, + parameters=self.parameters, + parent_labels=parent_labels, ) yield self @@ -755,7 +760,9 @@ def start( dependencies: Optional[Dict[str, Set[TaskRunInput]]] = None, ) -> Generator[None, None, None]: with self.initialize_run(task_run_id=task_run_id, dependencies=dependencies): - with trace.use_span(self._telemetry._span): + with trace.use_span( + self._telemetry.span + ) if self._telemetry.span else nullcontext(): self.begin_run() try: yield @@ -1058,7 +1065,7 @@ async def handle_success(self, result: R, transaction: Transaction) -> R: await self.set_state(terminal_state) self._return_value = result - self._telemetry.end_span_on_success(terminal_state.message) + self._telemetry.end_span_on_success() return result @@ -1206,15 +1213,16 @@ async def initialize_run( async with AsyncClientContext.get_or_create(): self._client = get_client() self._is_started = True - flow_run_context = FlowRunContext.get() + parent_flow_run_context = FlowRunContext.get() + parent_task_run_context = TaskRunContext.get() try: if not self.task_run: self.task_run = await self.task.create_local_run( id=task_run_id, parameters=self.parameters, - flow_run_context=flow_run_context, - parent_task_run_context=TaskRunContext.get(), + flow_run_context=parent_flow_run_context, + parent_task_run_context=parent_task_run_context, wait_for=self.wait_for, extra_task_inputs=dependencies, ) @@ -1231,11 +1239,15 @@ async def initialize_run( f"Created task run {self.task_run.name!r} for task {self.task.name!r}" ) - labels = ( - flow_run_context.flow_run.labels if flow_run_context else {} - ) + parent_labels = {} + if parent_flow_run_context and parent_flow_run_context.flow_run: + parent_labels = parent_flow_run_context.flow_run.labels + self._telemetry.start_span( - self.task_run, self.parameters, labels + run=self.task_run, + name=self.task.name, + parameters=self.parameters, + parent_labels=parent_labels, ) yield self @@ -1291,7 +1303,9 @@ async def start( async with self.initialize_run( task_run_id=task_run_id, dependencies=dependencies ): - with trace.use_span(self._telemetry._span): + with trace.use_span( + self._telemetry.span + ) if self._telemetry.span else nullcontext(): await self.begin_run() try: yield diff --git a/src/prefect/telemetry/run_telemetry.py b/src/prefect/telemetry/run_telemetry.py index 08de1a2ebd0b..bb7cc81de5f9 100644 --- a/src/prefect/telemetry/run_telemetry.py +++ b/src/prefect/telemetry/run_telemetry.py @@ -1,22 +1,28 @@ import time from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any, Dict, Optional +from typing import TYPE_CHECKING, Any, Dict, Optional, Union +from opentelemetry import propagate, trace from opentelemetry.propagators.textmap import Setter from opentelemetry.trace import ( + Span, Status, StatusCode, get_tracer, ) import prefect -from prefect.client.schemas import TaskRun +from prefect.client.schemas import FlowRun, TaskRun from prefect.client.schemas.objects import State +from prefect.context import FlowRunContext from prefect.types import KeyValueLabels if TYPE_CHECKING: from opentelemetry.trace import Tracer +LABELS_TRACEPARENT_KEY = "__OTEL_TRACEPARENT" +TRACEPARENT_KEY = "traceparent" + class OTELSetter(Setter[KeyValueLabels]): """ @@ -36,67 +42,74 @@ class RunTelemetry: _tracer: "Tracer" = field( default_factory=lambda: get_tracer("prefect", prefect.__version__) ) - _span = None + span: Optional[Span] = None def start_span( self, - task_run: TaskRun, + run: Union[TaskRun, FlowRun], + name: Optional[str] = None, parameters: Optional[Dict[str, Any]] = None, - labels: Optional[Dict[str, Any]] = None, + parent_labels: Optional[Dict[str, Any]] = None, ): """ Start a span for a task run. """ if parameters is None: parameters = {} - if labels is None: - labels = {} + if parent_labels is None: + parent_labels = {} parameter_attributes = { f"prefect.run.parameter.{k}": type(v).__name__ for k, v in parameters.items() } - self._span = self._tracer.start_span( - name=task_run.name, + run_type = "task" if isinstance(run, TaskRun) else "flow" + + self.span = self._tracer.start_span( + name=name or run.name, attributes={ - "prefect.run.type": "task", - "prefect.run.id": str(task_run.id), - "prefect.tags": task_run.tags, + f"prefect.{run_type}.name": name or run.name, + "prefect.run.type": run_type, + "prefect.run.id": str(run.id), + "prefect.tags": run.tags, **parameter_attributes, - **labels, + **parent_labels, }, ) + return self.span - def end_span_on_success(self, terminal_message: str) -> None: + def end_span_on_success(self) -> None: """ End a span for a task run on success. """ - if self._span: - self._span.set_status(Status(StatusCode.OK), terminal_message) - self._span.end(time.time_ns()) - self._span = None + if self.span: + self.span.set_status(Status(StatusCode.OK)) + self.span.end(time.time_ns()) + self.span = None - def end_span_on_failure(self, terminal_message: str) -> None: + def end_span_on_failure(self, terminal_message: Optional[str] = None) -> None: """ End a span for a task run on failure. """ - if self._span: - self._span.set_status(Status(StatusCode.ERROR, terminal_message)) - self._span.end(time.time_ns()) - self._span = None + if self.span: + self.span.set_status( + Status(StatusCode.ERROR, terminal_message or "Run failed") + ) + self.span.end(time.time_ns()) + self.span = None - def record_exception(self, exc: Exception) -> None: + def record_exception(self, exc: BaseException) -> None: """ Record an exception on a span. """ - if self._span: - self._span.record_exception(exc) + if self.span: + self.span.record_exception(exc) def update_state(self, new_state: State) -> None: """ Update a span with the state of a task run. """ - if self._span: - self._span.add_event( + if self.span: + self.span.add_event( new_state.name or new_state.type, { "prefect.state.message": new_state.message or "", @@ -105,3 +118,29 @@ def update_state(self, new_state: State) -> None: "prefect.state.id": str(new_state.id), }, ) + + def propagate_traceparent(self) -> Optional[KeyValueLabels]: + """ + Propagate a traceparent to a span. + """ + parent_flow_run_ctx = FlowRunContext.get() + + if parent_flow_run_ctx and parent_flow_run_ctx.flow_run: + if traceparent := parent_flow_run_ctx.flow_run.labels.get( + LABELS_TRACEPARENT_KEY + ): + carrier: KeyValueLabels = {TRACEPARENT_KEY: traceparent} + propagate.get_global_textmap().inject( + carrier={TRACEPARENT_KEY: traceparent}, + setter=OTELSetter(), + ) + return carrier + else: + if self.span: + carrier: KeyValueLabels = {} + propagate.get_global_textmap().inject( + carrier, + context=trace.set_span_in_context(self.span), + setter=OTELSetter(), + ) + return carrier diff --git a/tests/telemetry/test_instrumentation.py b/tests/telemetry/test_instrumentation.py index 6fae35a96895..ecb7377be899 100644 --- a/tests/telemetry/test_instrumentation.py +++ b/tests/telemetry/test_instrumentation.py @@ -1,4 +1,5 @@ import os +from typing import Literal from uuid import UUID, uuid4 import pytest @@ -13,7 +14,10 @@ from opentelemetry.sdk.trace import TracerProvider from tests.telemetry.instrumentation_tester import InstrumentationTester +import prefect from prefect import flow, task +from prefect.client.orchestration import SyncPrefectClient +from prefect.context import FlowRunContext from prefect.task_engine import ( run_task_async, run_task_sync, @@ -170,9 +174,215 @@ def test_logger_provider( assert log_handler._logger_provider == logger_provider +class TestFlowRunInstrumentation: + @pytest.fixture(params=["async", "sync"]) + async def engine_type( + self, request: pytest.FixtureRequest + ) -> Literal["async", "sync"]: + return request.param + + async def test_flow_run_creates_and_stores_otel_traceparent( + self, + engine_type: Literal["async", "sync"], + instrumentation: InstrumentationTester, + sync_prefect_client: SyncPrefectClient, + ): + """Test that when no parent traceparent exists, the flow run stores its own span's traceparent""" + + @flow(name="child-flow") + async def async_child_flow() -> str: + return "hello from child" + + @flow(name="child-flow") + def sync_child_flow() -> str: + return "hello from child" + + @flow(name="parent-flow") + async def async_parent_flow() -> str: + return await async_child_flow() + + @flow(name="parent-flow") + def sync_parent_flow() -> str: + return sync_child_flow() + + if engine_type == "async": + await async_parent_flow() + else: + sync_parent_flow() + + spans = instrumentation.get_finished_spans() + + next( + span + for span in spans + if span.attributes.get("prefect.flow.name") == "parent-flow" + ) + child_span = next( + span + for span in spans + if span.attributes.get("prefect.flow.name") == "child-flow" + ) + + # Get the child flow run + child_flow_run_id = child_span.attributes.get("prefect.run.id") + child_flow_run = sync_prefect_client.read_flow_run(UUID(child_flow_run_id)) + + # Verify the child flow run has its span's traceparent in its labels + assert "__OTEL_TRACEPARENT" in child_flow_run.labels + assert child_flow_run.labels["__OTEL_TRACEPARENT"].startswith("00-") + trace_id_hex = child_flow_run.labels["__OTEL_TRACEPARENT"].split("-")[1] + assert int(trace_id_hex, 16) == child_span.context.trace_id + + async def test_flow_run_propagates_otel_traceparent_to_subflow( + self, + engine_type: Literal["async", "sync"], + instrumentation: InstrumentationTester, + ): + """Test that OTEL traceparent gets propagated from parent flow to child flow""" + + @flow(name="child-flow") + async def async_child_flow() -> str: + return "hello from child" + + @flow(name="child-flow") + def sync_child_flow() -> str: + return "hello from child" + + @flow(name="parent-flow") + async def async_parent_flow() -> str: + # Set OTEL context in the parent flow's labels + flow_run = FlowRunContext.get().flow_run + mock_traceparent = "00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01" + flow_run.labels["__OTEL_TRACEPARENT"] = mock_traceparent + return await async_child_flow() + + @flow(name="parent-flow") + def sync_parent_flow() -> str: + # Set OTEL context in the parent flow's labels + flow_run = FlowRunContext.get().flow_run + mock_traceparent = "00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01" + flow_run.labels["__OTEL_TRACEPARENT"] = mock_traceparent + return sync_child_flow() + + parent_flow = async_parent_flow if engine_type == "async" else sync_parent_flow + await parent_flow() if engine_type == "async" else parent_flow() + + spans = instrumentation.get_finished_spans() + + parent_span = next( + span + for span in spans + if span.attributes.get("prefect.flow.name") == "parent-flow" + ) + child_span = next( + span + for span in spans + if span.attributes.get("prefect.flow.name") == "child-flow" + ) + + assert parent_span is not None + assert child_span is not None + assert child_span.context.trace_id == parent_span.context.trace_id + + async def test_flow_run_instrumentation( + self, + engine_type: Literal["async", "sync"], + instrumentation: InstrumentationTester, + ): + @flow(name="instrumented-flow") + async def async_flow() -> str: + return 42 + + @flow(name="instrumented-flow") + def sync_flow() -> str: + return 42 + + test_flow = async_flow if engine_type == "async" else sync_flow + await test_flow() if engine_type == "async" else test_flow() + + spans = instrumentation.get_finished_spans() + assert len(spans) == 1 + + span = spans[0] + assert span is not None + instrumentation.assert_span_instrumented_for(span, prefect) + + instrumentation.assert_has_attributes( + span, + { + "prefect.flow.name": "instrumented-flow", + "prefect.run.type": "flow", + }, + ) + + async def test_flow_run_inherits_parent_labels( + self, + engine_type: Literal["async", "sync"], + instrumentation: InstrumentationTester, + sync_prefect_client: SyncPrefectClient, + ): + """Test that parent flow labels get propagated to child flow spans""" + + @flow(name="child-flow") + async def async_child_flow() -> str: + return "hello from child" + + @flow(name="child-flow") + def sync_child_flow() -> str: + return "hello from child" + + @flow(name="parent-flow") + async def async_parent_flow() -> str: + # Set custom labels in parent flow + flow_run = FlowRunContext.get().flow_run + flow_run.labels.update( + {"test.label": "test-value", "environment": "testing"} + ) + return await async_child_flow() + + @flow(name="parent-flow") + def sync_parent_flow() -> str: + # Set custom labels in parent flow + flow_run = FlowRunContext.get().flow_run + flow_run.labels.update( + {"test.label": "test-value", "environment": "testing"} + ) + return sync_child_flow() + + if engine_type == "async": + state = await async_parent_flow(return_state=True) + else: + state = sync_parent_flow(return_state=True) + + spans = instrumentation.get_finished_spans() + child_spans = [ + span + for span in spans + if span.attributes.get("prefect.flow.name") == "child-flow" + ] + assert len(child_spans) == 1 + + # Get the parent flow run + parent_flow_run = sync_prefect_client.read_flow_run( + state.state_details.flow_run_id + ) + + # Verify the child span has the parent flow's labels + instrumentation.assert_has_attributes( + child_spans[0], + { + **parent_flow_run.labels, + "prefect.run.type": "flow", + "prefect.flow.name": "child-flow", + }, + ) + + class TestTaskRunInstrumentation: @pytest.fixture(params=["async", "sync"]) - async def engine_type(self, request): + async def engine_type( + self, request: pytest.FixtureRequest + ) -> Literal["async", "sync"]: return request.param async def run_task(self, task, task_run_id, parameters, engine_type): @@ -184,7 +394,9 @@ async def run_task(self, task, task_run_id, parameters, engine_type): return run_task_sync(task, task_run_id=task_run_id, parameters=parameters) async def test_span_creation( - self, engine_type, instrumentation: InstrumentationTester + self, + engine_type: Literal["async", "sync"], + instrumentation: InstrumentationTester, ): @task async def async_task(x: int, y: int): @@ -213,7 +425,11 @@ def sync_task(x: int, y: int): ) assert spans[0].name == task_fn.name - async def test_span_attributes(self, engine_type, instrumentation): + async def test_span_attributes( + self, + engine_type: Literal["async", "sync"], + instrumentation: InstrumentationTester, + ): @task async def async_task(x: int, y: int): return x + y @@ -245,7 +461,11 @@ def sync_task(x: int, y: int): ) assert spans[0].name == task_fn.__name__ - async def test_span_events(self, engine_type, instrumentation): + async def test_span_events( + self, + engine_type: Literal["async", "sync"], + instrumentation: InstrumentationTester, + ): @task async def async_task(x: int, y: int): return x + y @@ -270,7 +490,11 @@ def sync_task(x: int, y: int): assert events[0].name == "Running" assert events[1].name == "Completed" - async def test_span_status_on_success(self, engine_type, instrumentation): + async def test_span_status_on_success( + self, + engine_type: Literal["async", "sync"], + instrumentation: InstrumentationTester, + ): @task async def async_task(x: int, y: int): return x + y @@ -293,7 +517,11 @@ def sync_task(x: int, y: int): assert len(spans) == 1 assert spans[0].status.status_code == trace.StatusCode.OK - async def test_span_status_on_failure(self, engine_type, instrumentation): + async def test_span_status_on_failure( + self, + engine_type: Literal["async", "sync"], + instrumentation: InstrumentationTester, + ): @task async def async_task(x: int, y: int): raise ValueError("Test error") @@ -318,7 +546,11 @@ def sync_task(x: int, y: int): assert spans[0].status.status_code == trace.StatusCode.ERROR assert "Test error" in spans[0].status.description - async def test_span_exception_recording(self, engine_type, instrumentation): + async def test_span_exception_recording( + self, + engine_type: Literal["async", "sync"], + instrumentation: InstrumentationTester, + ): @task async def async_task(x: int, y: int): raise Exception("Test error") @@ -347,7 +579,12 @@ def sync_task(x: int, y: int): assert exception_event.attributes["exception.type"] == "Exception" assert exception_event.attributes["exception.message"] == "Test error" - async def test_flow_labels(self, engine_type, instrumentation, sync_prefect_client): + async def test_flow_labels( + self, + engine_type: Literal["async", "sync"], + instrumentation: InstrumentationTester, + sync_prefect_client: SyncPrefectClient, + ): """Test that parent flow ID gets propagated to task spans""" @task diff --git a/tests/test_flow_engine.py b/tests/test_flow_engine.py index 9807f8c219cd..9ceeae8241bd 100644 --- a/tests/test_flow_engine.py +++ b/tests/test_flow_engine.py @@ -10,9 +10,7 @@ import anyio import pydantic import pytest -from opentelemetry import trace -import prefect from prefect import Flow, __development_base_path__, flow, task from prefect.client.orchestration import PrefectClient, SyncPrefectClient from prefect.client.schemas.filters import FlowFilter, FlowRunFilter @@ -49,8 +47,6 @@ from prefect.utilities.callables import get_call_parameters from prefect.utilities.filesystem import tmpchdir -from .telemetry.instrumentation_tester import InstrumentationTester - @flow async def foo(): @@ -615,7 +611,7 @@ def my_flow(): # after a flow run retry, the stale value will be pulled from the cache. async def test_flow_retry_with_no_error_in_flow_and_one_failed_child_flow( - self, sync_prefect_client: SyncPrefectClient + self, sync_prefect_client ): child_run_count = 0 flow_run_count = 0 @@ -1853,325 +1849,3 @@ async def expensive_flow(): concurrency_limit_v2.name ) assert response.active_slots == 0 - - -class TestFlowRunInstrumentation: - def test_flow_run_instrumentation(self, instrumentation: InstrumentationTester): - @flow - def instrumented_flow(): - from prefect.states import Completed - - return Completed(message="The flow is with you") - - instrumented_flow() - - spans = instrumentation.get_finished_spans() - assert len(spans) == 1 - span = spans[0] - assert span is not None - instrumentation.assert_span_instrumented_for(span, prefect) - - instrumentation.assert_has_attributes( - span, - { - "prefect.run.type": "flow", - "prefect.tags": (), - "prefect.flow.name": "instrumented-flow", - "prefect.run.id": mock.ANY, - }, - ) - assert span.status.status_code == trace.StatusCode.OK - - assert len(span.events) == 2 - assert span.events[0].name == "Running" - instrumentation.assert_has_attributes( - span.events[0], - { - "prefect.state.message": "", - "prefect.state.type": StateType.RUNNING, - "prefect.state.name": "Running", - "prefect.state.id": mock.ANY, - }, - ) - - assert span.events[1].name == "Completed" - instrumentation.assert_has_attributes( - span.events[1], - { - "prefect.state.message": "The flow is with you", - "prefect.state.type": StateType.COMPLETED, - "prefect.state.name": "Completed", - "prefect.state.id": mock.ANY, - }, - ) - - def test_flow_run_instrumentation_captures_tags( - self, - instrumentation: InstrumentationTester, - ): - from prefect import tags - - @flow - def instrumented_flow(): - pass - - with tags("foo", "bar"): - instrumented_flow() - - spans = instrumentation.get_finished_spans() - assert len(spans) == 1 - span = spans[0] - assert span is not None - instrumentation.assert_span_instrumented_for(span, prefect) - - instrumentation.assert_has_attributes( - span, - { - "prefect.run.type": "flow", - "prefect.flow.name": "instrumented-flow", - "prefect.run.id": mock.ANY, - }, - ) - # listy span attributes are serialized to tuples -- order seems nondeterministic so ignore rather than flake - assert set(span.attributes.get("prefect.tags")) == {"foo", "bar"} # type: ignore - assert span.status.status_code == trace.StatusCode.OK - - def test_flow_run_instrumentation_captures_labels( - self, - instrumentation: InstrumentationTester, - sync_prefect_client: SyncPrefectClient, - ): - @flow - def instrumented_flow(): - pass - - state = instrumented_flow(return_state=True) - - assert state.state_details.flow_run_id is not None - flow_run = sync_prefect_client.read_flow_run(state.state_details.flow_run_id) - - spans = instrumentation.get_finished_spans() - assert len(spans) == 1 - span = spans[0] - assert span is not None - instrumentation.assert_span_instrumented_for(span, prefect) - - instrumentation.assert_has_attributes( - span, - { - **flow_run.labels, - "prefect.run.type": "flow", - "prefect.flow.name": "instrumented-flow", - "prefect.run.id": mock.ANY, - }, - ) - - def test_flow_run_instrumentation_on_exception( - self, instrumentation: InstrumentationTester - ): - @flow - def a_broken_flow(): - raise Exception("This flow broke!") - - with pytest.raises(Exception): - a_broken_flow() - - spans = instrumentation.get_finished_spans() - assert len(spans) == 1 - span = spans[0] - assert span is not None - instrumentation.assert_span_instrumented_for(span, prefect) - - instrumentation.assert_has_attributes( - span, - { - "prefect.run.type": "flow", - "prefect.tags": (), - "prefect.flow.name": "a-broken-flow", - "prefect.run.id": mock.ANY, - }, - ) - - assert span.status.status_code == trace.StatusCode.ERROR - assert ( - span.status.description - == "Flow run encountered an exception: Exception: This flow broke!" - ) - - assert len(span.events) == 3 - assert span.events[0].name == "Running" - instrumentation.assert_has_attributes( - span.events[0], - { - "prefect.state.message": "", - "prefect.state.type": StateType.RUNNING, - "prefect.state.name": "Running", - "prefect.state.id": mock.ANY, - }, - ) - - assert span.events[1].name == "Failed" - instrumentation.assert_has_attributes( - span.events[1], - { - "prefect.state.message": "Flow run encountered an exception: Exception: This flow broke!", - "prefect.state.type": StateType.FAILED, - "prefect.state.name": "Failed", - "prefect.state.id": mock.ANY, - }, - ) - - assert span.events[2].name == "exception" - instrumentation.assert_has_attributes( - span.events[2], - { - "exception.type": "Exception", - "exception.message": "This flow broke!", - "exception.stacktrace": mock.ANY, - "exception.escaped": "False", - }, - ) - - def test_flow_run_instrumentation_on_timeout( - self, instrumentation: InstrumentationTester - ): - @flow(timeout_seconds=0.1) - def a_slow_flow(): - time.sleep(1) - - with pytest.raises(TimeoutError): - a_slow_flow() - - spans = instrumentation.get_finished_spans() - assert len(spans) == 1 - span = spans[0] - assert span is not None - instrumentation.assert_span_instrumented_for(span, prefect) - - instrumentation.assert_has_attributes( - span, - { - "prefect.run.type": "flow", - "prefect.tags": (), - "prefect.flow.name": "a-slow-flow", - "prefect.run.id": mock.ANY, - }, - ) - - assert span.status.status_code == trace.StatusCode.ERROR - assert span.status.description == "Flow run exceeded timeout of 0.1 second(s)" - - assert len(span.events) == 3 - assert span.events[0].name == "Running" - instrumentation.assert_has_attributes( - span.events[0], - { - "prefect.state.message": "", - "prefect.state.type": StateType.RUNNING, - "prefect.state.name": "Running", - "prefect.state.id": mock.ANY, - }, - ) - - assert span.events[1].name == "TimedOut" - instrumentation.assert_has_attributes( - span.events[1], - { - "prefect.state.message": "Flow run exceeded timeout of 0.1 second(s)", - "prefect.state.type": StateType.FAILED, - "prefect.state.name": "TimedOut", - "prefect.state.id": mock.ANY, - }, - ) - - assert span.events[2].name == "exception" - instrumentation.assert_has_attributes( - span.events[2], - { - "exception.type": "prefect.flow_engine.FlowRunTimeoutError", - "exception.message": "Scope timed out after 0.1 second(s).", - "exception.stacktrace": mock.ANY, - "exception.escaped": "False", - }, - ) - - async def test_flow_run_propagates_otel_traceparent_to_subflow( - self, instrumentation: InstrumentationTester - ): - """Test that OTEL traceparent gets propagated from parent flow to child flow""" - - @flow - def child_flow(): - return "hello from child" - - @flow - def parent_flow(): - flow_run_ctx = FlowRunContext.get() - assert flow_run_ctx - assert flow_run_ctx.flow_run - flow_run = flow_run_ctx.flow_run - mock_traceparent = "00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01" - flow_run.labels["__OTEL_TRACEPARENT"] = mock_traceparent - - return child_flow() - - parent_flow() - - spans = instrumentation.get_finished_spans() - - parent_span = next( - span - for span in spans - if span.attributes - and span.attributes.get("prefect.flow.name") == "parent-flow" - ) - child_span = next( - span - for span in spans - if span.attributes - and span.attributes.get("prefect.flow.name") == "child-flow" - ) - - assert parent_span is not None - assert child_span is not None - assert child_span.context and parent_span.context - assert child_span.context.trace_id == parent_span.context.trace_id - - async def test_flow_run_creates_and_stores_otel_traceparent( - self, instrumentation: InstrumentationTester, sync_prefect_client - ): - """Test that when no parent traceparent exists, the flow run stores its own span's traceparent""" - - @flow - def child_flow(): - return "hello from child" - - @flow - def parent_flow(): - return child_flow() - - parent_flow() - - spans = instrumentation.get_finished_spans() - - next( - span - for span in spans - if span.attributes - and span.attributes.get("prefect.flow.name") == "parent-flow" - ) - child_span = next( - span - for span in spans - if span.attributes - and span.attributes.get("prefect.flow.name") == "child-flow" - ) - - child_flow_run_id = child_span.attributes.get("prefect.run.id") - assert child_flow_run_id - child_flow_run = sync_prefect_client.read_flow_run(UUID(child_flow_run_id)) - - assert "__OTEL_TRACEPARENT" in child_flow_run.labels - assert child_flow_run.labels["__OTEL_TRACEPARENT"].startswith("00-") - trace_id_hex = child_flow_run.labels["__OTEL_TRACEPARENT"].split("-")[1] - assert int(trace_id_hex, 16) == child_span.context.trace_id