From 8328998fd81cb6c54b7de4fccc62e3df8baffd1c Mon Sep 17 00:00:00 2001 From: JamieDeMaria Date: Wed, 6 Dec 2023 18:20:11 -0500 Subject: [PATCH] add parent class so typing can start to work --- .../_core/definitions/op_invocation.py | 42 +++++++++------ .../_core/execution/context/invocation.py | 54 ++++++++++++++++++- 2 files changed, 78 insertions(+), 18 deletions(-) diff --git a/python_modules/dagster/dagster/_core/definitions/op_invocation.py b/python_modules/dagster/dagster/_core/definitions/op_invocation.py index 115ccfcf55ae0..916ed63e82c77 100644 --- a/python_modules/dagster/dagster/_core/definitions/op_invocation.py +++ b/python_modules/dagster/dagster/_core/definitions/op_invocation.py @@ -19,6 +19,7 @@ DagsterInvariantViolationError, DagsterTypeCheckDidNotPass, ) +from dagster._core.execution.context.compute import AssetExecutionContext, OpExecutionContext from .events import ( AssetKey, @@ -32,7 +33,7 @@ from .result import MaterializeResult if TYPE_CHECKING: - from ..execution.context.invocation import RunlessOpExecutionContext + from ..execution.context.invocation import BaseRunlessContext from .assets import AssetsDefinition from .composition import PendingNodeInvocation from .decorators.op_decorator import DecoratedOpFunction @@ -100,6 +101,14 @@ def _separate_args_and_kwargs( ) +def _get_op_context( + context: Union[OpExecutionContext, AssetExecutionContext] +) -> OpExecutionContext: + if isinstance(context, AssetExecutionContext): + return context.op_execution_context + return context + + def direct_invocation_result( def_or_invocation: Union[ "OpDefinition", "PendingNodeInvocation[OpDefinition]", "AssetsDefinition" @@ -109,7 +118,7 @@ def direct_invocation_result( ) -> Any: from dagster._config.pythonic_config import Config from dagster._core.execution.context.invocation import ( - RunlessOpExecutionContext, + BaseRunlessContext, build_op_context, ) @@ -149,12 +158,12 @@ def direct_invocation_result( " no context was provided when invoking." ) if len(args) > 0: - if args[0] is not None and not isinstance(args[0], RunlessOpExecutionContext): + if args[0] is not None and not isinstance(args[0], BaseRunlessContext): raise DagsterInvalidInvocationError( f"Decorated function '{compute_fn.name}' has context argument, " "but no context was provided when invoking." ) - context = cast(RunlessOpExecutionContext, args[0]) + context = args[0] # update args to omit context args = args[1:] else: # context argument is provided under kwargs @@ -165,14 +174,14 @@ def direct_invocation_result( f"'{context_param_name}', but no value for '{context_param_name}' was " f"found when invoking. Provided kwargs: {kwargs}" ) - context = cast(RunlessOpExecutionContext, kwargs[context_param_name]) + context = kwargs[context_param_name] # update kwargs to remove context kwargs = { kwarg: val for kwarg, val in kwargs.items() if not kwarg == context_param_name } # allow passing context, even if the function doesn't have an arg for it - elif len(args) > 0 and isinstance(args[0], RunlessOpExecutionContext): - context = cast(RunlessOpExecutionContext, args[0]) + elif len(args) > 0 and isinstance(args[0], BaseRunlessContext): + context = args[0] args = args[1:] resource_arg_mapping = {arg.name: arg.name for arg in compute_fn.get_resource_args()} @@ -230,7 +239,7 @@ def direct_invocation_result( def _resolve_inputs( - op_def: "OpDefinition", args, kwargs, context: "RunlessOpExecutionContext" + op_def: "OpDefinition", args, kwargs, context: "BaseRunlessContext" ) -> Mapping[str, Any]: from dagster._core.execution.plan.execute_step import do_type_check @@ -333,7 +342,7 @@ def _resolve_inputs( return input_dict -def _key_for_result(result: MaterializeResult, context: "RunlessOpExecutionContext") -> AssetKey: +def _key_for_result(result: MaterializeResult, context: "BaseRunlessContext") -> AssetKey: if not context.bound_properties.assets_def: raise DagsterInvariantViolationError( f"Op {context.bound_properties.alias} does not have an assets definition." @@ -352,7 +361,7 @@ def _key_for_result(result: MaterializeResult, context: "RunlessOpExecutionConte def _output_name_for_result_obj( event: MaterializeResult, - context: "RunlessOpExecutionContext", + context: "BaseRunlessContext", ): if not context.bound_properties.assets_def: raise DagsterInvariantViolationError( @@ -365,7 +374,7 @@ def _output_name_for_result_obj( def _handle_gen_event( event: T, op_def: "OpDefinition", - context: "RunlessOpExecutionContext", + context: "BaseRunlessContext", output_defs: Mapping[str, OutputDefinition], outputs_seen: Set[str], ) -> T: @@ -399,7 +408,7 @@ def _handle_gen_event( def _type_check_output_wrapper( - op_def: "OpDefinition", result: Any, context: "RunlessOpExecutionContext" + op_def: "OpDefinition", result: Any, context: "BaseRunlessContext" ) -> Any: """Type checks and returns the result of a op. @@ -493,12 +502,13 @@ def type_check_gen(gen): def _type_check_function_output( - op_def: "OpDefinition", result: T, context: "RunlessOpExecutionContext" + op_def: "OpDefinition", result: T, context: "BaseRunlessContext" ) -> T: from ..execution.plan.compute_generator import validate_and_coerce_op_result_to_iterator output_defs_by_name = {output_def.name: output_def for output_def in op_def.output_defs} - for event in validate_and_coerce_op_result_to_iterator(result, context, op_def.output_defs): + op_context = _get_op_context(context) + for event in validate_and_coerce_op_result_to_iterator(result, op_context, op_def.output_defs): if isinstance(event, (Output, DynamicOutput)): _type_check_output(output_defs_by_name[event.output_name], event, context) elif isinstance(event, (MaterializeResult)): @@ -512,14 +522,14 @@ def _type_check_function_output( def _type_check_output( output_def: "OutputDefinition", output: Union[Output, DynamicOutput], - context: "RunlessOpExecutionContext", + context: "BaseRunlessContext", ) -> None: """Validates and performs core type check on a provided output. Args: output_def (OutputDefinition): The output definition to validate against. output (Any): The output to validate. - context (RunlessOpExecutionContext): Context containing resources to be used for type + context (BaseRunlessContext): Context containing resources to be used for type check. """ from ..execution.plan.execute_step import do_type_check diff --git a/python_modules/dagster/dagster/_core/execution/context/invocation.py b/python_modules/dagster/dagster/_core/execution/context/invocation.py index f99850726526a..d794375b62819 100644 --- a/python_modules/dagster/dagster/_core/execution/context/invocation.py +++ b/python_modules/dagster/dagster/_core/execution/context/invocation.py @@ -1,3 +1,4 @@ +from abc import abstractmethod from contextlib import ExitStack from typing import ( AbstractSet, @@ -67,6 +68,37 @@ def _property_msg(prop_name: str, method_name: str) -> str: ) +class BaseRunlessContext: + @abstractmethod + def bind( + self, + op_def: OpDefinition, + pending_invocation: Optional[PendingNodeInvocation[OpDefinition]], + assets_def: Optional[AssetsDefinition], + config_from_args: Optional[Mapping[str, Any]], + resources_from_args: Optional[Mapping[str, Any]], + ): + """Instances of BsaeRunlessContest must implement bind.""" + + @abstractmethod + def unbind(self): + """Instances of BsaeRunlessContest must implement unbind.""" + + @property + @abstractmethod + def bound_properties(self) -> "BoundProperties": + """Instances of BaseRunlessContext must contain a BoundProperties object.""" + + @property + @abstractmethod + def execution_properties(self) -> "RunlessExecutionProperties": + """Instances of BaseRunlessContext must contain a RunlessExecutionProperties object.""" + + @abstractmethod + def for_type(self, dagster_type: DagsterType) -> TypeCheckContext: + pass + + class BoundProperties( NamedTuple( "_BoundProperties", @@ -224,7 +256,7 @@ def set_requires_typed_event_stream(self, *, error_message: Optional[str]) -> No self._typed_event_stream_error_message = error_message -class RunlessOpExecutionContext(OpExecutionContext): +class RunlessOpExecutionContext(OpExecutionContext, BaseRunlessContext): """The ``context`` object available as the first argument to an op's compute function when being invoked directly. Can also be used as a context manager. """ @@ -717,7 +749,7 @@ def set_requires_typed_event_stream(self, *, error_message: Optional[str]) -> No self._execution_properties.set_requires_typed_event_stream(error_message=error_message) -class RunlessAssetExecutionContext(AssetExecutionContext): +class RunlessAssetExecutionContext(AssetExecutionContext, BaseRunlessContext): """The ``context`` object available as the first argument to an asset's compute function when being invoked directly. Can also be used as a context manager. """ @@ -727,6 +759,16 @@ def __init__(self, op_execution_context: RunlessOpExecutionContext): self._run_props = None + def __enter__(self): + self.op_execution_context._cm_scope_entered = True # noqa: SLF001 + return self + + def __exit__(self, *exc): + self.op_execution_context._exit_stack.close() # noqa: SLF001 + + def __del__(self): + self.op_execution_context._exit_stack.close() # noqa: SLF001 + def _check_bound(self, fn_name: str, fn_type: str): if not self._op_execution_context._bound_properties: # noqa: SLF001 raise DagsterInvalidPropertyError(_property_msg(fn_name, fn_type)) @@ -763,6 +805,14 @@ def bind( def unbind(self): self._op_execution_context = self._op_execution_context.unbind() + @property + def bound_properties(self) -> BoundProperties: + return self.op_execution_context.bound_properties + + @property + def execution_properties(self) -> RunlessExecutionProperties: + return self.op_execution_context.execution_properties + @property def op_execution_context(self) -> RunlessOpExecutionContext: return self._op_execution_context