Skip to content

Commit

Permalink
commit
Browse files Browse the repository at this point in the history
  • Loading branch information
smackesey committed Sep 21, 2023
1 parent 5ba294c commit 536624b
Show file tree
Hide file tree
Showing 10 changed files with 131 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ def number_y_job(context: AssetExecutionContext):
k8s_job_name=job_name,
)
reader.consume_pod_logs(core_api, job_name, namespace)
yield from ext_context.get_materialize_results()
yield from ext_context.get_results()

result = materialize(
[number_y_job],
Expand Down
35 changes: 33 additions & 2 deletions python_modules/dagster-ext/dagster_ext/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,8 @@ class ExtDataProvenance(TypedDict):
is_user_provided: bool


ExtAssetCheckSeverity = Literal["WARN", "ERROR"]

ExtMetadataRawValue = Union[int, float, str, Mapping[str, Any], Sequence[Any], bool, None]


Expand Down Expand Up @@ -165,7 +167,7 @@ def _resolve_optionally_passed_asset_key(
data: ExtContextData,
asset_key: Optional[str],
method: str,
already_materialized_assets: Set[str],
already_materialized_assets: Optional[Set[str]] = None,
) -> str:
asset_keys = _assert_defined_asset_property(data["asset_keys"], method)
asset_key = _assert_opt_param_type(asset_key, str, method, "asset_key")
Expand All @@ -180,7 +182,7 @@ def _resolve_optionally_passed_asset_key(
" targets multiple assets."
)
asset_key = asset_keys[0]
if asset_key in already_materialized_assets:
if already_materialized_assets is not None and asset_key in already_materialized_assets:
raise DagsterExtError(
f"Calling `{method}` with asset key `{asset_key}` is undefined. Asset has already been"
" materialized, so no additional data can be reported for it."
Expand Down Expand Up @@ -800,6 +802,35 @@ def report_asset_materialization(
)
self._materialized_assets.add(asset_key)

def report_asset_check(
    self,
    check_name: str,
    success: bool,
    severity: ExtAssetCheckSeverity = "ERROR",
    metadata: Optional[Mapping[str, Union[ExtMetadataRawValue, ExtMetadataValue]]] = None,
    asset_key: Optional[str] = None,
) -> None:
    """Report the outcome of an asset check to the orchestrating process.

    Writes a ``report_asset_check`` message containing the resolved asset key,
    check name, success flag, normalized metadata, and severity.

    Args:
        check_name: Name of the check being reported. Must be a ``str``.
        success: Whether the check passed. Must be a ``bool``.
        severity: Severity of the check, either ``"WARN"`` or ``"ERROR"``.
            Defaults to ``"ERROR"``.
        metadata: Optional metadata for the check result. A missing or empty
            mapping is sent as ``None``; otherwise it is normalized before
            being written.
        asset_key: Key of the asset the check targets. When omitted, it is
            resolved from the context data by
            ``_resolve_optionally_passed_asset_key``.

    Raises:
        NOTE(review): the validation helpers raise on invalid arguments; the
        exact exception type is defined outside this block -- confirm.
    """
    # No `already_materialized_assets` set is passed here, so a check may be
    # reported for an asset whose materialization was already reported.
    asset_key = _resolve_optionally_passed_asset_key(
        self._data, asset_key, "report_asset_check"
    )
    check_name = _assert_param_type(check_name, str, "report_asset_check", "check_name")
    success = _assert_param_type(success, bool, "report_asset_check", "success")
    # Validate severity like every other argument so that a bad value fails
    # in the external process instead of surfacing later on the receiving side.
    severity = _assert_param_value(
        severity, ["WARN", "ERROR"], "report_asset_check", "severity"
    )
    metadata = (
        _normalize_param_metadata(metadata, "report_asset_check", "metadata")
        if metadata
        else None
    )
    self._write_message(
        "report_asset_check",
        {
            "asset_key": asset_key,
            "check_name": check_name,
            "success": success,
            "metadata": metadata,
            "severity": severity,
        },
    )

def log(self, message: str, level: str = "info") -> None:
message = _assert_param_type(message, str, "log", "asset_key")
level = _assert_param_value(level, ["info", "warning", "error"], "log", "level")
Expand Down
12 changes: 12 additions & 0 deletions python_modules/dagster-ext/dagster_ext_tests/test_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,16 @@ def test_single_asset_context():
)

_assert_unknown_asset_key(context, "report_asset_materialization", asset_key="fake")
context.report_asset_check(
"foo_check",
True,
metadata={
"meta_1": 1,
"meta_2": {"raw_value": "foo", "type": "text"},
},
)

_assert_unknown_asset_key(context, "report_asset_check", "foo_check", True, asset_key="fake")


def test_multi_asset_context():
Expand Down Expand Up @@ -115,6 +125,8 @@ def test_multi_asset_context():

_assert_undefined_asset_key(context, "report_asset_materialization", "bar")
_assert_unknown_asset_key(context, "report_asset_materialization", "bar", asset_key="fake")
_assert_undefined_asset_key(context, "report_asset_check", "foo_check", True)
_assert_unknown_asset_key(context, "report_asset_check", "foo_check", True, asset_key="fake")


def test_no_partition_context():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import boto3
import pytest
from dagster._core.definitions.asset_check_spec import AssetCheckSpec
from dagster._core.definitions.asset_spec import AssetSpec
from dagster._core.definitions.data_version import (
DATA_VERSION_IS_USER_PROVIDED_TAG,
Expand Down Expand Up @@ -44,6 +45,7 @@
ext_protocol,
)
from dagster._core.instance_for_test import instance_for_test
from dagster._core.storage.asset_check_execution_record import AssetCheckExecutionRecordStatus
from dagster_aws.ext import ExtS3MessageReader
from moto.server import ThreadedMotoServer

Expand Down Expand Up @@ -96,6 +98,15 @@ def script_fn():
metadata={"bar": {"raw_value": context.get_extra("bar"), "type": "md"}},
data_version="alpha",
)
context.report_asset_check(
"foo_check",
success=True,
severity="WARN",
metadata={
"meta_1": 1,
"meta_2": {"raw_value": "foo", "type": "text"},
},
)

with temp_script(script_fn) as script_path:
yield script_path
Expand Down Expand Up @@ -147,7 +158,7 @@ def test_ext_subprocess(
else:
assert False, "Unreachable"

@asset
@asset(check_specs=[AssetCheckSpec(name="foo_check", asset=AssetKey(["foo"]))])
def foo(context: AssetExecutionContext, ext: ExtSubprocess):
extras = {"bar": "baz"}
cmd = [_PYTHON_EXECUTABLE, external_script]
Expand Down Expand Up @@ -176,6 +187,14 @@ def foo(context: AssetExecutionContext, ext: ExtSubprocess):
captured = capsys.readouterr()
assert re.search(r"dagster - INFO - [^\n]+ - hello world\n", captured.err, re.MULTILINE)

asset_check_executions = instance.event_log_storage.get_asset_check_executions(
asset_key=foo.key,
check_name="foo_check",
limit=1,
)
assert len(asset_check_executions) == 1
assert asset_check_executions[0].status == AssetCheckExecutionRecordStatus.SUCCEEDED


def test_ext_multi_asset():
def script_fn():
Expand Down Expand Up @@ -335,7 +354,7 @@ def script_fn():


def test_ext_no_client(external_script):
@asset
@asset(check_specs=[AssetCheckSpec(name="foo_check", asset=AssetKey(["subproc_run"]))])
def subproc_run(context: AssetExecutionContext):
extras = {"bar": "baz"}
cmd = [_PYTHON_EXECUTABLE, external_script]
Expand All @@ -347,7 +366,7 @@ def subproc_run(context: AssetExecutionContext):
extras=extras,
) as ext_context:
subprocess.run(cmd, env=ext_context.get_external_process_env_vars(), check=False)
yield from ext_context.get_materialize_results()
yield from ext_context.get_results()

with instance_for_test() as instance:
materialize(
Expand All @@ -361,6 +380,14 @@ def subproc_run(context: AssetExecutionContext):
assert mat.asset_materialization.tags[DATA_VERSION_TAG] == "alpha"
assert mat.asset_materialization.tags[DATA_VERSION_IS_USER_PROVIDED_TAG]

asset_check_executions = instance.event_log_storage.get_asset_check_executions(
asset_key=subproc_run.key,
check_name="foo_check",
limit=1,
)
assert len(asset_check_executions) == 1
assert asset_check_executions[0].status == AssetCheckExecutionRecordStatus.SUCCEEDED


def test_ext_no_client_no_yield():
def script_fn():
Expand Down
6 changes: 2 additions & 4 deletions python_modules/dagster/dagster/_core/ext/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,7 @@
from dagster._core.execution.context.compute import OpExecutionContext

if TYPE_CHECKING:
from dagster._core.definitions.result import MaterializeResult

from .context import ExtMessageHandler
from .context import ExtMessageHandler, ExtResult


class ExtClient(ABC):
Expand All @@ -23,7 +21,7 @@ def run(
*,
context: OpExecutionContext,
extras: Optional[ExtExtras] = None,
) -> Iterator["MaterializeResult"]: ...
) -> Iterator["ExtResult"]: ...


class ExtContextInjector(ABC):
Expand Down
45 changes: 39 additions & 6 deletions python_modules/dagster/dagster/_core/ext/context.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from contextlib import contextmanager
from dataclasses import dataclass
from queue import Queue
from typing import Any, Dict, Iterator, Mapping, Optional, Set
from typing import Any, Iterator, Mapping, Optional, Set, Union

from dagster_ext import (
DAGSTER_EXT_ENV_KEYS,
Expand All @@ -12,12 +12,16 @@
ExtExtras,
ExtMessage,
ExtMetadataType,
ExtMetadataValue,
ExtParams,
ExtTimeWindow,
encode_env_var,
)
from typing_extensions import TypeAlias

import dagster._check as check
from dagster._core.definitions.asset_check_result import AssetCheckResult
from dagster._core.definitions.asset_check_spec import AssetCheckSeverity
from dagster._core.definitions.data_version import DataProvenance, DataVersion
from dagster._core.definitions.events import AssetKey
from dagster._core.definitions.metadata import MetadataValue, normalize_metadata_value
Expand All @@ -28,16 +32,16 @@
from dagster._core.execution.context.invocation import BoundOpExecutionContext
from dagster._core.ext.client import ExtMessageReader

ExtResult: TypeAlias = Union[MaterializeResult, AssetCheckResult]


class ExtMessageHandler:
def __init__(self, context: OpExecutionContext) -> None:
self._context = context
# Queue is thread-safe
self._result_queue: Queue[MaterializeResult] = Queue()
self._result_queue: Queue[ExtResult] = Queue()
# Only read by the main thread after all messages are handled, so no need for a lock
self._unmaterialized_assets: Set[AssetKey] = set(context.selected_asset_keys)
self._metadata: Dict[AssetKey, Dict[str, MetadataValue]] = {}
self._data_versions: Dict[AssetKey, DataVersion] = {}

@contextmanager
def handle_messages(self, message_reader: ExtMessageReader) -> Iterator[ExtParams]:
Expand All @@ -46,7 +50,7 @@ def handle_messages(self, message_reader: ExtMessageReader) -> Iterator[ExtParam
for key in self._unmaterialized_assets:
self._result_queue.put(MaterializeResult(asset_key=key))

def clear_result_queue(self) -> Iterator[MaterializeResult]:
def clear_result_queue(self) -> Iterator[ExtResult]:
while not self._result_queue.empty():
yield self._result_queue.get()

Expand Down Expand Up @@ -86,6 +90,8 @@ def _resolve_metadata_value(self, value: Any, metadata_type: ExtMetadataType) ->
def handle_message(self, message: ExtMessage) -> None:
    """Dispatch one message from the external process to its handler.

    The message's ``"method"`` selects the handler and its ``"params"`` are
    expanded as keyword arguments. Messages with any other method are ignored.
    """
    method = message["method"]
    if method == "report_asset_materialization":
        self._handle_report_asset_materialization(**message["params"])  # type: ignore
        return
    if method == "report_asset_check":
        self._handle_report_asset_check(**message["params"])  # type: ignore
        return
    if method == "log":
        self._handle_log(**message["params"])  # type: ignore

Expand All @@ -108,6 +114,33 @@ def _handle_report_asset_materialization(
self._result_queue.put(result)
self._unmaterialized_assets.remove(resolved_asset_key)

def _handle_report_asset_check(
    self,
    asset_key: str,
    check_name: str,
    success: bool,
    severity: str,
    metadata: Optional[Mapping[str, ExtMetadataValue]] = None,
) -> None:
    """Convert a ``report_asset_check`` message into a queued ``AssetCheckResult``.

    Args:
        asset_key: User-string form of the asset key the check targets.
        check_name: Name of the reported check.
        success: Whether the check passed.
        severity: String value of an ``AssetCheckSeverity`` member.
        metadata: Mapping of metadata label to a ``{"raw_value", "type"}``
            entry, or ``None`` when the external process reported no metadata.
    """
    check.str_param(asset_key, "asset_key")
    check.str_param(check_name, "check_name")
    check.bool_param(success, "success")
    check.literal_param(severity, "severity", [x.value for x in AssetCheckSeverity])
    # The sender emits `None` when no metadata was supplied; the result of
    # `opt_mapping_param` is iterated directly below.
    # NOTE(review): presumably it substitutes an empty mapping for `None` -- confirm.
    metadata = check.opt_mapping_param(metadata, "metadata", key_type=str)
    resolved_asset_key = AssetKey.from_user_string(asset_key)
    resolved_metadata = {
        k: self._resolve_metadata_value(v["raw_value"], v["type"]) for k, v in metadata.items()
    }
    resolved_severity = AssetCheckSeverity(severity)
    result = AssetCheckResult(
        asset_key=resolved_asset_key,
        check_name=check_name,
        success=success,
        severity=resolved_severity,
        metadata=resolved_metadata,
    )
    self._result_queue.put(result)

def _handle_log(self, message: str, level: str = "info") -> None:
check.str_param(message, "message")
self._context.log.log(level, message)
Expand Down Expand Up @@ -138,7 +171,7 @@ def get_external_process_env_vars(self):
),
}

def get_materialize_results(self) -> Iterator[MaterializeResult]:
def get_results(self) -> Iterator[ExtResult]:
yield from self.message_handler.clear_result_queue()


Expand Down
8 changes: 4 additions & 4 deletions python_modules/dagster/dagster/_core/ext/subprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,14 @@

from dagster import _check as check
from dagster._core.definitions.resource_annotation import ResourceParam
from dagster._core.definitions.result import MaterializeResult
from dagster._core.errors import DagsterExternalExecutionError
from dagster._core.execution.context.compute import OpExecutionContext
from dagster._core.ext.client import (
ExtClient,
ExtContextInjector,
ExtMessageReader,
)
from dagster._core.ext.context import ExtResult
from dagster._core.ext.utils import (
ExtTempFileContextInjector,
ExtTempFileMessageReader,
Expand Down Expand Up @@ -67,7 +67,7 @@ def run(
extras: Optional[ExtExtras] = None,
env: Optional[Mapping[str, str]] = None,
cwd: Optional[str] = None,
) -> Iterator[MaterializeResult]:
) -> Iterator[ExtResult]:
with ext_protocol(
context=context,
context_injector=self.context_injector,
Expand All @@ -84,15 +84,15 @@ def run(
},
)
while True:
yield from ext_context.get_materialize_results()
yield from ext_context.get_results()
if process.poll() is not None:
break

if process.returncode != 0:
raise DagsterExternalExecutionError(
f"External execution process failed with code {process.returncode}"
)
yield from ext_context.get_materialize_results()
yield from ext_context.get_results()


ExtSubprocess = ResourceParam[_ExtSubprocess]
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@

import dagster._check as check
from dagster._core.definitions.resource_annotation import ResourceParam
from dagster._core.definitions.result import MaterializeResult
from dagster._core.errors import DagsterExternalExecutionError
from dagster._core.execution.context.compute import OpExecutionContext
from dagster._core.ext.client import ExtClient, ExtContextInjector, ExtMessageReader
from dagster._core.ext.context import ExtResult
from dagster._core.ext.utils import (
ExtBlobStoreMessageReader,
ext_protocol,
Expand Down Expand Up @@ -67,7 +67,7 @@ def run(
context: OpExecutionContext,
extras: Optional[ExtExtras] = None,
submit_args: Optional[Mapping[str, str]] = None,
) -> Iterator[MaterializeResult]:
) -> Iterator[ExtResult]:
"""Run a Databricks job with the EXT protocol.
Args:
Expand Down Expand Up @@ -116,9 +116,9 @@ def run(
raise DagsterExternalExecutionError(
f"Error running Databricks job: {run.state.state_message}"
)
yield from ext_context.get_materialize_results()
yield from ext_context.get_results()
time.sleep(5)
yield from ext_context.get_materialize_results()
yield from ext_context.get_results()


ExtDatabricks = ResourceParam[_ExtDatabricks]
Expand Down
Loading

0 comments on commit 536624b

Please sign in to comment.