From 01edc65e63d295b71f1fbfb3627f351afce9d3e8 Mon Sep 17 00:00:00 2001 From: Sean Mackesey Date: Fri, 15 Sep 2023 10:21:38 -0400 Subject: [PATCH] [ext] Use MaterializeResult for ext_protocol --- .../tests/test_external_asset.py | 5 ++- .../test_external_execution.py | 34 ++++++++++++++- .../dagster/dagster/_core/ext/client.py | 6 ++- .../dagster/dagster/_core/ext/context.py | 42 ++++++++++++++++--- .../dagster/dagster/_core/ext/subprocess.py | 7 +++- .../dagster/dagster/_core/ext/utils.py | 4 +- .../dagster_databricks/ext.py | 8 ++-- .../dagster-docker/dagster_docker/ext.py | 6 ++- .../libraries/dagster-k8s/dagster_k8s/ext.py | 11 +++-- 9 files changed, 99 insertions(+), 24 deletions(-) diff --git a/integration_tests/test_suites/k8s-test-suite/tests/test_external_asset.py b/integration_tests/test_suites/k8s-test-suite/tests/test_external_asset.py index 6be73a917fe00..b72e2f5e16aa6 100644 --- a/integration_tests/test_suites/k8s-test-suite/tests/test_external_asset.py +++ b/integration_tests/test_suites/k8s-test-suite/tests/test_external_asset.py @@ -26,7 +26,7 @@ def number_y( context: AssetExecutionContext, ext_k8s_pod: ExtK8sPod, ): - ext_k8s_pod.run( + return ext_k8s_pod.run( context=context, namespace=namespace, image=docker_image, @@ -138,7 +138,7 @@ def number_y( ], ) - ext_k8s_pod.run( + return ext_k8s_pod.run( context=context, namespace=namespace, extras={ @@ -197,6 +197,7 @@ def number_y_job(context: AssetExecutionContext): k8s_job_name=job_name, ) reader.consume_pod_logs(core_api, job_name, namespace) + return ext_context.get_materialize_results() result = materialize( [number_y_job], diff --git a/python_modules/dagster-ext/dagster_ext_tests/test_external_execution.py b/python_modules/dagster-ext/dagster_ext_tests/test_external_execution.py index 84396dce781b4..35c47b8570fb0 100644 --- a/python_modules/dagster-ext/dagster_ext_tests/test_external_execution.py +++ b/python_modules/dagster-ext/dagster_ext_tests/test_external_execution.py @@ -148,7 +148,7 @@ def test_ext_subprocess( def foo(context: AssetExecutionContext, ext: ExtSubprocess): extras = {"bar": "baz"} cmd = [_PYTHON_EXECUTABLE, external_script] - ext.run( + return ext.run( cmd, context=context, extras=extras, @@ -201,7 +201,7 @@ def script_fn(): def foo(context: AssetExecutionContext, ext: ExtSubprocess): with temp_script(script_fn) as script_path: cmd = [_PYTHON_EXECUTABLE, script_path] - ext.run(cmd, context=context) + return ext.run(cmd, context=context) with instance_for_test() as instance: materialize( @@ -313,6 +313,9 @@ def subproc_run(context: AssetExecutionContext): extras=extras, ) as ext_context: subprocess.run(cmd, env=ext_context.get_external_process_env_vars(), check=False) + _ext_context = ext_context + mat_results = _ext_context.get_materialize_results() + return mat_results[0] if len(mat_results) == 1 else mat_results with instance_for_test() as instance: materialize( @@ -325,3 +328,30 @@ def subproc_run(context: AssetExecutionContext): assert mat.asset_materialization.tags assert mat.asset_materialization.tags[DATA_VERSION_TAG] == "alpha" assert mat.asset_materialization.tags[DATA_VERSION_IS_USER_PROVIDED_TAG] + + +def test_ext_no_client_premature_get_results(external_script): + @asset + def subproc_run(context: AssetExecutionContext): + extras = {"bar": "baz"} + cmd = [_PYTHON_EXECUTABLE, external_script] + + with ext_protocol( + context, + ExtTempFileContextInjector(), + ExtTempFileMessageReader(), + extras=extras, + ) as ext_context: + subprocess.run(cmd, env=ext_context.get_external_process_env_vars(), check=False) + return ext_context.get_materialize_results() + + with pytest.raises( + DagsterExternalExecutionError, + match=( + "`get_materialize_results` must be called after the `ext_protocol` context manager has" + " exited." + ), + ): + materialize( + [subproc_run], + ) diff --git a/python_modules/dagster/dagster/_core/ext/client.py b/python_modules/dagster/dagster/_core/ext/client.py index 2e89e50754be3..ab3c9fdd7e97d 100644 --- a/python_modules/dagster/dagster/_core/ext/client.py +++ b/python_modules/dagster/dagster/_core/ext/client.py @@ -1,6 +1,6 @@ from abc import ABC, abstractmethod from contextlib import contextmanager -from typing import TYPE_CHECKING, Iterator, Optional +from typing import TYPE_CHECKING, Iterator, Optional, Tuple, Union from dagster_ext import ( ExtContextData, @@ -11,6 +11,8 @@ from dagster._core.execution.context.compute import OpExecutionContext if TYPE_CHECKING: + from dagster._core.definitions.result import MaterializeResult + from .context import ExtMessageHandler @@ -21,7 +23,7 @@ def run( *, context: OpExecutionContext, extras: Optional[ExtExtras] = None, - ) -> None: ... + ) -> Union["MaterializeResult", Tuple["MaterializeResult", ...]]: ... class ExtContextInjector(ABC): diff --git a/python_modules/dagster/dagster/_core/ext/context.py b/python_modules/dagster/dagster/_core/ext/context.py index 6f7e5c47339ce..0cfbe18c98147 100644 --- a/python_modules/dagster/dagster/_core/ext/context.py +++ b/python_modules/dagster/dagster/_core/ext/context.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Any, Mapping, Optional, get_args +from typing import Any, Dict, Mapping, Optional, Tuple, get_args from dagster_ext import ( DAGSTER_EXT_ENV_KEYS, @@ -19,7 +19,9 @@ from dagster._core.definitions.events import AssetKey from dagster._core.definitions.metadata import MetadataValue, normalize_metadata_value from dagster._core.definitions.partition_key_range import PartitionKeyRange +from dagster._core.definitions.result import MaterializeResult from dagster._core.definitions.time_window_partitions import TimeWindow +from dagster._core.errors import DagsterExternalExecutionError from dagster._core.execution.context.compute import OpExecutionContext from dagster._core.execution.context.invocation import BoundOpExecutionContext @@ -27,6 +29,16 @@ class ExtMessageHandler: def __init__(self, context: OpExecutionContext) -> None: self._context = context + self._metadata: Dict[AssetKey, Dict[str, MetadataValue]] = {} + self._data_versions: Dict[AssetKey, DataVersion] = {} + + @property + def metadata(self) -> Dict[AssetKey, Dict[str, MetadataValue]]: + return self._metadata + + @property + def data_versions(self) -> Dict[AssetKey, DataVersion]: + return self._data_versions # Type ignores because we currently validate in individual handlers def handle_message(self, message: ExtMessage) -> None: @@ -43,10 +55,9 @@ def _handle_report_asset_metadata( check.str_param(asset_key, "asset_key") check.str_param(label, "label") check.opt_literal_param(type, "type", get_args(ExtMetadataType)) - key = AssetKey.from_user_string(asset_key) - output_name = self._context.output_for_asset_key(key) + resolved_asset_key = AssetKey.from_user_string(asset_key) metadata_value = self._resolve_metadata_value(value, type) - self._context.add_output_metadata({label: metadata_value}, output_name) + self._metadata.setdefault(resolved_asset_key, {})[label] = metadata_value def _resolve_metadata_value( self, value: Any, metadata_type: Optional[ExtMetadataType] @@ -85,8 +96,8 @@ def _resolve_metadata_value( def _handle_report_asset_data_version(self, asset_key: str, data_version: str) -> None: check.str_param(asset_key, "asset_key") check.str_param(data_version, "data_version") - key = AssetKey.from_user_string(asset_key) - self._context.set_data_version(key, DataVersion(data_version)) + resolved_asset_key = AssetKey.from_user_string(asset_key) + self._data_versions[resolved_asset_key] = DataVersion(data_version) def _handle_log(self, message: str, level: str = "info") -> None: check.str_param(message, "message") @@ -108,6 +119,7 @@ class ExtOrchestrationContext: message_handler: ExtMessageHandler context_injector_params: ExtParams message_reader_params: ExtParams + is_task_finished: bool = False def get_external_process_env_vars(self): return { @@ -118,6 +130,24 @@ def get_external_process_env_vars(self): ), } + def get_materialize_results(self) -> Tuple[MaterializeResult, ...]: + if not self.is_task_finished: + raise DagsterExternalExecutionError( + "`get_materialize_results` must be called after the `ext_protocol` context manager" + " has exited." + ) + return tuple( + self._materialize_result_for_asset(AssetKey.from_user_string(key)) + for key in self.context_data["asset_keys"] or [] + ) + + def _materialize_result_for_asset(self, asset_key: AssetKey): + return MaterializeResult( + asset_key=asset_key, + metadata=self.message_handler.metadata.get(asset_key), + data_version=self.message_handler.data_versions.get(asset_key), + ) + def build_external_execution_context_data( context: OpExecutionContext, diff --git a/python_modules/dagster/dagster/_core/ext/subprocess.py b/python_modules/dagster/dagster/_core/ext/subprocess.py index aa07751501d09..41b6e8d526009 100644 --- a/python_modules/dagster/dagster/_core/ext/subprocess.py +++ b/python_modules/dagster/dagster/_core/ext/subprocess.py @@ -1,10 +1,11 @@ from subprocess import Popen -from typing import Mapping, Optional, Sequence, Union +from typing import Mapping, Optional, Sequence, Tuple, Union from dagster_ext import ExtExtras from dagster import _check as check from dagster._core.definitions.resource_annotation import ResourceParam +from dagster._core.definitions.result import MaterializeResult from dagster._core.errors import DagsterExternalExecutionError from dagster._core.execution.context.compute import OpExecutionContext from dagster._core.ext.client import ( @@ -66,7 +67,7 @@ def run( extras: Optional[ExtExtras] = None, env: Optional[Mapping[str, str]] = None, cwd: Optional[str] = None, - ) -> None: + ) -> Union[MaterializeResult, Tuple[MaterializeResult, ...]]: with ext_protocol( context=context, context_injector=self.context_injector, @@ -88,6 +89,8 @@ def run( raise DagsterExternalExecutionError( f"External execution process failed with code {process.returncode}" ) + mat_results = ext_context.get_materialize_results() + return mat_results[0] if len(mat_results) == 1 else mat_results ExtSubprocess = ResourceParam[_ExtSubprocess] diff --git a/python_modules/dagster/dagster/_core/ext/utils.py b/python_modules/dagster/dagster/_core/ext/utils.py index 96ded10869647..34ee9a1f71490 100644 --- a/python_modules/dagster/dagster/_core/ext/utils.py +++ b/python_modules/dagster/dagster/_core/ext/utils.py @@ -205,9 +205,11 @@ def ext_protocol( ) as ci_params, message_reader.read_messages( msg_handler, ) as mr_params: - yield ExtOrchestrationContext( + ext_context = ExtOrchestrationContext( context_data=context_data, message_handler=msg_handler, context_injector_params=ci_params, message_reader_params=mr_params, ) + yield ext_context + ext_context.is_task_finished = True diff --git a/python_modules/libraries/dagster-databricks/dagster_databricks/ext.py b/python_modules/libraries/dagster-databricks/dagster_databricks/ext.py index 22bb89034cd07..b7f522a1e463d 100644 --- a/python_modules/libraries/dagster-databricks/dagster_databricks/ext.py +++ b/python_modules/libraries/dagster-databricks/dagster_databricks/ext.py @@ -5,10 +5,11 @@ import string import time from contextlib import contextmanager -from typing import Iterator, Mapping, Optional +from typing import Iterator, Mapping, Optional, Tuple, Union import dagster._check as check from dagster._core.definitions.resource_annotation import ResourceParam +from dagster._core.definitions.result import MaterializeResult from dagster._core.errors import DagsterExternalExecutionError from dagster._core.execution.context.compute import OpExecutionContext from dagster._core.ext.client import ExtClient, ExtContextInjector, ExtMessageReader @@ -66,7 +67,7 @@ def run( context: OpExecutionContext, extras: Optional[ExtExtras] = None, submit_args: Optional[Mapping[str, str]] = None, - ) -> None: + ) -> Union[MaterializeResult, Tuple[MaterializeResult, ...]]: """Run a Databricks job with the EXT protocol. Args: @@ -106,7 +107,7 @@ def run( jobs.RunLifeCycleState.SKIPPED, ): if run.state.result_state == jobs.RunResultState.SUCCESS: - return + break else: raise DagsterExternalExecutionError( f"Error running Databricks job: {run.state.state_message}" @@ -116,6 +117,7 @@ def run( f"Error running Databricks job: {run.state.state_message}" ) time.sleep(5) + return ext_context.get_materialize_results() ExtDatabricks = ResourceParam[_ExtDatabricks] diff --git a/python_modules/libraries/dagster-docker/dagster_docker/ext.py b/python_modules/libraries/dagster-docker/dagster_docker/ext.py index 0862d72d014cc..aa379722416a2 100644 --- a/python_modules/libraries/dagster-docker/dagster_docker/ext.py +++ b/python_modules/libraries/dagster-docker/dagster_docker/ext.py @@ -1,5 +1,5 @@ from contextlib import contextmanager -from typing import Any, Iterator, Mapping, Optional, Sequence, Union +from typing import Any, Iterator, Mapping, Optional, Sequence, Tuple, Union import docker from dagster import ( @@ -7,6 +7,7 @@ ResourceParam, _check as check, ) +from dagster._core.definitions.result import MaterializeResult from dagster._core.ext.client import ( ExtClient, ExtContextInjector, @@ -94,7 +95,7 @@ def run( registry: Optional[Mapping[str, str]] = None, container_kwargs: Optional[Mapping[str, Any]] = None, extras: Optional[ExtExtras] = None, - ) -> None: + ) -> Union[MaterializeResult, Tuple[MaterializeResult, ...]]: """Create a docker container and run it to completion, enriched with the ext protocol. Args: @@ -162,6 +163,7 @@ def run( raise DagsterExtError(f"Container exited with non-zero status code: {result}") finally: container.stop() + return ext_context.get_materialize_results() def _create_container( self, diff --git a/python_modules/libraries/dagster-k8s/dagster_k8s/ext.py b/python_modules/libraries/dagster-k8s/dagster_k8s/ext.py index 6a37341f770eb..ff4f99ee048c2 100644 --- a/python_modules/libraries/dagster-k8s/dagster_k8s/ext.py +++ b/python_modules/libraries/dagster-k8s/dagster_k8s/ext.py @@ -1,7 +1,7 @@ import random import string from contextlib import contextmanager -from typing import Any, Iterator, Mapping, Optional, Sequence, Union +from typing import Any, Iterator, Mapping, Optional, Sequence, Tuple, Union import kubernetes from dagster import ( @@ -9,6 +9,7 @@ _check as check, ) from dagster._core.definitions.resource_annotation import ResourceParam +from dagster._core.definitions.result import MaterializeResult from dagster._core.errors import DagsterInvariantViolationError from dagster._core.ext.client import ( ExtClient, @@ -123,7 +124,7 @@ def run( base_pod_meta: Optional[Mapping[str, Any]] = None, base_pod_spec: Optional[Mapping[str, Any]] = None, extras: Optional[ExtExtras] = None, - ) -> None: + ) -> Union[MaterializeResult, Tuple[MaterializeResult, ...]]: """Publish a kubernetes pod and wait for it to complete, enriched with the ext protocol. Args: @@ -158,7 +159,7 @@ def run( extras=extras, context_injector=self.context_injector, message_reader=self.message_reader, - ) as ext_process: + ) as ext_context: namespace = namespace or "default" pod_name = get_pod_name(context.run_id, context.op.name) pod_body = build_pod_body( @@ -166,7 +167,7 @@ def run( image=image, command=command, env_vars={ - **ext_process.get_external_process_env_vars(), + **ext_context.get_external_process_env_vars(), **(self.env or {}), **(env or {}), }, @@ -196,6 +197,8 @@ def run( ) finally: client.core_api.delete_namespaced_pod(pod_name, namespace) + mats = ext_context.get_materialize_results() + return mats[0] if len(mats) == 1 else mats def build_pod_body(