Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ext] asset check support #16466

Merged
merged 1 commit into from
Sep 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions python_modules/dagster-ext/dagster_ext/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,8 @@ class ExtDataProvenance(TypedDict):
is_user_provided: bool


ExtAssetCheckSeverity = Literal["WARN", "ERROR"]

ExtMetadataRawValue = Union[int, float, str, Mapping[str, Any], Sequence[Any], bool, None]


Expand Down Expand Up @@ -808,6 +810,35 @@ def report_asset_materialization(
)
self._materialized_assets.add(asset_key)

def report_asset_check(
self,
check_name: str,
success: bool,
severity: ExtAssetCheckSeverity = "ERROR",
metadata: Optional[Mapping[str, Union[ExtMetadataRawValue, ExtMetadataValue]]] = None,
asset_key: Optional[str] = None,
) -> None:
asset_key = _resolve_optionally_passed_asset_key(
self._data, asset_key, "report_asset_check"
)
check_name = _assert_param_type(check_name, str, "report_asset_check", "check_name")
success = _assert_param_type(success, bool, "report_asset_check", "success")
metadata = (
_normalize_param_metadata(metadata, "report_asset_check", "metadata")
if metadata
else None
)
self._write_message(
"report_asset_check",
{
"asset_key": asset_key,
"check_name": check_name,
"success": success,
"metadata": metadata,
"severity": severity,
},
)

def log(self, message: str, level: str = "info") -> None:
message = _assert_param_type(message, str, "log", "asset_key")
level = _assert_param_value(level, ["info", "warning", "error"], "log", "level")
Expand Down
12 changes: 12 additions & 0 deletions python_modules/dagster-ext/dagster_ext_tests/test_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,16 @@ def test_single_asset_context():
)

_assert_unknown_asset_key(context, "report_asset_materialization", asset_key="fake")
context.report_asset_check(
"foo_check",
True,
metadata={
"meta_1": 1,
"meta_2": {"raw_value": "foo", "type": "text"},
},
)

_assert_unknown_asset_key(context, "report_asset_check", "foo_check", True, asset_key="fake")


def test_multi_asset_context():
Expand Down Expand Up @@ -115,6 +125,8 @@ def test_multi_asset_context():

_assert_undefined_asset_key(context, "report_asset_materialization", "bar")
_assert_unknown_asset_key(context, "report_asset_materialization", "bar", asset_key="fake")
_assert_undefined_asset_key(context, "report_asset_check", "foo_check", True)
_assert_unknown_asset_key(context, "report_asset_check", "foo_check", True, asset_key="fake")


def test_no_partition_context():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import boto3
import pytest
from dagster._core.definitions.asset_check_spec import AssetCheckSpec
from dagster._core.definitions.asset_spec import AssetSpec
from dagster._core.definitions.data_version import (
DATA_VERSION_IS_USER_PROVIDED_TAG,
Expand Down Expand Up @@ -44,6 +45,7 @@
ext_protocol,
)
from dagster._core.instance_for_test import instance_for_test
from dagster._core.storage.asset_check_execution_record import AssetCheckExecutionRecordStatus
from dagster_aws.ext import ExtS3MessageReader
from moto.server import ThreadedMotoServer

Expand Down Expand Up @@ -96,6 +98,15 @@ def script_fn():
metadata={"bar": {"raw_value": context.get_extra("bar"), "type": "md"}},
data_version="alpha",
)
context.report_asset_check(
"foo_check",
success=True,
severity="WARN",
metadata={
"meta_1": 1,
"meta_2": {"raw_value": "foo", "type": "text"},
},
)

with temp_script(script_fn) as script_path:
yield script_path
Expand Down Expand Up @@ -147,7 +158,7 @@ def test_ext_subprocess(
else:
assert False, "Unreachable"

@asset
@asset(check_specs=[AssetCheckSpec(name="foo_check", asset=AssetKey(["foo"]))])
def foo(context: AssetExecutionContext, ext: ExtSubprocess):
extras = {"bar": "baz"}
cmd = [_PYTHON_EXECUTABLE, external_script]
Expand Down Expand Up @@ -176,6 +187,14 @@ def foo(context: AssetExecutionContext, ext: ExtSubprocess):
captured = capsys.readouterr()
assert re.search(r"dagster - INFO - [^\n]+ - hello world\n", captured.err, re.MULTILINE)

asset_check_executions = instance.event_log_storage.get_asset_check_executions(
asset_key=foo.key,
check_name="foo_check",
limit=1,
)
assert len(asset_check_executions) == 1
assert asset_check_executions[0].status == AssetCheckExecutionRecordStatus.SUCCEEDED


def test_ext_multi_asset():
def script_fn():
Expand Down Expand Up @@ -335,7 +354,7 @@ def script_fn():


def test_ext_no_client(external_script):
@asset
@asset(check_specs=[AssetCheckSpec(name="foo_check", asset=AssetKey(["subproc_run"]))])
def subproc_run(context: AssetExecutionContext):
extras = {"bar": "baz"}
cmd = [_PYTHON_EXECUTABLE, external_script]
Expand All @@ -361,6 +380,14 @@ def subproc_run(context: AssetExecutionContext):
assert mat.asset_materialization.tags[DATA_VERSION_TAG] == "alpha"
assert mat.asset_materialization.tags[DATA_VERSION_IS_USER_PROVIDED_TAG]

asset_check_executions = instance.event_log_storage.get_asset_check_executions(
asset_key=subproc_run.key,
check_name="foo_check",
limit=1,
)
assert len(asset_check_executions) == 1
assert asset_check_executions[0].status == AssetCheckExecutionRecordStatus.SUCCEEDED


def test_ext_no_client_no_yield():
def script_fn():
Expand Down
35 changes: 31 additions & 4 deletions python_modules/dagster/dagster/_core/ext/context.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from contextlib import contextmanager
from dataclasses import dataclass
from queue import Queue
from typing import Any, Dict, Iterator, Mapping, Optional, Set
from typing import Any, Iterator, Mapping, Optional, Set, Union

from dagster_ext import (
DAGSTER_EXT_ENV_KEYS,
Expand All @@ -20,6 +20,8 @@
from typing_extensions import TypeAlias

import dagster._check as check
from dagster._core.definitions.asset_check_result import AssetCheckResult
from dagster._core.definitions.asset_check_spec import AssetCheckSeverity
from dagster._core.definitions.data_version import DataProvenance, DataVersion
from dagster._core.definitions.events import AssetKey
from dagster._core.definitions.metadata import MetadataValue, normalize_metadata_value
Expand All @@ -30,7 +32,7 @@
from dagster._core.execution.context.invocation import BoundOpExecutionContext
from dagster._core.ext.client import ExtMessageReader

ExtResult: TypeAlias = MaterializeResult
ExtResult: TypeAlias = Union[MaterializeResult, AssetCheckResult]


class ExtMessageHandler:
Expand All @@ -40,8 +42,6 @@ def __init__(self, context: OpExecutionContext) -> None:
self._result_queue: Queue[ExtResult] = Queue()
# Only read by the main thread after all messages are handled, so no need for a lock
self._unmaterialized_assets: Set[AssetKey] = set(context.selected_asset_keys)
self._metadata: Dict[AssetKey, Dict[str, MetadataValue]] = {}
self._data_versions: Dict[AssetKey, DataVersion] = {}

@contextmanager
def handle_messages(self, message_reader: ExtMessageReader) -> Iterator[ExtParams]:
Expand Down Expand Up @@ -97,6 +97,8 @@ def _resolve_metadata_value(self, value: Any, metadata_type: ExtMetadataType) ->
def handle_message(self, message: ExtMessage) -> None:
if message["method"] == "report_asset_materialization":
self._handle_report_asset_materialization(**message["params"]) # type: ignore
elif message["method"] == "report_asset_check":
self._handle_report_asset_check(**message["params"]) # type: ignore
elif message["method"] == "log":
self._handle_log(**message["params"]) # type: ignore

Expand All @@ -120,6 +122,31 @@ def _handle_report_asset_materialization(
self._result_queue.put(result)
self._unmaterialized_assets.remove(resolved_asset_key)

def _handle_report_asset_check(
self,
asset_key: str,
check_name: str,
success: bool,
severity: str,
metadata: Mapping[str, ExtMetadataValue],
) -> None:
check.str_param(asset_key, "asset_key")
check.str_param(check_name, "check_name")
check.bool_param(success, "success")
check.literal_param(severity, "severity", [x.value for x in AssetCheckSeverity])
metadata = check.opt_mapping_param(metadata, "metadata", key_type=str)
resolved_asset_key = AssetKey.from_user_string(asset_key)
resolved_metadata = self._resolve_metadata(metadata)
resolved_severity = AssetCheckSeverity(severity)
result = AssetCheckResult(
asset_key=resolved_asset_key,
check_name=check_name,
success=success,
severity=resolved_severity,
metadata=resolved_metadata,
)
self._result_queue.put(result)

def _handle_log(self, message: str, level: str = "info") -> None:
check.str_param(message, "message")
self._context.log.log(level, message)
Expand Down