Skip to content

Commit

Permalink
update typing
Browse files Browse the repository at this point in the history
  • Loading branch information
jamiedemaria committed Nov 28, 2023
1 parent bbfb0a4 commit 493bdaa
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,9 @@
from .result import MaterializeResult

if TYPE_CHECKING:
from ..execution.context.invocation import DirectInvocationOpExecutionContext
from ..execution.context.invocation import (
DirectInvocationOpExecutionContext,
)
from .assets import AssetsDefinition
from .composition import PendingNodeInvocation
from .decorators.op_decorator import DecoratedOpFunction
Expand Down Expand Up @@ -109,7 +111,7 @@ def direct_invocation_result(
) -> Any:
from dagster._config.pythonic_config import Config
from dagster._core.execution.context.invocation import (
DirectInvocationOpExecutionContext,
BaseDirectInvocationContext,
build_op_context,
)

Expand Down Expand Up @@ -149,12 +151,12 @@ def direct_invocation_result(
" no context was provided when invoking."
)
if len(args) > 0:
if args[0] is not None and not isinstance(args[0], DirectInvocationOpExecutionContext):
if args[0] is not None and not isinstance(args[0], BaseDirectInvocationContext):
raise DagsterInvalidInvocationError(
f"Decorated function '{compute_fn.name}' has context argument, "
"but no context was provided when invoking."
)
context = cast(DirectInvocationOpExecutionContext, args[0])
context = args[0]
# update args to omit context
args = args[1:]
else: # context argument is provided under kwargs
Expand All @@ -165,14 +167,14 @@ def direct_invocation_result(
f"'{context_param_name}', but no value for '{context_param_name}' was "
f"found when invoking. Provided kwargs: {kwargs}"
)
context = cast(DirectInvocationOpExecutionContext, kwargs[context_param_name])
context = kwargs[context_param_name]
# update kwargs to remove context
kwargs = {
kwarg: val for kwarg, val in kwargs.items() if not kwarg == context_param_name
}
# allow passing context, even if the function doesn't have an arg for it
elif len(args) > 0 and isinstance(args[0], DirectInvocationOpExecutionContext):
context = cast(DirectInvocationOpExecutionContext, args[0])
elif len(args) > 0 and isinstance(args[0], BaseDirectInvocationContext):
context = args[0]
args = args[1:]

resource_arg_mapping = {arg.name: arg.name for arg in compute_fn.get_resource_args()}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,19 @@ def __new__(cls):
)


class DirectInvocationOpExecutionContext(OpExecutionContext):
class BaseDirectInvocationContext:
def bind(
self,
op_def: OpDefinition,
pending_invocation: Optional[PendingNodeInvocation[OpDefinition]],
assets_def: Optional[AssetsDefinition],
config_from_args: Optional[Mapping[str, Any]],
resources_from_args: Optional[Mapping[str, Any]],
):
pass


class DirectInvocationOpExecutionContext(OpExecutionContext, BaseDirectInvocationContext):
"""The ``context`` object available as the first argument to an op's compute function when
being invoked directly. Can also be used as a context manager.
"""
Expand Down Expand Up @@ -676,7 +688,7 @@ def _validate_resource_requirements(
ensure_requirements_satisfied(resource_defs, [requirement])


class DirectInvocationAssetExecutionContext(AssetExecutionContext):
class DirectInvocationAssetExecutionContext(AssetExecutionContext, BaseDirectInvocationContext):
"""The ``context`` object available as the first argument to an op's compute function when
being invoked directly. Can also be used as a context manager.
"""
Expand Down

0 comments on commit 493bdaa

Please sign in to comment.