Skip to content

Commit

Permalink
update interfaces
Browse files Browse the repository at this point in the history
  • Loading branch information
jamiedemaria committed Dec 7, 2023
1 parent 21ab088 commit 42acce4
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 20 deletions.
31 changes: 16 additions & 15 deletions python_modules/dagster/dagster/_core/definitions/op_invocation.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,9 @@
from .result import MaterializeResult

if TYPE_CHECKING:
from ..execution.context.invocation import RunlessOpExecutionContext
from ..execution.context.invocation import (
BaseDirectInvocationContext,
)
from .assets import AssetsDefinition
from .composition import PendingNodeInvocation
from .decorators.op_decorator import DecoratedOpFunction
Expand Down Expand Up @@ -230,7 +232,7 @@ def direct_invocation_result(


def _resolve_inputs(
op_def: "OpDefinition", args, kwargs, context: "RunlessOpExecutionContext"
op_def: "OpDefinition", args, kwargs, context: "BaseDirectInvocationContext"
) -> Mapping[str, Any]:
from dagster._core.execution.plan.execute_step import do_type_check

Expand Down Expand Up @@ -268,9 +270,8 @@ def _resolve_inputs(
"but no context parameter was defined for the op."
)

node_label = op_def.node_type_str
raise DagsterInvalidInvocationError(
f"Too many input arguments were provided for {node_label} '{context.bound_properties.alias}'."
f"Too many input arguments were provided for {context.bound_properties.step_description}'."
f" {suggestion}"
)

Expand Down Expand Up @@ -313,7 +314,7 @@ def _resolve_inputs(
input_dict[k] = v

# Type check inputs
op_label = context.bound_properties.step_description
step_label = context.bound_properties.step_description

for input_name, val in input_dict.items():
input_def = input_defs_by_name[input_name]
Expand All @@ -322,7 +323,7 @@ def _resolve_inputs(
if not type_check.success:
raise DagsterTypeCheckDidNotPass(
description=(
f'Type check failed for {op_label} input "{input_def.name}" - '
f'Type check failed for {step_label} input "{input_def.name}" - '
f'expected type "{dagster_type.display_name}". '
f"Description: {type_check.description}"
),
Expand Down Expand Up @@ -352,7 +353,7 @@ def _key_for_result(result: MaterializeResult, context: "RunlessOpExecutionConte

def _output_name_for_result_obj(
event: MaterializeResult,
context: "RunlessOpExecutionContext",
context: "BaseDirectInvocationContext",
):
if not context.bound_properties.assets_def:
raise DagsterInvariantViolationError(
Expand All @@ -365,7 +366,7 @@ def _output_name_for_result_obj(
def _handle_gen_event(
event: T,
op_def: "OpDefinition",
context: "RunlessOpExecutionContext",
context: "BaseDirectInvocationContext",
output_defs: Mapping[str, OutputDefinition],
outputs_seen: Set[str],
) -> T:
Expand All @@ -391,15 +392,15 @@ def _handle_gen_event(
output_def, DynamicOutputDefinition
):
raise DagsterInvariantViolationError(
f"Invocation of {op_def.node_type_str} '{context.bound_properties.alias}' yielded"
f"Invocation of {context.bound_properties.step_description} yielded"
f" an output '{output_def.name}' multiple times."
)
outputs_seen.add(output_def.name)
return event


def _type_check_output_wrapper(
op_def: "OpDefinition", result: Any, context: "RunlessOpExecutionContext"
op_def: "OpDefinition", result: Any, context: "BaseDirectInvocationContext"
) -> Any:
"""Type checks and returns the result of a op.
Expand Down Expand Up @@ -493,7 +494,7 @@ def type_check_gen(gen):


def _type_check_function_output(
op_def: "OpDefinition", result: T, context: "RunlessOpExecutionContext"
op_def: "OpDefinition", result: T, context: "BaseDirectInvocationContext"
) -> T:
from ..execution.plan.compute_generator import validate_and_coerce_op_result_to_iterator

Expand All @@ -512,25 +513,25 @@ def _type_check_function_output(
def _type_check_output(
output_def: "OutputDefinition",
output: Union[Output, DynamicOutput],
context: "RunlessOpExecutionContext",
context: "BaseDirectInvocationContext",
) -> None:
"""Validates and performs core type check on a provided output.
Args:
output_def (OutputDefinition): The output definition to validate against.
output (Any): The output to validate.
context (RunlessOpExecutionContext): Context containing resources to be used for type
context (BaseDirectInvocationContext): Context containing resources to be used for type
check.
"""
from ..execution.plan.execute_step import do_type_check

op_label = context.bound_properties.step_description
step_label = context.bound_properties.step_description
dagster_type = output_def.dagster_type
type_check = do_type_check(context.for_type(dagster_type), dagster_type, output.value)
if not type_check.success:
raise DagsterTypeCheckDidNotPass(
description=(
f'Type check failed for {op_label} output "{output.output_name}" - '
f'Type check failed for {step_label} output "{output.output_name}" - '
f'expected type "{dagster_type.display_name}". '
f"Description: {type_check.description}"
),
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import warnings
from abc import abstractmethod
from contextlib import ExitStack
from typing import (
AbstractSet,
Expand Down Expand Up @@ -56,7 +58,13 @@
from dagster._utils.merger import merge_dicts
from dagster._utils.warnings import deprecation_warning

from .compute import AssetExecutionContext, ExecutionProperties, OpExecutionContext, RunProperties
from .compute import (
AssetExecutionContext,
ContextHasExecutionProperties,
ExecutionProperties,
OpExecutionContext,
RunProperties,
)
from .system import StepExecutionContext, TypeCheckContext


Expand Down Expand Up @@ -221,7 +229,8 @@ def set_requires_typed_event_stream(self, *, error_message: Optional[str]) -> No
self._typed_event_stream_error_message = error_message


class BaseDirectInvocationContext:
class BaseDirectInvocationContext(ContextHasExecutionProperties):
@abstractmethod
def bind(
self,
op_def: OpDefinition,
Expand All @@ -232,6 +241,14 @@ def bind(
):
pass

@abstractmethod
def for_type(self, dagster_type: DagsterType) -> TypeCheckContext:
pass

@abstractmethod
def observe_output(self, output_name: str, mapping_key: Optional[str] = None) -> None:
pass


class RunlessOpExecutionContext(OpExecutionContext, BaseDirectInvocationContext):
"""The ``context`` object available as the first argument to an op's compute function when
Expand Down Expand Up @@ -817,9 +834,18 @@ def unbind(self):

self._bound = False

@property
def op_execution_context(self) -> OpExecutionContext:
return self._op_execution_context
def for_type(self, dagster_type: DagsterType) -> TypeCheckContext:
self._check_bound(fn_name="for_type", fn_type="method")
resources = cast(NamedTuple, self.resources)
return TypeCheckContext(
self.run_id,
self.log,
ScopedResourcesBuilder(resources._asdict()),
dagster_type,
)

def observe_output(self, output_name: str, mapping_key: Optional[str] = None) -> None:
self._op_execution_context.observe_output(output_name=output_name, mapping_key=mapping_key)

@property
def run_properties(self) -> RunProperties:
Expand Down

0 comments on commit 42acce4

Please sign in to comment.