From 6cef302e39b1703ceba138af041027c6e0333111 Mon Sep 17 00:00:00 2001 From: JamieDeMaria Date: Fri, 17 Nov 2023 16:19:27 -0500 Subject: [PATCH] update interfaces --- .../_core/definitions/op_invocation.py | 31 ++++++++-------- .../_core/execution/context/invocation.py | 36 ++++++++++++++++--- 2 files changed, 47 insertions(+), 20 deletions(-) diff --git a/python_modules/dagster/dagster/_core/definitions/op_invocation.py b/python_modules/dagster/dagster/_core/definitions/op_invocation.py index 29454ff0a918a..e275da777b429 100644 --- a/python_modules/dagster/dagster/_core/definitions/op_invocation.py +++ b/python_modules/dagster/dagster/_core/definitions/op_invocation.py @@ -32,7 +32,9 @@ from .result import MaterializeResult if TYPE_CHECKING: - from ..execution.context.invocation import RunlessOpExecutionContext + from ..execution.context.invocation import ( + BaseDirectInvocationContext, + ) from .assets import AssetsDefinition from .composition import PendingNodeInvocation from .decorators.op_decorator import DecoratedOpFunction @@ -230,7 +232,7 @@ def direct_invocation_result( def _resolve_inputs( - op_def: "OpDefinition", args, kwargs, context: "RunlessOpExecutionContext" + op_def: "OpDefinition", args, kwargs, context: "BaseDirectInvocationContext" ) -> Mapping[str, Any]: from dagster._core.execution.plan.execute_step import do_type_check @@ -268,9 +270,8 @@ def _resolve_inputs( "but no context parameter was defined for the op." ) - node_label = op_def.node_type_str raise DagsterInvalidInvocationError( - f"Too many input arguments were provided for {node_label} '{context.bound_properties.alias}'." + f"Too many input arguments were provided for {context.bound_properties.step_description}'." f" {suggestion}" ) @@ -313,7 +314,7 @@ def _resolve_inputs( input_dict[k] = v # Type check inputs - op_label = context.bound_properties.step_description + step_label = context.bound_properties.step_description for input_name, val in input_dict.items(): input_def = input_defs_by_name[input_name] @@ -322,7 +323,7 @@ def _resolve_inputs( if not type_check.success: raise DagsterTypeCheckDidNotPass( description=( - f'Type check failed for {op_label} input "{input_def.name}" - ' + f'Type check failed for {step_label} input "{input_def.name}" - ' f'expected type "{dagster_type.display_name}". ' f"Description: {type_check.description}" ), @@ -352,7 +353,7 @@ def _key_for_result(result: MaterializeResult, context: "RunlessOpExecutionConte def _output_name_for_result_obj( event: MaterializeResult, - context: "RunlessOpExecutionContext", + context: "BaseDirectInvocationContext", ): if not context.bound_properties.assets_def: raise DagsterInvariantViolationError( @@ -365,7 +366,7 @@ def _output_name_for_result_obj( def _handle_gen_event( event: T, op_def: "OpDefinition", - context: "RunlessOpExecutionContext", + context: "BaseDirectInvocationContext", output_defs: Mapping[str, OutputDefinition], outputs_seen: Set[str], ) -> T: @@ -391,7 +392,7 @@ def _handle_gen_event( output_def, DynamicOutputDefinition ): raise DagsterInvariantViolationError( - f"Invocation of {op_def.node_type_str} '{context.bound_properties.alias}' yielded" + f"Invocation of {context.bound_properties.step_description} yielded" f" an output '{output_def.name}' multiple times." ) outputs_seen.add(output_def.name) @@ -399,7 +400,7 @@ def _handle_gen_event( def _type_check_output_wrapper( - op_def: "OpDefinition", result: Any, context: "RunlessOpExecutionContext" + op_def: "OpDefinition", result: Any, context: "BaseDirectInvocationContext" ) -> Any: """Type checks and returns the result of a op. @@ -493,7 +494,7 @@ def type_check_gen(gen): def _type_check_function_output( - op_def: "OpDefinition", result: T, context: "RunlessOpExecutionContext" + op_def: "OpDefinition", result: T, context: "BaseDirectInvocationContext" ) -> T: from ..execution.plan.compute_generator import validate_and_coerce_op_result_to_iterator @@ -512,25 +513,25 @@ def _type_check_function_output( def _type_check_output( output_def: "OutputDefinition", output: Union[Output, DynamicOutput], - context: "RunlessOpExecutionContext", + context: "BaseDirectInvocationContext", ) -> None: """Validates and performs core type check on a provided output. Args: output_def (OutputDefinition): The output definition to validate against. output (Any): The output to validate. - context (RunlessOpExecutionContext): Context containing resources to be used for type + context (BaseDirectInvocationContext): Context containing resources to be used for type check. """ from ..execution.plan.execute_step import do_type_check - op_label = context.bound_properties.step_description + step_label = context.bound_properties.step_description dagster_type = output_def.dagster_type type_check = do_type_check(context.for_type(dagster_type), dagster_type, output.value) if not type_check.success: raise DagsterTypeCheckDidNotPass( description=( - f'Type check failed for {op_label} output "{output.output_name}" - ' + f'Type check failed for {step_label} output "{output.output_name}" - ' f'expected type "{dagster_type.display_name}". ' f"Description: {type_check.description}" ), diff --git a/python_modules/dagster/dagster/_core/execution/context/invocation.py b/python_modules/dagster/dagster/_core/execution/context/invocation.py index 05d2cb31d02ad..0c735a70a3227 100644 --- a/python_modules/dagster/dagster/_core/execution/context/invocation.py +++ b/python_modules/dagster/dagster/_core/execution/context/invocation.py @@ -1,3 +1,5 @@ +import warnings +from abc import abstractmethod from contextlib import ExitStack from typing import ( AbstractSet, @@ -56,7 +58,13 @@ from dagster._utils.merger import merge_dicts from dagster._utils.warnings import deprecation_warning -from .compute import AssetExecutionContext, ExecutionProperties, OpExecutionContext, RunProperties +from .compute import ( + AssetExecutionContext, + ContextHasExecutionProperties, + ExecutionProperties, + OpExecutionContext, + RunProperties, +) from .system import StepExecutionContext, TypeCheckContext @@ -221,7 +229,8 @@ def set_requires_typed_event_stream(self, *, error_message: Optional[str]) -> No self._typed_event_stream_error_message = error_message -class BaseDirectInvocationContext: +class BaseDirectInvocationContext(ContextHasExecutionProperties): + @abstractmethod def bind( self, op_def: OpDefinition, @@ -232,6 +241,14 @@ def bind( ): pass + @abstractmethod + def for_type(self, dagster_type: DagsterType) -> TypeCheckContext: + pass + + @abstractmethod + def observe_output(self, output_name: str, mapping_key: Optional[str] = None) -> None: + pass + class RunlessOpExecutionContext(OpExecutionContext, BaseDirectInvocationContext): """The ``context`` object available as the first argument to an op's compute function when @@ -821,9 +838,18 @@ def unbind(self): self._bound = False - @property - def op_execution_context(self) -> OpExecutionContext: - return self._op_execution_context + def for_type(self, dagster_type: DagsterType) -> TypeCheckContext: + self._check_bound(fn_name="for_type", fn_type="method") + resources = cast(NamedTuple, self.resources) + return TypeCheckContext( + self.run_id, + self.log, + ScopedResourcesBuilder(resources._asdict()), + dagster_type, + ) + + def observe_output(self, output_name: str, mapping_key: Optional[str] = None) -> None: + self._op_execution_context.observe_output(output_name=output_name, mapping_key=mapping_key) @property def run_properties(self) -> RunProperties: