Skip to content

Commit

Permalink
[ext] add asset checks to ext
Browse files Browse the repository at this point in the history
  • Loading branch information
smackesey committed Sep 14, 2023
1 parent 7a0617c commit c9cb902
Show file tree
Hide file tree
Showing 9 changed files with 231 additions and 8 deletions.
67 changes: 66 additions & 1 deletion python_modules/dagster-ext/dagster_ext/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
TYPE_CHECKING,
Any,
ClassVar,
Dict,
Generic,
Iterator,
Literal,
Expand All @@ -26,6 +27,7 @@
Type,
TypedDict,
TypeVar,
Union,
cast,
get_args,
)
Expand Down Expand Up @@ -98,6 +100,16 @@ class ExtDataProvenance(TypedDict):
is_user_provided: bool


# Severity strings that may accompany a reported asset check result.
ExtAssetCheckSeverity = Literal["WARN", "ERROR"]

# Plain (untagged) value forms accepted for a single metadata entry.
ExtMetadataRawValue = Union[int, float, str, Mapping[str, Any], Sequence[Any], bool, None]


class ExtMetadataValue(TypedDict):
    """A metadata value paired with an optional explicit metadata type."""

    # Explicit type tag for `value`; `None` when no type was specified.
    metadata_type: Optional["ExtMetadataType"]
    # The raw metadata payload.
    value: ExtMetadataRawValue


ExtMetadataType = Literal[
"text",
"url",
Expand Down Expand Up @@ -248,6 +260,30 @@ def _assert_opt_param_value(
return value


def _normalize_param_metadata(
    metadata: Mapping[str, Union[ExtMetadataRawValue, ExtMetadataValue]], method: str, param: str
) -> Mapping[str, Union[ExtMetadataRawValue, ExtMetadataValue]]:
    """Validate `metadata` and coerce every entry to the `{value, metadata_type}` form.

    Entries that are already dicts must have exactly the keys declared on
    `ExtMetadataValue`; any other value is wrapped with `metadata_type` set to `None`.

    Args:
        metadata: The user-supplied metadata; checked against `dict` via `_assert_param_type`.
        method: Name of the calling method, used in error messages.
        param: Name of the parameter being validated, used in error messages.

    Returns:
        A new dict mapping each key to an `ExtMetadataValue`-shaped dict.

    Raises:
        DagsterExtError: If a key is not a string, or a dict value does not have exactly the
            `ExtMetadataValue` keys.
    """
    _assert_param_type(metadata, dict, method, param)
    expected_keys = set(ExtMetadataValue.__annotations__)
    normalized: Dict[str, ExtMetadataValue] = {}
    for label, raw in metadata.items():
        if not isinstance(label, str):
            raise DagsterExtError(
                f"Invalid type for parameter `{param}` of `{method}`. Expected a dict with string"
                f" keys, got a key `{label}` of type `{type(label)}`."
            )
        if not isinstance(raw, dict):
            # Bare value: wrap it and leave the metadata type unspecified.
            normalized[label] = {"value": raw, "metadata_type": None}
            continue
        if set(raw) != expected_keys:
            raise DagsterExtError(
                f"Invalid type for parameter `{param}` of `{method}`. Expected a dict with"
                " string keys and values that are either raw metadata values or dictionaries"
                f" with schema `{{value: ..., metadata_type: ...}}`. Got a value `{raw}`."
            )
        normalized[label] = cast(ExtMetadataValue, raw)
    return normalized


def _assert_param_json_serializable(value: _T, method: str, param: str) -> _T:
try:
json.dumps(value)
Expand Down Expand Up @@ -701,7 +737,7 @@ def extras(self) -> Mapping[str, Any]:
def report_asset_metadata(
self,
label: str,
value: Any,
value: ExtMetadataRawValue,
metadata_type: Optional[ExtMetadataType] = None,
asset_key: Optional[str] = None,
) -> None:
Expand Down Expand Up @@ -729,6 +765,35 @@ def report_asset_data_version(self, data_version: str, asset_key: Optional[str]
"report_asset_data_version", {"asset_key": asset_key, "data_version": data_version}
)

def report_asset_check_result(
    self,
    check_name: str,
    success: bool,
    severity: ExtAssetCheckSeverity = "ERROR",
    metadata: Optional[Mapping[str, Union[ExtMetadataRawValue, ExtMetadataValue]]] = None,
    asset_key: Optional[str] = None,
) -> None:
    """Report the result of an asset check by writing a `report_asset_check` message.

    Args:
        check_name: Name of the check being reported; must be a `str`.
        success: Whether the check passed; must be a `bool`.
        severity: One of the `ExtAssetCheckSeverity` values. Defaults to `"ERROR"`.
        metadata: Optional metadata for the result. Each value may be a raw value or a
            `{value, metadata_type}` dict; entries are normalized before sending. Empty or
            omitted metadata is sent as `None`.
        asset_key: Asset the check belongs to; resolved through
            `_resolve_optionally_passed_asset_key` when omitted.
    """
    asset_key = _resolve_optionally_passed_asset_key(
        self._data, asset_key, "report_asset_check_result"
    )
    check_name = _assert_param_type(check_name, str, "report_asset_check_result", "check_name")
    success = _assert_param_type(success, bool, "report_asset_check_result", "success")
    # Validate severity here, like the other parameters, so an invalid value is rejected
    # before it is written to the message stream.
    severity = _assert_param_value(
        severity, get_args(ExtAssetCheckSeverity), "report_asset_check_result", "severity"
    )
    metadata = (
        _normalize_param_metadata(metadata, "report_asset_check_result", "metadata")
        if metadata
        else None
    )
    self._write_message(
        "report_asset_check",
        {
            "asset_key": asset_key,
            "check_name": check_name,
            "success": success,
            "metadata": metadata,
            "severity": severity,
        },
    )

def log(self, message: str, level: str = "info") -> None:
message = _assert_param_type(message, str, "log", "asset_key")
level = _assert_param_value(level, ["info", "warning", "error"], "log", "level")
Expand Down
15 changes: 15 additions & 0 deletions python_modules/dagster-ext/dagster_ext_tests/test_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,9 +82,20 @@ def test_single_asset_context():
context.report_asset_metadata("bar", "boo")
context.report_asset_metadata("baz", 2, "int")
context.report_asset_data_version("bar")
context.report_asset_check_result(
"foo_check",
True,
metadata={
"meta_1": 1,
"meta_2": {"value": "foo", "metadata_type": "text"},
},
)

_assert_unknown_asset_key(context, "report_asset_metadata", "bar", "baz", asset_key="fake")
_assert_unknown_asset_key(context, "report_asset_data_version", "bar", asset_key="fake")
_assert_unknown_asset_key(
context, "report_asset_check_result", "foo_check", True, asset_key="fake"
)


def test_multi_asset_context():
Expand Down Expand Up @@ -114,6 +125,10 @@ def test_multi_asset_context():
_assert_unknown_asset_key(context, "report_asset_metadata", "bar", "baz", asset_key="fake")
_assert_undefined_asset_key(context, "report_asset_data_version", "bar")
_assert_unknown_asset_key(context, "report_asset_data_version", "bar", asset_key="fake")
_assert_undefined_asset_key(context, "report_asset_check_result", "foo_check", True)
_assert_unknown_asset_key(
context, "report_asset_check_result", "foo_check", True, asset_key="fake"
)


def test_no_partition_context():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import boto3
import pytest
from dagster._core.definitions.asset_check_spec import AssetCheckSpec
from dagster._core.definitions.data_version import (
DATA_VERSION_IS_USER_PROVIDED_TAG,
DATA_VERSION_TAG,
Expand Down Expand Up @@ -43,6 +44,7 @@
ext_protocol,
)
from dagster._core.instance_for_test import instance_for_test
from dagster._core.storage.asset_check_execution_record import AssetCheckExecutionRecordStatus
from dagster_aws.ext import ExtS3MessageReader
from moto.server import ThreadedMotoServer

Expand Down Expand Up @@ -93,6 +95,15 @@ def script_fn():
time.sleep(0.1) # sleep to make sure that we encompass multiple intervals for blob store IO
context.report_asset_metadata("bar", context.get_extra("bar"), metadata_type="md")
context.report_asset_data_version("alpha")
context.report_asset_check_result(
"foo_check",
success=True,
severity="WARN",
metadata={
"meta_1": 1,
"meta_2": {"value": "foo", "metadata_type": "text"},
},
)

with temp_script(script_fn) as script_path:
yield script_path
Expand Down Expand Up @@ -144,7 +155,7 @@ def test_ext_subprocess(
else:
assert False, "Unreachable"

@asset
@asset(check_specs=[AssetCheckSpec(name="foo_check", asset=AssetKey(["foo"]))])
def foo(context: AssetExecutionContext, ext: ExtSubprocess):
extras = {"bar": "baz"}
cmd = [_PYTHON_EXECUTABLE, external_script]
Expand Down Expand Up @@ -177,6 +188,14 @@ def foo(context: AssetExecutionContext, ext: ExtSubprocess):
captured = capsys.readouterr()
assert re.search(r"dagster - INFO - [^\n]+ - hello world\n", captured.err, re.MULTILINE)

asset_check_executions = instance.event_log_storage.get_asset_check_executions(
asset_key=foo.key,
check_name="foo_check",
limit=1,
)
assert len(asset_check_executions) == 1
assert asset_check_executions[0].status == AssetCheckExecutionRecordStatus.SUCCEEDED


def test_ext_typed_metadata():
def script_fn():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -644,6 +644,15 @@ def asset_checks_def_for_node(
def asset_checks_defs(self) -> Iterable[AssetChecksDefinition]:
return self.asset_checks_defs_by_node_handle.values()

def get_asset_check_for_output_name(self, output_name: str) -> Optional[AssetCheckHandle]:
    """Return the first asset check handle whose node output is named `output_name`.

    Returns `None` when no registered asset check maps to an output with that name.
    """
    handles_by_check = self.node_output_handles_by_asset_check_handle
    return next(
        (
            check_handle
            for check_handle, output_handle in handles_by_check.items()
            if output_handle.output_name == output_name
        ),
        None,
    )

def get_output_name_for_asset_check(self, asset_check_handle: AssetCheckHandle) -> str:
"""Output name in the leaf op."""
return self.node_output_handles_by_asset_check_handle[asset_check_handle].output_name
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@

import dagster._check as check
from dagster._annotations import deprecated, experimental, public
from dagster._core.definitions.asset_check_spec import AssetCheckSpec
from dagster._core.definitions.asset_check_result import AssetCheckResult
from dagster._core.definitions.asset_check_spec import AssetCheckHandle, AssetCheckSpec
from dagster._core.definitions.assets import AssetsDefinition
from dagster._core.definitions.data_version import (
DataProvenance,
Expand Down Expand Up @@ -464,6 +465,53 @@ def get_output_metadata(
output_name=output_name, mapping_key=mapping_key
)

@public
@experimental
def add_asset_check_result(self, asset_check_result: AssetCheckResult) -> None:
    """Add an asset check result for an asset being materialized in the current step.

    Args:
        asset_check_result (AssetCheckResult): The asset check result to add.

    **Examples:**

    .. code-block:: python

        from dagster import (
            AssetCheckResult,
            AssetCheckSeverity,
            AssetCheckSpec,
            AssetKey,
            asset,
        )

        @asset(check_specs=[AssetCheckSpec(name="my_check", asset=AssetKey("foo_asset"))])
        def foo_asset(context):
            ...
            context.add_asset_check_result(
                AssetCheckResult(
                    asset_key=AssetKey("foo_asset"),
                    check_name="my_check",
                    success=True,
                    severity=AssetCheckSeverity.WARN,
                    metadata={"foo": "bar"}
                )
            )
            ...
    """
    check.inst_param(asset_check_result, "asset_check_result", AssetCheckResult)
    # The result is stored on the step execution context rather than yielded from the op.
    self._step_execution_context.add_result_object(asset_check_result)

def has_asset_check_result(self, handle: AssetCheckHandle) -> bool:
    """Return whether a result for the check identified by `handle` was added in this step.

    A result matches when it is an `AssetCheckResult` whose `asset_key` and `check_name`
    equal the handle's `asset_key` and `name`.
    """
    return any(
        isinstance(candidate, AssetCheckResult)
        and candidate.asset_key == handle.asset_key
        and candidate.check_name == handle.name
        for candidate in self.get_step_execution_context().result_objects
    )

def get_step_execution_context(self) -> StepExecutionContext:
"""Allows advanced users (e.g. framework authors) to punch through to the underlying
step execution context.
Expand Down
10 changes: 10 additions & 0 deletions python_modules/dagster/dagster/_core/execution/context/system.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

import dagster._check as check
from dagster._annotations import public
from dagster._core.definitions.asset_check_result import AssetCheckResult
from dagster._core.definitions.data_version import (
DATA_VERSION_TAG,
SKIP_PARTITION_DATA_VERSION_DEPENDENCY_THRESHOLD,
Expand Down Expand Up @@ -79,6 +80,7 @@
from dagster._core.definitions.dependency import NodeHandle
from dagster._core.definitions.resource_definition import Resources
from dagster._core.event_api import EventLogRecord
from dagster._core.execution.plan.compute import OpOutputUnion
from dagster._core.execution.plan.plan import ExecutionPlan
from dagster._core.execution.plan.state import KnownExecutionState
from dagster._core.instance import DagsterInstance
Expand Down Expand Up @@ -553,6 +555,7 @@ def __init__(
self._step_output_capture = {}

self._output_metadata: Dict[str, Any] = {}
self._result_objects: List["OpOutputUnion"] = []
self._seen_outputs: Dict[str, Union[str, Set[str]]] = {}

self._input_asset_version_info: Dict[AssetKey, Optional["InputAssetVersionInfo"]] = {}
Expand Down Expand Up @@ -790,6 +793,13 @@ def get_output_metadata(
return metadata.get(mapping_key)
return metadata

def add_result_object(self, obj: "OpOutputUnion") -> None:
    """Record a result object on this step context, after any previously added ones."""
    self._result_objects.append(obj)

@property
def result_objects(self) -> Sequence[AssetCheckResult]:
    """Result objects added to this step via `add_result_object`, in insertion order."""
    # NOTE(review): the backing list is declared as List["OpOutputUnion"], which looks broader
    # than the Sequence[AssetCheckResult] annotated here -- confirm which type is intended.
    return self._result_objects

def _get_source_run_id_from_logs(self, step_output_handle: StepOutputHandle) -> Optional[str]:
# walk through event logs to find the right run_id based on the run lineage

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,12 +133,34 @@ def _coerce_op_compute_fn_to_iterator(
def _zip_and_iterate_op_result(
result: Any, context: OpExecutionContext, output_defs: Sequence[OutputDefinition]
) -> Iterator[Tuple[int, Any, OutputDefinition]]:
if len(output_defs) > 1:
result = _validate_multi_return(context, result, output_defs)
for position, (output_def, element) in enumerate(zip(output_defs, result)):
expected_return_outputs = _filter_expected_output_defs(context, output_defs)
if len(expected_return_outputs) > 1:
result = _validate_multi_return(context, result, expected_return_outputs)
for position, (output_def, element) in enumerate(zip(expected_return_outputs, result)):
yield position, output_def, element
else:
yield 0, output_defs[0], result
yield 0, expected_return_outputs[0], result


# Output defs that correspond to asset check results already registered on the context are
# dropped here -- those are not expected to appear in the op's return value.
def _filter_expected_output_defs(
    context: OpExecutionContext, output_defs: Sequence[OutputDefinition]
) -> Sequence[OutputDefinition]:
    expected = []
    for candidate in output_defs:
        if _is_already_registered_asset_check_result(context, candidate):
            continue
        expected.append(candidate)
    return expected


def _is_already_registered_asset_check_result(
    context: OpExecutionContext, output_def: OutputDefinition
) -> bool:
    """Return True if `output_def` is an asset check output whose result is already on `context`.

    Args:
        context: The op execution context; its job's asset layer is used to map the output
            name to an asset check handle.
        output_def: The output definition to test.

    Returns:
        False when the output does not correspond to any asset check, otherwise whether the
        context already holds a result for that check.
    """
    asset_check_handle = context.job_def.asset_layer.get_asset_check_for_output_name(
        output_def.name
    )
    # Always return a real bool (not None) so the predicate has a consistent type.
    if asset_check_handle is None:
        return False
    return bool(context.has_asset_check_result(asset_check_handle))


def _validate_multi_return(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import inspect
from itertools import chain
from typing import (
AbstractSet,
Any,
Expand Down Expand Up @@ -90,7 +91,7 @@ def _process_asset_results_to_events(
- An AssetCheckEvaluation, which combines the check result with information from the context
to create a full picture of the asset check's evaluation.
"""
for user_event in user_event_sequence:
for user_event in chain(user_event_sequence, step_context.result_objects):
if isinstance(user_event, MaterializeResult):
assets_def = step_context.job_def.asset_layer.assets_def_for_node(
step_context.node_handle
Expand Down
Loading

0 comments on commit c9cb902

Please sign in to comment.