Skip to content

Commit

Permalink
commit
Browse files Browse the repository at this point in the history
  • Loading branch information
smackesey committed Sep 22, 2023
1 parent 762811c commit ca5bead
Show file tree
Hide file tree
Showing 9 changed files with 121 additions and 20 deletions.
31 changes: 31 additions & 0 deletions python_modules/dagster-ext/dagster_ext/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,8 @@ class ExtDataProvenance(TypedDict):
is_user_provided: bool


ExtAssetCheckSeverity = Literal["WARN", "ERROR"]

ExtMetadataRawValue = Union[int, float, str, Mapping[str, Any], Sequence[Any], bool, None]


Expand Down Expand Up @@ -799,6 +801,35 @@ def report_asset_materialization(
)
self._materialized_assets.add(asset_key)

def report_asset_check(
self,
check_name: str,
success: bool,
severity: ExtAssetCheckSeverity = "ERROR",
metadata: Optional[Mapping[str, Union[ExtMetadataRawValue, ExtMetadataValue]]] = None,
asset_key: Optional[str] = None,
) -> None:
asset_key = _resolve_optionally_passed_asset_key(
self._data, asset_key, "report_asset_check"
)
check_name = _assert_param_type(check_name, str, "report_asset_check", "check_name")
success = _assert_param_type(success, bool, "report_asset_check", "success")
metadata = (
_normalize_param_metadata(metadata, "report_asset_check", "metadata")
if metadata
else None
)
self._write_message(
"report_asset_check",
{
"asset_key": asset_key,
"check_name": check_name,
"success": success,
"metadata": metadata,
"severity": severity,
},
)

def log(self, message: str, level: str = "info") -> None:
message = _assert_param_type(message, str, "log", "asset_key")
level = _assert_param_value(level, ["info", "warning", "error"], "log", "level")
Expand Down
12 changes: 12 additions & 0 deletions python_modules/dagster-ext/dagster_ext_tests/test_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,16 @@ def test_single_asset_context():
)

_assert_unknown_asset_key(context, "report_asset_materialization", asset_key="fake")
context.report_asset_check(
"foo_check",
True,
metadata={
"meta_1": 1,
"meta_2": {"raw_value": "foo", "type": "text"},
},
)

_assert_unknown_asset_key(context, "report_asset_check", "foo_check", True, asset_key="fake")


def test_multi_asset_context():
Expand Down Expand Up @@ -115,6 +125,8 @@ def test_multi_asset_context():

_assert_undefined_asset_key(context, "report_asset_materialization", "bar")
_assert_unknown_asset_key(context, "report_asset_materialization", "bar", asset_key="fake")
_assert_undefined_asset_key(context, "report_asset_check", "foo_check", True)
_assert_unknown_asset_key(context, "report_asset_check", "foo_check", True, asset_key="fake")


def test_no_partition_context():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import boto3
import pytest
from dagster._core.definitions.asset_check_spec import AssetCheckSpec
from dagster._core.definitions.asset_spec import AssetSpec
from dagster._core.definitions.data_version import (
DATA_VERSION_IS_USER_PROVIDED_TAG,
Expand Down Expand Up @@ -44,6 +45,7 @@
ext_protocol,
)
from dagster._core.instance_for_test import instance_for_test
from dagster._core.storage.asset_check_execution_record import AssetCheckExecutionRecordStatus
from dagster_aws.ext import ExtS3MessageReader
from moto.server import ThreadedMotoServer

Expand Down Expand Up @@ -96,6 +98,15 @@ def script_fn():
metadata={"bar": {"raw_value": context.get_extra("bar"), "type": "md"}},
data_version="alpha",
)
context.report_asset_check(
"foo_check",
success=True,
severity="WARN",
metadata={
"meta_1": 1,
"meta_2": {"raw_value": "foo", "type": "text"},
},
)

with temp_script(script_fn) as script_path:
yield script_path
Expand Down Expand Up @@ -147,7 +158,7 @@ def test_ext_subprocess(
else:
assert False, "Unreachable"

@asset
@asset(check_specs=[AssetCheckSpec(name="foo_check", asset=AssetKey(["foo"]))])
def foo(context: AssetExecutionContext, ext: ExtSubprocess):
extras = {"bar": "baz"}
cmd = [_PYTHON_EXECUTABLE, external_script]
Expand Down Expand Up @@ -176,6 +187,14 @@ def foo(context: AssetExecutionContext, ext: ExtSubprocess):
captured = capsys.readouterr()
assert re.search(r"dagster - INFO - [^\n]+ - hello world\n", captured.err, re.MULTILINE)

asset_check_executions = instance.event_log_storage.get_asset_check_executions(
asset_key=foo.key,
check_name="foo_check",
limit=1,
)
assert len(asset_check_executions) == 1
assert asset_check_executions[0].status == AssetCheckExecutionRecordStatus.SUCCEEDED


def test_ext_multi_asset():
def script_fn():
Expand Down Expand Up @@ -335,7 +354,7 @@ def script_fn():


def test_ext_no_client(external_script):
@asset
@asset(check_specs=[AssetCheckSpec(name="foo_check", asset=AssetKey(["subproc_run"]))])
def subproc_run(context: AssetExecutionContext):
extras = {"bar": "baz"}
cmd = [_PYTHON_EXECUTABLE, external_script]
Expand All @@ -361,6 +380,14 @@ def subproc_run(context: AssetExecutionContext):
assert mat.asset_materialization.tags[DATA_VERSION_TAG] == "alpha"
assert mat.asset_materialization.tags[DATA_VERSION_IS_USER_PROVIDED_TAG]

asset_check_executions = instance.event_log_storage.get_asset_check_executions(
asset_key=subproc_run.key,
check_name="foo_check",
limit=1,
)
assert len(asset_check_executions) == 1
assert asset_check_executions[0].status == AssetCheckExecutionRecordStatus.SUCCEEDED


def test_ext_no_client_no_yield():
def script_fn():
Expand Down
6 changes: 2 additions & 4 deletions python_modules/dagster/dagster/_core/ext/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,7 @@
from dagster._core.execution.context.compute import OpExecutionContext

if TYPE_CHECKING:
from dagster._core.definitions.result import MaterializeResult

from .context import ExtMessageHandler
from .context import ExtMessageHandler, ExtResult


class ExtClient(ABC):
Expand All @@ -23,7 +21,7 @@ def run(
*,
context: OpExecutionContext,
extras: Optional[ExtExtras] = None,
) -> Iterator["MaterializeResult"]: ...
) -> Iterator["ExtResult"]: ...


class ExtContextInjector(ABC):
Expand Down
45 changes: 39 additions & 6 deletions python_modules/dagster/dagster/_core/ext/context.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from contextlib import contextmanager
from dataclasses import dataclass
from queue import Queue
from typing import Any, Dict, Iterator, Mapping, Optional, Set
from typing import Any, Iterator, Mapping, Optional, Set, Union

from dagster_ext import (
DAGSTER_EXT_ENV_KEYS,
Expand All @@ -12,12 +12,16 @@
ExtExtras,
ExtMessage,
ExtMetadataType,
ExtMetadataValue,
ExtParams,
ExtTimeWindow,
encode_env_var,
)
from typing_extensions import TypeAlias

import dagster._check as check
from dagster._core.definitions.asset_check_result import AssetCheckResult
from dagster._core.definitions.asset_check_spec import AssetCheckSeverity
from dagster._core.definitions.data_version import DataProvenance, DataVersion
from dagster._core.definitions.events import AssetKey
from dagster._core.definitions.metadata import MetadataValue, normalize_metadata_value
Expand All @@ -28,16 +32,16 @@
from dagster._core.execution.context.invocation import BoundOpExecutionContext
from dagster._core.ext.client import ExtMessageReader

ExtResult: TypeAlias = Union[MaterializeResult, AssetCheckResult]


class ExtMessageHandler:
def __init__(self, context: OpExecutionContext) -> None:
self._context = context
# Queue is thread-safe
self._result_queue: Queue[MaterializeResult] = Queue()
self._result_queue: Queue[ExtResult] = Queue()
# Only read by the main thread after all messages are handled, so no need for a lock
self._unmaterialized_assets: Set[AssetKey] = set(context.selected_asset_keys)
self._metadata: Dict[AssetKey, Dict[str, MetadataValue]] = {}
self._data_versions: Dict[AssetKey, DataVersion] = {}

@contextmanager
def handle_messages(self, message_reader: ExtMessageReader) -> Iterator[ExtParams]:
Expand All @@ -46,7 +50,7 @@ def handle_messages(self, message_reader: ExtMessageReader) -> Iterator[ExtParam
for key in self._unmaterialized_assets:
self._result_queue.put(MaterializeResult(asset_key=key))

def clear_result_queue(self) -> Iterator[MaterializeResult]:
def clear_result_queue(self) -> Iterator[ExtResult]:
while not self._result_queue.empty():
yield self._result_queue.get()

Expand Down Expand Up @@ -86,6 +90,8 @@ def _resolve_metadata_value(self, value: Any, metadata_type: ExtMetadataType) ->
def handle_message(self, message: ExtMessage) -> None:
if message["method"] == "report_asset_materialization":
self._handle_report_asset_materialization(**message["params"]) # type: ignore
elif message["method"] == "report_asset_check":
self._handle_report_asset_check(**message["params"]) # type: ignore
elif message["method"] == "log":
self._handle_log(**message["params"]) # type: ignore

Expand All @@ -108,6 +114,33 @@ def _handle_report_asset_materialization(
self._result_queue.put(result)
self._unmaterialized_assets.remove(resolved_asset_key)

def _handle_report_asset_check(
self,
asset_key: str,
check_name: str,
success: bool,
severity: str,
metadata: Mapping[str, ExtMetadataValue],
) -> None:
check.str_param(asset_key, "asset_key")
check.str_param(check_name, "check_name")
check.bool_param(success, "success")
check.literal_param(severity, "severity", [x.value for x in AssetCheckSeverity])
metadata = check.opt_mapping_param(metadata, "metadata", key_type=str)
resolved_asset_key = AssetKey.from_user_string(asset_key)
resolved_metadata = {
k: self._resolve_metadata_value(v["raw_value"], v["type"]) for k, v in metadata.items()
}
resolved_severity = AssetCheckSeverity(severity)
result = AssetCheckResult(
asset_key=resolved_asset_key,
check_name=check_name,
success=success,
severity=resolved_severity,
metadata=resolved_metadata,
)
self._result_queue.put(result)

def _handle_log(self, message: str, level: str = "info") -> None:
check.str_param(message, "message")
self._context.log.log(level, message)
Expand Down Expand Up @@ -138,7 +171,7 @@ def get_external_process_env_vars(self):
),
}

def get_results(self) -> Iterator[MaterializeResult]:
def get_results(self) -> Iterator[ExtResult]:
yield from self.message_handler.clear_result_queue()


Expand Down
4 changes: 2 additions & 2 deletions python_modules/dagster/dagster/_core/ext/subprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,14 @@

from dagster import _check as check
from dagster._core.definitions.resource_annotation import ResourceParam
from dagster._core.definitions.result import MaterializeResult
from dagster._core.errors import DagsterExternalExecutionError
from dagster._core.execution.context.compute import OpExecutionContext
from dagster._core.ext.client import (
ExtClient,
ExtContextInjector,
ExtMessageReader,
)
from dagster._core.ext.context import ExtResult
from dagster._core.ext.utils import (
ExtTempFileContextInjector,
ExtTempFileMessageReader,
Expand Down Expand Up @@ -67,7 +67,7 @@ def run(
extras: Optional[ExtExtras] = None,
env: Optional[Mapping[str, str]] = None,
cwd: Optional[str] = None,
) -> Iterator[MaterializeResult]:
) -> Iterator[ExtResult]:
with ext_protocol(
context=context,
context_injector=self.context_injector,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@

import dagster._check as check
from dagster._core.definitions.resource_annotation import ResourceParam
from dagster._core.definitions.result import MaterializeResult
from dagster._core.errors import DagsterExternalExecutionError
from dagster._core.execution.context.compute import OpExecutionContext
from dagster._core.ext.client import ExtClient, ExtContextInjector, ExtMessageReader
from dagster._core.ext.context import ExtResult
from dagster._core.ext.utils import (
ExtBlobStoreMessageReader,
ext_protocol,
Expand Down Expand Up @@ -67,7 +67,7 @@ def run(
context: OpExecutionContext,
extras: Optional[ExtExtras] = None,
submit_args: Optional[Mapping[str, str]] = None,
) -> Iterator[MaterializeResult]:
) -> Iterator[ExtResult]:
"""Run a Databricks job with the EXT protocol.
Args:
Expand Down
4 changes: 2 additions & 2 deletions python_modules/libraries/dagster-docker/dagster_docker/ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,14 @@
ResourceParam,
_check as check,
)
from dagster._core.definitions.result import MaterializeResult
from dagster._core.ext.client import (
ExtClient,
ExtContextInjector,
ExtMessageReader,
)
from dagster._core.ext.context import (
ExtMessageHandler,
ExtResult,
)
from dagster._core.ext.utils import (
ExtEnvContextInjector,
Expand Down Expand Up @@ -95,7 +95,7 @@ def run(
registry: Optional[Mapping[str, str]] = None,
container_kwargs: Optional[Mapping[str, Any]] = None,
extras: Optional[ExtExtras] = None,
) -> Iterator[MaterializeResult]:
) -> Iterator[ExtResult]:
"""Create a docker container and run it to completion, enriched with the ext protocol.
Args:
Expand Down
4 changes: 2 additions & 2 deletions python_modules/libraries/dagster-k8s/dagster_k8s/ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
_check as check,
)
from dagster._core.definitions.resource_annotation import ResourceParam
from dagster._core.definitions.result import MaterializeResult
from dagster._core.errors import DagsterInvariantViolationError
from dagster._core.ext.client import (
ExtClient,
Expand All @@ -19,6 +18,7 @@
)
from dagster._core.ext.context import (
ExtMessageHandler,
ExtResult,
)
from dagster._core.ext.utils import (
ExtEnvContextInjector,
Expand Down Expand Up @@ -124,7 +124,7 @@ def run(
base_pod_meta: Optional[Mapping[str, Any]] = None,
base_pod_spec: Optional[Mapping[str, Any]] = None,
extras: Optional[ExtExtras] = None,
) -> Iterator[MaterializeResult]:
) -> Iterator[ExtResult]:
"""Publish a kubernetes pod and wait for it to complete, enriched with the ext protocol.
Args:
Expand Down

0 comments on commit ca5bead

Please sign in to comment.