Skip to content

Commit

Permalink
do the if else thing
Browse files Browse the repository at this point in the history
  • Loading branch information
jamiedemaria committed Jan 2, 2024
1 parent 0561d2c commit 48e6d71
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 5 deletions.
16 changes: 12 additions & 4 deletions python_modules/dagster/dagster/_core/execution/plan/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,14 @@
]


def _get_op_context(
context: Union[OpExecutionContext, AssetExecutionContext]
) -> OpExecutionContext:
if isinstance(context, AssetExecutionContext):
return context.op_execution_context
return context


def create_step_outputs(
node: Node,
handle: NodeHandle,
Expand Down Expand Up @@ -189,12 +197,12 @@ def _yield_compute_results(
),
user_event_generator,
):
if compute_context.has_events():
yield from compute_context.consume_events()
if _get_op_context(compute_context).has_events():
yield from _get_op_context(compute_context).consume_events()
yield _validate_event(event, step_context)

if compute_context.has_events():
yield from compute_context.consume_events()
if _get_op_context(compute_context).has_events():
yield from _get_op_context(compute_context).consume_events()


def execute_core_compute(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,18 @@
from dagster._utils import is_named_tuple_instance
from dagster._utils.warnings import disable_dagster_warnings

from ..context.compute import OpExecutionContext
from ..context.compute import AssetExecutionContext, OpExecutionContext


def _get_op_context(
context: Union[OpExecutionContext, AssetExecutionContext]
) -> OpExecutionContext:
if isinstance(context, AssetExecutionContext):
return context.op_execution_context
return context


# called in execute_step if the fn is not decorated
def create_op_compute_wrapper(
op_def: OpDefinition,
) -> Callable[[OpExecutionContext, Mapping[str, InputDefinition]], Any]:
Expand Down Expand Up @@ -94,6 +103,7 @@ def compute(
return compute


# called in this file (create_op_compute_wrapper)
async def _coerce_async_op_to_async_gen(
awaitable: Awaitable[Any], context: OpExecutionContext, output_defs: Sequence[OutputDefinition]
) -> AsyncIterator[Any]:
Expand All @@ -102,6 +112,7 @@ async def _coerce_async_op_to_async_gen(
yield event


# called in this file, and in op_invocation for direct invocation
def invoke_compute_fn(
fn: Callable,
context: OpExecutionContext,
Expand All @@ -125,6 +136,7 @@ def invoke_compute_fn(
return fn(context, **args_to_pass) if context_arg_provided else fn(**args_to_pass)


# called in this file (create_op_compute_wrapper)
def _coerce_op_compute_fn_to_iterator(
fn, output_defs, context, context_arg_provided, kwargs, config_arg_class, resource_arg_mapping
):
Expand All @@ -135,6 +147,7 @@ def _coerce_op_compute_fn_to_iterator(
yield event


# called in this file (validate_and_coerce_op_result_to_iterator)
def _zip_and_iterate_op_result(
result: Any, context: OpExecutionContext, output_defs: Sequence[OutputDefinition]
) -> Iterator[Tuple[int, Any, OutputDefinition]]:
Expand Down Expand Up @@ -162,6 +175,7 @@ def _zip_and_iterate_op_result(

# Filter out output_defs corresponding to asset check results that already exist on a
# MaterializeResult.
# called in this file (_zip_and_iterate_op_result)
def _filter_expected_output_defs(
result: Any, context: OpExecutionContext, output_defs: Sequence[OutputDefinition]
) -> Sequence[OutputDefinition]:
Expand All @@ -177,6 +191,7 @@ def _filter_expected_output_defs(
return [out for out in output_defs if out.name not in remove_outputs]


# called in this file (_zip_and_iterate_op_result)
def _validate_multi_return(
context: OpExecutionContext,
result: Any,
Expand Down Expand Up @@ -212,6 +227,7 @@ def _validate_multi_return(
return result


# called in this file (validate_and_coerce_op_result_to_iterator)
def _get_annotation_for_output_position(
position: int, op_def: OpDefinition, output_defs: Sequence[OutputDefinition]
) -> Any:
Expand All @@ -226,6 +242,7 @@ def _get_annotation_for_output_position(
return inspect.Parameter.empty


# called in this file (validate_and_coerce_op_result_to_iterator)
def _check_output_object_name(
output: Union[DynamicOutput, Output], output_def: OutputDefinition, position: int
) -> None:
Expand All @@ -239,6 +256,7 @@ def _check_output_object_name(
)


# called in op_invocation and this file
def validate_and_coerce_op_result_to_iterator(
result: Any, context: OpExecutionContext, output_defs: Sequence[OutputDefinition]
) -> Iterator[Any]:
Expand Down

0 comments on commit 48e6d71

Please sign in to comment.