Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

provide AssetExecutionContext class to context #16493

Closed
4 changes: 2 additions & 2 deletions python_modules/dagster-test/dagster_test/toys/asset_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,10 +144,10 @@ def random_fail_check_on_partitioned_asset():
can_subset=True,
)
def multi_asset_1_and_2(context):
if AssetKey("multi_asset_piece_1") in context.selected_asset_keys:
if AssetKey("multi_asset_piece_1") in context.asset_keys:
yield Output(1, output_name="one")
yield AssetCheckResult(success=True, metadata={"foo": "bar"})
if AssetKey("multi_asset_piece_2") in context.selected_asset_keys:
if AssetKey("multi_asset_piece_2") in context.asset_keys:
yield Output(1, output_name="two")


Expand Down
6 changes: 6 additions & 0 deletions python_modules/dagster/dagster/_core/decorator_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,3 +270,9 @@ def is_resource_def(obj: Any) -> TypeGuard["ResourceDefinition"]:
"""
class_names = [cls.__name__ for cls in inspect.getmro(obj.__class__)]
return "ResourceDefinition" in class_names


def is_context_provided(params: Sequence[Parameter]) -> bool:
if len(params) == 0:
return False
return params[0].name in get_valid_name_permutations("context")
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from dagster._annotations import deprecated_param, experimental_param
from dagster._builtins import Nothing
from dagster._config import UserConfigSchema
from dagster._core.decorator_utils import get_function_params, get_valid_name_permutations
from dagster._core.decorator_utils import get_function_params
from dagster._core.definitions.asset_dep import AssetDep, CoercibleToAssetDep
from dagster._core.definitions.auto_materialize_policy import AutoMaterializePolicy
from dagster._core.definitions.config import ConfigMapping
Expand Down Expand Up @@ -854,11 +854,10 @@ def inner(fn: Callable[..., Any]) -> AssetsDefinition:


def get_function_params_without_context_or_config_or_resources(fn: Callable) -> List[Parameter]:
from dagster._core.decorator_utils import is_context_provided

params = get_function_params(fn)
is_context_provided = len(params) > 0 and params[0].name in get_valid_name_permutations(
"context"
)
input_params = params[1:] if is_context_provided else params
input_params = params[1:] if is_context_provided(params) else params

resource_arg_names = {arg.name for arg in get_resource_args(fn)}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from dagster._core.decorator_utils import (
format_docstring_for_description,
get_function_params,
get_valid_name_permutations,
is_context_provided,
param_is_var_keyword,
positional_arg_name_list,
)
Expand Down Expand Up @@ -342,12 +342,6 @@ def has_context_arg(self) -> bool:
return False


def is_context_provided(params: Sequence[Parameter]) -> bool:
if len(params) == 0:
return False
return params[0].name in get_valid_name_permutations("context")


def resolve_checked_op_fn_inputs(
decorator_name: str,
fn_name: str,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,10 +63,8 @@
def wrap_source_asset_observe_fn_in_op_compute_fn(
source_asset: "SourceAsset",
) -> "DecoratedOpFunction":
from dagster._core.definitions.decorators.op_decorator import (
DecoratedOpFunction,
is_context_provided,
)
from dagster._core.decorator_utils import is_context_provided
from dagster._core.definitions.decorators.op_decorator import DecoratedOpFunction
from dagster._core.execution.context.compute import (
OpExecutionContext,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1723,6 +1723,7 @@ def log_event(self, event: UserEvent) -> None:
def get_asset_provenance(self, asset_key: AssetKey) -> Optional[DataProvenance]:
return self._op_execution_context.get_asset_provenance(asset_key)


def build_execution_context(
step_context: StepExecutionContext,
) -> Union[OpExecutionContext, AssetExecutionContext]:
Expand All @@ -1736,6 +1737,7 @@ def build_execution_context(
op AssetExecutionContext Error - we cannot init an AssetExecutionContext w/o an AssetsDefinition
op OpExecutionContext OpExecutionContext
op None OpExecutionContext

For ops in graph-backed assets
step type annotation result
op AssetExecutionContext AssetExecutionContext
Expand All @@ -1744,6 +1746,7 @@ def build_execution_context(
"""
is_sda_step = step_context.is_sda_step
is_op_in_graph_asset = is_sda_step and step_context.is_op_in_graph

context_annotation = EmptyAnnotation
compute_fn = step_context.op_def._compute_fn # noqa: SLF001
compute_fn = (
Expand All @@ -1766,14 +1769,16 @@ def build_execution_context(
" OpExecutionContext, or left blank."
)

op_context = OpExecutionContext(step_context)

if context_annotation is EmptyAnnotation:
# if no type hint has been given, default to:
# * AssetExecutionContext for sda steps, not in graph-backed assets
# * OpExecutionContext for non sda steps
# * OpExecutionContext for ops in graph-backed assets
if is_op_in_graph_asset or not is_sda_step:
return OpExecutionContext(step_context)
return AssetExecutionContext(step_context)
return op_context
return AssetExecutionContext(op_context)
if context_annotation is AssetExecutionContext:
return AssetExecutionContext(step_context)
return OpExecutionContext(step_context)
return AssetExecutionContext(op_context)
return op_context
11 changes: 7 additions & 4 deletions python_modules/dagster/dagster/_core/execution/plan/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
from dagster._core.definitions.result import MaterializeResult
from dagster._core.errors import DagsterExecutionStepExecutionError, DagsterInvariantViolationError
from dagster._core.events import DagsterEvent
from dagster._core.execution.context.compute import build_execution_context
from dagster._core.execution.context.compute import AssetExecutionContext, build_execution_context
from dagster._core.execution.context.system import StepExecutionContext
from dagster._core.system_config.objects import ResolvedRunConfig
from dagster._utils import iterate_with_context
Expand Down Expand Up @@ -169,6 +169,9 @@ def _yield_compute_results(
user_event_generator = gen_from_async_gen(user_event_generator)

op_label = step_context.describe_op()
op_execution_context = (
context.op_execution_context if isinstance(context, AssetExecutionContext) else context
)

for event in iterate_with_context(
lambda: op_execution_error_boundary(
Expand All @@ -181,12 +184,12 @@ def _yield_compute_results(
),
user_event_generator,
):
if context.has_events():
if op_execution_context.has_events():
yield from context.consume_events()
yield _validate_event(event, step_context)

if context.has_events():
yield from context.consume_events()
if op_execution_context.has_events():
yield from op_execution_context.consume_events()


def execute_core_compute(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
from dagster._utils import is_named_tuple_instance
from dagster._utils.warnings import disable_dagster_warnings

from ..context.compute import OpExecutionContext
from ..context.compute import AssetExecutionContext, OpExecutionContext


class NoAnnotationSentinel:
Expand Down Expand Up @@ -242,15 +242,22 @@ def _check_output_object_name(


def validate_and_coerce_op_result_to_iterator(
result: Any, context: OpExecutionContext, output_defs: Sequence[OutputDefinition]
result: Any,
context: Union[OpExecutionContext, AssetExecutionContext],
output_defs: Sequence[OutputDefinition],
) -> Iterator[Any]:
if isinstance(context, AssetExecutionContext):
step_description = f" asset '{context.op_execution_context.op_def.name}'"
context = context.op_execution_context
else:
step_description = context.describe_op()
if inspect.isgenerator(result):
# this happens when a user explicitly returns a generator in the op
for event in result:
yield event
elif isinstance(result, (AssetMaterialization, ExpectationResult)):
raise DagsterInvariantViolationError(
f"Error in {context.describe_op()}: If you are "
f"Error in {step_description}: If you are "
"returning an AssetMaterialization "
"or an ExpectationResult from "
f"{context.op_def.node_type_str} you must yield them "
Expand All @@ -263,7 +270,7 @@ def validate_and_coerce_op_result_to_iterator(
yield result
elif result is not None and not output_defs:
raise DagsterInvariantViolationError(
f"Error in {context.describe_op()}: Unexpectedly returned output of type"
f"Error in {step_description}: Unexpectedly returned output of type"
f" {type(result)}. {context.op_def.node_type_str.capitalize()} is explicitly defined to"
" return no results."
)
Expand All @@ -289,15 +296,15 @@ def validate_and_coerce_op_result_to_iterator(
if output_def.is_dynamic:
if not isinstance(element, list):
raise DagsterInvariantViolationError(
f"Error with output for {context.describe_op()}: "
f"Error with output for {step_description}: "
f"dynamic output '{output_def.name}' expected a list of "
"DynamicOutput objects, but instead received instead an "
f"object of type {type(element)}."
)
for item in element:
if not isinstance(item, DynamicOutput):
raise DagsterInvariantViolationError(
f"Error with output for {context.describe_op()}: "
f"Error with output for {step_description}: "
f"dynamic output '{output_def.name}' at position {position} expected a "
"list of DynamicOutput objects, but received an "
f"item with type {type(item)}."
Expand All @@ -319,7 +326,7 @@ def validate_and_coerce_op_result_to_iterator(
annotation
):
raise DagsterInvariantViolationError(
f"Error with output for {context.describe_op()}: received Output object for"
f"Error with output for {step_description}: received Output object for"
f" output '{output_def.name}' which does not have an Output annotation."
f" Annotation has type {annotation}."
)
Expand All @@ -337,15 +344,15 @@ def validate_and_coerce_op_result_to_iterator(
# output object was not received, throw an error.
if is_generic_output_annotation(annotation):
raise DagsterInvariantViolationError(
f"Error with output for {context.describe_op()}: output "
f"Error with output for {step_description}: output "
f"'{output_def.name}' has generic output annotation, "
"but did not receive an Output object for this output. "
f"Received instead an object of type {type(element)}."
)
if result is None and output_def.is_required is False:
context.log.warning(
'Value "None" returned for non-required output '
f'"{output_def.name}" of {context.describe_op()}. '
f'"{output_def.name}" of {step_description}. '
"This value will be passed to downstream "
f"{context.op_def.node_type_str}s. For conditional "
"execution, results must be yielded: "
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1585,7 +1585,7 @@ def my_function():
can_subset=True,
)
def subset(context: AssetExecutionContext):
# ...use context.selected_asset_keys materialize subset of assets without IO manager
# ...use context.asset_keys materialize subset of assets without IO manager
pass

with pytest.raises(
Expand Down Expand Up @@ -1882,7 +1882,7 @@ def basic_deps():
can_subset=True,
)
def basic_subset(context: AssetExecutionContext):
for key in context.selected_asset_keys:
for key in context.asset_keys:
yield MaterializeResult(asset_key=key)

mats = _exec_asset(basic_subset, ["table_A"])
Expand Down
Loading