Skip to content

Commit

Permalink
commit
Browse files Browse the repository at this point in the history
  • Loading branch information
smackesey committed Sep 22, 2023
1 parent f7f8147 commit 04fa87e
Show file tree
Hide file tree
Showing 4 changed files with 105 additions and 6 deletions.
31 changes: 31 additions & 0 deletions python_modules/dagster-ext/dagster_ext/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,8 @@ class ExtDataProvenance(TypedDict):
is_user_provided: bool


ExtAssetCheckSeverity = Literal["WARN", "ERROR"]

ExtMetadataRawValue = Union[int, float, str, Mapping[str, Any], Sequence[Any], bool, None]


Expand Down Expand Up @@ -799,6 +801,35 @@ def report_asset_materialization(
)
self._materialized_assets.add(asset_key)

def report_asset_check(
self,
check_name: str,
success: bool,
severity: ExtAssetCheckSeverity = "ERROR",
metadata: Optional[Mapping[str, Union[ExtMetadataRawValue, ExtMetadataValue]]] = None,
asset_key: Optional[str] = None,
) -> None:
asset_key = _resolve_optionally_passed_asset_key(
self._data, asset_key, "report_asset_check"
)
check_name = _assert_param_type(check_name, str, "report_asset_check", "check_name")
success = _assert_param_type(success, bool, "report_asset_check", "success")
metadata = (
_normalize_param_metadata(metadata, "report_asset_check", "metadata")
if metadata
else None
)
self._write_message(
"report_asset_check",
{
"asset_key": asset_key,
"check_name": check_name,
"success": success,
"metadata": metadata,
"severity": severity,
},
)

def log(self, message: str, level: str = "info") -> None:
message = _assert_param_type(message, str, "log", "asset_key")
level = _assert_param_value(level, ["info", "warning", "error"], "log", "level")
Expand Down
12 changes: 12 additions & 0 deletions python_modules/dagster-ext/dagster_ext_tests/test_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,16 @@ def test_single_asset_context():
)

_assert_unknown_asset_key(context, "report_asset_materialization", asset_key="fake")
context.report_asset_check(
"foo_check",
True,
metadata={
"meta_1": 1,
"meta_2": {"raw_value": "foo", "type": "text"},
},
)

_assert_unknown_asset_key(context, "report_asset_check", "foo_check", True, asset_key="fake")


def test_multi_asset_context():
Expand Down Expand Up @@ -115,6 +125,8 @@ def test_multi_asset_context():

_assert_undefined_asset_key(context, "report_asset_materialization", "bar")
_assert_unknown_asset_key(context, "report_asset_materialization", "bar", asset_key="fake")
_assert_undefined_asset_key(context, "report_asset_check", "foo_check", True)
_assert_unknown_asset_key(context, "report_asset_check", "foo_check", True, asset_key="fake")


def test_no_partition_context():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import boto3
import pytest
from dagster._core.definitions.asset_check_spec import AssetCheckSpec
from dagster._core.definitions.asset_spec import AssetSpec
from dagster._core.definitions.data_version import (
DATA_VERSION_IS_USER_PROVIDED_TAG,
Expand Down Expand Up @@ -44,6 +45,7 @@
ext_protocol,
)
from dagster._core.instance_for_test import instance_for_test
from dagster._core.storage.asset_check_execution_record import AssetCheckExecutionRecordStatus
from dagster_aws.ext import ExtS3MessageReader
from moto.server import ThreadedMotoServer

Expand Down Expand Up @@ -96,6 +98,15 @@ def script_fn():
metadata={"bar": {"raw_value": context.get_extra("bar"), "type": "md"}},
data_version="alpha",
)
context.report_asset_check(
"foo_check",
success=True,
severity="WARN",
metadata={
"meta_1": 1,
"meta_2": {"raw_value": "foo", "type": "text"},
},
)

with temp_script(script_fn) as script_path:
yield script_path
Expand Down Expand Up @@ -147,7 +158,7 @@ def test_ext_subprocess(
else:
assert False, "Unreachable"

@asset
@asset(check_specs=[AssetCheckSpec(name="foo_check", asset=AssetKey(["foo"]))])
def foo(context: AssetExecutionContext, ext: ExtSubprocess):
extras = {"bar": "baz"}
cmd = [_PYTHON_EXECUTABLE, external_script]
Expand Down Expand Up @@ -176,6 +187,14 @@ def foo(context: AssetExecutionContext, ext: ExtSubprocess):
captured = capsys.readouterr()
assert re.search(r"dagster - INFO - [^\n]+ - hello world\n", captured.err, re.MULTILINE)

asset_check_executions = instance.event_log_storage.get_asset_check_executions(
asset_key=foo.key,
check_name="foo_check",
limit=1,
)
assert len(asset_check_executions) == 1
assert asset_check_executions[0].status == AssetCheckExecutionRecordStatus.SUCCEEDED


def test_ext_multi_asset():
def script_fn():
Expand Down Expand Up @@ -335,7 +354,7 @@ def script_fn():


def test_ext_no_client(external_script):
@asset
@asset(check_specs=[AssetCheckSpec(name="foo_check", asset=AssetKey(["subproc_run"]))])
def subproc_run(context: AssetExecutionContext):
extras = {"bar": "baz"}
cmd = [_PYTHON_EXECUTABLE, external_script]
Expand All @@ -361,6 +380,14 @@ def subproc_run(context: AssetExecutionContext):
assert mat.asset_materialization.tags[DATA_VERSION_TAG] == "alpha"
assert mat.asset_materialization.tags[DATA_VERSION_IS_USER_PROVIDED_TAG]

asset_check_executions = instance.event_log_storage.get_asset_check_executions(
asset_key=subproc_run.key,
check_name="foo_check",
limit=1,
)
assert len(asset_check_executions) == 1
assert asset_check_executions[0].status == AssetCheckExecutionRecordStatus.SUCCEEDED


def test_ext_no_client_no_yield():
def script_fn():
Expand Down
37 changes: 33 additions & 4 deletions python_modules/dagster/dagster/_core/ext/context.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from contextlib import contextmanager
from dataclasses import dataclass
from queue import Queue
from typing import Any, Dict, Iterator, Mapping, Optional, Set
from typing import Any, Iterator, Mapping, Optional, Set, Union

from dagster_ext import (
DAGSTER_EXT_ENV_KEYS,
Expand All @@ -20,6 +20,8 @@
from typing_extensions import TypeAlias

import dagster._check as check
from dagster._core.definitions.asset_check_result import AssetCheckResult
from dagster._core.definitions.asset_check_spec import AssetCheckSeverity
from dagster._core.definitions.data_version import DataProvenance, DataVersion
from dagster._core.definitions.events import AssetKey
from dagster._core.definitions.metadata import MetadataValue, normalize_metadata_value
Expand All @@ -30,7 +32,7 @@
from dagster._core.execution.context.invocation import BoundOpExecutionContext
from dagster._core.ext.client import ExtMessageReader

ExtResult: TypeAlias = MaterializeResult
ExtResult: TypeAlias = Union[MaterializeResult, AssetCheckResult]


class ExtMessageHandler:
Expand All @@ -40,8 +42,6 @@ def __init__(self, context: OpExecutionContext) -> None:
self._result_queue: Queue[ExtResult] = Queue()
# Only read by the main thread after all messages are handled, so no need for a lock
self._unmaterialized_assets: Set[AssetKey] = set(context.selected_asset_keys)
self._metadata: Dict[AssetKey, Dict[str, MetadataValue]] = {}
self._data_versions: Dict[AssetKey, DataVersion] = {}

@contextmanager
def handle_messages(self, message_reader: ExtMessageReader) -> Iterator[ExtParams]:
Expand Down Expand Up @@ -97,6 +97,8 @@ def _resolve_metadata_value(self, value: Any, metadata_type: ExtMetadataType) ->
def handle_message(self, message: ExtMessage) -> None:
if message["method"] == "report_asset_materialization":
self._handle_report_asset_materialization(**message["params"]) # type: ignore
elif message["method"] == "report_asset_check":
self._handle_report_asset_check(**message["params"]) # type: ignore
elif message["method"] == "log":
self._handle_log(**message["params"]) # type: ignore

Expand All @@ -120,6 +122,33 @@ def _handle_report_asset_materialization(
self._result_queue.put(result)
self._unmaterialized_assets.remove(resolved_asset_key)

def _handle_report_asset_check(
self,
asset_key: str,
check_name: str,
success: bool,
severity: str,
metadata: Mapping[str, ExtMetadataValue],
) -> None:
check.str_param(asset_key, "asset_key")
check.str_param(check_name, "check_name")
check.bool_param(success, "success")
check.literal_param(severity, "severity", [x.value for x in AssetCheckSeverity])
metadata = check.opt_mapping_param(metadata, "metadata", key_type=str)
resolved_asset_key = AssetKey.from_user_string(asset_key)
resolved_metadata = {
k: self._resolve_metadata_value(v["raw_value"], v["type"]) for k, v in metadata.items()
}
resolved_severity = AssetCheckSeverity(severity)
result = AssetCheckResult(
asset_key=resolved_asset_key,
check_name=check_name,
success=success,
severity=resolved_severity,
metadata=resolved_metadata,
)
self._result_queue.put(result)

def _handle_log(self, message: str, level: str = "info") -> None:
check.str_param(message, "message")
self._context.log.log(level, message)
Expand Down

0 comments on commit 04fa87e

Please sign in to comment.