Skip to content

Commit

Permalink
Update flow run instrumentation to use RunTelemetry class (#16233)
Browse files Browse the repository at this point in the history
Co-authored-by: Chris Pickett <[email protected]>
  • Loading branch information
jeanluciano and bunchesofdonald authored Dec 12, 2024
1 parent a8e0980 commit dce8eca
Show file tree
Hide file tree
Showing 6 changed files with 413 additions and 466 deletions.
146 changes: 63 additions & 83 deletions src/prefect/flow_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,8 @@

from anyio import CancelScope
from opentelemetry import propagate, trace
from opentelemetry.trace import Tracer, get_tracer
from typing_extensions import ParamSpec

import prefect
from prefect import Task
from prefect.client.orchestration import PrefectClient, SyncPrefectClient, get_client
from prefect.client.schemas import FlowRun, TaskRun
Expand Down Expand Up @@ -72,7 +70,12 @@
exception_to_failed_state,
return_value_to_state,
)
from prefect.telemetry.run_telemetry import OTELSetter
from prefect.telemetry.run_telemetry import (
LABELS_TRACEPARENT_KEY,
TRACEPARENT_KEY,
OTELSetter,
RunTelemetry,
)
from prefect.types import KeyValueLabels
from prefect.utilities._engine import get_hook_name, resolve_custom_flow_run_name
from prefect.utilities.annotations import NotSet
Expand All @@ -95,8 +98,6 @@

P = ParamSpec("P")
R = TypeVar("R")
LABELS_TRACEPARENT_KEY = "__OTEL_TRACEPARENT"
TRACEPARENT_KEY = "traceparent"


class FlowRunTimeoutError(TimeoutError):
Expand Down Expand Up @@ -136,10 +137,7 @@ class BaseFlowRunEngine(Generic[P, R]):
_is_started: bool = False
short_circuit: bool = False
_flow_run_name_set: bool = False
_tracer: Tracer = field(
default_factory=lambda: get_tracer("prefect", prefect.__version__)
)
_span: Optional[trace.Span] = None
_telemetry: RunTelemetry = field(default_factory=RunTelemetry)

def __post_init__(self):
if self.flow is None and self.flow_run_id is None:
Expand All @@ -152,21 +150,6 @@ def __post_init__(self):
def state(self) -> State:
return self.flow_run.state # type: ignore

def _end_span_on_success(self):
if not self._span:
return
self._span.set_status(trace.Status(trace.StatusCode.OK))
self._span.end(time.time_ns())
self._span = None

def _end_span_on_error(self, exc: BaseException, description: Optional[str]):
if not self._span:
return
self._span.record_exception(exc)
self._span.set_status(trace.Status(trace.StatusCode.ERROR, description))
self._span.end(time.time_ns())
self._span = None

def is_running(self) -> bool:
if getattr(self, "flow_run", None) is None:
return False
Expand All @@ -185,6 +168,7 @@ def _update_otel_labels(
self, span: trace.Span, client: Union[SyncPrefectClient, PrefectClient]
):
parent_flow_run_ctx = FlowRunContext.get()

if parent_flow_run_ctx and parent_flow_run_ctx.flow_run:
if traceparent := parent_flow_run_ctx.flow_run.labels.get(
LABELS_TRACEPARENT_KEY
Expand All @@ -194,6 +178,7 @@ def _update_otel_labels(
carrier={TRACEPARENT_KEY: traceparent},
setter=OTELSetter(),
)

else:
carrier: KeyValueLabels = {}
propagate.get_global_textmap().inject(
Expand Down Expand Up @@ -315,16 +300,7 @@ def set_state(self, state: State, force: bool = False) -> State:
self.flow_run.state_name = state.name # type: ignore
self.flow_run.state_type = state.type # type: ignore

if self._span:
self._span.add_event(
state.name or state.type,
{
"prefect.state.message": state.message or "",
"prefect.state.type": state.type,
"prefect.state.name": state.name or state.type,
"prefect.state.id": str(state.id),
},
)
self._telemetry.update_state(state)
return state

def result(self, raise_on_failure: bool = True) -> "Union[R, State, None]":
Expand Down Expand Up @@ -374,7 +350,7 @@ def handle_success(self, result: R) -> R:
self.set_state(terminal_state)
self._return_value = resolved_result

self._end_span_on_success()
self._telemetry.end_span_on_success()

return result

Expand Down Expand Up @@ -406,8 +382,8 @@ def handle_exception(
)
state = self.set_state(Running())
self._raised = exc

self._end_span_on_error(exc, state.message)
self._telemetry.record_exception(exc)
self._telemetry.end_span_on_failure(state.message)

return state

Expand All @@ -426,17 +402,17 @@ def handle_timeout(self, exc: TimeoutError) -> None:
)
self.set_state(state)
self._raised = exc

self._end_span_on_error(exc, message)
self._telemetry.record_exception(exc)
self._telemetry.end_span_on_failure(message)

def handle_crash(self, exc: BaseException) -> None:
state = run_coro_as_sync(exception_to_crashed_state(exc))
self.logger.error(f"Crash detected! {state.message}")
self.logger.debug("Crash details:", exc_info=exc)
self.set_state(state, force=True)
self._raised = exc

self._end_span_on_error(exc, state.message if state else "")
self._telemetry.record_exception(exc)
self._telemetry.end_span_on_failure(state.message if state else None)

def load_subflow_run(
self,
Expand Down Expand Up @@ -680,20 +656,23 @@ def initialize_run(self):
flow_version=self.flow.version,
empirical_policy=self.flow_run.empirical_policy,
)

span = self._tracer.start_span(
name=self.flow_run.name,
attributes={
**self.flow_run.labels,
"prefect.run.type": "flow",
"prefect.run.id": str(self.flow_run.id),
"prefect.tags": self.flow_run.tags,
"prefect.flow.name": self.flow.name,
},
parent_flow_run = FlowRunContext.get()
parent_labels = {}
if parent_flow_run and parent_flow_run.flow_run:
parent_labels = parent_flow_run.flow_run.labels

self._telemetry.start_span(
name=self.flow.name,
run=self.flow_run,
parameters=self.parameters,
parent_labels=parent_labels,
)
self._update_otel_labels(span, self.client)

self._span = span
carrier = self._telemetry.propagate_traceparent()
if carrier:
self.client.update_flow_run_labels(
flow_run_id=self.flow_run.id,
labels={LABELS_TRACEPARENT_KEY: carrier[TRACEPARENT_KEY]},
)

try:
yield self
Expand Down Expand Up @@ -736,7 +715,9 @@ def initialize_run(self):
@contextmanager
def start(self) -> Generator[None, None, None]:
with self.initialize_run():
with trace.use_span(self._span) if self._span else nullcontext():
with trace.use_span(
self._telemetry.span
) if self._telemetry.span else nullcontext():
self.begin_run()

if self.state.is_running():
Expand Down Expand Up @@ -892,16 +873,7 @@ async def set_state(self, state: State, force: bool = False) -> State:
self.flow_run.state_name = state.name # type: ignore
self.flow_run.state_type = state.type # type: ignore

if self._span:
self._span.add_event(
state.name or state.type,
{
"prefect.state.message": state.message or "",
"prefect.state.type": state.type,
"prefect.state.name": state.name or state.type,
"prefect.state.id": str(state.id),
},
)
self._telemetry.update_state(state)
return state

async def result(self, raise_on_failure: bool = True) -> "Union[R, State, None]":
Expand Down Expand Up @@ -949,7 +921,7 @@ async def handle_success(self, result: R) -> R:
await self.set_state(terminal_state)
self._return_value = resolved_result

self._end_span_on_success()
self._telemetry.end_span_on_success()

return result

Expand Down Expand Up @@ -979,8 +951,8 @@ async def handle_exception(
)
state = await self.set_state(Running())
self._raised = exc

self._end_span_on_error(exc, state.message)
self._telemetry.record_exception(exc)
self._telemetry.end_span_on_failure(state.message)

return state

Expand All @@ -1000,7 +972,8 @@ async def handle_timeout(self, exc: TimeoutError) -> None:
await self.set_state(state)
self._raised = exc

self._end_span_on_error(exc, message)
self._telemetry.record_exception(exc)
self._telemetry.end_span_on_failure(message)

async def handle_crash(self, exc: BaseException) -> None:
# need to shield from asyncio cancellation to ensure we update the state
Expand All @@ -1012,7 +985,8 @@ async def handle_crash(self, exc: BaseException) -> None:
await self.set_state(state, force=True)
self._raised = exc

self._end_span_on_error(exc, state.message)
self._telemetry.record_exception(exc)
self._telemetry.end_span_on_failure(state.message)

async def load_subflow_run(
self,
Expand Down Expand Up @@ -1254,19 +1228,23 @@ async def initialize_run(self):
flow_version=self.flow.version,
empirical_policy=self.flow_run.empirical_policy,
)

span = self._tracer.start_span(
name=self.flow_run.name,
attributes={
**self.flow_run.labels,
"prefect.run.type": "flow",
"prefect.run.id": str(self.flow_run.id),
"prefect.tags": self.flow_run.tags,
"prefect.flow.name": self.flow.name,
},
parent_flow_run = FlowRunContext.get()
parent_labels = {}
if parent_flow_run and parent_flow_run.flow_run:
parent_labels = parent_flow_run.flow_run.labels

self._telemetry.start_span(
name=self.flow.name,
run=self.flow_run,
parameters=self.parameters,
parent_labels=parent_labels,
)
self._update_otel_labels(span, self.client)
self._span = span
carrier = self._telemetry.propagate_traceparent()
if carrier:
await self.client.update_flow_run_labels(
flow_run_id=self.flow_run.id,
labels={LABELS_TRACEPARENT_KEY: carrier[TRACEPARENT_KEY]},
)

try:
yield self
Expand Down Expand Up @@ -1309,7 +1287,9 @@ async def initialize_run(self):
@asynccontextmanager
async def start(self) -> AsyncGenerator[None, None]:
async with self.initialize_run():
with trace.use_span(self._span) if self._span else nullcontext():
with trace.use_span(
self._telemetry.span
) if self._telemetry.span else nullcontext():
await self.begin_run()

if self.state.is_running():
Expand Down
3 changes: 3 additions & 0 deletions src/prefect/server/models/flow_runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,9 @@
logger = get_logger("flow_runs")


logger = get_logger("flow_runs")


T = TypeVar("T", bound=tuple)


Expand Down
Loading

0 comments on commit dce8eca

Please sign in to comment.