Skip to content

Commit

Permalink
yield MaterializeResult
Browse files Browse the repository at this point in the history
  • Loading branch information
smackesey committed Sep 22, 2023
1 parent 6fe5203 commit 90fbb88
Show file tree
Hide file tree
Showing 9 changed files with 109 additions and 88 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def number_y(
context: AssetExecutionContext,
ext_k8s_pod: ExtK8sPod,
):
return ext_k8s_pod.run(
yield from ext_k8s_pod.run(
context=context,
namespace=namespace,
image=docker_image,
Expand Down Expand Up @@ -138,7 +138,7 @@ def number_y(
],
)

return ext_k8s_pod.run(
yield from ext_k8s_pod.run(
context=context,
namespace=namespace,
extras={
Expand Down Expand Up @@ -197,7 +197,7 @@ def number_y_job(context: AssetExecutionContext):
k8s_job_name=job_name,
)
reader.consume_pod_logs(core_api, job_name, namespace)
return ext_context.get_materialize_results()
yield from ext_context.get_materialize_results()

result = materialize(
[number_y_job],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,12 @@

import boto3
import pytest
from dagster._core.definitions.asset_spec import AssetSpec
from dagster._core.definitions.data_version import (
DATA_VERSION_IS_USER_PROVIDED_TAG,
DATA_VERSION_TAG,
)
from dagster._core.definitions.decorators.asset_decorator import asset
from dagster._core.definitions.decorators.asset_decorator import asset, multi_asset
from dagster._core.definitions.events import AssetKey
from dagster._core.definitions.materialize import materialize
from dagster._core.definitions.metadata import (
Expand All @@ -30,8 +31,8 @@
TextMetadataValue,
UrlMetadataValue,
)
from dagster._core.errors import DagsterExternalExecutionError
from dagster._core.execution.context.compute import AssetExecutionContext
from dagster._core.errors import DagsterExternalExecutionError, DagsterInvariantViolationError
from dagster._core.execution.context.compute import AssetExecutionContext, OpExecutionContext
from dagster._core.execution.context.invocation import build_asset_context
from dagster._core.ext.subprocess import (
ExtSubprocess,
Expand Down Expand Up @@ -150,7 +151,7 @@ def test_ext_subprocess(
def foo(context: AssetExecutionContext, ext: ExtSubprocess):
extras = {"bar": "baz"}
cmd = [_PYTHON_EXECUTABLE, external_script]
return ext.run(
yield from ext.run(
cmd,
context=context,
extras=extras,
Expand All @@ -163,11 +164,7 @@ def foo(context: AssetExecutionContext, ext: ExtSubprocess):
resource = ExtSubprocess(context_injector=context_injector, message_reader=message_reader)

with instance_for_test() as instance:
materialize(
[foo],
instance=instance,
resources={"ext": resource},
)
materialize([foo], instance=instance, resources={"ext": resource})
mat = instance.get_latest_materialization_event(foo.key)
assert mat and mat.asset_materialization
assert isinstance(mat.asset_materialization.metadata["bar"], MarkdownMetadataValue)
Expand All @@ -180,6 +177,35 @@ def foo(context: AssetExecutionContext, ext: ExtSubprocess):
assert re.search(r"dagster - INFO - [^\n]+ - hello world\n", captured.err, re.MULTILINE)


def test_ext_multi_asset():
def script_fn():
from dagster_ext import init_dagster_ext

context = init_dagster_ext()
context.report_asset_materialization(
{"foo_meta": "ok"}, data_version="alpha", asset_key="foo"
)
context.report_asset_materialization(data_version="alpha", asset_key="bar")

@multi_asset(specs=[AssetSpec("foo"), AssetSpec("bar")])
def foo_bar(context: AssetExecutionContext, ext: ExtSubprocess):
with temp_script(script_fn) as script_path:
cmd = [_PYTHON_EXECUTABLE, script_path]
yield from ext.run(cmd, context=context)

with instance_for_test() as instance:
materialize([foo_bar], instance=instance, resources={"ext": ExtSubprocess()})
foo_mat = instance.get_latest_materialization_event(AssetKey(["foo"]))
assert foo_mat and foo_mat.asset_materialization
assert foo_mat.asset_materialization.metadata["foo_meta"].value == "ok"
assert foo_mat.asset_materialization.tags
assert foo_mat.asset_materialization.tags[DATA_VERSION_TAG] == "alpha"
bar_mat = instance.get_latest_materialization_event(AssetKey(["foo"]))
assert bar_mat and bar_mat.asset_materialization
assert bar_mat.asset_materialization.tags
assert bar_mat.asset_materialization.tags[DATA_VERSION_TAG] == "alpha"


def test_ext_typed_metadata():
def script_fn():
from dagster_ext import init_dagster_ext
Expand Down Expand Up @@ -207,7 +233,7 @@ def script_fn():
def foo(context: AssetExecutionContext, ext: ExtSubprocess):
with temp_script(script_fn) as script_path:
cmd = [_PYTHON_EXECUTABLE, script_path]
return ext.run(cmd, context=context)
yield from ext.run(cmd, context=context)

with instance_for_test() as instance:
materialize(
Expand Down Expand Up @@ -254,7 +280,7 @@ def script_fn():
def foo(context: AssetExecutionContext, ext: ExtSubprocess):
with temp_script(script_fn) as script_path:
cmd = [_PYTHON_EXECUTABLE, script_path]
ext.run(cmd, context=context)
yield from ext.run(cmd, context=context)

with pytest.raises(DagsterExternalExecutionError):
materialize([foo], resources={"ext": ExtSubprocess()})
Expand Down Expand Up @@ -321,9 +347,7 @@ def subproc_run(context: AssetExecutionContext):
extras=extras,
) as ext_context:
subprocess.run(cmd, env=ext_context.get_external_process_env_vars(), check=False)
_ext_context = ext_context
mat_results = _ext_context.get_materialize_results()
return mat_results[0] if len(mat_results) == 1 else mat_results
yield from ext_context.get_materialize_results()

with instance_for_test() as instance:
materialize(
Expand All @@ -338,28 +362,22 @@ def subproc_run(context: AssetExecutionContext):
assert mat.asset_materialization.tags[DATA_VERSION_IS_USER_PROVIDED_TAG]


def test_ext_no_client_premature_get_results(external_script):
@asset
def subproc_run(context: AssetExecutionContext):
extras = {"bar": "baz"}
cmd = [_PYTHON_EXECUTABLE, external_script]
def test_ext_no_client_no_yield():
def script_fn():
pass

with ext_protocol(
context,
ExtTempFileContextInjector(),
ExtTempFileMessageReader(),
extras=extras,
) as ext_context:
subprocess.run(cmd, env=ext_context.get_external_process_env_vars(), check=False)
return ext_context.get_materialize_results()
@asset
def foo(context: OpExecutionContext):
with temp_script(script_fn) as external_script:
with ext_protocol(
context,
ExtTempFileContextInjector(),
ExtTempFileMessageReader(),
) as ext_context:
cmd = [_PYTHON_EXECUTABLE, external_script]
subprocess.run(cmd, env=ext_context.get_external_process_env_vars(), check=False)

with pytest.raises(
DagsterExternalExecutionError,
match=(
"`get_materialize_results` must be called after the `ext_protocol` context manager has"
" exited."
),
DagsterInvariantViolationError, match="did not yield or return expected outputs"
):
materialize(
[subproc_run],
)
materialize([foo])
4 changes: 2 additions & 2 deletions python_modules/dagster/dagster/_core/ext/client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from abc import ABC, abstractmethod
from contextlib import contextmanager
from typing import TYPE_CHECKING, Iterator, Optional, Tuple, Union
from typing import TYPE_CHECKING, Iterator, Optional

from dagster_ext import (
ExtContextData,
Expand All @@ -23,7 +23,7 @@ def run(
*,
context: OpExecutionContext,
extras: Optional[ExtExtras] = None,
) -> Union["MaterializeResult", Tuple["MaterializeResult", ...]]: ...
) -> Iterator["MaterializeResult"]: ...


class ExtContextInjector(ABC):
Expand Down
57 changes: 30 additions & 27 deletions python_modules/dagster/dagster/_core/ext/context.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Any, Mapping, Optional
from typing import Any, Mapping, Optional, get_args
from queue import Queue
from typing import Any, Dict, Iterator, Mapping, Optional, Set
from typing import Any, Dict, Mapping, Optional, Tuple, get_args

from dagster_ext import (
DAGSTER_EXT_ENV_KEYS,
Expand All @@ -25,14 +24,31 @@
from dagster._core.definitions.partition_key_range import PartitionKeyRange
from dagster._core.definitions.result import MaterializeResult
from dagster._core.definitions.time_window_partitions import TimeWindow
from dagster._core.errors import DagsterExternalExecutionError
from dagster._core.execution.context.compute import OpExecutionContext
from dagster._core.execution.context.invocation import BoundOpExecutionContext
from dagster._core.ext.client import ExtMessageReader


class ExtMessageHandler:
def __init__(self, context: OpExecutionContext) -> None:
self._context = context
# Queue is thread-safe
self._result_queue: Queue[MaterializeResult] = Queue()
# Only read by the main thread after all messages are handled, so no need for a lock
self._unmaterialized_assets: Set[AssetKey] = set(context.selected_asset_keys)
self._metadata: Dict[AssetKey, Dict[str, MetadataValue]] = {}
self._data_versions: Dict[AssetKey, DataVersion] = {}

@contextmanager
def handle_messages(self, message_reader: ExtMessageReader) -> Iterator[ExtParams]:
with message_reader.read_messages(self) as params:
yield params
for key in self._unmaterialized_assets:
self._result_queue.put(MaterializeResult(asset_key=key))

def clear_result_queue(self) -> Iterator[MaterializeResult]:
while not self._result_queue.empty():
yield self._result_queue.get()

def _resolve_metadata_value(self, value: Any, metadata_type: ExtMetadataType) -> MetadataValue:
if metadata_type == EXT_METADATA_TYPE_INFER:
Expand Down Expand Up @@ -83,11 +99,14 @@ def _handle_report_asset_materialization(
resolved_metadata = {
k: self._resolve_metadata_value(v["raw_value"], v["type"]) for k, v in metadata.items()
}
if data_version is not None:
self._context.set_data_version(resolved_asset_key, DataVersion(data_version))
if resolved_metadata:
output_name = self._context.output_for_asset_key(resolved_asset_key)
self._context.add_output_metadata(resolved_metadata, output_name)
resolved_data_version = None if data_version is None else DataVersion(data_version)
result = MaterializeResult(
asset_key=resolved_asset_key,
metadata=resolved_metadata,
data_version=resolved_data_version,
)
self._result_queue.put(result)
self._unmaterialized_assets.remove(resolved_asset_key)

def _handle_log(self, message: str, level: str = "info") -> None:
check.str_param(message, "message")
Expand All @@ -109,7 +128,6 @@ class ExtOrchestrationContext:
message_handler: ExtMessageHandler
context_injector_params: ExtParams
message_reader_params: ExtParams
is_task_finished: bool = False

def get_external_process_env_vars(self):
return {
Expand All @@ -120,23 +138,8 @@ def get_external_process_env_vars(self):
),
}

def get_materialize_results(self) -> Tuple[MaterializeResult, ...]:
if not self.is_task_finished:
raise DagsterExternalExecutionError(
"`get_materialize_results` must be called after the `ext_protocol` context manager"
" has exited."
)
return tuple(
self._materialize_result_for_asset(AssetKey.from_user_string(key))
for key in self.context_data["asset_keys"] or []
)

def _materialize_result_for_asset(self, asset_key: AssetKey):
return MaterializeResult(
asset_key=asset_key,
metadata=self.message_handler.metadata.get(asset_key),
data_version=self.message_handler.data_versions.get(asset_key),
)
def get_materialize_results(self) -> Iterator[MaterializeResult]:
yield from self.message_handler.clear_result_queue()


def build_external_execution_context_data(
Expand Down
12 changes: 7 additions & 5 deletions python_modules/dagster/dagster/_core/ext/subprocess.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from subprocess import Popen
from typing import Mapping, Optional, Sequence, Tuple, Union
from typing import Iterator, Mapping, Optional, Sequence, Union

from dagster_ext import ExtExtras

Expand Down Expand Up @@ -67,7 +67,7 @@ def run(
extras: Optional[ExtExtras] = None,
env: Optional[Mapping[str, str]] = None,
cwd: Optional[str] = None,
) -> Union[MaterializeResult, Tuple[MaterializeResult, ...]]:
) -> Iterator[MaterializeResult]:
with ext_protocol(
context=context,
context_injector=self.context_injector,
Expand All @@ -83,14 +83,16 @@ def run(
**(env or {}),
},
)
process.wait()
while True:
yield from ext_context.get_materialize_results()
if process.poll() is not None:
break

if process.returncode != 0:
raise DagsterExternalExecutionError(
f"External execution process failed with code {process.returncode}"
)
mat_results = ext_context.get_materialize_results()
return mat_results[0] if len(mat_results) == 1 else mat_results
yield from ext_context.get_materialize_results()


ExtSubprocess = ResourceParam[_ExtSubprocess]
12 changes: 5 additions & 7 deletions python_modules/dagster/dagster/_core/ext/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,18 +195,16 @@ def ext_protocol(
"""Enter the context managed context injector and message reader that power the EXT protocol and receive the environment variables
that need to be provided to the external process.
"""
# This will trigger an error if expected outputs are not yielded
context.set_explicit_mode()
context_data = build_external_execution_context_data(context, extras)
message_handler = ExtMessageHandler(context)
with context_injector.inject_context(
context_data,
) as ci_params, message_reader.read_messages(
message_handler,
) as mr_params:
ext_context = ExtOrchestrationContext(
context_data
) as ci_params, message_handler.handle_messages(message_reader) as mr_params:
yield ExtOrchestrationContext(
context_data=context_data,
message_handler=message_handler,
context_injector_params=ci_params,
message_reader_params=mr_params,
)
yield ext_context
ext_context.is_task_finished = True
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import string
import time
from contextlib import contextmanager
from typing import Iterator, Mapping, Optional, Tuple, Union
from typing import Iterator, Mapping, Optional

import dagster._check as check
from dagster._core.definitions.resource_annotation import ResourceParam
Expand Down Expand Up @@ -67,7 +67,7 @@ def run(
context: OpExecutionContext,
extras: Optional[ExtExtras] = None,
submit_args: Optional[Mapping[str, str]] = None,
) -> Union[MaterializeResult, Tuple[MaterializeResult, ...]]:
) -> Iterator[MaterializeResult]:
"""Run a Databricks job with the EXT protocol.
Args:
Expand Down Expand Up @@ -116,8 +116,9 @@ def run(
raise DagsterExternalExecutionError(
f"Error running Databricks job: {run.state.state_message}"
)
yield from ext_context.get_materialize_results()
time.sleep(5)
return ext_context.get_materialize_results()
yield from ext_context.get_materialize_results()


ExtDatabricks = ResourceParam[_ExtDatabricks]
Expand Down
Loading

0 comments on commit 90fbb88

Please sign in to comment.