diff --git a/python_modules/dagster/dagster/_core/execution/plan/compute.py b/python_modules/dagster/dagster/_core/execution/plan/compute.py index 3a55902de60fe..5bebf7612c8f0 100644 --- a/python_modules/dagster/dagster/_core/execution/plan/compute.py +++ b/python_modules/dagster/dagster/_core/execution/plan/compute.py @@ -61,6 +61,14 @@ ] +def _get_op_context( + context: Union[OpExecutionContext, AssetExecutionContext] +) -> OpExecutionContext: + if isinstance(context, AssetExecutionContext): + return context.op_execution_context + return context + + def create_step_outputs( node: Node, handle: NodeHandle, @@ -189,12 +197,12 @@ def _yield_compute_results( ), user_event_generator, ): - if compute_context.has_events(): - yield from compute_context.consume_events() + if _get_op_context(compute_context).has_events(): + yield from _get_op_context(compute_context).consume_events() yield _validate_event(event, step_context) - if compute_context.has_events(): - yield from compute_context.consume_events() + if _get_op_context(compute_context).has_events(): + yield from _get_op_context(compute_context).consume_events() def execute_core_compute( @@ -245,7 +253,8 @@ def execute_core_compute( output.name for output in step.step_outputs # checks are required if we're in requires_typed_event_stream mode - if compute_context.requires_typed_event_stream or output.properties.asset_check_key + if _get_op_context(compute_context).requires_typed_event_stream + or output.properties.asset_check_key } omitted_outputs = expected_op_output_names.difference(emitted_result_names) if omitted_outputs: @@ -254,9 +263,10 @@ def execute_core_compute( f"expected outputs {omitted_outputs!r}." ) - if compute_context.requires_typed_event_stream: - if compute_context.typed_event_stream_error_message: - message += " " + compute_context.typed_event_stream_error_message + if _get_op_context(compute_context).requires_typed_event_stream: + error_message = _get_op_context(compute_context).typed_event_stream_error_message + if error_message: + message += " " + error_message raise DagsterInvariantViolationError(message) else: step_context.log.info(message) diff --git a/python_modules/dagster/dagster/_core/execution/plan/compute_generator.py b/python_modules/dagster/dagster/_core/execution/plan/compute_generator.py index 375be39a7ea43..dd930fa0c98a6 100644 --- a/python_modules/dagster/dagster/_core/execution/plan/compute_generator.py +++ b/python_modules/dagster/dagster/_core/execution/plan/compute_generator.py @@ -36,9 +36,18 @@ from dagster._utils import is_named_tuple_instance from dagster._utils.warnings import disable_dagster_warnings -from ..context.compute import OpExecutionContext +from ..context.compute import AssetExecutionContext, OpExecutionContext +def _get_op_context( + context: Union[OpExecutionContext, AssetExecutionContext] +) -> OpExecutionContext: + if isinstance(context, AssetExecutionContext): + return context.op_execution_context + return context + + +# called in execute_step if the fn is not decorated def create_op_compute_wrapper( op_def: OpDefinition, ) -> Callable[[OpExecutionContext, Mapping[str, InputDefinition]], Any]: @@ -94,6 +103,7 @@ def compute( return compute +# called in this file (create_op_compute_wrapper) async def _coerce_async_op_to_async_gen( awaitable: Awaitable[Any], context: OpExecutionContext, output_defs: Sequence[OutputDefinition] ) -> AsyncIterator[Any]: @@ -102,6 +112,7 @@ async def _coerce_async_op_to_async_gen( yield event +# called in this file, and in op_invocation for direct invocation def invoke_compute_fn( fn: Callable, context: OpExecutionContext, @@ -125,6 +136,7 @@ def invoke_compute_fn( return fn(context, **args_to_pass) if context_arg_provided else fn(**args_to_pass) +# called in this file (create_op_compute_wrapper) def _coerce_op_compute_fn_to_iterator( fn, output_defs, context, context_arg_provided, kwargs, config_arg_class, resource_arg_mapping ): @@ -135,6 +147,7 @@ def _coerce_op_compute_fn_to_iterator( yield event +# called in this file (validate_and_coerce_op_result_to_iterator) def _zip_and_iterate_op_result( result: Any, context: OpExecutionContext, output_defs: Sequence[OutputDefinition] ) -> Iterator[Tuple[int, Any, OutputDefinition]]: @@ -162,6 +175,7 @@ def _zip_and_iterate_op_result( # Filter out output_defs corresponding to asset check results that already exist on a # MaterializeResult. +# called in this file (_zip_and_iterate_op_result) def _filter_expected_output_defs( result: Any, context: OpExecutionContext, output_defs: Sequence[OutputDefinition] ) -> Sequence[OutputDefinition]: @@ -177,6 +191,7 @@ def _filter_expected_output_defs( return [out for out in output_defs if out.name not in remove_outputs] +# called in this file (_zip_and_iterate_op_result) def _validate_multi_return( context: OpExecutionContext, result: Any, @@ -212,6 +227,7 @@ def _validate_multi_return( return result +# called in this file (validate_and_coerce_op_result_to_iterator) def _get_annotation_for_output_position( position: int, op_def: OpDefinition, output_defs: Sequence[OutputDefinition] ) -> Any: @@ -226,6 +242,7 @@ def _get_annotation_for_output_position( return inspect.Parameter.empty +# called in this file (validate_and_coerce_op_result_to_iterator) def _check_output_object_name( output: Union[DynamicOutput, Output], output_def: OutputDefinition, position: int ) -> None: @@ -239,6 +256,7 @@ def _check_output_object_name( ) +# called in op_invocation and this file def validate_and_coerce_op_result_to_iterator( result: Any, context: OpExecutionContext, output_defs: Sequence[OutputDefinition] ) -> Iterator[Any]: