Skip to content

Commit

Permalink
yield MaterializeResult
Browse files Browse the repository at this point in the history
  • Loading branch information
smackesey committed Sep 22, 2023
1 parent 02af50e commit d884d4b
Show file tree
Hide file tree
Showing 12 changed files with 170 additions and 107 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def number_y(
context: AssetExecutionContext,
ext_k8s_pod: ExtK8sPod,
):
return ext_k8s_pod.run(
yield from ext_k8s_pod.run(
context=context,
namespace=namespace,
image=docker_image,
Expand Down Expand Up @@ -138,7 +138,7 @@ def number_y(
],
)

return ext_k8s_pod.run(
yield from ext_k8s_pod.run(
context=context,
namespace=namespace,
extras={
Expand Down Expand Up @@ -197,7 +197,7 @@ def number_y_job(context: AssetExecutionContext):
k8s_job_name=job_name,
)
reader.consume_pod_logs(core_api, job_name, namespace)
return ext_context.get_materialize_results()
yield from ext_context.get_results()

result = materialize(
[number_y_job],
Expand Down
15 changes: 7 additions & 8 deletions python_modules/dagster-ext/dagster_ext/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
Mapping,
Optional,
Sequence,
Set,
TextIO,
Type,
TypedDict,
Expand Down Expand Up @@ -165,7 +164,6 @@ def _resolve_optionally_passed_asset_key(
data: ExtContextData,
asset_key: Optional[str],
method: str,
already_materialized_assets: Set[str],
) -> str:
asset_keys = _assert_defined_asset_property(data["asset_keys"], method)
asset_key = _assert_opt_param_type(asset_key, str, method, "asset_key")
Expand All @@ -180,11 +178,6 @@ def _resolve_optionally_passed_asset_key(
" targets multiple assets."
)
asset_key = asset_keys[0]
if asset_key in already_materialized_assets:
raise DagsterExtError(
f"Calling `{method}` with asset key `{asset_key}` is undefined. Asset has already been"
" materialized, so no additional data can be reported for it."
)
return asset_key


Expand Down Expand Up @@ -784,8 +777,14 @@ def report_asset_materialization(
asset_key: Optional[str] = None,
):
asset_key = _resolve_optionally_passed_asset_key(
self._data, asset_key, "report_asset_materialization", self._materialized_assets
self._data, asset_key, "report_asset_materialization"
)
if asset_key in self._materialized_assets:
raise DagsterExtError(
f"Calling `report_asset_materialization` with asset key `{asset_key}` is undefined."
" Asset has already been materialized, so no additional data can be reported"
" for it."
)
metadata = (
_normalize_param_metadata(metadata, "report_asset_materialization", "metadata")
if metadata
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,12 @@

import boto3
import pytest
from dagster._core.definitions.asset_spec import AssetSpec
from dagster._core.definitions.data_version import (
DATA_VERSION_IS_USER_PROVIDED_TAG,
DATA_VERSION_TAG,
)
from dagster._core.definitions.decorators.asset_decorator import asset
from dagster._core.definitions.decorators.asset_decorator import asset, multi_asset
from dagster._core.definitions.events import AssetKey
from dagster._core.definitions.materialize import materialize
from dagster._core.definitions.metadata import (
Expand All @@ -30,8 +31,8 @@
TextMetadataValue,
UrlMetadataValue,
)
from dagster._core.errors import DagsterExternalExecutionError
from dagster._core.execution.context.compute import AssetExecutionContext
from dagster._core.errors import DagsterExternalExecutionError, DagsterInvariantViolationError
from dagster._core.execution.context.compute import AssetExecutionContext, OpExecutionContext
from dagster._core.execution.context.invocation import build_asset_context
from dagster._core.ext.subprocess import (
ExtSubprocess,
Expand Down Expand Up @@ -150,7 +151,7 @@ def test_ext_subprocess(
def foo(context: AssetExecutionContext, ext: ExtSubprocess):
extras = {"bar": "baz"}
cmd = [_PYTHON_EXECUTABLE, external_script]
return ext.run(
yield from ext.run(
cmd,
context=context,
extras=extras,
Expand All @@ -163,11 +164,7 @@ def foo(context: AssetExecutionContext, ext: ExtSubprocess):
resource = ExtSubprocess(context_injector=context_injector, message_reader=message_reader)

with instance_for_test() as instance:
materialize(
[foo],
instance=instance,
resources={"ext": resource},
)
materialize([foo], instance=instance, resources={"ext": resource})
mat = instance.get_latest_materialization_event(foo.key)
assert mat and mat.asset_materialization
assert isinstance(mat.asset_materialization.metadata["bar"], MarkdownMetadataValue)
Expand All @@ -180,6 +177,35 @@ def foo(context: AssetExecutionContext, ext: ExtSubprocess):
assert re.search(r"dagster - INFO - [^\n]+ - hello world\n", captured.err, re.MULTILINE)


def test_ext_multi_asset():
    """Check that one external script can report materializations for several assets.

    The external script reports a materialization for `foo` (with metadata and a
    data version) and for `bar` (data version only). After `materialize` runs, the
    latest materialization event of *each* asset is looked up under its own key.
    """

    def script_fn():
        from dagster_ext import init_dagster_ext

        context = init_dagster_ext()
        context.report_asset_materialization(
            {"foo_meta": "ok"}, data_version="alpha", asset_key="foo"
        )
        context.report_asset_materialization(data_version="alpha", asset_key="bar")

    @multi_asset(specs=[AssetSpec("foo"), AssetSpec("bar")])
    def foo_bar(context: AssetExecutionContext, ext: ExtSubprocess):
        with temp_script(script_fn) as script_path:
            cmd = [_PYTHON_EXECUTABLE, script_path]
            yield from ext.run(cmd, context=context)

    with instance_for_test() as instance:
        materialize([foo_bar], instance=instance, resources={"ext": ExtSubprocess()})
        foo_mat = instance.get_latest_materialization_event(AssetKey(["foo"]))
        assert foo_mat and foo_mat.asset_materialization
        assert foo_mat.asset_materialization.metadata["foo_meta"].value == "ok"
        assert foo_mat.asset_materialization.tags
        assert foo_mat.asset_materialization.tags[DATA_VERSION_TAG] == "alpha"
        # Look up `bar` under its own key; querying `foo` again would leave the
        # second asset's materialization unverified.
        bar_mat = instance.get_latest_materialization_event(AssetKey(["bar"]))
        assert bar_mat and bar_mat.asset_materialization
        assert bar_mat.asset_materialization.tags
        assert bar_mat.asset_materialization.tags[DATA_VERSION_TAG] == "alpha"


def test_ext_typed_metadata():
def script_fn():
from dagster_ext import init_dagster_ext
Expand Down Expand Up @@ -207,7 +233,7 @@ def script_fn():
def foo(context: AssetExecutionContext, ext: ExtSubprocess):
with temp_script(script_fn) as script_path:
cmd = [_PYTHON_EXECUTABLE, script_path]
return ext.run(cmd, context=context)
yield from ext.run(cmd, context=context)

with instance_for_test() as instance:
materialize(
Expand Down Expand Up @@ -254,7 +280,7 @@ def script_fn():
def foo(context: AssetExecutionContext, ext: ExtSubprocess):
with temp_script(script_fn) as script_path:
cmd = [_PYTHON_EXECUTABLE, script_path]
ext.run(cmd, context=context)
yield from ext.run(cmd, context=context)

with pytest.raises(DagsterExternalExecutionError):
materialize([foo], resources={"ext": ExtSubprocess()})
Expand Down Expand Up @@ -321,9 +347,7 @@ def subproc_run(context: AssetExecutionContext):
extras=extras,
) as ext_context:
subprocess.run(cmd, env=ext_context.get_external_process_env_vars(), check=False)
_ext_context = ext_context
mat_results = _ext_context.get_materialize_results()
return mat_results[0] if len(mat_results) == 1 else mat_results
yield from ext_context.get_results()

with instance_for_test() as instance:
materialize(
Expand All @@ -338,28 +362,26 @@ def subproc_run(context: AssetExecutionContext):
assert mat.asset_materialization.tags[DATA_VERSION_IS_USER_PROVIDED_TAG]


def test_ext_no_client_premature_get_results(external_script):
@asset
def subproc_run(context: AssetExecutionContext):
extras = {"bar": "baz"}
cmd = [_PYTHON_EXECUTABLE, external_script]
def test_ext_no_client_no_yield():
def script_fn():
pass

with ext_protocol(
context,
ExtTempFileContextInjector(),
ExtTempFileMessageReader(),
extras=extras,
) as ext_context:
subprocess.run(cmd, env=ext_context.get_external_process_env_vars(), check=False)
return ext_context.get_materialize_results()
@asset
def foo(context: OpExecutionContext):
with temp_script(script_fn) as external_script:
with ext_protocol(
context,
ExtTempFileContextInjector(),
ExtTempFileMessageReader(),
) as ext_context:
cmd = [_PYTHON_EXECUTABLE, external_script]
subprocess.run(cmd, env=ext_context.get_external_process_env_vars(), check=False)

with pytest.raises(
DagsterExternalExecutionError,
DagsterInvariantViolationError,
match=(
"`get_materialize_results` must be called after the `ext_protocol` context manager has"
" exited."
r"did not yield or return expected outputs.*Did you forget to `yield from"
r" ext_context.get_results\(\)`?"
),
):
materialize(
[subproc_run],
)
materialize([foo])
8 changes: 3 additions & 5 deletions python_modules/dagster/dagster/_core/ext/client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from abc import ABC, abstractmethod
from contextlib import contextmanager
from typing import TYPE_CHECKING, Iterator, Optional, Tuple, Union
from typing import TYPE_CHECKING, Iterator, Optional

from dagster_ext import (
ExtContextData,
Expand All @@ -11,9 +11,7 @@
from dagster._core.execution.context.compute import OpExecutionContext

if TYPE_CHECKING:
from dagster._core.definitions.result import MaterializeResult

from .context import ExtMessageHandler
from .context import ExtMessageHandler, ExtResult


class ExtClient(ABC):
Expand All @@ -23,7 +21,7 @@ def run(
*,
context: OpExecutionContext,
extras: Optional[ExtExtras] = None,
) -> Union["MaterializeResult", Tuple["MaterializeResult", ...]]: ...
) -> Iterator["ExtResult"]: ...


class ExtContextInjector(ABC):
Expand Down
60 changes: 33 additions & 27 deletions python_modules/dagster/dagster/_core/ext/context.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Any, Mapping, Optional
from typing import Any, Mapping, Optional, get_args
from queue import Queue
from typing import Any, Dict, Iterator, Mapping, Optional, Set
from typing import Any, Dict, Mapping, Optional, Tuple, get_args

from dagster_ext import (
DAGSTER_EXT_ENV_KEYS,
Expand All @@ -17,6 +16,7 @@
ExtTimeWindow,
encode_env_var,
)
from typing_extensions import TypeAlias

import dagster._check as check
from dagster._core.definitions.data_version import DataProvenance, DataVersion
Expand All @@ -25,14 +25,33 @@
from dagster._core.definitions.partition_key_range import PartitionKeyRange
from dagster._core.definitions.result import MaterializeResult
from dagster._core.definitions.time_window_partitions import TimeWindow
from dagster._core.errors import DagsterExternalExecutionError
from dagster._core.execution.context.compute import OpExecutionContext
from dagster._core.execution.context.invocation import BoundOpExecutionContext
from dagster._core.ext.client import ExtMessageReader

ExtResult: TypeAlias = MaterializeResult


class ExtMessageHandler:
def __init__(self, context: OpExecutionContext) -> None:
    """Initialize the handler's state for one execution.

    Args:
        context: The execution context for the current step. Its
            `selected_asset_keys` seed the set of assets that have not yet
            been reported as materialized.
    """
    self._context = context
    # Queue is thread-safe
    self._result_queue: Queue[ExtResult] = Queue()
    # Only read by the main thread after all messages are handled, so no need for a lock
    self._unmaterialized_assets: Set[AssetKey] = set(context.selected_asset_keys)
    # NOTE(review): no reads of these two mappings are visible in this view —
    # confirm they are still needed.
    self._metadata: Dict[AssetKey, Dict[str, MetadataValue]] = {}
    self._data_versions: Dict[AssetKey, DataVersion] = {}

@contextmanager
def handle_messages(self, message_reader: ExtMessageReader) -> Iterator[ExtParams]:
    """Read messages via `message_reader` while the caller's `with` body runs.

    Yields the params produced by `message_reader.read_messages(self)`. After
    the reader context has exited normally, a bare `MaterializeResult` is
    queued for every asset key still present in `_unmaterialized_assets`.
    """
    with message_reader.read_messages(self) as reader_params:
        yield reader_params
    # Runs only after the reader context has closed.
    enqueue = self._result_queue.put
    for pending_key in self._unmaterialized_assets:
        enqueue(MaterializeResult(asset_key=pending_key))

def clear_result_queue(self) -> Iterator[ExtResult]:
    """Yield queued results one at a time until the queue reports empty."""
    while True:
        if self._result_queue.empty():
            break
        yield self._result_queue.get()

def _resolve_metadata_value(self, value: Any, metadata_type: ExtMetadataType) -> MetadataValue:
if metadata_type == EXT_METADATA_TYPE_INFER:
Expand Down Expand Up @@ -83,11 +102,14 @@ def _handle_report_asset_materialization(
resolved_metadata = {
k: self._resolve_metadata_value(v["raw_value"], v["type"]) for k, v in metadata.items()
}
if data_version is not None:
self._context.set_data_version(resolved_asset_key, DataVersion(data_version))
if resolved_metadata:
output_name = self._context.output_for_asset_key(resolved_asset_key)
self._context.add_output_metadata(resolved_metadata, output_name)
resolved_data_version = None if data_version is None else DataVersion(data_version)
result = MaterializeResult(
asset_key=resolved_asset_key,
metadata=resolved_metadata,
data_version=resolved_data_version,
)
self._result_queue.put(result)
self._unmaterialized_assets.remove(resolved_asset_key)

def _handle_log(self, message: str, level: str = "info") -> None:
check.str_param(message, "message")
Expand All @@ -109,7 +131,6 @@ class ExtOrchestrationContext:
message_handler: ExtMessageHandler
context_injector_params: ExtParams
message_reader_params: ExtParams
is_task_finished: bool = False

def get_external_process_env_vars(self):
return {
Expand All @@ -120,23 +141,8 @@ def get_external_process_env_vars(self):
),
}

def get_materialize_results(self) -> Tuple[MaterializeResult, ...]:
if not self.is_task_finished:
raise DagsterExternalExecutionError(
"`get_materialize_results` must be called after the `ext_protocol` context manager"
" has exited."
)
return tuple(
self._materialize_result_for_asset(AssetKey.from_user_string(key))
for key in self.context_data["asset_keys"] or []
)

def _materialize_result_for_asset(self, asset_key: AssetKey):
return MaterializeResult(
asset_key=asset_key,
metadata=self.message_handler.metadata.get(asset_key),
data_version=self.message_handler.data_versions.get(asset_key),
)
def get_results(self) -> Iterator[ExtResult]:
    """Yield every result currently queued on the message handler."""
    queued_results = self.message_handler.clear_result_queue()
    yield from queued_results


def build_external_execution_context_data(
Expand Down
Loading

0 comments on commit d884d4b

Please sign in to comment.