Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Only have one kind of context for direct invocation #17554

Merged
merged 44 commits into from
Jan 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
44 commits
Select commit Hold shift + click to select a range
1b2ac55
how crazy is it to not have bound op execution context
jamiedemaria Oct 31, 2023
579eb42
fix test
jamiedemaria Nov 1, 2023
c806e3c
wip test
jamiedemaria Nov 1, 2023
da5b4e9
test for boundness
jamiedemaria Nov 1, 2023
45d13de
make a copy of the context instead
jamiedemaria Nov 1, 2023
bdb7cdf
with unbind
jamiedemaria Nov 2, 2023
6c8cbe7
remove unbind
jamiedemaria Nov 2, 2023
edd3092
return copy but don't override enter exit del
jamiedemaria Nov 3, 2023
610f8db
warn on double invocation
jamiedemaria Nov 3, 2023
1b2adbf
test the state and warnings
jamiedemaria Nov 3, 2023
cdd4733
lil fixes
jamiedemaria Nov 13, 2023
707d8af
update test
jamiedemaria Nov 13, 2023
516a556
fix most of failing tests
jamiedemaria Nov 14, 2023
48cb1b9
re-org to have sub objects
jamiedemaria Nov 28, 2023
df5e1d2
remove unused _assets_def
jamiedemaria Nov 28, 2023
e1b8fcb
fix prop access
jamiedemaria Nov 28, 2023
ee4a6fe
fix config
jamiedemaria Nov 28, 2023
e2827d5
update boundproperties to be a plain class so attrs are mutable
jamiedemaria Nov 29, 2023
7fa3db6
wip
jamiedemaria Nov 29, 2023
19ad255
add tests for different execution types
jamiedemaria Nov 30, 2023
aad8d50
fix dictionary check
jamiedemaria Nov 30, 2023
b23b982
test fixes
jamiedemaria Nov 30, 2023
1a3aba4
test update
jamiedemaria Nov 30, 2023
a3b273d
test demo for unbinding on errors
jamiedemaria Dec 1, 2023
08a9549
re-org tests
jamiedemaria Dec 1, 2023
9273e00
update comments
jamiedemaria Dec 1, 2023
858c510
handle raised errors
jamiedemaria Dec 4, 2023
d737f5f
make pyright happy
jamiedemaria Dec 4, 2023
eff0b6a
re-org to invocation props
jamiedemaria Dec 4, 2023
b725535
clean up tests
jamiedemaria Dec 4, 2023
7edf1b2
rename DirectInvocationOpExecutionContext to RunlessOpExecutionContext
jamiedemaria Dec 5, 2023
1548074
rename to runlessexecutionproperties
jamiedemaria Dec 6, 2023
fc1c282
use bound properties in invocation
jamiedemaria Dec 6, 2023
051d742
make things properties
jamiedemaria Dec 6, 2023
1011311
fix new fn callsite
jamiedemaria Dec 6, 2023
9fa1d0e
fix test
jamiedemaria Dec 6, 2023
129bfee
access via property
jamiedemaria Dec 8, 2023
9537bde
add is_bound prop
jamiedemaria Dec 8, 2023
0371e8e
use a method that's actually on the context
jamiedemaria Dec 8, 2023
deec5a3
rename
jamiedemaria Dec 19, 2023
46328a5
comments
jamiedemaria Dec 19, 2023
92ea8d8
missed a name
jamiedemaria Dec 19, 2023
5f174fe
final cleanup
jamiedemaria Jan 29, 2024
8e4327d
update boundproperties to perinvocationproperties
jamiedemaria Jan 29, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 81 additions & 42 deletions python_modules/dagster/dagster/_core/definitions/op_invocation.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from .result import MaterializeResult

if TYPE_CHECKING:
from ..execution.context.invocation import BoundOpExecutionContext
from ..execution.context.invocation import DirectOpExecutionContext
from .assets import AssetsDefinition
from .composition import PendingNodeInvocation
from .decorators.op_decorator import DecoratedOpFunction
Expand Down Expand Up @@ -109,7 +109,7 @@ def direct_invocation_result(
) -> Any:
from dagster._config.pythonic_config import Config
from dagster._core.execution.context.invocation import (
UnboundOpExecutionContext,
DirectOpExecutionContext,
build_op_context,
)

Expand Down Expand Up @@ -149,12 +149,12 @@ def direct_invocation_result(
" no context was provided when invoking."
)
if len(args) > 0:
if args[0] is not None and not isinstance(args[0], UnboundOpExecutionContext):
if args[0] is not None and not isinstance(args[0], DirectOpExecutionContext):
raise DagsterInvalidInvocationError(
f"Decorated function '{compute_fn.name}' has context argument, "
"but no context was provided when invoking."
)
context = cast(UnboundOpExecutionContext, args[0])
context = cast(DirectOpExecutionContext, args[0])
# update args to omit context
args = args[1:]
else: # context argument is provided under kwargs
Expand All @@ -165,14 +165,14 @@ def direct_invocation_result(
f"'{context_param_name}', but no value for '{context_param_name}' was "
f"found when invoking. Provided kwargs: {kwargs}"
)
context = cast(UnboundOpExecutionContext, kwargs[context_param_name])
context = cast(DirectOpExecutionContext, kwargs[context_param_name])
# update kwargs to remove context
kwargs = {
kwarg: val for kwarg, val in kwargs.items() if not kwarg == context_param_name
}
# allow passing context, even if the function doesn't have an arg for it
elif len(args) > 0 and isinstance(args[0], UnboundOpExecutionContext):
context = cast(UnboundOpExecutionContext, args[0])
elif len(args) > 0 and isinstance(args[0], DirectOpExecutionContext):
context = cast(DirectOpExecutionContext, args[0])
args = args[1:]

resource_arg_mapping = {arg.name: arg.name for arg in compute_fn.get_resource_args()}
Expand Down Expand Up @@ -206,24 +206,31 @@ def direct_invocation_result(
),
)

input_dict = _resolve_inputs(op_def, input_args, input_kwargs, bound_context)
try:
# if the compute function fails, we want to ensure we unbind the context. This
# try-except handles "vanilla" asset and op invocation (generators and async handled in
# _type_check_output_wrapper)

result = invoke_compute_fn(
fn=compute_fn.decorated_fn,
context=bound_context,
kwargs=input_dict,
context_arg_provided=compute_fn.has_context_arg(),
config_arg_cls=(
compute_fn.get_config_arg().annotation if compute_fn.has_config_arg() else None
),
resource_args=resource_arg_mapping,
)
input_dict = _resolve_inputs(op_def, input_args, input_kwargs, bound_context)

return _type_check_output_wrapper(op_def, result, bound_context)
result = invoke_compute_fn(
fn=compute_fn.decorated_fn,
context=bound_context,
kwargs=input_dict,
context_arg_provided=compute_fn.has_context_arg(),
config_arg_cls=(
compute_fn.get_config_arg().annotation if compute_fn.has_config_arg() else None
),
resource_args=resource_arg_mapping,
)
return _type_check_output_wrapper(op_def, result, bound_context)
except Exception:
bound_context.unbind()
raise


def _resolve_inputs(
op_def: "OpDefinition", args, kwargs, context: "BoundOpExecutionContext"
op_def: "OpDefinition", args, kwargs, context: "DirectOpExecutionContext"
) -> Mapping[str, Any]:
from dagster._core.execution.plan.execute_step import do_type_check

Expand Down Expand Up @@ -263,7 +270,7 @@ def _resolve_inputs(

node_label = op_def.node_type_str
raise DagsterInvalidInvocationError(
f"Too many input arguments were provided for {node_label} '{context.alias}'."
f"Too many input arguments were provided for {node_label} '{context.per_invocation_properties.alias}'."
f" {suggestion}"
)

Expand Down Expand Up @@ -306,7 +313,7 @@ def _resolve_inputs(
input_dict[k] = v

# Type check inputs
op_label = context.describe_op()
op_label = context.per_invocation_properties.step_description

for input_name, val in input_dict.items():
input_def = input_defs_by_name[input_name]
Expand All @@ -326,31 +333,42 @@ def _resolve_inputs(
return input_dict


def _key_for_result(result: MaterializeResult, context: "BoundOpExecutionContext") -> AssetKey:
def _key_for_result(result: MaterializeResult, context: "DirectOpExecutionContext") -> AssetKey:
if not context.per_invocation_properties.assets_def:
raise DagsterInvariantViolationError(
f"Op {context.per_invocation_properties.alias} does not have an assets definition."
)
if result.asset_key:
return result.asset_key

if len(context.assets_def.keys) == 1:
return next(iter(context.assets_def.keys))
if (
context.per_invocation_properties.assets_def
and len(context.per_invocation_properties.assets_def.keys) == 1
):
return next(iter(context.per_invocation_properties.assets_def.keys))

raise DagsterInvariantViolationError(
"MaterializeResult did not include asset_key and it can not be inferred. Specify which"
f" asset_key, options are: {context.assets_def.keys}"
f" asset_key, options are: {context.per_invocation_properties.assets_def.keys}"
)


def _output_name_for_result_obj(
event: MaterializeResult,
context: "BoundOpExecutionContext",
context: "DirectOpExecutionContext",
):
if not context.per_invocation_properties.assets_def:
raise DagsterInvariantViolationError(
f"Op {context.per_invocation_properties.alias} does not have an assets definition."
)
asset_key = _key_for_result(event, context)
return context.assets_def.get_output_name_for_asset_key(asset_key)
return context.per_invocation_properties.assets_def.get_output_name_for_asset_key(asset_key)


def _handle_gen_event(
event: T,
op_def: "OpDefinition",
context: "BoundOpExecutionContext",
context: "DirectOpExecutionContext",
output_defs: Mapping[str, OutputDefinition],
outputs_seen: Set[str],
) -> T:
Expand All @@ -376,15 +394,15 @@ def _handle_gen_event(
output_def, DynamicOutputDefinition
):
raise DagsterInvariantViolationError(
f"Invocation of {op_def.node_type_str} '{context.alias}' yielded"
f"Invocation of {op_def.node_type_str} '{context.per_invocation_properties.alias}' yielded"
f" an output '{output_def.name}' multiple times."
)
outputs_seen.add(output_def.name)
return event


def _type_check_output_wrapper(
op_def: "OpDefinition", result: Any, context: "BoundOpExecutionContext"
op_def: "OpDefinition", result: Any, context: "DirectOpExecutionContext"
) -> Any:
"""Type checks and returns the result of a op.

Expand All @@ -399,8 +417,14 @@ def _type_check_output_wrapper(
async def to_gen(async_gen):
outputs_seen = set()

async for event in async_gen:
yield _handle_gen_event(event, op_def, context, output_defs, outputs_seen)
try:
# if the compute function fails, we want to ensure we unbind the context. For
# async generators, the errors will only be surfaced here
async for event in async_gen:
yield _handle_gen_event(event, op_def, context, output_defs, outputs_seen)
except Exception:
context.unbind()
raise

for output_def in op_def.output_defs:
if (
Expand All @@ -413,17 +437,24 @@ async def to_gen(async_gen):
yield Output(output_name=output_def.name, value=None)
else:
raise DagsterInvariantViolationError(
f"Invocation of {op_def.node_type_str} '{context.alias}' did not"
f"Invocation of {op_def.node_type_str} '{context.per_invocation_properties.alias}' did not"
f" return an output for non-optional output '{output_def.name}'"
)
context.unbind()

return to_gen(result)

# Coroutine result case
elif inspect.iscoroutine(result):

async def type_check_coroutine(coro):
out = await coro
try:
# if the compute function fails, we want to ensure we unbind the context. For
# async, the errors will only be surfaced here
out = await coro
except Exception:
context.unbind()
raise
return _type_check_function_output(op_def, out, context)

return type_check_coroutine(result)
Expand All @@ -433,8 +464,14 @@ async def type_check_coroutine(coro):

def type_check_gen(gen):
outputs_seen = set()
for event in gen:
yield _handle_gen_event(event, op_def, context, output_defs, outputs_seen)
try:
# if the compute function fails, we want to ensure we unbind the context. For
# generators, the errors will only be surfaced here
for event in gen:
yield _handle_gen_event(event, op_def, context, output_defs, outputs_seen)
except Exception:
context.unbind()
raise

for output_def in op_def.output_defs:
if (
Expand All @@ -447,9 +484,10 @@ def type_check_gen(gen):
yield Output(output_name=output_def.name, value=None)
else:
raise DagsterInvariantViolationError(
f'Invocation of {op_def.node_type_str} "{context.alias}" did not'
f'Invocation of {op_def.node_type_str} "{context.per_invocation_properties.alias}" did not'
f' return an output for non-optional output "{output_def.name}"'
)
context.unbind()

return type_check_gen(result)

Expand All @@ -458,7 +496,7 @@ def type_check_gen(gen):


def _type_check_function_output(
op_def: "OpDefinition", result: T, context: "BoundOpExecutionContext"
op_def: "OpDefinition", result: T, context: "DirectOpExecutionContext"
) -> T:
from ..execution.plan.compute_generator import validate_and_coerce_op_result_to_iterator

Expand All @@ -470,25 +508,26 @@ def _type_check_function_output(
# ensure result objects are contextually valid
_output_name_for_result_obj(event, context)

context.unbind()
return result


def _type_check_output(
output_def: "OutputDefinition",
output: Union[Output, DynamicOutput],
context: "BoundOpExecutionContext",
context: "DirectOpExecutionContext",
) -> None:
"""Validates and performs core type check on a provided output.

Args:
output_def (OutputDefinition): The output definition to validate against.
output (Any): The output to validate.
context (BoundOpExecutionContext): Context containing resources to be used for type
context (DirectOpExecutionContext): Context containing resources to be used for type
check.
"""
from ..execution.plan.execute_step import do_type_check

op_label = context.describe_op()
op_label = context.per_invocation_properties.step_description
dagster_type = output_def.dagster_type
type_check = do_type_check(context.for_type(dagster_type), dagster_type, output.value)
if not type_check.success:
Expand Down
Loading