Skip to content

Commit

Permalink
[ext] Use MaterializeResult for ext_protocol (#16624)
Browse files Browse the repository at this point in the history
## Summary & Motivation

Make ext rely on yielding `MaterializeResult` to register metadata and
data version, as opposed to modifying `OpExecutionContext`. This allows
results to stream back as they are reported rather than being bulk
reported when computation completes.

This required addition of a `report_asset_materialization` method that
can be called on the `ExtContext`. This will queue a
`MaterializationResult` on the orchestration side. The queue can be
cleared from the `ExtOrchestrationContext` at any time by calling
`ExtOrchestrationContext.get_results`. Errors are raised if attempting
materialize an asset twice or report data version/metadata after
materialization.

Once the `ext_protocol` block exits, any as-yet-unmaterialized assets
are queued on the `MessageHandler`, so that calling
`ExtOrchestrationContext.get_results` after exit will yield all the
remaining `MaterializeResult` objects. Note that yielding from this
method after `ext_protocol` close is required to guarantee all buffered
data is yielded, since there is no guarantee that all messages have been
processed before `ext_protocol` completes its exit routine.

To head off the confusing scenario where a user forgets to yield outside
the block and sees auto-created materializations that lack any reported
metadata, we call `set_require_typed_event_stream` on the
`OpExecutionContext`. This will cause an error during output processing
if an expected output was not returned/yielded.

## How I Tested These Changes

New unit tests.
  • Loading branch information
smackesey authored Sep 22, 2023
1 parent 547a624 commit f7b3fee
Show file tree
Hide file tree
Showing 12 changed files with 190 additions and 49 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def number_y(
context: AssetExecutionContext,
ext_k8s_pod: ExtK8sPod,
):
ext_k8s_pod.run(
yield from ext_k8s_pod.run(
context=context,
namespace=namespace,
image=docker_image,
Expand Down Expand Up @@ -138,7 +138,7 @@ def number_y(
],
)

ext_k8s_pod.run(
yield from ext_k8s_pod.run(
context=context,
namespace=namespace,
extras={
Expand Down Expand Up @@ -197,6 +197,7 @@ def number_y_job(context: AssetExecutionContext):
k8s_job_name=job_name,
)
reader.consume_pod_logs(core_api, job_name, namespace)
yield from ext_context.get_results()

result = materialize(
[number_y_job],
Expand Down
15 changes: 7 additions & 8 deletions python_modules/dagster-ext/dagster_ext/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
Mapping,
Optional,
Sequence,
Set,
TextIO,
Type,
TypedDict,
Expand Down Expand Up @@ -165,7 +164,6 @@ def _resolve_optionally_passed_asset_key(
data: ExtContextData,
asset_key: Optional[str],
method: str,
already_materialized_assets: Set[str],
) -> str:
asset_keys = _assert_defined_asset_property(data["asset_keys"], method)
asset_key = _assert_opt_param_type(asset_key, str, method, "asset_key")
Expand All @@ -180,11 +178,6 @@ def _resolve_optionally_passed_asset_key(
" targets multiple assets."
)
asset_key = asset_keys[0]
if asset_key in already_materialized_assets:
raise DagsterExtError(
f"Calling `{method}` with asset key `{asset_key}` is undefined. Asset has already been"
" materialized, so no additional data can be reported for it."
)
return asset_key


Expand Down Expand Up @@ -784,8 +777,14 @@ def report_asset_materialization(
asset_key: Optional[str] = None,
):
asset_key = _resolve_optionally_passed_asset_key(
self._data, asset_key, "report_asset_materialization", self._materialized_assets
self._data, asset_key, "report_asset_materialization"
)
if asset_key in self._materialized_assets:
raise DagsterExtError(
f"Calling `report_asset_materialization` with asset key `{asset_key}` is undefined."
" Asset has already been materialized, so no additional data can be reported"
" for it."
)
metadata = (
_normalize_param_metadata(metadata, "report_asset_materialization", "metadata")
if metadata
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,12 @@

import boto3
import pytest
from dagster._core.definitions.asset_spec import AssetSpec
from dagster._core.definitions.data_version import (
DATA_VERSION_IS_USER_PROVIDED_TAG,
DATA_VERSION_TAG,
)
from dagster._core.definitions.decorators.asset_decorator import asset
from dagster._core.definitions.decorators.asset_decorator import asset, multi_asset
from dagster._core.definitions.events import AssetKey
from dagster._core.definitions.materialize import materialize
from dagster._core.definitions.metadata import (
Expand All @@ -30,8 +31,8 @@
TextMetadataValue,
UrlMetadataValue,
)
from dagster._core.errors import DagsterExternalExecutionError
from dagster._core.execution.context.compute import AssetExecutionContext
from dagster._core.errors import DagsterExternalExecutionError, DagsterInvariantViolationError
from dagster._core.execution.context.compute import AssetExecutionContext, OpExecutionContext
from dagster._core.execution.context.invocation import build_asset_context
from dagster._core.ext.subprocess import (
ExtSubprocess,
Expand Down Expand Up @@ -150,7 +151,7 @@ def test_ext_subprocess(
def foo(context: AssetExecutionContext, ext: ExtSubprocess):
extras = {"bar": "baz"}
cmd = [_PYTHON_EXECUTABLE, external_script]
ext.run(
yield from ext.run(
cmd,
context=context,
extras=extras,
Expand All @@ -163,11 +164,7 @@ def foo(context: AssetExecutionContext, ext: ExtSubprocess):
resource = ExtSubprocess(context_injector=context_injector, message_reader=message_reader)

with instance_for_test() as instance:
materialize(
[foo],
instance=instance,
resources={"ext": resource},
)
materialize([foo], instance=instance, resources={"ext": resource})
mat = instance.get_latest_materialization_event(foo.key)
assert mat and mat.asset_materialization
assert isinstance(mat.asset_materialization.metadata["bar"], MarkdownMetadataValue)
Expand All @@ -180,6 +177,35 @@ def foo(context: AssetExecutionContext, ext: ExtSubprocess):
assert re.search(r"dagster - INFO - [^\n]+ - hello world\n", captured.err, re.MULTILINE)


def test_ext_multi_asset():
def script_fn():
from dagster_ext import init_dagster_ext

context = init_dagster_ext()
context.report_asset_materialization(
{"foo_meta": "ok"}, data_version="alpha", asset_key="foo"
)
context.report_asset_materialization(data_version="alpha", asset_key="bar")

@multi_asset(specs=[AssetSpec("foo"), AssetSpec("bar")])
def foo_bar(context: AssetExecutionContext, ext: ExtSubprocess):
with temp_script(script_fn) as script_path:
cmd = [_PYTHON_EXECUTABLE, script_path]
yield from ext.run(cmd, context=context)

with instance_for_test() as instance:
materialize([foo_bar], instance=instance, resources={"ext": ExtSubprocess()})
foo_mat = instance.get_latest_materialization_event(AssetKey(["foo"]))
assert foo_mat and foo_mat.asset_materialization
assert foo_mat.asset_materialization.metadata["foo_meta"].value == "ok"
assert foo_mat.asset_materialization.tags
assert foo_mat.asset_materialization.tags[DATA_VERSION_TAG] == "alpha"
bar_mat = instance.get_latest_materialization_event(AssetKey(["foo"]))
assert bar_mat and bar_mat.asset_materialization
assert bar_mat.asset_materialization.tags
assert bar_mat.asset_materialization.tags[DATA_VERSION_TAG] == "alpha"


def test_ext_typed_metadata():
def script_fn():
from dagster_ext import init_dagster_ext
Expand Down Expand Up @@ -207,7 +233,7 @@ def script_fn():
def foo(context: AssetExecutionContext, ext: ExtSubprocess):
with temp_script(script_fn) as script_path:
cmd = [_PYTHON_EXECUTABLE, script_path]
ext.run(cmd, context=context)
yield from ext.run(cmd, context=context)

with instance_for_test() as instance:
materialize(
Expand Down Expand Up @@ -254,7 +280,7 @@ def script_fn():
def foo(context: AssetExecutionContext, ext: ExtSubprocess):
with temp_script(script_fn) as script_path:
cmd = [_PYTHON_EXECUTABLE, script_path]
ext.run(cmd, context=context)
yield from ext.run(cmd, context=context)

with pytest.raises(DagsterExternalExecutionError):
materialize([foo], resources={"ext": ExtSubprocess()})
Expand Down Expand Up @@ -321,6 +347,7 @@ def subproc_run(context: AssetExecutionContext):
extras=extras,
) as ext_context:
subprocess.run(cmd, env=ext_context.get_external_process_env_vars(), check=False)
yield from ext_context.get_results()

with instance_for_test() as instance:
materialize(
Expand All @@ -333,3 +360,28 @@ def subproc_run(context: AssetExecutionContext):
assert mat.asset_materialization.tags
assert mat.asset_materialization.tags[DATA_VERSION_TAG] == "alpha"
assert mat.asset_materialization.tags[DATA_VERSION_IS_USER_PROVIDED_TAG]


def test_ext_no_client_no_yield():
def script_fn():
pass

@asset
def foo(context: OpExecutionContext):
with temp_script(script_fn) as external_script:
with ext_protocol(
context,
ExtTempFileContextInjector(),
ExtTempFileMessageReader(),
) as ext_context:
cmd = [_PYTHON_EXECUTABLE, external_script]
subprocess.run(cmd, env=ext_context.get_external_process_env_vars(), check=False)

with pytest.raises(
DagsterInvariantViolationError,
match=(
r"did not yield or return expected outputs.*Did you forget to `yield from"
r" ext_context.get_results\(\)`?"
),
):
materialize([foo])
4 changes: 2 additions & 2 deletions python_modules/dagster/dagster/_core/ext/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from dagster._core.execution.context.compute import OpExecutionContext

if TYPE_CHECKING:
from .context import ExtMessageHandler
from .context import ExtMessageHandler, ExtResult


class ExtClient(ABC):
Expand All @@ -21,7 +21,7 @@ def run(
*,
context: OpExecutionContext,
extras: Optional[ExtExtras] = None,
) -> None: ...
) -> Iterator["ExtResult"]: ...


class ExtContextInjector(ABC):
Expand Down
59 changes: 49 additions & 10 deletions python_modules/dagster/dagster/_core/ext/context.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Any, Mapping, Optional
from queue import Queue
from typing import Any, Dict, Iterator, Mapping, Optional, Set

from dagster_ext import (
DAGSTER_EXT_ENV_KEYS,
Expand All @@ -10,24 +12,54 @@
ExtExtras,
ExtMessage,
ExtMetadataType,
ExtMetadataValue,
ExtParams,
ExtTimeWindow,
encode_env_var,
)
from typing_extensions import TypeAlias

import dagster._check as check
from dagster._core.definitions.data_version import DataProvenance, DataVersion
from dagster._core.definitions.events import AssetKey
from dagster._core.definitions.metadata import MetadataValue, normalize_metadata_value
from dagster._core.definitions.partition_key_range import PartitionKeyRange
from dagster._core.definitions.result import MaterializeResult
from dagster._core.definitions.time_window_partitions import TimeWindow
from dagster._core.execution.context.compute import OpExecutionContext
from dagster._core.execution.context.invocation import BoundOpExecutionContext
from dagster._core.ext.client import ExtMessageReader

ExtResult: TypeAlias = MaterializeResult


class ExtMessageHandler:
def __init__(self, context: OpExecutionContext) -> None:
self._context = context
# Queue is thread-safe
self._result_queue: Queue[ExtResult] = Queue()
# Only read by the main thread after all messages are handled, so no need for a lock
self._unmaterialized_assets: Set[AssetKey] = set(context.selected_asset_keys)
self._metadata: Dict[AssetKey, Dict[str, MetadataValue]] = {}
self._data_versions: Dict[AssetKey, DataVersion] = {}

@contextmanager
def handle_messages(self, message_reader: ExtMessageReader) -> Iterator[ExtParams]:
with message_reader.read_messages(self) as params:
yield params
for key in self._unmaterialized_assets:
self._result_queue.put(MaterializeResult(asset_key=key))

def clear_result_queue(self) -> Iterator[ExtResult]:
while not self._result_queue.empty():
yield self._result_queue.get()

def _resolve_metadata(
self, metadata: Mapping[str, ExtMetadataValue]
) -> Mapping[str, MetadataValue]:
return {
k: self._resolve_metadata_value(v["raw_value"], v["type"]) for k, v in metadata.items()
}

def _resolve_metadata_value(self, value: Any, metadata_type: ExtMetadataType) -> MetadataValue:
if metadata_type == EXT_METADATA_TYPE_INFER:
Expand Down Expand Up @@ -69,20 +101,24 @@ def handle_message(self, message: ExtMessage) -> None:
self._handle_log(**message["params"]) # type: ignore

def _handle_report_asset_materialization(
self, asset_key: str, metadata: Optional[Mapping[str, Any]], data_version: Optional[str]
self,
asset_key: str,
metadata: Optional[Mapping[str, ExtMetadataValue]],
data_version: Optional[str],
) -> None:
check.str_param(asset_key, "asset_key")
check.opt_str_param(data_version, "data_version")
metadata = check.opt_mapping_param(metadata, "metadata", key_type=str)
resolved_asset_key = AssetKey.from_user_string(asset_key)
resolved_metadata = {
k: self._resolve_metadata_value(v["raw_value"], v["type"]) for k, v in metadata.items()
}
if data_version is not None:
self._context.set_data_version(resolved_asset_key, DataVersion(data_version))
if resolved_metadata:
output_name = self._context.output_for_asset_key(resolved_asset_key)
self._context.add_output_metadata(resolved_metadata, output_name)
resolved_metadata = self._resolve_metadata(metadata)
resolved_data_version = None if data_version is None else DataVersion(data_version)
result = MaterializeResult(
asset_key=resolved_asset_key,
metadata=resolved_metadata,
data_version=resolved_data_version,
)
self._result_queue.put(result)
self._unmaterialized_assets.remove(resolved_asset_key)

def _handle_log(self, message: str, level: str = "info") -> None:
check.str_param(message, "message")
Expand Down Expand Up @@ -114,6 +150,9 @@ def get_external_process_env_vars(self):
),
}

def get_results(self) -> Iterator[ExtResult]:
yield from self.message_handler.clear_result_queue()


def build_external_execution_context_data(
context: OpExecutionContext,
Expand Down
9 changes: 6 additions & 3 deletions python_modules/dagster/dagster/_core/ext/subprocess.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from subprocess import Popen
from typing import Mapping, Optional, Sequence, Union
from typing import Iterator, Mapping, Optional, Sequence, Union

from dagster_ext import ExtExtras

Expand All @@ -12,6 +12,7 @@
ExtContextInjector,
ExtMessageReader,
)
from dagster._core.ext.context import ExtResult
from dagster._core.ext.utils import (
ExtTempFileContextInjector,
ExtTempFileMessageReader,
Expand Down Expand Up @@ -66,7 +67,7 @@ def run(
extras: Optional[ExtExtras] = None,
env: Optional[Mapping[str, str]] = None,
cwd: Optional[str] = None,
) -> None:
) -> Iterator[ExtResult]:
with ext_protocol(
context=context,
context_injector=self.context_injector,
Expand All @@ -82,12 +83,14 @@ def run(
**(env or {}),
},
)
process.wait()
while process.poll() is None:
yield from ext_context.get_results()

if process.returncode != 0:
raise DagsterExternalExecutionError(
f"External execution process failed with code {process.returncode}"
)
yield from ext_context.get_results()


ExtSubprocess = ResourceParam[_ExtSubprocess]
Loading

0 comments on commit f7b3fee

Please sign in to comment.