Skip to content

Commit

Permalink
play around with providing asset execution context
Browse files Browse the repository at this point in the history
  • Loading branch information
jamiedemaria committed Sep 14, 2023
1 parent a6e9235 commit 1be35f3
Show file tree
Hide file tree
Showing 4 changed files with 247 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,13 @@ def name(self):
def has_context_arg(self) -> bool:
return is_context_provided(get_function_params(self.decorated_fn))

def get_context_arg(self) -> Parameter:
for param in get_function_params(self.decorated_fn):
if param.name == "context":
return param

check.failed("Requested context arg on function that does not have one")

@lru_cache(maxsize=1)
def _get_function_params(self) -> Sequence[Parameter]:
return get_function_params(self.decorated_fn)
Expand Down
52 changes: 52 additions & 0 deletions python_modules/dagster/dagster/_core/execution/context/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
Optional,
Sequence,
Set,
Union,
cast,
)

Expand All @@ -21,6 +22,7 @@
DataVersion,
extract_data_provenance_from_entry,
)
from dagster._core.definitions.decorators.op_decorator import DecoratedOpFunction
from dagster._core.definitions.dependency import Node, NodeHandle
from dagster._core.definitions.events import (
AssetKey,
Expand All @@ -36,6 +38,7 @@
from dagster._core.definitions.step_launcher import StepLauncher
from dagster._core.definitions.time_window_partitions import TimeWindow
from dagster._core.errors import (
DagsterInvalidDefinitionError,
DagsterInvalidPropertyError,
DagsterInvariantViolationError,
)
Expand Down Expand Up @@ -1483,3 +1486,52 @@ def asset_check_spec(self) -> AssetCheckSpec:
def partition_key_range_for_asset_key(self, asset_key: AssetKey) -> PartitionKeyRange:
"""TODO - implement in stacked pr."""
pass


def build_execution_context(
step_context: StepExecutionContext,
) -> Union[OpExecutionContext, AssetExecutionContext]:
"""Get the correct context based on the type of step (op or asset) and the user provided context
type annotation. Follows these rules.
step type annotation result
asset AssetExecutionContext AssetExecutionContext
asset OpExecutionContext AssetExecutionContext - with deprecation warning
asset None AssetExecutionContext
op AssetExecutionContext Error
op OpExecutionContext OpExecutionContext
op None OpExecutionContext
"""
is_sda_step = step_context.is_sda_step
is_asset_context = False
is_op_context = False

compute_fn = step_context.op_def._compute_fn # noqa: SLF001
if isinstance(compute_fn, DecoratedOpFunction) and compute_fn.has_context_arg():
context_param = compute_fn.get_context_arg()
is_asset_context = context_param.annotation is AssetExecutionContext
is_op_context = context_param.annotation is OpExecutionContext

if is_asset_context and not is_sda_step:
raise DagsterInvalidDefinitionError(
"When executed in jobs, the op context should be annotated with OpExecutionContext, not"
" AssetExecutionContext."
)

if is_op_context and is_sda_step:
deprecation_warning(
"Contexts with type annotation OpExecutionContext for @assets, @multi_assets,"
" @graph_asset, and @graph_multi_asset.",
"1.7.0",
additional_warn_text="Please annotate the context with AssetExecutionContext",
stacklevel=1,
)

op_context = OpExecutionContext(step_context)

# TODO - determine special casing for graph backed assets

if is_sda_step:
return AssetExecutionContext(op_context)
return op_context
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from dagster._core.definitions.result import MaterializeResult
from dagster._core.errors import DagsterExecutionStepExecutionError, DagsterInvariantViolationError
from dagster._core.events import DagsterEvent
from dagster._core.execution.context.compute import OpExecutionContext
from dagster._core.execution.context.compute import build_execution_context
from dagster._core.execution.context.system import StepExecutionContext
from dagster._core.system_config.objects import ResolvedRunConfig
from dagster._utils import iterate_with_context
Expand Down Expand Up @@ -146,7 +146,7 @@ def _yield_compute_results(
) -> Iterator[OpOutputUnion]:
check.inst_param(step_context, "step_context", StepExecutionContext)

context = OpExecutionContext(step_context)
context = build_execution_context(step_context)
user_event_generator = compute_fn(context, inputs)

if isinstance(user_event_generator, Output):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,18 @@
import dagster._check as check
from dagster import OpExecutionContext, job, op
import pytest
from dagster import OpExecutionContext, asset, job, materialize, op
from dagster._core.definitions.asset_out import AssetOut
from dagster._core.definitions.decorators.asset_decorator import (
graph_asset,
graph_multi_asset,
multi_asset,
)
from dagster._core.definitions.events import Output
from dagster._core.definitions.graph_definition import GraphDefinition
from dagster._core.definitions.job_definition import JobDefinition
from dagster._core.definitions.op_definition import OpDefinition
from dagster._core.errors import DagsterInvalidDefinitionError
from dagster._core.execution.context.compute import AssetExecutionContext
from dagster._core.storage.dagster_run import DagsterRun


Expand All @@ -20,3 +31,177 @@ def foo():
ctx_op()

assert foo.execute_in_process().success


def test_correct_context_provided_no_type_hints():
# asset
@asset
def the_asset(context):
assert isinstance(context, AssetExecutionContext)

materialize([the_asset])

# ops, jobs
@op
def the_op(context):
assert isinstance(context, OpExecutionContext)
# AssetExecutionContext is an instance of OpExecutionContext, so add this additional check
assert not isinstance(context, AssetExecutionContext)

@job
def the_job():
the_op()

assert the_job.execute_in_process().success

# multi_asset
@multi_asset(outs={"out1": AssetOut(dagster_type=None), "out2": AssetOut(dagster_type=None)})
def the_multi_asset(context):
assert isinstance(context, AssetExecutionContext)
return None, None

materialize([the_multi_asset])

# graph backed asset
@op
def the_asset_op(context):
assert isinstance(context, AssetExecutionContext)

@graph_asset
def the_graph_asset():
return the_asset_op()

materialize([the_graph_asset])

# graph backed multi asset
@graph_multi_asset(
outs={"out1": AssetOut(dagster_type=None), "out2": AssetOut(dagster_type=None)}
)
def the_graph_multi_asset():
return the_asset_op(), the_asset_op()

materialize([the_graph_multi_asset])

# job created using Definitions classes, not decorators
def plain_python(context, *args):
assert isinstance(context, OpExecutionContext)
# AssetExecutionContext is an instance of OpExecutionContext, so add this additional check
assert not isinstance(context, AssetExecutionContext)
yield Output(1)

no_decorator_op = OpDefinition(compute_fn=plain_python, name="no_decorator_op")
no_decorator_graph = GraphDefinition(name="no_decorator_graph", node_defs=[no_decorator_op])

no_decorator_graph.to_job(name="no_decorator_job").execute_in_process()


def test_correct_context_provided_with_expected_type_hints():
# asset
@asset
def the_asset(context: AssetExecutionContext):
assert isinstance(context, AssetExecutionContext)

materialize([the_asset])

# ops, jobs
@op
def the_op(context: OpExecutionContext):
assert isinstance(context, OpExecutionContext)
# AssetExecutionContext is an instance of OpExecutionContext, so add this additional check
assert not isinstance(context, AssetExecutionContext)

@job
def the_job():
the_op()

assert the_job.execute_in_process().success

# multi_asset
@multi_asset(outs={"out1": AssetOut(dagster_type=None), "out2": AssetOut(dagster_type=None)})
def the_multi_asset(context: AssetExecutionContext):
assert isinstance(context, AssetExecutionContext)
return None, None

materialize([the_multi_asset])

# job created using Definitions classes, not decorators
def plain_python(context: OpExecutionContext, *args):
assert isinstance(context, OpExecutionContext)
# AssetExecutionContext is an instance of OpExecutionContext, so add this additional check
assert not isinstance(context, AssetExecutionContext)
yield Output(1)

no_decorator_op = OpDefinition(compute_fn=plain_python, name="no_decorator_op")
no_decorator_graph = GraphDefinition(name="no_decorator_graph", node_defs=[no_decorator_op])

no_decorator_graph.to_job(name="no_decorator_job").execute_in_process()


def test_graph_asset_with_op_context():
# TODO - this test fails right now. How do we want to handle this case?
# If we want to provide an OpExecutionContext to this op, then we need a way
# to determine if the asset is a graph-backed asset rather than an @asset or @multi_asset so that we
# can special case this behavior
#
# weird edge case:
# an op is used in both a job and a graph backed asset. This would mean in the job it would get an
# OpExecutionContext, but in the graph backed asset it would get an AssetExecutionContext. Once we
# deprecate the op methods from AssetExecutionContext this will be a problem since a method like
# describe_op would be accessed as context.describe_op in the job and context.op_execution_context.describe_op
# in the graph backed asset

@op
def the_op(context: OpExecutionContext):
assert isinstance(context, OpExecutionContext)
# AssetExecutionContext is an instance of OpExecutionContext, so add this additional check
assert not isinstance(context, AssetExecutionContext)

@graph_asset
def the_graph_asset():
return the_op()

materialize([the_graph_asset])

@graph_multi_asset(
outs={"out1": AssetOut(dagster_type=None), "out2": AssetOut(dagster_type=None)}
)
def the_graph_multi_asset():
return the_op(), the_op()

materialize([the_graph_multi_asset])


def test_graph_asset_with_asset_context():
@op
def the_op(context: AssetExecutionContext):
assert isinstance(context, AssetExecutionContext)

@graph_asset
def the_graph_asset():
return the_op()

materialize([the_graph_asset])

@graph_multi_asset(
outs={"out1": AssetOut(dagster_type=None), "out2": AssetOut(dagster_type=None)}
)
def the_graph_multi_asset():
return the_op(), the_op()

materialize([the_graph_multi_asset])


def test_error_on_context_type_mismatch():
@op
def the_op(context: AssetExecutionContext):
pass

@job
def the_job():
the_op()

with pytest.raises(
DagsterInvalidDefinitionError,
match="When executed in jobs, the op context should be annotated with OpExecutionContext",
):
assert the_job.execute_in_process().success

0 comments on commit 1be35f3

Please sign in to comment.