Skip to content

Commit

Permalink
[ext] Use MaterializeResult for ext_protocol
Browse files Browse the repository at this point in the history
  • Loading branch information
smackesey committed Sep 19, 2023
1 parent 09b2cea commit dd06840
Show file tree
Hide file tree
Showing 8 changed files with 95 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ def test_ext_subprocess(
def foo(context: AssetExecutionContext, ext: ExtSubprocess):
extras = {"bar": "baz"}
cmd = [_PYTHON_EXECUTABLE, external_script]
ext.run(
return ext.run(
cmd,
context=context,
extras=extras,
Expand Down Expand Up @@ -201,7 +201,7 @@ def script_fn():
def foo(context: AssetExecutionContext, ext: ExtSubprocess):
with temp_script(script_fn) as script_path:
cmd = [_PYTHON_EXECUTABLE, script_path]
ext.run(cmd, context=context)
return ext.run(cmd, context=context)

with instance_for_test() as instance:
materialize(
Expand Down Expand Up @@ -313,6 +313,9 @@ def subproc_run(context: AssetExecutionContext):
extras=extras,
) as ext_context:
subprocess.run(cmd, env=ext_context.get_external_process_env_vars(), check=False)
_ext_context = ext_context
mat_results = _ext_context.get_materialize_results()
return mat_results[0] if len(mat_results) == 1 else mat_results

with instance_for_test() as instance:
materialize(
Expand All @@ -325,3 +328,30 @@ def subproc_run(context: AssetExecutionContext):
assert mat.asset_materialization.tags
assert mat.asset_materialization.tags[DATA_VERSION_TAG] == "alpha"
assert mat.asset_materialization.tags[DATA_VERSION_IS_USER_PROVIDED_TAG]


def test_ext_no_client_premature_get_results(external_script):
@asset
def subproc_run(context: AssetExecutionContext):
extras = {"bar": "baz"}
cmd = [_PYTHON_EXECUTABLE, external_script]

with ext_protocol(
context,
ExtTempFileContextInjector(),
ExtTempFileMessageReader(),
extras=extras,
) as ext_context:
subprocess.run(cmd, env=ext_context.get_external_process_env_vars(), check=False)
return ext_context.get_materialize_results()

with pytest.raises(
DagsterExternalExecutionError,
match=(
"`get_materialize_results` must be called after the `ext_protocol` context manager has"
" exited."
),
):
materialize(
[subproc_run],
)
6 changes: 4 additions & 2 deletions python_modules/dagster/dagster/_core/ext/client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from abc import ABC, abstractmethod
from contextlib import contextmanager
from typing import TYPE_CHECKING, Iterator, Optional
from typing import TYPE_CHECKING, Iterator, Optional, Tuple, Union

from dagster_ext import (
ExtContextData,
Expand All @@ -11,6 +11,8 @@
from dagster._core.execution.context.compute import OpExecutionContext

if TYPE_CHECKING:
from dagster._core.definitions.result import MaterializeResult

from .context import ExtMessageHandler


Expand All @@ -21,7 +23,7 @@ def run(
*,
context: OpExecutionContext,
extras: Optional[ExtExtras] = None,
) -> None: ...
) -> Union["MaterializeResult", Tuple["MaterializeResult", ...]]: ...


class ExtContextInjector(ABC):
Expand Down
42 changes: 36 additions & 6 deletions python_modules/dagster/dagster/_core/ext/context.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import Any, Mapping, Optional, get_args
from typing import Any, Dict, Mapping, Optional, Tuple, get_args

from dagster_ext import (
DAGSTER_EXT_ENV_KEYS,
Expand All @@ -19,14 +19,26 @@
from dagster._core.definitions.events import AssetKey
from dagster._core.definitions.metadata import MetadataValue, normalize_metadata_value
from dagster._core.definitions.partition_key_range import PartitionKeyRange
from dagster._core.definitions.result import MaterializeResult
from dagster._core.definitions.time_window_partitions import TimeWindow
from dagster._core.errors import DagsterExternalExecutionError
from dagster._core.execution.context.compute import OpExecutionContext
from dagster._core.execution.context.invocation import BoundOpExecutionContext


class ExtMessageHandler:
def __init__(self, context: OpExecutionContext) -> None:
self._context = context
self._metadata: Dict[AssetKey, Dict[str, MetadataValue]] = {}
self._data_versions: Dict[AssetKey, DataVersion] = {}

@property
def metadata(self) -> Dict[AssetKey, Dict[str, MetadataValue]]:
return self._metadata

@property
def data_versions(self) -> Dict[AssetKey, DataVersion]:
return self._data_versions

# Type ignores because we currently validate in individual handlers
def handle_message(self, message: ExtMessage) -> None:
Expand All @@ -43,10 +55,9 @@ def _handle_report_asset_metadata(
check.str_param(asset_key, "asset_key")
check.str_param(label, "label")
check.opt_literal_param(type, "type", get_args(ExtMetadataType))
key = AssetKey.from_user_string(asset_key)
output_name = self._context.output_for_asset_key(key)
resolved_asset_key = AssetKey.from_user_string(asset_key)
metadata_value = self._resolve_metadata_value(value, type)
self._context.add_output_metadata({label: metadata_value}, output_name)
self._metadata.setdefault(resolved_asset_key, {})[label] = metadata_value

def _resolve_metadata_value(
self, value: Any, metadata_type: Optional[ExtMetadataType]
Expand Down Expand Up @@ -85,8 +96,8 @@ def _resolve_metadata_value(
def _handle_report_asset_data_version(self, asset_key: str, data_version: str) -> None:
check.str_param(asset_key, "asset_key")
check.str_param(data_version, "data_version")
key = AssetKey.from_user_string(asset_key)
self._context.set_data_version(key, DataVersion(data_version))
resolved_asset_key = AssetKey.from_user_string(asset_key)
self._data_versions[resolved_asset_key] = DataVersion(data_version)

def _handle_log(self, message: str, level: str = "info") -> None:
check.str_param(message, "message")
Expand All @@ -108,6 +119,7 @@ class ExtOrchestrationContext:
message_handler: ExtMessageHandler
context_injector_params: ExtParams
message_reader_params: ExtParams
is_task_finished: bool = False

def get_external_process_env_vars(self):
return {
Expand All @@ -118,6 +130,24 @@ def get_external_process_env_vars(self):
),
}

def get_materialize_results(self) -> Tuple[MaterializeResult, ...]:
if not self.is_task_finished:
raise DagsterExternalExecutionError(
"`get_materialize_results` must be called after the `ext_protocol` context manager"
" has exited."
)
return tuple(
self._materialize_result_for_asset(AssetKey.from_user_string(key))
for key in self.context_data["asset_keys"] or []
)

def _materialize_result_for_asset(self, asset_key: AssetKey):
return MaterializeResult(
asset_key=asset_key,
metadata=self.message_handler.metadata.get(asset_key),
data_version=self.message_handler.data_versions.get(asset_key),
)


def build_external_execution_context_data(
context: OpExecutionContext,
Expand Down
7 changes: 5 additions & 2 deletions python_modules/dagster/dagster/_core/ext/subprocess.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from subprocess import Popen
from typing import Mapping, Optional, Sequence, Union
from typing import Mapping, Optional, Sequence, Tuple, Union

from dagster_ext import ExtExtras

from dagster import _check as check
from dagster._core.definitions.resource_annotation import ResourceParam
from dagster._core.definitions.result import MaterializeResult
from dagster._core.errors import DagsterExternalExecutionError
from dagster._core.execution.context.compute import OpExecutionContext
from dagster._core.ext.client import (
Expand Down Expand Up @@ -66,7 +67,7 @@ def run(
extras: Optional[ExtExtras] = None,
env: Optional[Mapping[str, str]] = None,
cwd: Optional[str] = None,
) -> None:
) -> Union[MaterializeResult, Tuple[MaterializeResult, ...]]:
with ext_protocol(
context=context,
context_injector=self.context_injector,
Expand All @@ -88,6 +89,8 @@ def run(
raise DagsterExternalExecutionError(
f"External execution process failed with code {process.returncode}"
)
mat_results = ext_context.get_materialize_results()
return mat_results[0] if len(mat_results) == 1 else mat_results


ExtSubprocess = ResourceParam[_ExtSubprocess]
4 changes: 3 additions & 1 deletion python_modules/dagster/dagster/_core/ext/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,9 +205,11 @@ def ext_protocol(
) as ci_params, message_reader.read_messages(
msg_handler,
) as mr_params:
yield ExtOrchestrationContext(
ext_context = ExtOrchestrationContext(
context_data=context_data,
message_handler=msg_handler,
context_injector_params=ci_params,
message_reader_params=mr_params,
)
yield ext_context
ext_context.is_task_finished = True
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@
import string
import time
from contextlib import contextmanager
from typing import Iterator, Mapping, Optional
from typing import Iterator, Mapping, Optional, Tuple, Union

import dagster._check as check
from dagster._core.definitions.resource_annotation import ResourceParam
from dagster._core.definitions.result import MaterializeResult
from dagster._core.errors import DagsterExternalExecutionError
from dagster._core.execution.context.compute import OpExecutionContext
from dagster._core.ext.client import ExtClient, ExtContextInjector, ExtMessageReader
Expand Down Expand Up @@ -66,7 +67,7 @@ def run(
context: OpExecutionContext,
extras: Optional[ExtExtras] = None,
submit_args: Optional[Mapping[str, str]] = None,
) -> None:
) -> Union[MaterializeResult, Tuple[MaterializeResult, ...]]:
"""Run a Databricks job with the EXT protocol.
Args:
Expand Down Expand Up @@ -106,7 +107,7 @@ def run(
jobs.RunLifeCycleState.SKIPPED,
):
if run.state.result_state == jobs.RunResultState.SUCCESS:
return
break
else:
raise DagsterExternalExecutionError(
f"Error running Databricks job: {run.state.state_message}"
Expand All @@ -116,6 +117,7 @@ def run(
f"Error running Databricks job: {run.state.state_message}"
)
time.sleep(5)
return ext_context.get_materialize_results()


ExtDatabricks = ResourceParam[_ExtDatabricks]
Expand Down
6 changes: 4 additions & 2 deletions python_modules/libraries/dagster-docker/dagster_docker/ext.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
from contextlib import contextmanager
from typing import Any, Iterator, Mapping, Optional, Sequence, Union
from typing import Any, Iterator, Mapping, Optional, Sequence, Tuple, Union

import docker
from dagster import (
OpExecutionContext,
ResourceParam,
_check as check,
)
from dagster._core.definitions.result import MaterializeResult
from dagster._core.ext.client import (
ExtClient,
ExtContextInjector,
Expand Down Expand Up @@ -94,7 +95,7 @@ def run(
registry: Optional[Mapping[str, str]] = None,
container_kwargs: Optional[Mapping[str, Any]] = None,
extras: Optional[ExtExtras] = None,
) -> None:
) -> Union[MaterializeResult, Tuple[MaterializeResult, ...]]:
"""Create a docker container and run it to completion, enriched with the ext protocol.
Args:
Expand Down Expand Up @@ -162,6 +163,7 @@ def run(
raise DagsterExtError(f"Container exited with non-zero status code: {result}")
finally:
container.stop()
return ext_context.get_materialize_results()

def _create_container(
self,
Expand Down
10 changes: 6 additions & 4 deletions python_modules/libraries/dagster-k8s/dagster_k8s/ext.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
import random
import string
from contextlib import contextmanager
from typing import Any, Iterator, Mapping, Optional, Sequence, Union
from typing import Any, Iterator, Mapping, Optional, Sequence, Tuple, Union

import kubernetes
from dagster import (
OpExecutionContext,
_check as check,
)
from dagster._core.definitions.resource_annotation import ResourceParam
from dagster._core.definitions.result import MaterializeResult
from dagster._core.errors import DagsterInvariantViolationError
from dagster._core.ext.client import (
ExtClient,
Expand Down Expand Up @@ -123,7 +124,7 @@ def run(
base_pod_meta: Optional[Mapping[str, Any]] = None,
base_pod_spec: Optional[Mapping[str, Any]] = None,
extras: Optional[ExtExtras] = None,
) -> None:
) -> Union[MaterializeResult, Tuple[MaterializeResult, ...]]:
"""Publish a kubernetes pod and wait for it to complete, enriched with the ext protocol.
Args:
Expand Down Expand Up @@ -158,15 +159,15 @@ def run(
extras=extras,
context_injector=self.context_injector,
message_reader=self.message_reader,
) as ext_process:
) as ext_context:
namespace = namespace or "default"
pod_name = get_pod_name(context.run_id, context.op.name)
pod_body = build_pod_body(
pod_name=pod_name,
image=image,
command=command,
env_vars={
**ext_process.get_external_process_env_vars(),
**ext_context.get_external_process_env_vars(),
**(self.env or {}),
**(env or {}),
},
Expand Down Expand Up @@ -196,6 +197,7 @@ def run(
)
finally:
client.core_api.delete_namespaced_pod(pod_name, namespace)
return ext_context.get_materialize_results()


def build_pod_body(
Expand Down

0 comments on commit dd06840

Please sign in to comment.