Skip to content

Commit

Permalink
yield MaterializeResult
Browse files Browse the repository at this point in the history
  • Loading branch information
smackesey committed Sep 22, 2023
1 parent 6fe5203 commit 76fe8e0
Show file tree
Hide file tree
Showing 10 changed files with 118 additions and 89 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def number_y(
context: AssetExecutionContext,
ext_k8s_pod: ExtK8sPod,
):
return ext_k8s_pod.run(
yield from ext_k8s_pod.run(
context=context,
namespace=namespace,
image=docker_image,
Expand Down Expand Up @@ -138,7 +138,7 @@ def number_y(
],
)

return ext_k8s_pod.run(
yield from ext_k8s_pod.run(
context=context,
namespace=namespace,
extras={
Expand Down Expand Up @@ -197,7 +197,7 @@ def number_y_job(context: AssetExecutionContext):
k8s_job_name=job_name,
)
reader.consume_pod_logs(core_api, job_name, namespace)
return ext_context.get_materialize_results()
yield from ext_context.get_results()

result = materialize(
[number_y_job],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,12 @@

import boto3
import pytest
from dagster._core.definitions.asset_spec import AssetSpec
from dagster._core.definitions.data_version import (
DATA_VERSION_IS_USER_PROVIDED_TAG,
DATA_VERSION_TAG,
)
from dagster._core.definitions.decorators.asset_decorator import asset
from dagster._core.definitions.decorators.asset_decorator import asset, multi_asset
from dagster._core.definitions.events import AssetKey
from dagster._core.definitions.materialize import materialize
from dagster._core.definitions.metadata import (
Expand All @@ -30,8 +31,8 @@
TextMetadataValue,
UrlMetadataValue,
)
from dagster._core.errors import DagsterExternalExecutionError
from dagster._core.execution.context.compute import AssetExecutionContext
from dagster._core.errors import DagsterExternalExecutionError, DagsterInvariantViolationError
from dagster._core.execution.context.compute import AssetExecutionContext, OpExecutionContext
from dagster._core.execution.context.invocation import build_asset_context
from dagster._core.ext.subprocess import (
ExtSubprocess,
Expand Down Expand Up @@ -150,7 +151,7 @@ def test_ext_subprocess(
def foo(context: AssetExecutionContext, ext: ExtSubprocess):
extras = {"bar": "baz"}
cmd = [_PYTHON_EXECUTABLE, external_script]
return ext.run(
yield from ext.run(
cmd,
context=context,
extras=extras,
Expand All @@ -163,11 +164,7 @@ def foo(context: AssetExecutionContext, ext: ExtSubprocess):
resource = ExtSubprocess(context_injector=context_injector, message_reader=message_reader)

with instance_for_test() as instance:
materialize(
[foo],
instance=instance,
resources={"ext": resource},
)
materialize([foo], instance=instance, resources={"ext": resource})
mat = instance.get_latest_materialization_event(foo.key)
assert mat and mat.asset_materialization
assert isinstance(mat.asset_materialization.metadata["bar"], MarkdownMetadataValue)
Expand All @@ -180,6 +177,35 @@ def foo(context: AssetExecutionContext, ext: ExtSubprocess):
assert re.search(r"dagster - INFO - [^\n]+ - hello world\n", captured.err, re.MULTILINE)


def test_ext_multi_asset():
    """Check that one external script can materialize every asset of a multi-asset.

    The script reports a materialization for both ``foo`` and ``bar``. Each
    asset's latest materialization event must carry the reported data version,
    and ``foo`` must additionally carry the reported ``foo_meta`` metadata.
    """

    def script_fn():
        from dagster_ext import init_dagster_ext

        context = init_dagster_ext()
        context.report_asset_materialization(
            {"foo_meta": "ok"}, data_version="alpha", asset_key="foo"
        )
        context.report_asset_materialization(data_version="alpha", asset_key="bar")

    @multi_asset(specs=[AssetSpec("foo"), AssetSpec("bar")])
    def foo_bar(context: AssetExecutionContext, ext: ExtSubprocess):
        with temp_script(script_fn) as script_path:
            cmd = [_PYTHON_EXECUTABLE, script_path]
            yield from ext.run(cmd, context=context)

    with instance_for_test() as instance:
        materialize([foo_bar], instance=instance, resources={"ext": ExtSubprocess()})

        foo_mat = instance.get_latest_materialization_event(AssetKey(["foo"]))
        assert foo_mat and foo_mat.asset_materialization
        assert foo_mat.asset_materialization.metadata["foo_meta"].value == "ok"
        assert foo_mat.asset_materialization.tags
        assert foo_mat.asset_materialization.tags[DATA_VERSION_TAG] == "alpha"

        # Query "bar" here: looking up "foo" a second time would leave the
        # second asset's materialization completely unverified.
        bar_mat = instance.get_latest_materialization_event(AssetKey(["bar"]))
        assert bar_mat and bar_mat.asset_materialization
        assert bar_mat.asset_materialization.tags
        assert bar_mat.asset_materialization.tags[DATA_VERSION_TAG] == "alpha"


def test_ext_typed_metadata():
def script_fn():
from dagster_ext import init_dagster_ext
Expand Down Expand Up @@ -207,7 +233,7 @@ def script_fn():
def foo(context: AssetExecutionContext, ext: ExtSubprocess):
with temp_script(script_fn) as script_path:
cmd = [_PYTHON_EXECUTABLE, script_path]
return ext.run(cmd, context=context)
yield from ext.run(cmd, context=context)

with instance_for_test() as instance:
materialize(
Expand Down Expand Up @@ -254,7 +280,7 @@ def script_fn():
def foo(context: AssetExecutionContext, ext: ExtSubprocess):
with temp_script(script_fn) as script_path:
cmd = [_PYTHON_EXECUTABLE, script_path]
ext.run(cmd, context=context)
yield from ext.run(cmd, context=context)

with pytest.raises(DagsterExternalExecutionError):
materialize([foo], resources={"ext": ExtSubprocess()})
Expand Down Expand Up @@ -321,9 +347,7 @@ def subproc_run(context: AssetExecutionContext):
extras=extras,
) as ext_context:
subprocess.run(cmd, env=ext_context.get_external_process_env_vars(), check=False)
_ext_context = ext_context
mat_results = _ext_context.get_materialize_results()
return mat_results[0] if len(mat_results) == 1 else mat_results
yield from ext_context.get_results()

with instance_for_test() as instance:
materialize(
Expand All @@ -338,28 +362,26 @@ def subproc_run(context: AssetExecutionContext):
assert mat.asset_materialization.tags[DATA_VERSION_IS_USER_PROVIDED_TAG]


def test_ext_no_client_premature_get_results(external_script):
@asset
def subproc_run(context: AssetExecutionContext):
extras = {"bar": "baz"}
cmd = [_PYTHON_EXECUTABLE, external_script]
def test_ext_no_client_no_yield():
def script_fn():
pass

with ext_protocol(
context,
ExtTempFileContextInjector(),
ExtTempFileMessageReader(),
extras=extras,
) as ext_context:
subprocess.run(cmd, env=ext_context.get_external_process_env_vars(), check=False)
return ext_context.get_materialize_results()
@asset
def foo(context: OpExecutionContext):
with temp_script(script_fn) as external_script:
with ext_protocol(
context,
ExtTempFileContextInjector(),
ExtTempFileMessageReader(),
) as ext_context:
cmd = [_PYTHON_EXECUTABLE, external_script]
subprocess.run(cmd, env=ext_context.get_external_process_env_vars(), check=False)

with pytest.raises(
DagsterExternalExecutionError,
DagsterInvariantViolationError,
match=(
"`get_materialize_results` must be called after the `ext_protocol` context manager has"
" exited."
r"did not yield or return expected outputs.*Did you forget to `yield from"
r" ext_context.get_results\(\)`?"
),
):
materialize(
[subproc_run],
)
materialize([foo])
4 changes: 2 additions & 2 deletions python_modules/dagster/dagster/_core/ext/client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from abc import ABC, abstractmethod
from contextlib import contextmanager
from typing import TYPE_CHECKING, Iterator, Optional, Tuple, Union
from typing import TYPE_CHECKING, Iterator, Optional

from dagster_ext import (
ExtContextData,
Expand All @@ -23,7 +23,7 @@ def run(
*,
context: OpExecutionContext,
extras: Optional[ExtExtras] = None,
) -> Union["MaterializeResult", Tuple["MaterializeResult", ...]]: ...
) -> Iterator["MaterializeResult"]: ...


class ExtContextInjector(ABC):
Expand Down
57 changes: 30 additions & 27 deletions python_modules/dagster/dagster/_core/ext/context.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Any, Mapping, Optional
from typing import Any, Mapping, Optional, get_args
from queue import Queue
from typing import Any, Dict, Iterator, Mapping, Optional, Set
from typing import Any, Dict, Mapping, Optional, Tuple, get_args

from dagster_ext import (
DAGSTER_EXT_ENV_KEYS,
Expand All @@ -25,14 +24,31 @@
from dagster._core.definitions.partition_key_range import PartitionKeyRange
from dagster._core.definitions.result import MaterializeResult
from dagster._core.definitions.time_window_partitions import TimeWindow
from dagster._core.errors import DagsterExternalExecutionError
from dagster._core.execution.context.compute import OpExecutionContext
from dagster._core.execution.context.invocation import BoundOpExecutionContext
from dagster._core.ext.client import ExtMessageReader


class ExtMessageHandler:
def __init__(self, context: OpExecutionContext) -> None:
# Keep the execution context; its `selected_asset_keys` seed the bookkeeping below.
self._context = context
# Queue is thread-safe
self._result_queue: Queue[MaterializeResult] = Queue()
# Only read by the main thread after all messages are handled, so no need for a lock
self._unmaterialized_assets: Set[AssetKey] = set(context.selected_asset_keys)
# Per-asset metadata and data versions, keyed by asset key; both start out empty.
# NOTE(review): these two may be left over from the pre-queue design -- confirm they are still read.
self._metadata: Dict[AssetKey, Dict[str, MetadataValue]] = {}
self._data_versions: Dict[AssetKey, DataVersion] = {}

@contextmanager
def handle_messages(self, message_reader: ExtMessageReader) -> Iterator[ExtParams]:
# Wrap `message_reader.read_messages(self)` and hand the params it produces to the caller.
with message_reader.read_messages(self) as params:
yield params
# Queue a bare MaterializeResult (asset key only) for every key still listed in
# `_unmaterialized_assets`.
# NOTE(review): indentation is lost in this view; presumably this loop runs after the
# reader context has exited -- confirm against the real file.
for key in self._unmaterialized_assets:
self._result_queue.put(MaterializeResult(asset_key=key))

def clear_result_queue(self) -> Iterator[MaterializeResult]:
    """Drain the result queue, yielding each buffered result in turn.

    Iteration ends as soon as the queue reports that it is empty; results
    put on the queue afterwards are picked up by a later call.
    """
    pending = self._result_queue
    while True:
        if pending.empty():
            return
        yield pending.get()

def _resolve_metadata_value(self, value: Any, metadata_type: ExtMetadataType) -> MetadataValue:
if metadata_type == EXT_METADATA_TYPE_INFER:
Expand Down Expand Up @@ -83,11 +99,14 @@ def _handle_report_asset_materialization(
resolved_metadata = {
k: self._resolve_metadata_value(v["raw_value"], v["type"]) for k, v in metadata.items()
}
if data_version is not None:
self._context.set_data_version(resolved_asset_key, DataVersion(data_version))
if resolved_metadata:
output_name = self._context.output_for_asset_key(resolved_asset_key)
self._context.add_output_metadata(resolved_metadata, output_name)
resolved_data_version = None if data_version is None else DataVersion(data_version)
result = MaterializeResult(
asset_key=resolved_asset_key,
metadata=resolved_metadata,
data_version=resolved_data_version,
)
self._result_queue.put(result)
self._unmaterialized_assets.remove(resolved_asset_key)

def _handle_log(self, message: str, level: str = "info") -> None:
check.str_param(message, "message")
Expand All @@ -109,7 +128,6 @@ class ExtOrchestrationContext:
message_handler: ExtMessageHandler
context_injector_params: ExtParams
message_reader_params: ExtParams
is_task_finished: bool = False

def get_external_process_env_vars(self):
return {
Expand All @@ -120,23 +138,8 @@ def get_external_process_env_vars(self):
),
}

def get_materialize_results(self) -> Tuple[MaterializeResult, ...]:
if not self.is_task_finished:
raise DagsterExternalExecutionError(
"`get_materialize_results` must be called after the `ext_protocol` context manager"
" has exited."
)
return tuple(
self._materialize_result_for_asset(AssetKey.from_user_string(key))
for key in self.context_data["asset_keys"] or []
)

def _materialize_result_for_asset(self, asset_key: AssetKey):
return MaterializeResult(
asset_key=asset_key,
metadata=self.message_handler.metadata.get(asset_key),
data_version=self.message_handler.data_versions.get(asset_key),
)
def get_results(self) -> Iterator[MaterializeResult]:
    """Yield every result the message handler currently has buffered."""
    buffered = self.message_handler.clear_result_queue()
    for result in buffered:
        yield result


def build_external_execution_context_data(
Expand Down
10 changes: 5 additions & 5 deletions python_modules/dagster/dagster/_core/ext/subprocess.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from subprocess import Popen
from typing import Mapping, Optional, Sequence, Tuple, Union
from typing import Iterator, Mapping, Optional, Sequence, Union

from dagster_ext import ExtExtras

Expand Down Expand Up @@ -67,7 +67,7 @@ def run(
extras: Optional[ExtExtras] = None,
env: Optional[Mapping[str, str]] = None,
cwd: Optional[str] = None,
) -> Union[MaterializeResult, Tuple[MaterializeResult, ...]]:
) -> Iterator[MaterializeResult]:
with ext_protocol(
context=context,
context_injector=self.context_injector,
Expand All @@ -83,14 +83,14 @@ def run(
**(env or {}),
},
)
process.wait()
while process.poll() is not None:
yield from ext_context.get_results()

if process.returncode != 0:
raise DagsterExternalExecutionError(
f"External execution process failed with code {process.returncode}"
)
mat_results = ext_context.get_materialize_results()
return mat_results[0] if len(mat_results) == 1 else mat_results
yield from ext_context.get_results()


ExtSubprocess = ResourceParam[_ExtSubprocess]
18 changes: 11 additions & 7 deletions python_modules/dagster/dagster/_core/ext/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,12 @@ def extract_message_or_forward_to_stdout(handler: "ExtMessageHandler", log_line:
sys.stdout.writelines((log_line, "\n"))


_FAIL_TO_YIELD_ERROR_MESSAGE = (
"Did you forget to `yield from ext_context.get_results()`? This should be called once after the"
" `ext_protocol` block has exited to yield any remaining buffered results."
)


@contextmanager
def ext_protocol(
context: OpExecutionContext,
Expand All @@ -195,18 +201,16 @@ def ext_protocol(
"""Enter the context managed context injector and message reader that power the EXT protocol and receive the environment variables
that need to be provided to the external process.
"""
# This will trigger an error if expected outputs are not yielded
context.set_require_typed_event_stream(error_message=_FAIL_TO_YIELD_ERROR_MESSAGE)
context_data = build_external_execution_context_data(context, extras)
message_handler = ExtMessageHandler(context)
with context_injector.inject_context(
context_data,
) as ci_params, message_reader.read_messages(
message_handler,
) as mr_params:
ext_context = ExtOrchestrationContext(
context_data
) as ci_params, message_handler.handle_messages(message_reader) as mr_params:
yield ExtOrchestrationContext(
context_data=context_data,
message_handler=message_handler,
context_injector_params=ci_params,
message_reader_params=mr_params,
)
yield ext_context
ext_context.is_task_finished = True
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import string
import time
from contextlib import contextmanager
from typing import Iterator, Mapping, Optional, Tuple, Union
from typing import Iterator, Mapping, Optional

import dagster._check as check
from dagster._core.definitions.resource_annotation import ResourceParam
Expand Down Expand Up @@ -67,7 +67,7 @@ def run(
context: OpExecutionContext,
extras: Optional[ExtExtras] = None,
submit_args: Optional[Mapping[str, str]] = None,
) -> Union[MaterializeResult, Tuple[MaterializeResult, ...]]:
) -> Iterator[MaterializeResult]:
"""Run a Databricks job with the EXT protocol.
Args:
Expand Down Expand Up @@ -116,8 +116,9 @@ def run(
raise DagsterExternalExecutionError(
f"Error running Databricks job: {run.state.state_message}"
)
yield from ext_context.get_results()
time.sleep(5)
return ext_context.get_materialize_results()
yield from ext_context.get_results()


ExtDatabricks = ResourceParam[_ExtDatabricks]
Expand Down
Loading

0 comments on commit 76fe8e0

Please sign in to comment.