support Asset and Op ExecutionContexts in standard execution path #17972

Closed
129 changes: 116 additions & 13 deletions python_modules/dagster/dagster/_core/execution/context/compute.py
@@ -66,6 +66,62 @@
from .system import StepExecutionContext


class ExecutionProperties:
"""Information to be used by dagster internals during execution.
You should not need to access these attributes directly.
"""

def __init__(
self, step_description: str, node_type: str, op_def: "OpDefinition", op_config: Any
):
self._step_description = step_description
self._node_type = node_type
self._op_def = op_def
self._events: List[DagsterEvent] = []
self._requires_typed_event_stream = False
self._typed_event_stream_error_message = None
self._op_config = op_config

@property
def step_description(self) -> str:
return self._step_description

@property
def node_type(self) -> str:
return self._node_type

@property
def op_def(self) -> "OpDefinition":
return self._op_def

@property
def op_config(self) -> Any:
return self._op_config

def consume_events(self) -> Iterator[DagsterEvent]:
events = self._events
self._events = []
yield from events

def has_events(self) -> bool:
return bool(self._events)

def log_event(self, event: DagsterEvent):
self._events.append(event)

@property
def requires_typed_event_stream(self) -> bool:
return self._requires_typed_event_stream

@property
def typed_event_stream_error_message(self) -> Optional[str]:
return self._typed_event_stream_error_message

def set_requires_typed_event_stream(self, *, error_message: Optional[str] = None):
self._requires_typed_event_stream = True
self._typed_event_stream_error_message = error_message
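
A rough sketch (not part of the diff) of how engine code can lean on this shared surface; the helper name and signature are hypothetical, but the pattern mirrors the event-draining changes made to plan/compute.py later in this PR:

from typing import Iterator

# Hypothetical helper; `context` can be any context exposing execution_properties.
def drain_logged_events(context: "ContextHasExecutionProperties") -> Iterator["DagsterEvent"]:
    props = context.execution_properties
    if props.has_events():
        # consume_events pops and yields everything logged since the last drain.
        yield from props.consume_events()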


# This metaclass has to exist for OpExecutionContext to have a metaclass
class AbstractComputeMetaclass(ABCMeta):
pass
@@ -118,6 +174,22 @@ def op_config(self) -> Any:
"""The parsed config specific to this op."""


class ContextHasExecutionProperties(ABC):
"""Base class that any context that can be used for execution or invocation of an op or asset
must implement.
"""

@property
@abstractmethod
def execution_properties(self) -> ExecutionProperties:
"""Context classes must contain an instance of ExecutionProperties."""

@property
@abstractmethod
def resources(self) -> Any:
"""Context classes must be able to provide currently available resources."""


class OpExecutionContextMetaClass(AbstractComputeMetaclass):
def __instancecheck__(cls, instance) -> bool:
# This makes isinstance(context, OpExecutionContext) throw a deprecation warning when
@@ -136,7 +208,11 @@ def __instancecheck__(cls, instance) -> bool:
return super().__instancecheck__(instance)


class OpExecutionContext(AbstractComputeExecutionContext, metaclass=OpExecutionContextMetaClass):
class OpExecutionContext(
AbstractComputeExecutionContext,
ContextHasExecutionProperties,
metaclass=OpExecutionContextMetaClass,
):
"""The ``context`` object that can be made available as the first argument to the function
used for computing an op or asset.

@@ -166,11 +242,25 @@ def __init__(self, step_execution_context: StepExecutionContext):
self._events: List[DagsterEvent] = []
self._output_metadata: Dict[str, Any] = {}

self._execution_props = ExecutionProperties( # TODO - consider initializing this to None and lazily creating/caching it in the property
step_description=self._step_execution_context.describe_op(),
node_type="op",
op_def=cast(
OpDefinition,
self._step_execution_context.job_def.get_node(self.node_handle).definition,
),
op_config=self._step_execution_context.op_config,
)

@property
def execution_properties(self) -> ExecutionProperties:
return self._execution_props

@public
@property
def op_config(self) -> Any:
"""Any: The parsed config specific to this op."""
return self._step_execution_context.op_config
return self.execution_properties.op_config

@property
def dagster_run(self) -> DagsterRun:
@@ -441,16 +531,14 @@ def run_tags(self) -> Mapping[str, str]:
return self._step_execution_context.run_tags

def has_events(self) -> bool:
return bool(self._events)
return self.execution_properties.has_events()

def consume_events(self) -> Iterator[DagsterEvent]:
"""Pops and yields all user-generated events that have been recorded from this context.

If consume_events has not yet been called, this will yield all logged events since the beginning of the op's computation. If consume_events has been called, it will yield all events since the last time consume_events was called. Designed for internal use. Users should never need to invoke this method.
"""
events = self._events
self._events = []
yield from events
yield from self.execution_properties.consume_events()
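
For illustration of the pop-and-clear semantics described in the docstring (consume_events is internal machinery, so this is a sketch rather than recommended user code):

from dagster import AssetMaterialization, OpExecutionContext, op

@op
def demo_consume_events(context: OpExecutionContext):
    # Events accumulate on the context until they are drained.
    context.log_event(AssetMaterialization("foo"))
    first_batch = list(context.consume_events())   # the materialization logged above
    second_batch = list(context.consume_events())  # empty: nothing logged since the last drain
    context.log.info(f"{len(first_batch)} event(s), then {len(second_batch)}")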

@public
def log_event(self, event: UserEvent) -> None:
@@ -472,17 +560,16 @@ def log_materialization(context):
context.log_event(AssetMaterialization("foo"))
"""
if isinstance(event, AssetMaterialization):
self._events.append(
DagsterEvent.asset_materialization(self._step_execution_context, event)
)
dagster_event = DagsterEvent.asset_materialization(self._step_execution_context, event)
elif isinstance(event, AssetObservation):
self._events.append(DagsterEvent.asset_observation(self._step_execution_context, event))
dagster_event = DagsterEvent.asset_observation(self._step_execution_context, event)
elif isinstance(event, ExpectationResult):
self._events.append(
DagsterEvent.step_expectation_result(self._step_execution_context, event)
dagster_event = DagsterEvent.step_expectation_result(
self._step_execution_context, event
)
else:
check.failed(f"Unexpected event {event}")
self.execution_properties.log_event(event=dagster_event)

@public
def add_output_metadata(
@@ -555,7 +642,7 @@ def retry_number(self) -> int:
return self._step_execution_context.previous_attempt_count

def describe_op(self) -> str:
return self._step_execution_context.describe_op()
return self.execution_properties.step_description

@public
def get_mapping_key(self) -> Optional[str]:
@@ -1420,6 +1507,11 @@ def __init__(self, op_execution_context: OpExecutionContext) -> None:
retry_number=self._op_execution_context.retry_number,
)

# TODO - confirm accuracy of this comment
# Start execution_props as None because enter_execution_context builds an AssetExecutionContext
# for all steps (including op steps), and op steps would fail on the self.assets_def call.
self._execution_props = None

@staticmethod
def get() -> "AssetExecutionContext":
ctx = _current_asset_execution_context.get()
@@ -1435,6 +1527,17 @@ def op_execution_context(self) -> OpExecutionContext:
def run_properties(self) -> RunProperties:
return self._run_props

@property
def execution_properties(self) -> ExecutionProperties:
if self._execution_props is None:
self._execution_props = ExecutionProperties(
step_description=f"asset {self.op_execution_context.node_handle}",
node_type="asset",
op_def=self.op_execution_context.op_def,
op_config=self.op_execution_context.op_config,
)
return self._execution_props
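
Since both context types now expose ExecutionProperties, internals can treat them uniformly; a hypothetical helper (not in this PR) illustrating the point:

def describe_step(context: "ContextHasExecutionProperties") -> str:
    # Works the same way whether the underlying step is an op or an asset.
    props = context.execution_properties
    return f"{props.step_description} (node type: {props.node_type})"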

######## Deprecated methods

@deprecated(**_get_deprecation_kwargs("run"))
@@ -56,7 +56,7 @@
from dagster._utils.merger import merge_dicts
from dagster._utils.warnings import deprecation_warning

from .compute import OpExecutionContext
from .compute import ExecutionProperties, OpExecutionContext
from .system import StepExecutionContext, TypeCheckContext


@@ -106,18 +106,26 @@ def __new__(
)


class RunlessExecutionProperties:
"""Maintains information about the invocation that is updated during execution time. This information
needs to be available to the user once invocation is complete, so that they can assert on events and
outputs. It needs to be cleared before the context is used for another invocation.
class RunlessExecutionProperties(ExecutionProperties):
"""Maintains properties that need to be available to the execution code. To support runless execution
(direct invocation), this class also maintains information about the invocation that is updated
during execution. This information needs to be available to the user once invocation is
complete, so that they can assert on events and outputs. It needs to be cleared before the
context is used for another invocation.
"""

def __init__(self):
def __init__(
self, step_description: str, node_type: str, op_def: "OpDefinition", op_config: Any
):
self._step_description = step_description
self._node_type = node_type
self._op_def = op_def
self._events: List[UserEvent] = []
self._seen_outputs = {}
self._output_metadata = {}
self._requires_typed_event_stream = False
self._typed_event_stream_error_message = None
self._op_config = op_config

@property
def user_events(self):
@@ -131,14 +139,6 @@ def seen_outputs(self):
def output_metadata(self):
return self._output_metadata

@property
def requires_typed_event_stream(self) -> bool:
return self._requires_typed_event_stream

@property
def typed_event_stream_error_message(self) -> Optional[str]:
return self._typed_event_stream_error_message

def log_event(self, event: UserEvent) -> None:
check.inst_param(
event,
@@ -292,7 +292,7 @@ def __init__(
# my_op(ctx)
# ctx._execution_properties.output_metadata # information is retained after invocation
# my_op(ctx) # ctx._execution_properties is cleared at the beginning of the next invocation
self._execution_properties = RunlessExecutionProperties()
self._execution_properties = None

def __enter__(self):
self._cm_scope_entered = True
@@ -326,9 +326,6 @@ def bind(
f"This context is currently being used to execute {self.alias}. The context cannot be used to execute another op until {self.alias} has finished executing."
)

# reset execution_properties
self._execution_properties = RunlessExecutionProperties()

# update the bound context with properties relevant to the execution of the op

invocation_tags = (
@@ -403,6 +400,11 @@ def bind(
step_description=step_description,
)

# reset execution_properties
self._execution_properties = RunlessExecutionProperties(
step_description=step_description, node_type="op", op_def=op_def, op_config=op_config
)

return self

def unbind(self):
Expand All @@ -414,6 +416,11 @@ def is_bound(self) -> bool:

@property
def execution_properties(self) -> RunlessExecutionProperties:
if self._execution_properties is None:
raise DagsterInvalidPropertyError(
"Cannot access execution_properties until after the context has been used to"
" invoke an op"
)
return self._execution_properties

@property
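
Before the execution-plan changes below, a short sketch of the direct-invocation behavior described above: execution properties are populated when the context is bound for an invocation and retained afterwards so tests can assert on them. This assumes build_op_context returns the runless context class modified in this file; the op and assertions are illustrative only.

from dagster import AssetMaterialization, build_op_context, op

@op
def emits_materialization(context):
    context.log_event(AssetMaterialization("my_table"))

ctx = build_op_context()
emits_materialization(ctx)
# The runless execution properties survive the invocation, so user events can be inspected.
assert any(
    isinstance(event, AssetMaterialization)
    for event in ctx.execution_properties.user_events
)
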
19 changes: 9 additions & 10 deletions python_modules/dagster/dagster/_core/execution/plan/compute.py
@@ -36,8 +36,7 @@
)
from dagster._core.events import DagsterEvent
from dagster._core.execution.context.compute import (
AssetExecutionContext,
OpExecutionContext,
ContextHasExecutionProperties,
)
from dagster._core.execution.context.system import StepExecutionContext
from dagster._core.system_config.objects import ResolvedRunConfig
@@ -154,7 +153,7 @@ def _yield_compute_results(
step_context: StepExecutionContext,
inputs: Mapping[str, Any],
compute_fn: OpComputeFunction,
compute_context: Union[OpExecutionContext, AssetExecutionContext],
compute_context: ContextHasExecutionProperties,
) -> Iterator[OpOutputUnion]:
user_event_generator = compute_fn(compute_context, inputs)

@@ -176,32 +175,32 @@
if inspect.isasyncgen(user_event_generator):
user_event_generator = gen_from_async_gen(user_event_generator)

op_label = step_context.describe_op()
step_label = compute_context.execution_properties.step_description

for event in iterate_with_context(
lambda: op_execution_error_boundary(
DagsterExecutionStepExecutionError,
msg_fn=lambda: f"Error occurred while executing {op_label}:",
msg_fn=lambda: f"Error occurred while executing {step_label}:",
step_context=step_context,
step_key=step_context.step.key,
op_def_name=step_context.op_def.name,
op_name=step_context.op.name,
),
user_event_generator,
):
if compute_context.has_events():
yield from compute_context.consume_events()
if compute_context.execution_properties.has_events():
yield from compute_context.execution_properties.consume_events()
yield _validate_event(event, step_context)

if compute_context.has_events():
yield from compute_context.consume_events()
if compute_context.execution_properties.has_events():
yield from compute_context.execution_properties.consume_events()


def execute_core_compute(
step_context: StepExecutionContext,
inputs: Mapping[str, Any],
compute_fn: OpComputeFunction,
compute_context: Union[OpExecutionContext, AssetExecutionContext],
compute_context: ContextHasExecutionProperties,
) -> Iterator[OpOutputUnion]:
"""Execute the user-specified compute for the op. Wrap in an error boundary and do
all relevant logging and metrics tracking.