From df0afcb40a5183261493aaff459b7bf483c15937 Mon Sep 17 00:00:00 2001 From: Sean Mackesey Date: Fri, 15 Sep 2023 10:21:38 -0400 Subject: [PATCH] [ext] Use MaterializeResult for ext_protocol --- .../test_external_execution.py | 7 ++-- .../dagster/dagster/_core/ext/client.py | 6 ++-- .../dagster/dagster/_core/ext/context.py | 35 +++++++++++++++---- .../dagster/dagster/_core/ext/subprocess.py | 8 +++-- .../dagster_databricks/ext.py | 9 +++-- .../dagster-docker/dagster_docker/ext.py | 7 ++-- .../libraries/dagster-k8s/dagster_k8s/ext.py | 11 +++--- 7 files changed, 62 insertions(+), 21 deletions(-) diff --git a/python_modules/dagster-ext/dagster_ext_tests/test_external_execution.py b/python_modules/dagster-ext/dagster_ext_tests/test_external_execution.py index 84396dce781b4..a54058ce21c88 100644 --- a/python_modules/dagster-ext/dagster_ext_tests/test_external_execution.py +++ b/python_modules/dagster-ext/dagster_ext_tests/test_external_execution.py @@ -148,7 +148,7 @@ def test_ext_subprocess( def foo(context: AssetExecutionContext, ext: ExtSubprocess): extras = {"bar": "baz"} cmd = [_PYTHON_EXECUTABLE, external_script] - ext.run( + return ext.run( cmd, context=context, extras=extras, @@ -201,7 +201,7 @@ def script_fn(): def foo(context: AssetExecutionContext, ext: ExtSubprocess): with temp_script(script_fn) as script_path: cmd = [_PYTHON_EXECUTABLE, script_path] - ext.run(cmd, context=context) + return ext.run(cmd, context=context) with instance_for_test() as instance: materialize( @@ -313,6 +313,9 @@ def subproc_run(context: AssetExecutionContext): extras=extras, ) as ext_context: subprocess.run(cmd, env=ext_context.get_external_process_env_vars(), check=False) + _ext_context = ext_context + mat_results = _ext_context.get_materialize_results() + return mat_results[0] if len(mat_results) == 1 else mat_results with instance_for_test() as instance: materialize( diff --git a/python_modules/dagster/dagster/_core/ext/client.py b/python_modules/dagster/dagster/_core/ext/client.py index 2e89e50754be3..ab3c9fdd7e97d 100644 --- a/python_modules/dagster/dagster/_core/ext/client.py +++ b/python_modules/dagster/dagster/_core/ext/client.py @@ -1,6 +1,6 @@ from abc import ABC, abstractmethod from contextlib import contextmanager -from typing import TYPE_CHECKING, Iterator, Optional +from typing import TYPE_CHECKING, Iterator, Optional, Tuple, Union from dagster_ext import ( ExtContextData, @@ -11,6 +11,8 @@ from dagster._core.execution.context.compute import OpExecutionContext if TYPE_CHECKING: + from dagster._core.definitions.result import MaterializeResult + from .context import ExtMessageHandler @@ -21,7 +23,7 @@ def run( *, context: OpExecutionContext, extras: Optional[ExtExtras] = None, - ) -> None: ... + ) -> Union["MaterializeResult", Tuple["MaterializeResult", ...]]: ... class ExtContextInjector(ABC): diff --git a/python_modules/dagster/dagster/_core/ext/context.py b/python_modules/dagster/dagster/_core/ext/context.py index 6f7e5c47339ce..5a42e03d4ad19 100644 --- a/python_modules/dagster/dagster/_core/ext/context.py +++ b/python_modules/dagster/dagster/_core/ext/context.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Any, Mapping, Optional, get_args +from typing import Any, Dict, Mapping, Optional, Tuple, get_args from dagster_ext import ( DAGSTER_EXT_ENV_KEYS, @@ -19,6 +19,7 @@ from dagster._core.definitions.events import AssetKey from dagster._core.definitions.metadata import MetadataValue, normalize_metadata_value from dagster._core.definitions.partition_key_range import PartitionKeyRange +from dagster._core.definitions.result import MaterializeResult from dagster._core.definitions.time_window_partitions import TimeWindow from dagster._core.execution.context.compute import OpExecutionContext from dagster._core.execution.context.invocation import BoundOpExecutionContext @@ -27,6 +28,16 @@ class ExtMessageHandler: def __init__(self, context: OpExecutionContext) -> None: self._context = context + self._metadata: Dict[AssetKey, Dict[str, MetadataValue]] = {} + self._data_versions: Dict[AssetKey, DataVersion] = {} + + @property + def metadata(self) -> Dict[AssetKey, Dict[str, MetadataValue]]: + return self._metadata + + @property + def data_versions(self) -> Dict[AssetKey, DataVersion]: + return self._data_versions # Type ignores because we currently validate in individual handlers def handle_message(self, message: ExtMessage) -> None: @@ -43,10 +54,9 @@ def _handle_report_asset_metadata( check.str_param(asset_key, "asset_key") check.str_param(label, "label") check.opt_literal_param(type, "type", get_args(ExtMetadataType)) - key = AssetKey.from_user_string(asset_key) - output_name = self._context.output_for_asset_key(key) + resolved_asset_key = AssetKey.from_user_string(asset_key) metadata_value = self._resolve_metadata_value(value, type) - self._context.add_output_metadata({label: metadata_value}, output_name) + self._metadata.setdefault(resolved_asset_key, {})[label] = metadata_value def _resolve_metadata_value( self, value: Any, metadata_type: Optional[ExtMetadataType] @@ -85,8 +95,8 @@ def _resolve_metadata_value( def _handle_report_asset_data_version(self, asset_key: str, data_version: str) -> None: check.str_param(asset_key, "asset_key") check.str_param(data_version, "data_version") - key = AssetKey.from_user_string(asset_key) - self._context.set_data_version(key, DataVersion(data_version)) + resolved_asset_key = AssetKey.from_user_string(asset_key) + self._data_versions[resolved_asset_key] = DataVersion(data_version) def _handle_log(self, message: str, level: str = "info") -> None: check.str_param(message, "message") @@ -118,6 +128,19 @@ def get_external_process_env_vars(self): ), } + def get_materialize_results(self) -> Tuple[MaterializeResult, ...]: + return tuple( + self._materialize_result_for_asset(AssetKey.from_user_string(key)) + for key in self.context_data["asset_keys"] or [] + ) + + def _materialize_result_for_asset(self, asset_key: AssetKey): + return MaterializeResult( + asset_key=asset_key, + metadata=self.message_handler.metadata.get(asset_key), + data_version=self.message_handler.data_versions.get(asset_key), + ) + def build_external_execution_context_data( context: OpExecutionContext, diff --git a/python_modules/dagster/dagster/_core/ext/subprocess.py b/python_modules/dagster/dagster/_core/ext/subprocess.py index aa07751501d09..e46240472600e 100644 --- a/python_modules/dagster/dagster/_core/ext/subprocess.py +++ b/python_modules/dagster/dagster/_core/ext/subprocess.py @@ -1,10 +1,11 @@ from subprocess import Popen -from typing import Mapping, Optional, Sequence, Union +from typing import Mapping, Optional, Sequence, Tuple, Union from dagster_ext import ExtExtras from dagster import _check as check from dagster._core.definitions.resource_annotation import ResourceParam +from dagster._core.definitions.result import MaterializeResult from dagster._core.errors import DagsterExternalExecutionError from dagster._core.execution.context.compute import OpExecutionContext from dagster._core.ext.client import ( @@ -66,7 +67,7 @@ def run( extras: Optional[ExtExtras] = None, env: Optional[Mapping[str, str]] = None, cwd: Optional[str] = None, - ) -> None: + ) -> Union[MaterializeResult, Tuple[MaterializeResult, ...]]: with ext_protocol( context=context, context_injector=self.context_injector, @@ -88,6 +89,9 @@ def run( raise DagsterExternalExecutionError( f"External execution process failed with code {process.returncode}" ) + _ext_context = ext_context + mat_results = _ext_context.get_materialize_results() + return mat_results[0] if len(mat_results) == 1 else mat_results ExtSubprocess = ResourceParam[_ExtSubprocess] diff --git a/python_modules/libraries/dagster-databricks/dagster_databricks/ext.py b/python_modules/libraries/dagster-databricks/dagster_databricks/ext.py index 22bb89034cd07..ba26f5e32c1fb 100644 --- a/python_modules/libraries/dagster-databricks/dagster_databricks/ext.py +++ b/python_modules/libraries/dagster-databricks/dagster_databricks/ext.py @@ -5,10 +5,11 @@ import string import time from contextlib import contextmanager -from typing import Iterator, Mapping, Optional +from typing import Iterator, Mapping, Optional, Tuple, Union import dagster._check as check from dagster._core.definitions.resource_annotation import ResourceParam +from dagster._core.definitions.result import MaterializeResult from dagster._core.errors import DagsterExternalExecutionError from dagster._core.execution.context.compute import OpExecutionContext from dagster._core.ext.client import ExtClient, ExtContextInjector, ExtMessageReader @@ -66,7 +67,7 @@ def run( context: OpExecutionContext, extras: Optional[ExtExtras] = None, submit_args: Optional[Mapping[str, str]] = None, - ) -> None: + ) -> Union[MaterializeResult, Tuple[MaterializeResult, ...]]: """Run a Databricks job with the EXT protocol. Args: @@ -106,7 +107,7 @@ def run( jobs.RunLifeCycleState.SKIPPED, ): if run.state.result_state == jobs.RunResultState.SUCCESS: - return + break else: raise DagsterExternalExecutionError( f"Error running Databricks job: {run.state.state_message}" @@ -116,6 +117,8 @@ def run( f"Error running Databricks job: {run.state.state_message}" ) time.sleep(5) + _ext_context = ext_context + return _ext_context.get_materialize_results() ExtDatabricks = ResourceParam[_ExtDatabricks] diff --git a/python_modules/libraries/dagster-docker/dagster_docker/ext.py b/python_modules/libraries/dagster-docker/dagster_docker/ext.py index 0862d72d014cc..f15c4f406edfa 100644 --- a/python_modules/libraries/dagster-docker/dagster_docker/ext.py +++ b/python_modules/libraries/dagster-docker/dagster_docker/ext.py @@ -1,5 +1,5 @@ from contextlib import contextmanager -from typing import Any, Iterator, Mapping, Optional, Sequence, Union +from typing import Any, Iterator, Mapping, Optional, Sequence, Tuple, Union import docker from dagster import ( @@ -7,6 +7,7 @@ ResourceParam, _check as check, ) +from dagster._core.definitions.result import MaterializeResult from dagster._core.ext.client import ( ExtClient, ExtContextInjector, @@ -94,7 +95,7 @@ def run( registry: Optional[Mapping[str, str]] = None, container_kwargs: Optional[Mapping[str, Any]] = None, extras: Optional[ExtExtras] = None, - ) -> None: + ) -> Union[MaterializeResult, Tuple[MaterializeResult, ...]]: """Create a docker container and run it to completion, enriched with the ext protocol. Args: @@ -162,6 +163,8 @@ def run( raise DagsterExtError(f"Container exited with non-zero status code: {result}") finally: container.stop() + _ext_context = ext_context + return _ext_context.get_materialize_results() def _create_container( self, diff --git a/python_modules/libraries/dagster-k8s/dagster_k8s/ext.py b/python_modules/libraries/dagster-k8s/dagster_k8s/ext.py index 6a37341f770eb..abffc02ac8e18 100644 --- a/python_modules/libraries/dagster-k8s/dagster_k8s/ext.py +++ b/python_modules/libraries/dagster-k8s/dagster_k8s/ext.py @@ -1,7 +1,7 @@ import random import string from contextlib import contextmanager -from typing import Any, Iterator, Mapping, Optional, Sequence, Union +from typing import Any, Iterator, Mapping, Optional, Sequence, Tuple, Union import kubernetes from dagster import ( @@ -9,6 +9,7 @@ _check as check, ) from dagster._core.definitions.resource_annotation import ResourceParam +from dagster._core.definitions.result import MaterializeResult from dagster._core.errors import DagsterInvariantViolationError from dagster._core.ext.client import ( ExtClient, @@ -123,7 +124,7 @@ def run( base_pod_meta: Optional[Mapping[str, Any]] = None, base_pod_spec: Optional[Mapping[str, Any]] = None, extras: Optional[ExtExtras] = None, - ) -> None: + ) -> Union[MaterializeResult, Tuple[MaterializeResult, ...]]: """Publish a kubernetes pod and wait for it to complete, enriched with the ext protocol. Args: @@ -158,7 +159,7 @@ def run( extras=extras, context_injector=self.context_injector, message_reader=self.message_reader, - ) as ext_process: + ) as ext_context: namespace = namespace or "default" pod_name = get_pod_name(context.run_id, context.op.name) pod_body = build_pod_body( @@ -166,7 +167,7 @@ def run( image=image, command=command, env_vars={ - **ext_process.get_external_process_env_vars(), + **ext_context.get_external_process_env_vars(), **(self.env or {}), **(env or {}), }, @@ -196,6 +197,8 @@ def run( ) finally: client.core_api.delete_namespaced_pod(pod_name, namespace) + _ext_context = ext_context + return _ext_context.get_materialize_results() def build_pod_body(