From 1be35f34a86fcc7047ec2ef6dd7905f844ac98b0 Mon Sep 17 00:00:00 2001
From: JamieDeMaria
Date: Wed, 13 Sep 2023 18:37:59 -0400
Subject: [PATCH] play around with providing asset execution context

---
Review notes (text in this section is ignored by git am):
 - Line structure of the patch restored; it had been collapsed onto a few long lines.
 - get_context_arg() now reuses has_context_arg() instead of matching only the literal
   parameter name "context", so the two methods cannot disagree.
 - The blob ids on the op_decorator.py "index" line predate that revision.

 .../definitions/decorators/op_decorator.py    |  13 +
 .../_core/execution/context/compute.py        |  52 +++++
 .../dagster/_core/execution/plan/compute.py   |   4 +-
 .../execution_tests/test_context.py           | 187 +++++++++++++++++-
 4 files changed, 253 insertions(+), 3 deletions(-)

diff --git a/python_modules/dagster/dagster/_core/definitions/decorators/op_decorator.py b/python_modules/dagster/dagster/_core/definitions/decorators/op_decorator.py
index 328641b66b056..a03d6cf23a1e1 100644
--- a/python_modules/dagster/dagster/_core/definitions/decorators/op_decorator.py
+++ b/python_modules/dagster/dagster/_core/definitions/decorators/op_decorator.py
@@ -286,6 +286,19 @@ def name(self):
     def has_context_arg(self) -> bool:
         return is_context_provided(get_function_params(self.decorated_fn))
 
+    def get_context_arg(self) -> Parameter:
+        """Return the context parameter of the decorated function.
+
+        Fails a check if the decorated function does not declare a context argument.
+        """
+        # Reuse has_context_arg() so this agrees with it on which parameter names count
+        # as a context, rather than matching the literal name "context" only.
+        # NOTE(review): assumes is_context_provided inspects the first parameter - confirm.
+        if self.has_context_arg():
+            return get_function_params(self.decorated_fn)[0]
+
+        check.failed("Requested context arg on function that does not have one")
+
     @lru_cache(maxsize=1)
     def _get_function_params(self) -> Sequence[Parameter]:
         return get_function_params(self.decorated_fn)
diff --git a/python_modules/dagster/dagster/_core/execution/context/compute.py b/python_modules/dagster/dagster/_core/execution/context/compute.py
index 12a6eae2b0d86..0c0fbf7205c82 100644
--- a/python_modules/dagster/dagster/_core/execution/context/compute.py
+++ b/python_modules/dagster/dagster/_core/execution/context/compute.py
@@ -9,6 +9,7 @@
     Optional,
     Sequence,
     Set,
+    Union,
     cast,
 )
 
@@ -21,6 +22,7 @@
     DataVersion,
     extract_data_provenance_from_entry,
 )
+from dagster._core.definitions.decorators.op_decorator import DecoratedOpFunction
 from dagster._core.definitions.dependency import Node, NodeHandle
 from dagster._core.definitions.events import (
     AssetKey,
@@ -36,6 +38,7 @@
 from dagster._core.definitions.step_launcher import StepLauncher
 from dagster._core.definitions.time_window_partitions import TimeWindow
 from dagster._core.errors import (
+    DagsterInvalidDefinitionError,
     DagsterInvalidPropertyError,
     DagsterInvariantViolationError,
 )
@@ -1483,3 +1486,52 @@ def asset_check_spec(self) -> AssetCheckSpec:
     def partition_key_range_for_asset_key(self, asset_key: AssetKey) -> PartitionKeyRange:
         """TODO - implement in stacked pr."""
         pass
+
+
+def build_execution_context(
+    step_context: StepExecutionContext,
+) -> Union[OpExecutionContext, AssetExecutionContext]:
+    """Get the correct context based on the type of step (op or asset) and the user provided context
+    type annotation. Follows these rules.
+
+    step type     annotation                result
+    asset         AssetExecutionContext     AssetExecutionContext
+    asset         OpExecutionContext        AssetExecutionContext - with deprecation warning
+    asset         None                      AssetExecutionContext
+    op            AssetExecutionContext     Error
+    op            OpExecutionContext        OpExecutionContext
+    op            None                      OpExecutionContext
+
+    """
+    is_sda_step = step_context.is_sda_step
+    is_asset_context = False
+    is_op_context = False
+
+    compute_fn = step_context.op_def._compute_fn  # noqa: SLF001
+    if isinstance(compute_fn, DecoratedOpFunction) and compute_fn.has_context_arg():
+        context_param = compute_fn.get_context_arg()
+        is_asset_context = context_param.annotation is AssetExecutionContext
+        is_op_context = context_param.annotation is OpExecutionContext
+
+    if is_asset_context and not is_sda_step:
+        raise DagsterInvalidDefinitionError(
+            "When executed in jobs, the op context should be annotated with OpExecutionContext, not"
+            " AssetExecutionContext."
+        )
+
+    if is_op_context and is_sda_step:
+        deprecation_warning(
+            "Contexts with type annotation OpExecutionContext for @assets, @multi_assets,"
+            " @graph_asset, and @graph_multi_asset.",
+            "1.7.0",
+            additional_warn_text="Please annotate the context with AssetExecutionContext",
+            stacklevel=1,
+        )
+
+    op_context = OpExecutionContext(step_context)
+
+    # TODO - determine special casing for graph backed assets
+
+    if is_sda_step:
+        return AssetExecutionContext(op_context)
+    return op_context
diff --git a/python_modules/dagster/dagster/_core/execution/plan/compute.py b/python_modules/dagster/dagster/_core/execution/plan/compute.py
index 549c091f655a2..146da55fd3caf 100644
--- a/python_modules/dagster/dagster/_core/execution/plan/compute.py
+++ b/python_modules/dagster/dagster/_core/execution/plan/compute.py
@@ -32,7 +32,7 @@
 from dagster._core.definitions.result import MaterializeResult
 from dagster._core.errors import DagsterExecutionStepExecutionError, DagsterInvariantViolationError
 from dagster._core.events import DagsterEvent
-from dagster._core.execution.context.compute import OpExecutionContext
+from dagster._core.execution.context.compute import build_execution_context
 from dagster._core.execution.context.system import StepExecutionContext
 from dagster._core.system_config.objects import ResolvedRunConfig
 from dagster._utils import iterate_with_context
@@ -146,7 +146,7 @@ def _yield_compute_results(
 ) -> Iterator[OpOutputUnion]:
     check.inst_param(step_context, "step_context", StepExecutionContext)
 
-    context = OpExecutionContext(step_context)
+    context = build_execution_context(step_context)
     user_event_generator = compute_fn(context, inputs)
 
     if isinstance(user_event_generator, Output):
diff --git a/python_modules/dagster/dagster_tests/core_tests/execution_tests/test_context.py b/python_modules/dagster/dagster_tests/core_tests/execution_tests/test_context.py
index bd4d523dfa6d5..2f94d2bc40814 100644
--- a/python_modules/dagster/dagster_tests/core_tests/execution_tests/test_context.py
+++ b/python_modules/dagster/dagster_tests/core_tests/execution_tests/test_context.py
@@ -1,7 +1,18 @@
 import dagster._check as check
-from dagster import OpExecutionContext, job, op
+import pytest
+from dagster import OpExecutionContext, asset, job, materialize, op
+from dagster._core.definitions.asset_out import AssetOut
+from dagster._core.definitions.decorators.asset_decorator import (
+    graph_asset,
+    graph_multi_asset,
+    multi_asset,
+)
+from dagster._core.definitions.events import Output
+from dagster._core.definitions.graph_definition import GraphDefinition
 from dagster._core.definitions.job_definition import JobDefinition
 from dagster._core.definitions.op_definition import OpDefinition
+from dagster._core.errors import DagsterInvalidDefinitionError
+from dagster._core.execution.context.compute import AssetExecutionContext
 from dagster._core.storage.dagster_run import DagsterRun
 
 
@@ -20,3 +31,177 @@ def foo():
         ctx_op()
 
     assert foo.execute_in_process().success
+
+
+def test_correct_context_provided_no_type_hints():
+    # asset
+    @asset
+    def the_asset(context):
+        assert isinstance(context, AssetExecutionContext)
+
+    materialize([the_asset])
+
+    # ops, jobs
+    @op
+    def the_op(context):
+        assert isinstance(context, OpExecutionContext)
+        # AssetExecutionContext is an instance of OpExecutionContext, so add this additional check
+        assert not isinstance(context, AssetExecutionContext)
+
+    @job
+    def the_job():
+        the_op()
+
+    assert the_job.execute_in_process().success
+
+    # multi_asset
+    @multi_asset(outs={"out1": AssetOut(dagster_type=None), "out2": AssetOut(dagster_type=None)})
+    def the_multi_asset(context):
+        assert isinstance(context, AssetExecutionContext)
+        return None, None
+
+    materialize([the_multi_asset])
+
+    # graph backed asset
+    @op
+    def the_asset_op(context):
+        assert isinstance(context, AssetExecutionContext)
+
+    @graph_asset
+    def the_graph_asset():
+        return the_asset_op()
+
+    materialize([the_graph_asset])
+
+    # graph backed multi asset
+    @graph_multi_asset(
+        outs={"out1": AssetOut(dagster_type=None), "out2": AssetOut(dagster_type=None)}
+    )
+    def the_graph_multi_asset():
+        return the_asset_op(), the_asset_op()
+
+    materialize([the_graph_multi_asset])
+
+    # job created using Definitions classes, not decorators
+    def plain_python(context, *args):
+        assert isinstance(context, OpExecutionContext)
+        # AssetExecutionContext is an instance of OpExecutionContext, so add this additional check
+        assert not isinstance(context, AssetExecutionContext)
+        yield Output(1)
+
+    no_decorator_op = OpDefinition(compute_fn=plain_python, name="no_decorator_op")
+    no_decorator_graph = GraphDefinition(name="no_decorator_graph", node_defs=[no_decorator_op])
+
+    no_decorator_graph.to_job(name="no_decorator_job").execute_in_process()
+
+
+def test_correct_context_provided_with_expected_type_hints():
+    # asset
+    @asset
+    def the_asset(context: AssetExecutionContext):
+        assert isinstance(context, AssetExecutionContext)
+
+    materialize([the_asset])
+
+    # ops, jobs
+    @op
+    def the_op(context: OpExecutionContext):
+        assert isinstance(context, OpExecutionContext)
+        # AssetExecutionContext is an instance of OpExecutionContext, so add this additional check
+        assert not isinstance(context, AssetExecutionContext)
+
+    @job
+    def the_job():
+        the_op()
+
+    assert the_job.execute_in_process().success
+
+    # multi_asset
+    @multi_asset(outs={"out1": AssetOut(dagster_type=None), "out2": AssetOut(dagster_type=None)})
+    def the_multi_asset(context: AssetExecutionContext):
+        assert isinstance(context, AssetExecutionContext)
+        return None, None
+
+    materialize([the_multi_asset])
+
+    # job created using Definitions classes, not decorators
+    def plain_python(context: OpExecutionContext, *args):
+        assert isinstance(context, OpExecutionContext)
+        # AssetExecutionContext is an instance of OpExecutionContext, so add this additional check
+        assert not isinstance(context, AssetExecutionContext)
+        yield Output(1)
+
+    no_decorator_op = OpDefinition(compute_fn=plain_python, name="no_decorator_op")
+    no_decorator_graph = GraphDefinition(name="no_decorator_graph", node_defs=[no_decorator_op])
+
+    no_decorator_graph.to_job(name="no_decorator_job").execute_in_process()
+
+
+def test_graph_asset_with_op_context():
+    # TODO - this test fails right now. How do we want to handle this case?
+    # If we want to provide an OpExecutionContext to this op, then we need a way
+    # to determine if the asset is a graph-backed asset rather than an @asset or @multi_asset so that we
+    # can special case this behavior
+    #
+    # weird edge case:
+    # an op is used in both a job and a graph backed asset. This would mean in the job it would get an
+    # OpExecutionContext, but in the graph backed asset it would get an AssetExecutionContext. Once we
+    # deprecate the op methods from AssetExecutionContext this will be a problem since a method like
+    # describe_op would be accessed as context.describe_op in the job and context.op_execution_context.describe_op
+    # in the graph backed asset
+
+    @op
+    def the_op(context: OpExecutionContext):
+        assert isinstance(context, OpExecutionContext)
+        # AssetExecutionContext is an instance of OpExecutionContext, so add this additional check
+        assert not isinstance(context, AssetExecutionContext)
+
+    @graph_asset
+    def the_graph_asset():
+        return the_op()
+
+    materialize([the_graph_asset])
+
+    @graph_multi_asset(
+        outs={"out1": AssetOut(dagster_type=None), "out2": AssetOut(dagster_type=None)}
+    )
+    def the_graph_multi_asset():
+        return the_op(), the_op()
+
+    materialize([the_graph_multi_asset])
+
+
+def test_graph_asset_with_asset_context():
+    @op
+    def the_op(context: AssetExecutionContext):
+        assert isinstance(context, AssetExecutionContext)
+
+    @graph_asset
+    def the_graph_asset():
+        return the_op()
+
+    materialize([the_graph_asset])
+
+    @graph_multi_asset(
+        outs={"out1": AssetOut(dagster_type=None), "out2": AssetOut(dagster_type=None)}
+    )
+    def the_graph_multi_asset():
+        return the_op(), the_op()
+
+    materialize([the_graph_multi_asset])
+
+
+def test_error_on_context_type_mismatch():
+    @op
+    def the_op(context: AssetExecutionContext):
+        pass
+
+    @job
+    def the_job():
+        the_op()
+
+    with pytest.raises(
+        DagsterInvalidDefinitionError,
+        match="When executed in jobs, the op context should be annotated with OpExecutionContext",
+    ):
+        assert the_job.execute_in_process().success