diff --git a/integration_tests/test_suites/k8s-test-suite/tests/test_external_asset.py b/integration_tests/test_suites/k8s-test-suite/tests/test_external_asset.py index 6be73a917fe00..450ee9874dd46 100644 --- a/integration_tests/test_suites/k8s-test-suite/tests/test_external_asset.py +++ b/integration_tests/test_suites/k8s-test-suite/tests/test_external_asset.py @@ -26,7 +26,7 @@ def number_y( context: AssetExecutionContext, ext_k8s_pod: ExtK8sPod, ): - ext_k8s_pod.run( + yield from ext_k8s_pod.run( context=context, namespace=namespace, image=docker_image, @@ -138,7 +138,7 @@ def number_y( ], ) - ext_k8s_pod.run( + yield from ext_k8s_pod.run( context=context, namespace=namespace, extras={ @@ -197,6 +197,7 @@ def number_y_job(context: AssetExecutionContext): k8s_job_name=job_name, ) reader.consume_pod_logs(core_api, job_name, namespace) + yield from ext_context.get_results() result = materialize( [number_y_job], diff --git a/python_modules/dagster-ext/dagster_ext/__init__.py b/python_modules/dagster-ext/dagster_ext/__init__.py index 38366a039aeaf..172e52b127e13 100644 --- a/python_modules/dagster-ext/dagster_ext/__init__.py +++ b/python_modules/dagster-ext/dagster_ext/__init__.py @@ -23,7 +23,6 @@ Mapping, Optional, Sequence, - Set, TextIO, Type, TypedDict, @@ -165,7 +164,6 @@ def _resolve_optionally_passed_asset_key( data: ExtContextData, asset_key: Optional[str], method: str, - already_materialized_assets: Set[str], ) -> str: asset_keys = _assert_defined_asset_property(data["asset_keys"], method) asset_key = _assert_opt_param_type(asset_key, str, method, "asset_key") @@ -180,11 +178,6 @@ def _resolve_optionally_passed_asset_key( " targets multiple assets." ) asset_key = asset_keys[0] - if asset_key in already_materialized_assets: - raise DagsterExtError( - f"Calling `{method}` with asset key `{asset_key}` is undefined. Asset has already been" - " materialized, so no additional data can be reported for it." - ) return asset_key @@ -784,8 +777,14 @@ def report_asset_materialization( asset_key: Optional[str] = None, ): asset_key = _resolve_optionally_passed_asset_key( - self._data, asset_key, "report_asset_materialization", self._materialized_assets + self._data, asset_key, "report_asset_materialization" ) + if asset_key in self._materialized_assets: + raise DagsterExtError( + f"Calling `report_asset_materialization` with asset key `{asset_key}` is undefined." + " Asset has already been materialized, so no additional data can be reported" + " for it." 
+            )
         metadata = (
             _normalize_param_metadata(metadata, "report_asset_materialization", "metadata")
             if metadata
diff --git a/python_modules/dagster-ext/dagster_ext_tests/test_external_execution.py b/python_modules/dagster-ext/dagster_ext_tests/test_external_execution.py
index f5227f75655b3..c2b0b540904cf 100644
--- a/python_modules/dagster-ext/dagster_ext_tests/test_external_execution.py
+++ b/python_modules/dagster-ext/dagster_ext_tests/test_external_execution.py
@@ -9,11 +9,12 @@

 import boto3
 import pytest
+from dagster._core.definitions.asset_spec import AssetSpec
 from dagster._core.definitions.data_version import (
     DATA_VERSION_IS_USER_PROVIDED_TAG,
     DATA_VERSION_TAG,
 )
-from dagster._core.definitions.decorators.asset_decorator import asset
+from dagster._core.definitions.decorators.asset_decorator import asset, multi_asset
 from dagster._core.definitions.events import AssetKey
 from dagster._core.definitions.materialize import materialize
 from dagster._core.definitions.metadata import (
@@ -30,8 +31,8 @@
     TextMetadataValue,
     UrlMetadataValue,
 )
-from dagster._core.errors import DagsterExternalExecutionError
-from dagster._core.execution.context.compute import AssetExecutionContext
+from dagster._core.errors import DagsterExternalExecutionError, DagsterInvariantViolationError
+from dagster._core.execution.context.compute import AssetExecutionContext, OpExecutionContext
 from dagster._core.execution.context.invocation import build_asset_context
 from dagster._core.ext.subprocess import (
     ExtSubprocess,
@@ -150,7 +151,7 @@ def test_ext_subprocess(
     def foo(context: AssetExecutionContext, ext: ExtSubprocess):
         extras = {"bar": "baz"}
         cmd = [_PYTHON_EXECUTABLE, external_script]
-        ext.run(
+        yield from ext.run(
             cmd,
             context=context,
             extras=extras,
@@ -163,11 +164,7 @@ def foo(context: AssetExecutionContext, ext: ExtSubprocess):
     resource = ExtSubprocess(context_injector=context_injector, message_reader=message_reader)

     with instance_for_test() as instance:
-        materialize(
-            [foo],
-            instance=instance,
-            resources={"ext": resource},
-        )
+        materialize([foo], instance=instance, resources={"ext": resource})
         mat = instance.get_latest_materialization_event(foo.key)
         assert mat and mat.asset_materialization
         assert isinstance(mat.asset_materialization.metadata["bar"], MarkdownMetadataValue)
@@ -180,6 +177,35 @@ def foo(context: AssetExecutionContext, ext: ExtSubprocess):
     assert re.search(r"dagster - INFO - [^\n]+ - hello world\n", captured.err, re.MULTILINE)


+def test_ext_multi_asset():
+    def script_fn():
+        from dagster_ext import init_dagster_ext
+
+        context = init_dagster_ext()
+        context.report_asset_materialization(
+            {"foo_meta": "ok"}, data_version="alpha", asset_key="foo"
+        )
+        context.report_asset_materialization(data_version="alpha", asset_key="bar")
+
+    @multi_asset(specs=[AssetSpec("foo"), AssetSpec("bar")])
+    def foo_bar(context: AssetExecutionContext, ext: ExtSubprocess):
+        with temp_script(script_fn) as script_path:
+            cmd = [_PYTHON_EXECUTABLE, script_path]
+            yield from ext.run(cmd, context=context)
+
+    with instance_for_test() as instance:
+        materialize([foo_bar], instance=instance, resources={"ext": ExtSubprocess()})
+        foo_mat = instance.get_latest_materialization_event(AssetKey(["foo"]))
+        assert foo_mat and foo_mat.asset_materialization
+        assert foo_mat.asset_materialization.metadata["foo_meta"].value == "ok"
+        assert foo_mat.asset_materialization.tags
+        assert foo_mat.asset_materialization.tags[DATA_VERSION_TAG] == "alpha"
+        bar_mat = instance.get_latest_materialization_event(AssetKey(["bar"]))
+
assert bar_mat and bar_mat.asset_materialization + assert bar_mat.asset_materialization.tags + assert bar_mat.asset_materialization.tags[DATA_VERSION_TAG] == "alpha" + + def test_ext_typed_metadata(): def script_fn(): from dagster_ext import init_dagster_ext @@ -207,7 +233,7 @@ def script_fn(): def foo(context: AssetExecutionContext, ext: ExtSubprocess): with temp_script(script_fn) as script_path: cmd = [_PYTHON_EXECUTABLE, script_path] - ext.run(cmd, context=context) + yield from ext.run(cmd, context=context) with instance_for_test() as instance: materialize( @@ -254,7 +280,7 @@ def script_fn(): def foo(context: AssetExecutionContext, ext: ExtSubprocess): with temp_script(script_fn) as script_path: cmd = [_PYTHON_EXECUTABLE, script_path] - ext.run(cmd, context=context) + yield from ext.run(cmd, context=context) with pytest.raises(DagsterExternalExecutionError): materialize([foo], resources={"ext": ExtSubprocess()}) @@ -321,6 +347,7 @@ def subproc_run(context: AssetExecutionContext): extras=extras, ) as ext_context: subprocess.run(cmd, env=ext_context.get_external_process_env_vars(), check=False) + yield from ext_context.get_results() with instance_for_test() as instance: materialize( @@ -333,3 +360,28 @@ def subproc_run(context: AssetExecutionContext): assert mat.asset_materialization.tags assert mat.asset_materialization.tags[DATA_VERSION_TAG] == "alpha" assert mat.asset_materialization.tags[DATA_VERSION_IS_USER_PROVIDED_TAG] + + +def test_ext_no_client_no_yield(): + def script_fn(): + pass + + @asset + def foo(context: OpExecutionContext): + with temp_script(script_fn) as external_script: + with ext_protocol( + context, + ExtTempFileContextInjector(), + ExtTempFileMessageReader(), + ) as ext_context: + cmd = [_PYTHON_EXECUTABLE, external_script] + subprocess.run(cmd, env=ext_context.get_external_process_env_vars(), check=False) + + with pytest.raises( + DagsterInvariantViolationError, + match=( + r"did not yield or return expected outputs.*Did you forget to `yield from" + r" ext_context.get_results\(\)`?" + ), + ): + materialize([foo]) diff --git a/python_modules/dagster/dagster/_core/ext/client.py b/python_modules/dagster/dagster/_core/ext/client.py index 2e89e50754be3..5d96312eb9eb7 100644 --- a/python_modules/dagster/dagster/_core/ext/client.py +++ b/python_modules/dagster/dagster/_core/ext/client.py @@ -11,7 +11,7 @@ from dagster._core.execution.context.compute import OpExecutionContext if TYPE_CHECKING: - from .context import ExtMessageHandler + from .context import ExtMessageHandler, ExtResult class ExtClient(ABC): @@ -21,7 +21,7 @@ def run( *, context: OpExecutionContext, extras: Optional[ExtExtras] = None, - ) -> None: ... + ) -> Iterator["ExtResult"]: ... 
class ExtContextInjector(ABC): diff --git a/python_modules/dagster/dagster/_core/ext/context.py b/python_modules/dagster/dagster/_core/ext/context.py index 1448694d6471c..be766c9d86ae8 100644 --- a/python_modules/dagster/dagster/_core/ext/context.py +++ b/python_modules/dagster/dagster/_core/ext/context.py @@ -1,5 +1,7 @@ +from contextlib import contextmanager from dataclasses import dataclass -from typing import Any, Mapping, Optional +from queue import Queue +from typing import Any, Dict, Iterator, Mapping, Optional, Set from dagster_ext import ( DAGSTER_EXT_ENV_KEYS, @@ -10,24 +12,54 @@ ExtExtras, ExtMessage, ExtMetadataType, + ExtMetadataValue, ExtParams, ExtTimeWindow, encode_env_var, ) +from typing_extensions import TypeAlias import dagster._check as check from dagster._core.definitions.data_version import DataProvenance, DataVersion from dagster._core.definitions.events import AssetKey from dagster._core.definitions.metadata import MetadataValue, normalize_metadata_value from dagster._core.definitions.partition_key_range import PartitionKeyRange +from dagster._core.definitions.result import MaterializeResult from dagster._core.definitions.time_window_partitions import TimeWindow from dagster._core.execution.context.compute import OpExecutionContext from dagster._core.execution.context.invocation import BoundOpExecutionContext +from dagster._core.ext.client import ExtMessageReader + +ExtResult: TypeAlias = MaterializeResult class ExtMessageHandler: def __init__(self, context: OpExecutionContext) -> None: self._context = context + # Queue is thread-safe + self._result_queue: Queue[ExtResult] = Queue() + # Only read by the main thread after all messages are handled, so no need for a lock + self._unmaterialized_assets: Set[AssetKey] = set(context.selected_asset_keys) + self._metadata: Dict[AssetKey, Dict[str, MetadataValue]] = {} + self._data_versions: Dict[AssetKey, DataVersion] = {} + + @contextmanager + def handle_messages(self, message_reader: ExtMessageReader) -> Iterator[ExtParams]: + with message_reader.read_messages(self) as params: + yield params + for key in self._unmaterialized_assets: + self._result_queue.put(MaterializeResult(asset_key=key)) + + def clear_result_queue(self) -> Iterator[ExtResult]: + while not self._result_queue.empty(): + yield self._result_queue.get() + + def _resolve_metadata( + self, metadata: Mapping[str, ExtMetadataValue] + ) -> Mapping[str, MetadataValue]: + return { + k: self._resolve_metadata_value(v["raw_value"], v["type"]) for k, v in metadata.items() + } def _resolve_metadata_value(self, value: Any, metadata_type: ExtMetadataType) -> MetadataValue: if metadata_type == EXT_METADATA_TYPE_INFER: @@ -69,20 +101,24 @@ def handle_message(self, message: ExtMessage) -> None: self._handle_log(**message["params"]) # type: ignore def _handle_report_asset_materialization( - self, asset_key: str, metadata: Optional[Mapping[str, Any]], data_version: Optional[str] + self, + asset_key: str, + metadata: Optional[Mapping[str, ExtMetadataValue]], + data_version: Optional[str], ) -> None: check.str_param(asset_key, "asset_key") check.opt_str_param(data_version, "data_version") metadata = check.opt_mapping_param(metadata, "metadata", key_type=str) resolved_asset_key = AssetKey.from_user_string(asset_key) - resolved_metadata = { - k: self._resolve_metadata_value(v["raw_value"], v["type"]) for k, v in metadata.items() - } - if data_version is not None: - self._context.set_data_version(resolved_asset_key, DataVersion(data_version)) - if resolved_metadata: - output_name 
= self._context.output_for_asset_key(resolved_asset_key) - self._context.add_output_metadata(resolved_metadata, output_name) + resolved_metadata = self._resolve_metadata(metadata) + resolved_data_version = None if data_version is None else DataVersion(data_version) + result = MaterializeResult( + asset_key=resolved_asset_key, + metadata=resolved_metadata, + data_version=resolved_data_version, + ) + self._result_queue.put(result) + self._unmaterialized_assets.remove(resolved_asset_key) def _handle_log(self, message: str, level: str = "info") -> None: check.str_param(message, "message") @@ -114,6 +150,9 @@ def get_external_process_env_vars(self): ), } + def get_results(self) -> Iterator[ExtResult]: + yield from self.message_handler.clear_result_queue() + def build_external_execution_context_data( context: OpExecutionContext, diff --git a/python_modules/dagster/dagster/_core/ext/subprocess.py b/python_modules/dagster/dagster/_core/ext/subprocess.py index aa07751501d09..be8f919763f19 100644 --- a/python_modules/dagster/dagster/_core/ext/subprocess.py +++ b/python_modules/dagster/dagster/_core/ext/subprocess.py @@ -1,5 +1,5 @@ from subprocess import Popen -from typing import Mapping, Optional, Sequence, Union +from typing import Iterator, Mapping, Optional, Sequence, Union from dagster_ext import ExtExtras @@ -12,6 +12,7 @@ ExtContextInjector, ExtMessageReader, ) +from dagster._core.ext.context import ExtResult from dagster._core.ext.utils import ( ExtTempFileContextInjector, ExtTempFileMessageReader, @@ -66,7 +67,7 @@ def run( extras: Optional[ExtExtras] = None, env: Optional[Mapping[str, str]] = None, cwd: Optional[str] = None, - ) -> None: + ) -> Iterator[ExtResult]: with ext_protocol( context=context, context_injector=self.context_injector, @@ -82,12 +83,14 @@ def run( **(env or {}), }, ) - process.wait() + while process.poll() is None: + yield from ext_context.get_results() if process.returncode != 0: raise DagsterExternalExecutionError( f"External execution process failed with code {process.returncode}" ) + yield from ext_context.get_results() ExtSubprocess = ResourceParam[_ExtSubprocess] diff --git a/python_modules/dagster/dagster/_core/ext/utils.py b/python_modules/dagster/dagster/_core/ext/utils.py index 9ea033ee4ea73..1a74e275d66b3 100644 --- a/python_modules/dagster/dagster/_core/ext/utils.py +++ b/python_modules/dagster/dagster/_core/ext/utils.py @@ -185,6 +185,12 @@ def extract_message_or_forward_to_stdout(handler: "ExtMessageHandler", log_line: sys.stdout.writelines((log_line, "\n")) +_FAIL_TO_YIELD_ERROR_MESSAGE = ( + "Did you forget to `yield from ext_context.get_results()`? `get_results` should be called once" + " after the `ext_protocol` block has exited to yield any remaining buffered results." +) + + @contextmanager def ext_protocol( context: OpExecutionContext, @@ -195,13 +201,13 @@ def ext_protocol( """Enter the context managed context injector and message reader that power the EXT protocol and receive the environment variables that need to be provided to the external process. 
""" + # This will trigger an error if expected outputs are not yielded + context.set_requires_typed_event_stream(error_message=_FAIL_TO_YIELD_ERROR_MESSAGE) context_data = build_external_execution_context_data(context, extras) message_handler = ExtMessageHandler(context) with context_injector.inject_context( - context_data, - ) as ci_params, message_reader.read_messages( - message_handler, - ) as mr_params: + context_data + ) as ci_params, message_handler.handle_messages(message_reader) as mr_params: yield ExtOrchestrationContext( context_data=context_data, message_handler=message_handler, diff --git a/python_modules/libraries/dagster-databricks/README.md b/python_modules/libraries/dagster-databricks/README.md index 12788b0beddcd..c59de34f84792 100644 --- a/python_modules/libraries/dagster-databricks/README.md +++ b/python_modules/libraries/dagster-databricks/README.md @@ -17,8 +17,9 @@ databricks jobs using Dagster's ext protocol. specification. After setting up ext communications channels (which by default use DBFS), it injects the information needed to connect to these channels from Databricks into the task specification. It then launches a Databricks job by -passing the specification to `WorkspaceClient.jobs.submit`. It polls the job -state and exits gracefully on success or failure: +passing the specification to `WorkspaceClient.jobs.submit`. It polls the job, +state, exits gracefully on success or failure, and returns a stream of +`MaterializeResult` events that should be yielded: ``` @@ -52,7 +53,7 @@ def databricks_asset(context: AssetExecutionContext, ext: ExtDatabricks): extras = {"sample_rate": 1.0} # synchronously execute the databricks job - ext.run( + yield from ext.run( task=task, context=context, extras=extras, @@ -114,6 +115,18 @@ launched within the scope of the `ext_process` context manager; (2) your job is launched on a cluster containing the environment variables available on the yielded `ext_context`. +While your databricks code is running, any calls to +`report_asset_materialization` in the external script are streamed back to +Dagster, causing a `MaterializationResult` object to be buffered on the +`ext_context`. You can either leave these objects buffered until execution is +complete (Option (1) in below example code) or stream them to Dagster machinery +during execution by calling `yield ext_context.get_results()` (Option (2)). + +With either option, once the `ext_protocol` block closes, you must call `yield +ext_context.get_results()` to yield any remaining buffered results, since we +cannot guarantee that all communications from databricks have been processed +until the `ext_protocol` block closes. + ``` import os @@ -142,10 +155,31 @@ def databricks_asset(context: AssetExecutionContext): message_reader=ExtDbfsMessageReader(client=client), ) as ext_context: + ##### Option (1) + # NON-STREAMING. Just pass the necessary environment variables down. + # During execution, all reported materializations are buffered on the + # `ext_context`. Yield them all after Databricks execution is finished. + # Dict[str, str] with environment variables containing ext comms info. env_vars = ext_context.get_external_process_env_vars() # Some function that handles launching/monitoring of the databricks job. # It must ensure that the `env_vars` are set on the executing cluster. custom_databricks_launch_code(env_vars) + + ##### Option (2) + # STREAMING. Pass `ext_context` down. 
During execution, you can yield any + # asset materializations that have been reported by calling ` + # ext_context.get_results()` as often as you like. `get_results` returns + # an iterator that your custom code can `yield from` to forward the + # results back to the materialize funciton. Note you will need to extract + # the env vars by calling `ext_context.get_external_process_env_vars()`, + # and launch the databricks job in the same way as with (1). + + # The function should return an `Iterator[MaterializeResult]`. + yield from custom_databricks_launch_code(ext_context) + + # With either option (1) or (2), this is required to yield any remaining + # buffered results. + yield from ext_context.get_results() ``` diff --git a/python_modules/libraries/dagster-databricks/dagster_databricks/ext.py b/python_modules/libraries/dagster-databricks/dagster_databricks/ext.py index 22bb89034cd07..19451003635bf 100644 --- a/python_modules/libraries/dagster-databricks/dagster_databricks/ext.py +++ b/python_modules/libraries/dagster-databricks/dagster_databricks/ext.py @@ -12,6 +12,7 @@ from dagster._core.errors import DagsterExternalExecutionError from dagster._core.execution.context.compute import OpExecutionContext from dagster._core.ext.client import ExtClient, ExtContextInjector, ExtMessageReader +from dagster._core.ext.context import ExtResult from dagster._core.ext.utils import ( ExtBlobStoreMessageReader, ext_protocol, @@ -66,7 +67,7 @@ def run( context: OpExecutionContext, extras: Optional[ExtExtras] = None, submit_args: Optional[Mapping[str, str]] = None, - ) -> None: + ) -> Iterator[ExtResult]: """Run a Databricks job with the EXT protocol. Args: @@ -106,7 +107,7 @@ def run( jobs.RunLifeCycleState.SKIPPED, ): if run.state.result_state == jobs.RunResultState.SUCCESS: - return + break else: raise DagsterExternalExecutionError( f"Error running Databricks job: {run.state.state_message}" @@ -115,7 +116,9 @@ def run( raise DagsterExternalExecutionError( f"Error running Databricks job: {run.state.state_message}" ) + yield from ext_context.get_results() time.sleep(5) + yield from ext_context.get_results() ExtDatabricks = ResourceParam[_ExtDatabricks] diff --git a/python_modules/libraries/dagster-docker/dagster_docker/ext.py b/python_modules/libraries/dagster-docker/dagster_docker/ext.py index 0862d72d014cc..8fac849329439 100644 --- a/python_modules/libraries/dagster-docker/dagster_docker/ext.py +++ b/python_modules/libraries/dagster-docker/dagster_docker/ext.py @@ -14,6 +14,7 @@ ) from dagster._core.ext.context import ( ExtMessageHandler, + ExtResult, ) from dagster._core.ext.utils import ( ExtEnvContextInjector, @@ -94,7 +95,7 @@ def run( registry: Optional[Mapping[str, str]] = None, container_kwargs: Optional[Mapping[str, Any]] = None, extras: Optional[ExtExtras] = None, - ) -> None: + ) -> Iterator[ExtResult]: """Create a docker container and run it to completion, enriched with the ext protocol. 
Args: @@ -162,6 +163,7 @@ def run( raise DagsterExtError(f"Container exited with non-zero status code: {result}") finally: container.stop() + return ext_context.get_results() def _create_container( self, diff --git a/python_modules/libraries/dagster-docker/dagster_docker_tests/test_ext.py b/python_modules/libraries/dagster-docker/dagster_docker_tests/test_ext.py index cf403520da8b8..cad94729fabef 100644 --- a/python_modules/libraries/dagster-docker/dagster_docker_tests/test_ext.py +++ b/python_modules/libraries/dagster-docker/dagster_docker_tests/test_ext.py @@ -31,7 +31,7 @@ def number_x( context: AssetExecutionContext, ext_docker: ExtDocker, ): - ext_docker.run( + yield from ext_docker.run( image=docker_image, command=[ "python", @@ -88,7 +88,7 @@ def number_x( }, } - ext_docker.run( + yield from ext_docker.run( image=docker_image, command=[ "python", diff --git a/python_modules/libraries/dagster-k8s/dagster_k8s/ext.py b/python_modules/libraries/dagster-k8s/dagster_k8s/ext.py index c0a42e123104d..8da2226cebdd9 100644 --- a/python_modules/libraries/dagster-k8s/dagster_k8s/ext.py +++ b/python_modules/libraries/dagster-k8s/dagster_k8s/ext.py @@ -18,6 +18,7 @@ ) from dagster._core.ext.context import ( ExtMessageHandler, + ExtResult, ) from dagster._core.ext.utils import ( ExtEnvContextInjector, @@ -123,7 +124,7 @@ def run( base_pod_meta: Optional[Mapping[str, Any]] = None, base_pod_spec: Optional[Mapping[str, Any]] = None, extras: Optional[ExtExtras] = None, - ) -> None: + ) -> Iterator[ExtResult]: """Publish a kubernetes pod and wait for it to complete, enriched with the ext protocol. Args: @@ -196,6 +197,7 @@ def run( ) finally: client.core_api.delete_namespaced_pod(pod_name, namespace) + return ext_context.get_results() def build_pod_body(
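
Taken together, these changes turn every `ExtClient.run` implementation into a generator of `MaterializeResult` objects that the calling asset must re-yield. A minimal sketch of the resulting orchestration-side usage, based on the subprocess tests above; the script name `external_script.py` is a hypothetical stand-in for any script that calls `report_asset_materialization` via `dagster_ext`:

```
from dagster import asset, materialize
from dagster._core.execution.context.compute import AssetExecutionContext
from dagster._core.ext.subprocess import ExtSubprocess


@asset
def my_asset(context: AssetExecutionContext, ext: ExtSubprocess):
    # `run` now yields buffered MaterializeResult objects as the external
    # process reports materializations; re-yielding them satisfies the typed
    # event stream requirement set by `ext_protocol`.
    # "external_script.py" is a placeholder for a script that uses dagster_ext.
    yield from ext.run(["python", "external_script.py"], context=context)


materialize([my_asset], resources={"ext": ExtSubprocess()})
```

For hand-rolled integrations that call `ext_protocol` directly, the equivalent obligation is the trailing `yield from ext_context.get_results()` that the new `_FAIL_TO_YIELD_ERROR_MESSAGE` points to.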