diff --git a/python_modules/dagster/dagster/_core/execution/context/compute.py b/python_modules/dagster/dagster/_core/execution/context/compute.py index b1dfe43cce873..96a6b99e99d10 100644 --- a/python_modules/dagster/dagster/_core/execution/context/compute.py +++ b/python_modules/dagster/dagster/_core/execution/context/compute.py @@ -1407,7 +1407,7 @@ class RunProperties(NamedTuple): retry_number: PublicAttr[int] -class AssetExecutionContext(OpExecutionContext): +class AssetExecutionContext: def __init__(self, op_execution_context: OpExecutionContext) -> None: self._op_execution_context = check.inst_param( op_execution_context, "op_execution_context", OpExecutionContext diff --git a/python_modules/dagster/dagster/_core/execution/plan/compute_generator.py b/python_modules/dagster/dagster/_core/execution/plan/compute_generator.py index dd930fa0c98a6..5118118970d45 100644 --- a/python_modules/dagster/dagster/_core/execution/plan/compute_generator.py +++ b/python_modules/dagster/dagster/_core/execution/plan/compute_generator.py @@ -67,7 +67,7 @@ def create_op_compute_wrapper( @wraps(fn) def compute( - context: OpExecutionContext, + context: Union[OpExecutionContext, AssetExecutionContext], input_defs: Mapping[str, InputDefinition], ) -> Union[Iterator[Output], AsyncIterator[Output]]: kwargs = {} @@ -105,7 +105,9 @@ def compute( # called in this file (create_op_compute_wrapper) async def _coerce_async_op_to_async_gen( - awaitable: Awaitable[Any], context: OpExecutionContext, output_defs: Sequence[OutputDefinition] + awaitable: Awaitable[Any], + context: Union[OpExecutionContext, AssetExecutionContext], + output_defs: Sequence[OutputDefinition], ) -> AsyncIterator[Any]: result = await awaitable for event in validate_and_coerce_op_result_to_iterator(result, context, output_defs): @@ -115,7 +117,7 @@ async def _coerce_async_op_to_async_gen( # called in this file, and in op_invocation for direct invocation def invoke_compute_fn( fn: Callable, - context: OpExecutionContext, + context: Union[OpExecutionContext, AssetExecutionContext], kwargs: Mapping[str, Any], context_arg_provided: bool, config_arg_cls: Optional[Type[Config]], @@ -258,7 +260,9 @@ def _check_output_object_name( # called in op_invocation and this file def validate_and_coerce_op_result_to_iterator( - result: Any, context: OpExecutionContext, output_defs: Sequence[OutputDefinition] + result: Any, + context: Union[OpExecutionContext, AssetExecutionContext], + output_defs: Sequence[OutputDefinition], ) -> Iterator[Any]: if inspect.isgenerator(result): # this happens when a user explicitly returns a generator in the op