From 8830b3374c51c2ed11ab9c2b7f273bd91c8e6e50 Mon Sep 17 00:00:00 2001 From: Sean Mackesey Date: Wed, 20 Sep 2023 15:15:57 -0400 Subject: [PATCH] yield MaterializeResult --- pyproject.toml | 4 +- .../dagster-ext/dagster_ext/__init__.py | 82 +++++++++++++++++-- .../dagster_ext_tests/test_context.py | 7 ++ .../test_external_execution.py | 74 ++++++++--------- .../_core/definitions/metadata/table.py | 2 +- .../dagster/dagster/_core/ext/context.py | 82 +++++++++++++++---- .../dagster/dagster/_core/ext/subprocess.py | 12 +-- .../dagster/dagster/_core/ext/utils.py | 2 +- .../dagster_databricks/ext.py | 4 +- .../dagster-docker/dagster_docker/ext.py | 4 +- .../libraries/dagster-k8s/dagster_k8s/ext.py | 7 +- 11 files changed, 202 insertions(+), 78 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 45d8e64db0c3e..e48ef636ca9cd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -257,8 +257,8 @@ required-version = "0.0.289" [tool.ruff.flake8-builtins] -# We use `id` in many places and almost never want to use the python builtin. -builtins-ignorelist = ["id"] +# Id and type are frequently helpful as local variable or parameter names. +builtins-ignorelist = ["id", "type"] [tool.ruff.flake8-tidy-imports.banned-api] diff --git a/python_modules/dagster-ext/dagster_ext/__init__.py b/python_modules/dagster-ext/dagster_ext/__init__.py index a98bc911b118a..88cd8f99aedff 100644 --- a/python_modules/dagster-ext/dagster_ext/__init__.py +++ b/python_modules/dagster-ext/dagster_ext/__init__.py @@ -16,16 +16,19 @@ TYPE_CHECKING, Any, ClassVar, + Dict, Generic, Iterator, Literal, Mapping, Optional, Sequence, + Set, TextIO, Type, TypedDict, TypeVar, + Union, cast, get_args, ) @@ -98,6 +101,14 @@ class ExtDataProvenance(TypedDict): is_user_provided: bool +ExtMetadataRawValue = Union[int, float, str, Mapping[str, Any], Sequence[Any], bool, None] + + +class ExtMetadataValue(TypedDict): + type: Optional["ExtMetadataType"] + raw_value: ExtMetadataRawValue + + ExtMetadataType = Literal[ "text", "url", @@ -148,7 +159,10 @@ def _assert_single_asset(data: ExtContextData, key: str) -> None: def _resolve_optionally_passed_asset_key( - data: ExtContextData, asset_key: Optional[str], method: str + data: ExtContextData, + asset_key: Optional[str], + method: str, + already_materialized_assets: Set[str], ) -> str: asset_keys = _assert_defined_asset_property(data["asset_keys"], method) asset_key = _assert_opt_param_type(asset_key, str, method, "asset_key") @@ -163,6 +177,11 @@ def _resolve_optionally_passed_asset_key( " targets multiple assets." ) asset_key = asset_keys[0] + if asset_key in already_materialized_assets: + raise DagsterExtError( + f"Calling `{method}` with asset key `{asset_key}` is undefined. Asset has already been" + " materialized, so no additional data can be reported for it." + ) return asset_key @@ -259,6 +278,30 @@ def _assert_param_json_serializable(value: _T, method: str, param: str) -> _T: return value +def _normalize_param_metadata( + metadata: Mapping[str, Union[ExtMetadataRawValue, ExtMetadataValue]], method: str, param: str +) -> Mapping[str, Union[ExtMetadataRawValue, ExtMetadataValue]]: + _assert_param_type(metadata, dict, method, param) + new_metadata: Dict[str, ExtMetadataValue] = {} + for key, value in metadata.items(): + if not isinstance(key, str): + raise DagsterExtError( + f"Invalid type for parameter `{param}` of `{method}`. Expected a dict with string" + f" keys, got a key `{key}` of type `{type(key)}`." + ) + elif isinstance(value, dict): + if not {*value.keys()} == {*ExtMetadataValue.__annotations__.keys()}: + raise DagsterExtError( + f"Invalid type for parameter `{param}` of `{method}`. Expected a dict with" + " string keys and values that are either raw metadata values or dictionaries" + f" with schema `{{raw_value: ..., type: ...}}`. Got a value `{value}`." + ) + new_metadata[key] = cast(ExtMetadataValue, value) + else: + new_metadata[key] = {"raw_value": value, "type": None} + return new_metadata + + def _param_from_env_var(key: str) -> Any: raw_value = os.environ.get(_param_name_to_env_var(key)) return decode_env_var(raw_value) if raw_value is not None else None @@ -625,6 +668,7 @@ def __init__( ) -> None: self._data = data self.message_channel = message_channel + self.materialized_assets: Set[str] = set() def _write_message(self, method: str, params: Optional[Mapping[str, Any]] = None) -> None: message = ExtMessage(method=method, params=params) @@ -730,26 +774,27 @@ def extras(self) -> Mapping[str, Any]: def report_asset_metadata( self, label: str, - value: Any, + value: ExtMetadataRawValue, metadata_type: Optional[ExtMetadataType] = None, asset_key: Optional[str] = None, ) -> None: asset_key = _resolve_optionally_passed_asset_key( - self._data, asset_key, "report_asset_metadata" + self._data, asset_key, "report_asset_metadata", self.materialized_assets ) label = _assert_param_type(label, str, "report_asset_metadata", "label") value = _assert_param_json_serializable(value, "report_asset_metadata", "value") - metadata_type = _assert_opt_param_value( - metadata_type, get_args(ExtMetadataType), "report_asset_metadata", "type" + type = _assert_opt_param_value( + metadata_type, get_args(ExtMetadataType), "report_asset_metadata", "metadata_type" ) + self._write_message( "report_asset_metadata", - {"asset_key": asset_key, "label": label, "value": value, "type": metadata_type}, + {"asset_key": asset_key, "label": label, "value": {"raw_value": value, "type": type}}, ) def report_asset_data_version(self, data_version: str, asset_key: Optional[str] = None) -> None: asset_key = _resolve_optionally_passed_asset_key( - self._data, asset_key, "report_asset_data_version" + self._data, asset_key, "report_asset_data_version", self.materialized_assets ) data_version = _assert_param_type( data_version, str, "report_asset_data_version", "data_version" @@ -758,6 +803,29 @@ def report_asset_data_version(self, data_version: str, asset_key: Optional[str] "report_asset_data_version", {"asset_key": asset_key, "data_version": data_version} ) + def report_asset_materialization( + self, + metadata: Optional[Mapping[str, Union[ExtMetadataRawValue, ExtMetadataValue]]] = None, + data_version: Optional[str] = None, + asset_key: Optional[str] = None, + ): + asset_key = _resolve_optionally_passed_asset_key( + self._data, asset_key, "report_asset_materialization", self.materialized_assets + ) + metadata = ( + _normalize_param_metadata(metadata, "report_asset_check_result", "metadata") + if metadata + else None + ) + data_version = _assert_opt_param_type( + data_version, str, "report_asset_data_version", "data_version" + ) + self._write_message( + "report_asset_materialization", + {"asset_key": asset_key, "data_version": data_version, "metadata": metadata}, + ) + self.materialized_assets.add(asset_key) + def log(self, message: str, level: str = "info") -> None: message = _assert_param_type(message, str, "log", "asset_key") level = _assert_param_value(level, ["info", "warning", "error"], "log", "level") diff --git a/python_modules/dagster-ext/dagster_ext_tests/test_context.py b/python_modules/dagster-ext/dagster_ext_tests/test_context.py index 9b2ab7aca7ce2..b17c95f117735 100644 --- a/python_modules/dagster-ext/dagster_ext_tests/test_context.py +++ b/python_modules/dagster-ext/dagster_ext_tests/test_context.py @@ -162,3 +162,10 @@ def test_extras_context(): assert context.get_extra("foo") == "bar" with pytest.raises(DagsterExtError, match="Extra `bar` is undefined"): context.get_extra("bar") + + +def test_report_after_materialization(): + context = _make_external_execution_context(asset_keys=["foo"]) + with pytest.raises(DagsterExtError, match="already been materialized"): + context.report_asset_materialization(asset_key="foo") + context.report_asset_data_version("alpha", asset_key="foo") diff --git a/python_modules/dagster-ext/dagster_ext_tests/test_external_execution.py b/python_modules/dagster-ext/dagster_ext_tests/test_external_execution.py index 35c47b8570fb0..fef08554b8303 100644 --- a/python_modules/dagster-ext/dagster_ext_tests/test_external_execution.py +++ b/python_modules/dagster-ext/dagster_ext_tests/test_external_execution.py @@ -9,11 +9,12 @@ import boto3 import pytest +from dagster._core.definitions.asset_spec import AssetSpec from dagster._core.definitions.data_version import ( DATA_VERSION_IS_USER_PROVIDED_TAG, DATA_VERSION_TAG, ) -from dagster._core.definitions.decorators.asset_decorator import asset +from dagster._core.definitions.decorators.asset_decorator import asset, multi_asset from dagster._core.definitions.events import AssetKey from dagster._core.definitions.materialize import materialize from dagster._core.definitions.metadata import ( @@ -148,7 +149,7 @@ def test_ext_subprocess( def foo(context: AssetExecutionContext, ext: ExtSubprocess): extras = {"bar": "baz"} cmd = [_PYTHON_EXECUTABLE, external_script] - return ext.run( + yield from ext.run( cmd, context=context, extras=extras, @@ -161,11 +162,7 @@ def foo(context: AssetExecutionContext, ext: ExtSubprocess): resource = ExtSubprocess(context_injector=context_injector, message_reader=message_reader) with instance_for_test() as instance: - materialize( - [foo], - instance=instance, - resources={"ext": resource}, - ) + materialize([foo], instance=instance, resources={"ext": resource}) mat = instance.get_latest_materialization_event(foo.key) assert mat and mat.asset_materialization assert isinstance(mat.asset_materialization.metadata["bar"], MarkdownMetadataValue) @@ -178,6 +175,35 @@ def foo(context: AssetExecutionContext, ext: ExtSubprocess): assert re.search(r"dagster - INFO - [^\n]+ - hello world\n", captured.err, re.MULTILINE) +def test_ext_multi_asset(): + def script_fn(): + from dagster_ext import init_dagster_ext + + context = init_dagster_ext() + context.report_asset_materialization( + {"foo_meta": "ok"}, data_version="alpha", asset_key="foo" + ) + context.report_asset_data_version("alpha", asset_key="bar") + + @multi_asset(specs=[AssetSpec("foo"), AssetSpec("bar")]) + def foo_bar(context: AssetExecutionContext, ext: ExtSubprocess): + with temp_script(script_fn) as script_path: + cmd = [_PYTHON_EXECUTABLE, script_path] + yield from ext.run(cmd, context=context) + + with instance_for_test() as instance: + materialize([foo_bar], instance=instance, resources={"ext": ExtSubprocess()}) + foo_mat = instance.get_latest_materialization_event(AssetKey(["foo"])) + assert foo_mat and foo_mat.asset_materialization + assert foo_mat.asset_materialization.metadata["foo_meta"].value == "ok" + assert foo_mat.asset_materialization.tags + assert foo_mat.asset_materialization.tags[DATA_VERSION_TAG] == "alpha" + bar_mat = instance.get_latest_materialization_event(AssetKey(["foo"])) + assert bar_mat and bar_mat.asset_materialization + assert bar_mat.asset_materialization.tags + assert bar_mat.asset_materialization.tags[DATA_VERSION_TAG] == "alpha" + + def test_ext_typed_metadata(): def script_fn(): from dagster_ext import init_dagster_ext @@ -201,7 +227,7 @@ def script_fn(): def foo(context: AssetExecutionContext, ext: ExtSubprocess): with temp_script(script_fn) as script_path: cmd = [_PYTHON_EXECUTABLE, script_path] - return ext.run(cmd, context=context) + yield from ext.run(cmd, context=context) with instance_for_test() as instance: materialize( @@ -248,7 +274,7 @@ def script_fn(): def foo(context: AssetExecutionContext, ext: ExtSubprocess): with temp_script(script_fn) as script_path: cmd = [_PYTHON_EXECUTABLE, script_path] - ext.run(cmd, context=context) + yield from ext.run(cmd, context=context) with pytest.raises(DagsterExternalExecutionError): materialize([foo], resources={"ext": ExtSubprocess()}) @@ -314,8 +340,7 @@ def subproc_run(context: AssetExecutionContext): ) as ext_context: subprocess.run(cmd, env=ext_context.get_external_process_env_vars(), check=False) _ext_context = ext_context - mat_results = _ext_context.get_materialize_results() - return mat_results[0] if len(mat_results) == 1 else mat_results + yield from _ext_context.get_materialize_results() with instance_for_test() as instance: materialize( @@ -328,30 +353,3 @@ def subproc_run(context: AssetExecutionContext): assert mat.asset_materialization.tags assert mat.asset_materialization.tags[DATA_VERSION_TAG] == "alpha" assert mat.asset_materialization.tags[DATA_VERSION_IS_USER_PROVIDED_TAG] - - -def test_ext_no_client_premature_get_results(external_script): - @asset - def subproc_run(context: AssetExecutionContext): - extras = {"bar": "baz"} - cmd = [_PYTHON_EXECUTABLE, external_script] - - with ext_protocol( - context, - ExtTempFileContextInjector(), - ExtTempFileMessageReader(), - extras=extras, - ) as ext_context: - subprocess.run(cmd, env=ext_context.get_external_process_env_vars(), check=False) - return ext_context.get_materialize_results() - - with pytest.raises( - DagsterExternalExecutionError, - match=( - "`get_materialize_results` must be called after the `ext_protocol` context manager has" - " exited." - ), - ): - materialize( - [subproc_run], - ) diff --git a/python_modules/dagster/dagster/_core/definitions/metadata/table.py b/python_modules/dagster/dagster/_core/definitions/metadata/table.py index cf2f60b78b8e9..69ae8dab265e7 100644 --- a/python_modules/dagster/dagster/_core/definitions/metadata/table.py +++ b/python_modules/dagster/dagster/_core/definitions/metadata/table.py @@ -196,7 +196,7 @@ class TableColumn( def __new__( cls, name: str, - type: str = "string", # noqa: A002 + type: str = "string", description: Optional[str] = None, constraints: Optional["TableColumnConstraints"] = None, ): diff --git a/python_modules/dagster/dagster/_core/ext/context.py b/python_modules/dagster/dagster/_core/ext/context.py index 0cfbe18c98147..3782bb20d1701 100644 --- a/python_modules/dagster/dagster/_core/ext/context.py +++ b/python_modules/dagster/dagster/_core/ext/context.py @@ -1,5 +1,6 @@ from dataclasses import dataclass -from typing import Any, Dict, Mapping, Optional, Tuple, get_args +from threading import Lock +from typing import Any, Dict, Iterator, List, Mapping, Optional, get_args from dagster_ext import ( DAGSTER_EXT_ENV_KEYS, @@ -9,6 +10,7 @@ ExtExtras, ExtMessage, ExtMetadataType, + ExtMetadataValue, ExtParams, ExtTimeWindow, encode_env_var, @@ -21,7 +23,6 @@ from dagster._core.definitions.partition_key_range import PartitionKeyRange from dagster._core.definitions.result import MaterializeResult from dagster._core.definitions.time_window_partitions import TimeWindow -from dagster._core.errors import DagsterExternalExecutionError from dagster._core.execution.context.compute import OpExecutionContext from dagster._core.execution.context.invocation import BoundOpExecutionContext @@ -29,9 +30,19 @@ class ExtMessageHandler: def __init__(self, context: OpExecutionContext) -> None: self._context = context + self._asset_materialized_map: Dict[AssetKey, bool] = { + key: False for key in context.selected_asset_keys + } + self._asset_materialization_queue: List[AssetKey] = [] self._metadata: Dict[AssetKey, Dict[str, MetadataValue]] = {} self._data_versions: Dict[AssetKey, DataVersion] = {} + # The lock is used to guard modification of the asset materialization tracking data + # structures, which are modified by both handler methods (and thus the message reader + # thread) and the `clear_materialization_queue` method, which can be called by the + # main thread. + self._lock = Lock() + @property def metadata(self) -> Dict[AssetKey, Dict[str, MetadataValue]]: return self._metadata @@ -40,24 +51,54 @@ def metadata(self) -> Dict[AssetKey, Dict[str, MetadataValue]]: def data_versions(self) -> Dict[AssetKey, DataVersion]: return self._data_versions + def clear_materialization_queue(self) -> Iterator[AssetKey]: + with self._lock: + while self._asset_materialization_queue: + key = self._asset_materialization_queue.pop(0) + yield key + self._asset_materialized_map[key] = True + + def enqueue_remaining_assets(self) -> None: + with self._lock: + unmaterialized_assets = {k for k, v in self._asset_materialized_map.items() if not v} + for key in unmaterialized_assets: + self._asset_materialization_queue.append(key) + + def enqueue_asset(self, asset_key: AssetKey) -> None: + with self._lock: + self._asset_materialization_queue.append(asset_key) + # Type ignores because we currently validate in individual handlers def handle_message(self, message: ExtMessage) -> None: if message["method"] == "report_asset_metadata": self._handle_report_asset_metadata(**message["params"]) # type: ignore elif message["method"] == "report_asset_data_version": self._handle_report_asset_data_version(**message["params"]) # type: ignore + elif message["method"] == "report_asset_materialization": + self._handle_report_asset_materialization(**message["params"]) # type: ignore elif message["method"] == "log": self._handle_log(**message["params"]) # type: ignore def _handle_report_asset_metadata( - self, asset_key: str, label: str, value: Any, type: ExtMetadataType # noqa: A002 + self, asset_key: str, label: str, value: ExtMetadataValue ) -> None: check.str_param(asset_key, "asset_key") check.str_param(label, "label") - check.opt_literal_param(type, "type", get_args(ExtMetadataType)) + check.mapping_param(value, "value", key_type=str) + assert "raw_value" in value + check.opt_literal_param(value["type"], "type", get_args(ExtMetadataType)) resolved_asset_key = AssetKey.from_user_string(asset_key) - metadata_value = self._resolve_metadata_value(value, type) - self._metadata.setdefault(resolved_asset_key, {})[label] = metadata_value + self._update_metadata(resolved_asset_key, {label: value}) + + def _update_metadata( + self, asset_key: AssetKey, metadata: Mapping[str, ExtMetadataValue] + ) -> None: + resolved_metadata = { + k: self._resolve_metadata_value(v["raw_value"], v["type"]) for k, v in metadata.items() + } + if asset_key not in self._metadata: + self._metadata[asset_key] = {} + self._metadata[asset_key].update(resolved_metadata) def _resolve_metadata_value( self, value: Any, metadata_type: Optional[ExtMetadataType] @@ -93,6 +134,18 @@ def _resolve_metadata_value( else: check.failed(f"Unexpected metadata type {metadata_type}") + def _handle_report_asset_materialization( + self, asset_key: str, metadata: Optional[Mapping[str, Any]], data_version: Optional[str] + ) -> None: + check.str_param(asset_key, "asset_key") + check.opt_str_param(data_version, "data_version") + metadata = check.opt_mapping_param(metadata, "metadata", key_type=str) + resolved_asset_key = AssetKey.from_user_string(asset_key) + self._update_metadata(resolved_asset_key, metadata) + if data_version is not None: + self._data_versions[resolved_asset_key] = DataVersion(data_version) + self.enqueue_asset(resolved_asset_key) + def _handle_report_asset_data_version(self, asset_key: str, data_version: str) -> None: check.str_param(asset_key, "asset_key") check.str_param(data_version, "data_version") @@ -119,7 +172,6 @@ class ExtOrchestrationContext: message_handler: ExtMessageHandler context_injector_params: ExtParams message_reader_params: ExtParams - is_task_finished: bool = False def get_external_process_env_vars(self): return { @@ -130,15 +182,13 @@ def get_external_process_env_vars(self): ), } - def get_materialize_results(self) -> Tuple[MaterializeResult, ...]: - if not self.is_task_finished: - raise DagsterExternalExecutionError( - "`get_materialize_results` must be called after the `ext_protocol` context manager" - " has exited." - ) - return tuple( - self._materialize_result_for_asset(AssetKey.from_user_string(key)) - for key in self.context_data["asset_keys"] or [] + def on_protocol_close(self): + self.message_handler.enqueue_remaining_assets() + + def get_materialize_results(self) -> Iterator[MaterializeResult]: + return iter( + self._materialize_result_for_asset(key) + for key in self.message_handler.clear_materialization_queue() ) def _materialize_result_for_asset(self, asset_key: AssetKey): diff --git a/python_modules/dagster/dagster/_core/ext/subprocess.py b/python_modules/dagster/dagster/_core/ext/subprocess.py index 41b6e8d526009..368ef4361bbb9 100644 --- a/python_modules/dagster/dagster/_core/ext/subprocess.py +++ b/python_modules/dagster/dagster/_core/ext/subprocess.py @@ -1,5 +1,5 @@ from subprocess import Popen -from typing import Mapping, Optional, Sequence, Tuple, Union +from typing import Iterator, Mapping, Optional, Sequence, Union from dagster_ext import ExtExtras @@ -67,7 +67,7 @@ def run( extras: Optional[ExtExtras] = None, env: Optional[Mapping[str, str]] = None, cwd: Optional[str] = None, - ) -> Union[MaterializeResult, Tuple[MaterializeResult, ...]]: + ) -> Iterator[MaterializeResult]: with ext_protocol( context=context, context_injector=self.context_injector, @@ -83,14 +83,16 @@ def run( **(env or {}), }, ) - process.wait() + while True: + yield from ext_context.get_materialize_results() + if process.poll() is not None: + break if process.returncode != 0: raise DagsterExternalExecutionError( f"External execution process failed with code {process.returncode}" ) - mat_results = ext_context.get_materialize_results() - return mat_results[0] if len(mat_results) == 1 else mat_results + yield from ext_context.get_materialize_results() ExtSubprocess = ResourceParam[_ExtSubprocess] diff --git a/python_modules/dagster/dagster/_core/ext/utils.py b/python_modules/dagster/dagster/_core/ext/utils.py index 34ee9a1f71490..7703388f05ec4 100644 --- a/python_modules/dagster/dagster/_core/ext/utils.py +++ b/python_modules/dagster/dagster/_core/ext/utils.py @@ -212,4 +212,4 @@ def ext_protocol( message_reader_params=mr_params, ) yield ext_context - ext_context.is_task_finished = True + ext_context.on_protocol_close() diff --git a/python_modules/libraries/dagster-databricks/dagster_databricks/ext.py b/python_modules/libraries/dagster-databricks/dagster_databricks/ext.py index b7f522a1e463d..9acf4e0cb7990 100644 --- a/python_modules/libraries/dagster-databricks/dagster_databricks/ext.py +++ b/python_modules/libraries/dagster-databricks/dagster_databricks/ext.py @@ -5,7 +5,7 @@ import string import time from contextlib import contextmanager -from typing import Iterator, Mapping, Optional, Tuple, Union +from typing import Iterator, Mapping, Optional import dagster._check as check from dagster._core.definitions.resource_annotation import ResourceParam @@ -67,7 +67,7 @@ def run( context: OpExecutionContext, extras: Optional[ExtExtras] = None, submit_args: Optional[Mapping[str, str]] = None, - ) -> Union[MaterializeResult, Tuple[MaterializeResult, ...]]: + ) -> Iterator[MaterializeResult]: """Run a Databricks job with the EXT protocol. Args: diff --git a/python_modules/libraries/dagster-docker/dagster_docker/ext.py b/python_modules/libraries/dagster-docker/dagster_docker/ext.py index aa379722416a2..539b2e73169d2 100644 --- a/python_modules/libraries/dagster-docker/dagster_docker/ext.py +++ b/python_modules/libraries/dagster-docker/dagster_docker/ext.py @@ -1,5 +1,5 @@ from contextlib import contextmanager -from typing import Any, Iterator, Mapping, Optional, Sequence, Tuple, Union +from typing import Any, Iterator, Mapping, Optional, Sequence, Union import docker from dagster import ( @@ -95,7 +95,7 @@ def run( registry: Optional[Mapping[str, str]] = None, container_kwargs: Optional[Mapping[str, Any]] = None, extras: Optional[ExtExtras] = None, - ) -> Union[MaterializeResult, Tuple[MaterializeResult, ...]]: + ) -> Iterator[MaterializeResult]: """Create a docker container and run it to completion, enriched with the ext protocol. Args: diff --git a/python_modules/libraries/dagster-k8s/dagster_k8s/ext.py b/python_modules/libraries/dagster-k8s/dagster_k8s/ext.py index ff4f99ee048c2..1de0c33f15c67 100644 --- a/python_modules/libraries/dagster-k8s/dagster_k8s/ext.py +++ b/python_modules/libraries/dagster-k8s/dagster_k8s/ext.py @@ -1,7 +1,7 @@ import random import string from contextlib import contextmanager -from typing import Any, Iterator, Mapping, Optional, Sequence, Tuple, Union +from typing import Any, Iterator, Mapping, Optional, Sequence, Union import kubernetes from dagster import ( @@ -124,7 +124,7 @@ def run( base_pod_meta: Optional[Mapping[str, Any]] = None, base_pod_spec: Optional[Mapping[str, Any]] = None, extras: Optional[ExtExtras] = None, - ) -> Union[MaterializeResult, Tuple[MaterializeResult, ...]]: + ) -> Iterator[MaterializeResult]: """Publish a kubernetes pod and wait for it to complete, enriched with the ext protocol. Args: @@ -197,8 +197,7 @@ def run( ) finally: client.core_api.delete_namespaced_pod(pod_name, namespace) - mats = ext_context.get_materialize_results() - return mats[0] if len(mats) == 1 else mats + return ext_context.get_materialize_results() def build_pod_body(