RFC: Much simpler AssetExecutionContext #16417

Closed
wants to merge 19 commits
218 changes: 208 additions & 10 deletions python_modules/dagster/dagster/_core/execution/context/compute.py
@@ -1,4 +1,4 @@
-from abc import ABC, abstractmethod
from abc import ABC, ABCMeta, abstractmethod
from typing import (
AbstractSet,
Any,
@@ -12,8 +12,6 @@
cast,
)

-from typing_extensions import TypeAlias

import dagster._check as check
from dagster._annotations import deprecated, experimental, public
from dagster._core.definitions.asset_check_spec import AssetCheckSpec
@@ -46,11 +44,17 @@
from dagster._core.log_manager import DagsterLogManager
from dagster._core.storage.dagster_run import DagsterRun
from dagster._utils.forked_pdb import ForkedPdb
from dagster._utils.warnings import deprecation_warning

from .system import StepExecutionContext


-class AbstractComputeExecutionContext(ABC):
# This metaclass has to exist for OpExecutionContext to have a metaclass
class AbstractComputeMetaclass(ABCMeta):
pass


class AbstractComputeExecutionContext(ABC, metaclass=AbstractComputeMetaclass):
"""Base class for op context implemented by OpExecutionContext and DagstermillExecutionContext."""

@abstractmethod
@@ -97,7 +101,18 @@ def op_config(self) -> Any:
"""The parsed config specific to this op."""


-class OpExecutionContext(AbstractComputeExecutionContext):
class OpExecutionContextMetaClass(AbstractComputeMetaclass):
def __instancecheck__(cls, instance) -> bool:
# This makes isinstance(context, OpExecutionContext) return True when
# the context is an AssetExecutionContext. This makes the new
# AssetExecutionContext backwards compatible with the old
# OpExecutionContext codepaths.
if isinstance(instance, AssetExecutionContext):
return True
return super().__instancecheck__(instance)
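For illustration (not part of the diff): this hook is what keeps pre-existing `isinstance` guards working once assets start receiving the wrapper type. A minimal sketch, where `some_op_context` stands in for a real OpExecutionContext instance:

```python
def pre_existing_user_code(context: OpExecutionContext) -> str:
    # Old user code frequently type-checks the context it is handed.
    assert isinstance(context, OpExecutionContext)
    return context.run_id

# AssetExecutionContext does NOT subclass OpExecutionContext, but the
# __instancecheck__ override above makes the assertion pass anyway.
pre_existing_user_code(AssetExecutionContext(some_op_context))  # hypothetical instance
```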


class OpExecutionContext(AbstractComputeExecutionContext, metaclass=OpExecutionContextMetaClass):
"""The ``context`` object that can be made available as the first argument to the function
used for computing an op or asset.

@@ -688,8 +703,191 @@ def asset_check_spec(self) -> AssetCheckSpec:
return asset_checks_def.spec


-# actually forking the object type for assets is tricky for users in the cases of:
-# * manually constructing ops to make AssetsDefinitions
-# * having ops in a graph that form a graph backed asset
-# so we have a single type that users can call by their preferred name where appropriate
-AssetExecutionContext: TypeAlias = OpExecutionContext
OP_EXECUTION_CONTEXT_ONLY_METHODS = set(
[
"describe_op",
"file_manager",
"has_assets_def",
"get_mapping_key",
# "get_step_execution_context", # used by internals
"job_def",
"job_name",
"node_handle",
"op",
"op_config",
# "op_def", # used by internals
"op_handle",
"retry_number",
"step_launcher",
# "has_events", # used by internals
"consumer_events",
]
)


PARTITION_KEY_RANGE_AS_ALT = "use partition_key_range or partition_key_range_for_asset_key instead"
INPUT_OUTPUT_ALT = "not use input or output names and instead use asset keys directly"
OUTPUT_METADATA_ALT = "return MaterializeResult from the asset instead"

DEPRECATED_IO_MANAGER_CENTRIC_CONTEXT_METHODS = {
"add_output_metadata": OUTPUT_METADATA_ALT,
"asset_key_for_input": INPUT_OUTPUT_ALT,
"asset_key_for_output": INPUT_OUTPUT_ALT,
"asset_partition_key_for_input": PARTITION_KEY_RANGE_AS_ALT,
"asset_partition_key_for_output": PARTITION_KEY_RANGE_AS_ALT,
"asset_partition_key_range_for_input": PARTITION_KEY_RANGE_AS_ALT,
"asset_partition_key_range_for_output": PARTITION_KEY_RANGE_AS_ALT,
"asset_partition_keys_for_input": PARTITION_KEY_RANGE_AS_ALT,
"asset_partition_keys_for_output": PARTITION_KEY_RANGE_AS_ALT,
"asset_partitions_time_window_for_input": PARTITION_KEY_RANGE_AS_ALT,
"asset_partitions_time_window_for_output": PARTITION_KEY_RANGE_AS_ALT,
"asset_partitions_def_for_input": PARTITION_KEY_RANGE_AS_ALT,
"asset_partitions_def_for_output": PARTITION_KEY_RANGE_AS_ALT,
"get_output_metadata": "use op_execution_context.op_def.get_output(...).metadata",
"merge_output_metadata": OUTPUT_METADATA_ALT,
"output_for_asset_key": INPUT_OUTPUT_ALT,
"selected_output_names": INPUT_OUTPUT_ALT,
}
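Most of these alternatives point away from I/O-manager-centric context calls and toward returning a result object. A hedged sketch of the suggested pattern, using `MaterializeResult` (which this PR imports elsewhere):

```python
from dagster import AssetExecutionContext, MaterializeResult, asset


@asset
def my_table(context: AssetExecutionContext) -> MaterializeResult:
    # Instead of the deprecated I/O-manager-centric context.add_output_metadata({...}),
    # attach metadata to the returned result.
    return MaterializeResult(metadata={"num_rows": 42})
```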

ALTERNATE_AVAILABLE_METHODS = {
"has_tag": "use dagster_run.has_tag instead",
"get_tag": "use dagster_run.get_tag instead",
"run_tags": "use dagster_run.tags instead",
"set_data_version": "use MaterializationResult instead",
}


class AssetExecutionContext:
Contributor:

Thread for figuring out why AssetExecutionContext was not made a subclass of OpExecutionContext to begin with.

Background:

  • AssetExecutionContext began as a type alias to align on naming and get a quick docs/examples improvement (here)
  • AssetExecutionContext was made a subclass of OpExecutionContext here
  • AssetExecutionContext was reverted back to a type alias here

In the revert PR, the reasoning for reverting was:

Conditions like:

* manually constructing AssetsDefinition with a manually written @op
* @ops that make up a graph backed AssetsDefinition

make having different context objects trickier for users than originally anticipated.

There is also a Slack thread mentioning this, where alex says:

the wall I hit with trying to split AssetExecutionContext and OpExecutionContext was resolving what the ops in a graph-backed-asset should receive.

Based on this, my interpretation is that the issue wasn't a technical one (a Python limitation, an inability to pass the correct context through, etc.), but a design question: "What is the correct context for an @op in a @graph_backed_asset to receive?"

Member Author:

Ok got it. Thanks for digging that up.

Member Author:

I think we can figure out a reasonable solution for the graph-backed asset case. We could alter what instance the user gets based on their typehint. We could also make an asset_execution_context property on OpExecutionContext so you can do the reverse in that case.
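A minimal sketch of that second idea, assuming a hypothetical `asset_execution_context` property that is not part of this diff:

```python
class OpExecutionContext(AbstractComputeExecutionContext, metaclass=OpExecutionContextMetaClass):
    ...

    @property
    def asset_execution_context(self) -> "AssetExecutionContext":
        # Hypothetical: lets an @op inside a graph-backed asset opt in to
        # the asset-oriented API by wrapping its own context.
        return AssetExecutionContext(self)
```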

def __init__(self, op_execution_context: OpExecutionContext) -> None:
self._op_execution_context = check.inst_param(
op_execution_context, "op_execution_context", OpExecutionContext
)

def __getattr__(self, attr) -> Any:
check.str_param(attr, "attr")

if attr in self.__dict__:
return getattr(self, attr)

if attr in OP_EXECUTION_CONTEXT_ONLY_METHODS:
deprecation_warning(
subject=f"AssetExecutionContext.{attr}",
additional_warn_text=(
f"You have called the deprecated method {attr} on AssetExecutionContext. Use"
" the underlying OpExecutionContext instead by calling"
f" op_execution_context.{attr}."
),
breaking_version="1.7",
stacklevel=1,
)

if attr in DEPRECATED_IO_MANAGER_CENTRIC_CONTEXT_METHODS:
alt = DEPRECATED_IO_MANAGER_CENTRIC_CONTEXT_METHODS[attr]

            deprecation_warning(
                subject=f"AssetExecutionContext.{attr}",
                additional_warn_text=(
                    f"You have called method {attr} on AssetExecutionContext that is oriented"
                    f" around I/O managers. If you are not using I/O managers we suggest you {alt}. If"
" you are using I/O managers the method still exists at"
f" op_execution_context.{attr}."
),
breaking_version="1.7",
stacklevel=1,
)

if attr in ALTERNATE_AVAILABLE_METHODS:
deprecation_warning(
subject=f"AssetExecutionContext.{attr}",
additional_warn_text=f"Instead {ALTERNATE_AVAILABLE_METHODS[attr]}.",
breaking_version="1.7",
stacklevel=1,
)

return getattr(self._op_execution_context, attr)
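As a hedged illustration of the fallback above: explicitly supported members resolve normally, while anything in the three dicts warns and then delegates, so old call sites keep working until the stated breaking version:

```python
def old_style_asset_body(context: AssetExecutionContext) -> None:
    context.log.info("materializing")     # supported member, no warning
    key = context.asset_key_for_output()  # in the I/O-manager-centric dict:
                                          # warns, then forwards to
                                          # context.op_execution_context.asset_key_for_output()
    context.log.info(f"writing {key}")
```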

# include all supported methods below

@public
@property
def op_execution_context(self) -> OpExecutionContext:
return self._op_execution_context

@public
@property
def run_id(self) -> str:
return self._op_execution_context.run_id

@public
@property
def dagster_run(self) -> DagsterRun:
"""PipelineRun: The current pipeline run."""
return self._step_execution_context.dagster_run

@public
@property
def asset_key(self) -> AssetKey:
return self._op_execution_context.asset_key

@public
@property
def pdb(self) -> ForkedPdb:
return self._op_execution_context.pdb

@public
@property
def log(self) -> DagsterLogManager:
"""DagsterLogManager: The log manager available in the execution context."""
return self._op_execution_context.log

@public
# renaming to avoid ambiguity in single run and multi-partition case
@property
def is_partitioned_execution(self) -> bool:
return self._op_execution_context.has_partition_key

@public
def log_event(self, event: UserEvent) -> None:
return self._op_execution_context.log_event(event)

@public
@property
def assets_def(self) -> AssetsDefinition:
return self._op_execution_context.assets_def

@public
    # TODO confirm semantics in the presence of asset subsetting
# seems like there should be both "asset_keys" and "selected_asset_keys"
@property
def selected_asset_keys(self) -> AbstractSet[AssetKey]:
return self._op_execution_context.selected_asset_keys
Contributor:

We will also likely be adding selected_asset_checks here. This refactor might be an opportunity to condense them into a single selection object; a sketch follows below.
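A hypothetical sketch of such a condensed selection object (names illustrative, not part of this diff; AssetKey and AssetCheckSpec are the module's existing imports):

```python
from dataclasses import dataclass
from typing import AbstractSet, Sequence


@dataclass(frozen=True)
class ExecutionSelection:
    # Hypothetical container unifying selected assets and asset checks.
    asset_keys: AbstractSet[AssetKey]
    asset_check_specs: Sequence[AssetCheckSpec]
```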


@public
@experimental
def get_asset_provenance(self, asset_key: AssetKey) -> Optional[DataProvenance]:
return self._op_execution_context.get_asset_provenance(asset_key)

@property
def asset_check_spec(self) -> AssetCheckSpec:
return self._op_execution_context.asset_check_spec

@public
@property
def partition_key_range(self) -> PartitionKeyRange:
return self._op_execution_context.asset_partition_key_range

@public
def partition_key_range_for_asset_key(self, asset_key: AssetKey) -> PartitionKeyRange:
Contributor:

should this be partition_key_range_for_dep()?

Contributor:

Just leaving a note that making the partition methods agnostic to "input" or "output" led to some pretty complex errors. Some of the methods (like the partition_key_range one) have different code paths for "inputs" and "outputs" because of partition mapping. It is not worth replicating all of that implementation detail in this RFC, but for a final implementation we should make sure this code path is really thoroughly tested. The sketch below illustrates the asymmetry.
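A small example of why the two sides differ, using standard time-window partitions (a sketch, not code from this PR):

```python
from dagster import DailyPartitionsDefinition, HourlyPartitionsDefinition, asset

hourly = HourlyPartitionsDefinition(start_date="2023-01-01-00:00")
daily = DailyPartitionsDefinition(start_date="2023-01-01")


@asset(partitions_def=hourly)
def events() -> None: ...


@asset(partitions_def=daily)
def daily_rollup(events) -> None:
    # For one daily output partition, the inferred TimeWindowPartitionMapping
    # resolves the *input* side to 24 hourly partitions of `events`, while the
    # *output* side is a single daily key: different ranges, different code paths.
    ...
```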

subset = self._op_execution_context.get_step_execution_context().asset_partitions_subset_for_asset_key(
asset_key
)
partition_key_ranges = subset.get_partition_key_ranges(
dynamic_partitions_store=self._op_execution_context.instance
)
if len(partition_key_ranges) != 1:
check.failed(
"Tried to access asset partition key range, but there are "
f"({len(partition_key_ranges)}) key ranges associated with this asset key.",
)
return partition_key_ranges[0]
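For completeness, a hedged usage sketch of the new method from a downstream asset (assumes a partitioned upstream asset named `events`):

```python
from dagster import AssetKey, asset


@asset(deps=["events"])
def downstream(context: AssetExecutionContext) -> None:
    key_range = context.partition_key_range_for_asset_key(AssetKey("events"))
    context.log.info(f"reading events partitions {key_range.start}..{key_range.end}")
```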
85 changes: 46 additions & 39 deletions python_modules/dagster/dagster/_core/execution/context/system.py
@@ -1072,50 +1072,57 @@ def asset_partitions_subset_for_input(
self, input_name: str, *, require_valid_partitions: bool = True
) -> PartitionsSubset:
asset_layer = self.job_def.asset_layer
-        assets_def = asset_layer.assets_def_for_node(self.node_handle)
-        upstream_asset_key = asset_layer.asset_key_for_input(self.node_handle, input_name)
-
-        if upstream_asset_key is not None:
-            upstream_asset_partitions_def = asset_layer.partitions_def_for_asset(upstream_asset_key)
-
-            if upstream_asset_partitions_def is not None:
-                partitions_def = assets_def.partitions_def if assets_def else None
-                partitions_subset = (
-                    partitions_def.empty_subset().with_partition_key_range(
-                        self.asset_partition_key_range, dynamic_partitions_store=self.instance
-                    )
-                    if partitions_def
-                    else None
-                )
-                partition_mapping = infer_partition_mapping(
-                    asset_layer.partition_mapping_for_node_input(
-                        self.node_handle, upstream_asset_key
-                    ),
-                    partitions_def,
-                    upstream_asset_partitions_def,
-                )
-                mapped_partitions_result = (
-                    partition_mapping.get_upstream_mapped_partitions_result_for_partitions(
-                        partitions_subset,
-                        upstream_asset_partitions_def,
-                        dynamic_partitions_store=self.instance,
-                    )
-                )
-
-                if (
-                    require_valid_partitions
-                    and mapped_partitions_result.required_but_nonexistent_partition_keys
-                ):
-                    raise DagsterInvariantViolationError(
-                        f"Partition key range {self.asset_partition_key_range} in"
-                        f" {self.node_handle.name} depends on invalid partition keys"
-                        f" {mapped_partitions_result.required_but_nonexistent_partition_keys} in"
-                        f" upstream asset {upstream_asset_key}"
-                    )
-
-                return mapped_partitions_result.partitions_subset
-
-        check.failed("The input has no asset partitions")
        asset_key = check.not_none(
            asset_layer.asset_key_for_input(self.node_handle, input_name),
            "The input has no asset key",
        )

        return self.asset_partitions_subset_for_asset_key(asset_key, require_valid_partitions)

    def asset_partitions_subset_for_asset_key(
        self, asset_key: AssetKey, require_valid_partitions: bool = True
    ):
        asset_layer = self.job_def.asset_layer
        assets_def = asset_layer.assets_def_for_node(self.node_handle)
        asset_partitions_def = check.not_none(
            asset_layer.partitions_def_for_asset(asset_key),
            "The asset key does not have a partitions definition",
        )

        # assets_def can be None in cases where op-only jobs are invoked with a partition key
        partitions_def = assets_def.partitions_def if assets_def else None
        partitions_subset = (
            partitions_def.empty_subset().with_partition_key_range(
                self.asset_partition_key_range, dynamic_partitions_store=self.instance
            )
            if partitions_def
            else None
        )
        partition_mapping = infer_partition_mapping(
            asset_layer.partition_mapping_for_node_input(self.node_handle, asset_key),
            partitions_def,
            asset_partitions_def,
        )
        mapped_partitions_result = (
            partition_mapping.get_upstream_mapped_partitions_result_for_partitions(
                partitions_subset,
                asset_partitions_def,
                dynamic_partitions_store=self.instance,
            )
        )

        if (
            require_valid_partitions
            and mapped_partitions_result.required_but_nonexistent_partition_keys
        ):
            raise DagsterInvariantViolationError(
                f"Partition key range {self.asset_partition_key_range} in"
                f" {self.node_handle.name} depends on invalid partition keys"
                f" {mapped_partitions_result.required_but_nonexistent_partition_keys} in"
                f" upstream asset {asset_key}"
            )

        return mapped_partitions_result.partitions_subset

def asset_partition_key_for_input(self, input_name: str) -> str:
start, end = self.asset_partition_key_range_for_input(input_name)
@@ -32,7 +32,7 @@
from dagster._core.definitions.result import MaterializeResult
from dagster._core.errors import DagsterExecutionStepExecutionError, DagsterInvariantViolationError
from dagster._core.events import DagsterEvent
-from dagster._core.execution.context.compute import OpExecutionContext
from dagster._core.execution.context.compute import AssetExecutionContext, OpExecutionContext
from dagster._core.execution.context.system import StepExecutionContext
from dagster._core.system_config.objects import ResolvedRunConfig
from dagster._utils import iterate_with_context
@@ -146,7 +146,11 @@ def _yield_compute_results(
) -> Iterator[OpOutputUnion]:
check.inst_param(step_context, "step_context", StepExecutionContext)

-    context = OpExecutionContext(step_context)
context = (
AssetExecutionContext(OpExecutionContext(step_context))
if step_context.is_sda_step
else OpExecutionContext(step_context)
)
user_event_generator = compute_fn(context, inputs)

if isinstance(user_event_generator, Output):
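Hedged sketch of the user-facing effect of this branch (not part of the diff): software-defined-asset steps now receive the wrapper, while plain op steps keep the old type:

```python
from dagster import AssetExecutionContext, OpExecutionContext, asset, op


@asset
def my_asset(context: AssetExecutionContext) -> None:
    # is_sda_step is True here, so the compute function receives an
    # AssetExecutionContext wrapping an OpExecutionContext.
    context.log.info(str(context.asset_key))


@op
def my_op(context: OpExecutionContext) -> None:
    # Non-asset steps still receive a plain OpExecutionContext.
    context.log.info(context.op_def.name)
```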
6 changes: 6 additions & 0 deletions python_modules/dagster/dagster/_core/storage/dagster_run.py
@@ -428,6 +428,12 @@ def tags_for_sensor(sensor) -> Mapping[str, str]:
def tags_for_backfill_id(backfill_id: str) -> Mapping[str, str]:
return {BACKFILL_ID_TAG: backfill_id}

def get_tag(self, key: str) -> Optional[str]:
return self.tags.get(key)

def has_tag(self, key: str) -> bool:
return key in self.tags
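These two helpers are what the deprecation messages above ("use dagster_run.get_tag instead") point at; a hedged usage sketch from inside an asset:

```python
from dagster import AssetExecutionContext, asset


@asset
def tagged_asset(context: AssetExecutionContext) -> None:
    run = context.dagster_run
    # Replaces the deprecated context.has_tag / context.get_tag.
    if run.has_tag("dagster/backfill"):
        context.log.info(f"part of backfill {run.get_tag('dagster/backfill')}")
```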


class RunsFilterSerializer(NamedTupleSerializer["RunsFilter"]):
def before_unpack(