From ca5beadd560788eb15c25604b2a53cc6f69e5590 Mon Sep 17 00:00:00 2001
From: Sean Mackesey <s.mackesey@gmail.com>
Date: Fri, 15 Sep 2023 08:54:31 -0400
Subject: [PATCH] commit

---
 .../dagster-ext/dagster_ext/__init__.py       | 31 +++++++++++++
 .../dagster_ext_tests/test_context.py         | 12 +++++
 .../test_external_execution.py                | 31 ++++++++++++-
 .../dagster/dagster/_core/ext/client.py       |  6 +--
 .../dagster/dagster/_core/ext/context.py      | 45 ++++++++++++++++---
 .../dagster/dagster/_core/ext/subprocess.py   |  4 +-
 .../dagster_databricks/ext.py                 |  4 +-
 .../dagster-docker/dagster_docker/ext.py      |  4 +-
 .../libraries/dagster-k8s/dagster_k8s/ext.py  |  4 +-
 9 files changed, 121 insertions(+), 20 deletions(-)

diff --git a/python_modules/dagster-ext/dagster_ext/__init__.py b/python_modules/dagster-ext/dagster_ext/__init__.py
index 172e52b127e13..1c332e376b97e 100644
--- a/python_modules/dagster-ext/dagster_ext/__init__.py
+++ b/python_modules/dagster-ext/dagster_ext/__init__.py
@@ -99,6 +99,8 @@ class ExtDataProvenance(TypedDict):
     is_user_provided: bool
 
 
+ExtAssetCheckSeverity = Literal["WARN", "ERROR"]
+
 ExtMetadataRawValue = Union[int, float, str, Mapping[str, Any], Sequence[Any], bool, None]
 
 
@@ -799,6 +801,35 @@ def report_asset_materialization(
         )
         self._materialized_assets.add(asset_key)
 
+    def report_asset_check(
+        self,
+        check_name: str,
+        success: bool,
+        severity: ExtAssetCheckSeverity = "ERROR",
+        metadata: Optional[Mapping[str, Union[ExtMetadataRawValue, ExtMetadataValue]]] = None,
+        asset_key: Optional[str] = None,
+    ) -> None:
+        asset_key = _resolve_optionally_passed_asset_key(
+            self._data, asset_key, "report_asset_check"
+        )
+        check_name = _assert_param_type(check_name, str, "report_asset_check", "check_name")
+        success = _assert_param_type(success, bool, "report_asset_check", "success")
+        metadata = (
+            _normalize_param_metadata(metadata, "report_asset_check", "metadata")
+            if metadata
+            else None
+        )
+        self._write_message(
+            "report_asset_check",
+            {
+                "asset_key": asset_key,
+                "check_name": check_name,
+                "success": success,
+                "metadata": metadata,
+                "severity": severity,
+            },
+        )
+
     def log(self, message: str, level: str = "info") -> None:
         message = _assert_param_type(message, str, "log", "asset_key")
         level = _assert_param_value(level, ["info", "warning", "error"], "log", "level")
diff --git a/python_modules/dagster-ext/dagster_ext_tests/test_context.py b/python_modules/dagster-ext/dagster_ext_tests/test_context.py
index 4294e53bbeda7..f23e856a7ea6d 100644
--- a/python_modules/dagster-ext/dagster_ext_tests/test_context.py
+++ b/python_modules/dagster-ext/dagster_ext_tests/test_context.py
@@ -88,6 +88,16 @@ def test_single_asset_context():
     )
 
     _assert_unknown_asset_key(context, "report_asset_materialization", asset_key="fake")
+    context.report_asset_check(
+        "foo_check",
+        True,
+        metadata={
+            "meta_1": 1,
+            "meta_2": {"raw_value": "foo", "type": "text"},
+        },
+    )
+
+    _assert_unknown_asset_key(context, "report_asset_check", "foo_check", True, asset_key="fake")
 
 
 def test_multi_asset_context():
@@ -115,6 +125,8 @@ def test_multi_asset_context():
 
     _assert_undefined_asset_key(context, "report_asset_materialization", "bar")
     _assert_unknown_asset_key(context, "report_asset_materialization", "bar", asset_key="fake")
+    _assert_undefined_asset_key(context, "report_asset_check", "foo_check", True)
+    _assert_unknown_asset_key(context, "report_asset_check", "foo_check", True, asset_key="fake")
 
 
 def test_no_partition_context():
diff --git a/python_modules/dagster-ext/dagster_ext_tests/test_external_execution.py b/python_modules/dagster-ext/dagster_ext_tests/test_external_execution.py
index c2b0b540904cf..1e011bfecf56c 100644
--- a/python_modules/dagster-ext/dagster_ext_tests/test_external_execution.py
+++ b/python_modules/dagster-ext/dagster_ext_tests/test_external_execution.py
@@ -9,6 +9,7 @@
 
 import boto3
 import pytest
+from dagster._core.definitions.asset_check_spec import AssetCheckSpec
 from dagster._core.definitions.asset_spec import AssetSpec
 from dagster._core.definitions.data_version import (
     DATA_VERSION_IS_USER_PROVIDED_TAG,
@@ -44,6 +45,7 @@
     ext_protocol,
 )
 from dagster._core.instance_for_test import instance_for_test
+from dagster._core.storage.asset_check_execution_record import AssetCheckExecutionRecordStatus
 from dagster_aws.ext import ExtS3MessageReader
 from moto.server import ThreadedMotoServer
 
@@ -96,6 +98,15 @@ def script_fn():
             metadata={"bar": {"raw_value": context.get_extra("bar"), "type": "md"}},
             data_version="alpha",
         )
+        context.report_asset_check(
+            "foo_check",
+            success=True,
+            severity="WARN",
+            metadata={
+                "meta_1": 1,
+                "meta_2": {"raw_value": "foo", "type": "text"},
+            },
+        )
 
     with temp_script(script_fn) as script_path:
         yield script_path
@@ -147,7 +158,7 @@ def test_ext_subprocess(
     else:
         assert False, "Unreachable"
 
-    @asset
+    @asset(check_specs=[AssetCheckSpec(name="foo_check", asset=AssetKey(["foo"]))])
     def foo(context: AssetExecutionContext, ext: ExtSubprocess):
         extras = {"bar": "baz"}
         cmd = [_PYTHON_EXECUTABLE, external_script]
@@ -176,6 +187,14 @@ def foo(context: AssetExecutionContext, ext: ExtSubprocess):
         captured = capsys.readouterr()
         assert re.search(r"dagster - INFO - [^\n]+ - hello world\n", captured.err, re.MULTILINE)
 
+        asset_check_executions = instance.event_log_storage.get_asset_check_executions(
+            asset_key=foo.key,
+            check_name="foo_check",
+            limit=1,
+        )
+        assert len(asset_check_executions) == 1
+        assert asset_check_executions[0].status == AssetCheckExecutionRecordStatus.SUCCEEDED
+
 
 def test_ext_multi_asset():
     def script_fn():
@@ -335,7 +354,7 @@ def script_fn():
 
 
 def test_ext_no_client(external_script):
-    @asset
+    @asset(check_specs=[AssetCheckSpec(name="foo_check", asset=AssetKey(["subproc_run"]))])
     def subproc_run(context: AssetExecutionContext):
         extras = {"bar": "baz"}
         cmd = [_PYTHON_EXECUTABLE, external_script]
@@ -361,6 +380,14 @@ def subproc_run(context: AssetExecutionContext):
         assert mat.asset_materialization.tags[DATA_VERSION_TAG] == "alpha"
         assert mat.asset_materialization.tags[DATA_VERSION_IS_USER_PROVIDED_TAG]
 
+        asset_check_executions = instance.event_log_storage.get_asset_check_executions(
+            asset_key=subproc_run.key,
+            check_name="foo_check",
+            limit=1,
+        )
+        assert len(asset_check_executions) == 1
+        assert asset_check_executions[0].status == AssetCheckExecutionRecordStatus.SUCCEEDED
+
 
 def test_ext_no_client_no_yield():
     def script_fn():
diff --git a/python_modules/dagster/dagster/_core/ext/client.py b/python_modules/dagster/dagster/_core/ext/client.py
index 47270207bb62e..5d96312eb9eb7 100644
--- a/python_modules/dagster/dagster/_core/ext/client.py
+++ b/python_modules/dagster/dagster/_core/ext/client.py
@@ -11,9 +11,7 @@
 from dagster._core.execution.context.compute import OpExecutionContext
 
 if TYPE_CHECKING:
-    from dagster._core.definitions.result import MaterializeResult
-
-    from .context import ExtMessageHandler
+    from .context import ExtMessageHandler, ExtResult
 
 
 class ExtClient(ABC):
@@ -23,7 +21,7 @@ def run(
         *,
         context: OpExecutionContext,
         extras: Optional[ExtExtras] = None,
-    ) -> Iterator["MaterializeResult"]: ...
+    ) -> Iterator["ExtResult"]: ...
 
 
 class ExtContextInjector(ABC):
diff --git a/python_modules/dagster/dagster/_core/ext/context.py b/python_modules/dagster/dagster/_core/ext/context.py
index 4b3731a393a40..0b42f21fd2615 100644
--- a/python_modules/dagster/dagster/_core/ext/context.py
+++ b/python_modules/dagster/dagster/_core/ext/context.py
@@ -1,7 +1,7 @@
 from contextlib import contextmanager
 from dataclasses import dataclass
 from queue import Queue
-from typing import Any, Dict, Iterator, Mapping, Optional, Set
+from typing import Any, Iterator, Mapping, Optional, Set, Union
 
 from dagster_ext import (
     DAGSTER_EXT_ENV_KEYS,
@@ -12,12 +12,16 @@
     ExtExtras,
     ExtMessage,
     ExtMetadataType,
+    ExtMetadataValue,
     ExtParams,
     ExtTimeWindow,
     encode_env_var,
 )
+from typing_extensions import TypeAlias
 
 import dagster._check as check
+from dagster._core.definitions.asset_check_result import AssetCheckResult
+from dagster._core.definitions.asset_check_spec import AssetCheckSeverity
 from dagster._core.definitions.data_version import DataProvenance, DataVersion
 from dagster._core.definitions.events import AssetKey
 from dagster._core.definitions.metadata import MetadataValue, normalize_metadata_value
@@ -28,16 +32,16 @@
 from dagster._core.execution.context.invocation import BoundOpExecutionContext
 from dagster._core.ext.client import ExtMessageReader
 
+ExtResult: TypeAlias = Union[MaterializeResult, AssetCheckResult]
+
 
 class ExtMessageHandler:
     def __init__(self, context: OpExecutionContext) -> None:
         self._context = context
         # Queue is thread-safe
-        self._result_queue: Queue[MaterializeResult] = Queue()
+        self._result_queue: Queue[ExtResult] = Queue()
         # Only read by the main thread after all messages are handled, so no need for a lock
         self._unmaterialized_assets: Set[AssetKey] = set(context.selected_asset_keys)
-        self._metadata: Dict[AssetKey, Dict[str, MetadataValue]] = {}
-        self._data_versions: Dict[AssetKey, DataVersion] = {}
 
     @contextmanager
     def handle_messages(self, message_reader: ExtMessageReader) -> Iterator[ExtParams]:
@@ -46,7 +50,7 @@ def handle_messages(self, message_reader: ExtMessageReader) -> Iterator[ExtParam
         for key in self._unmaterialized_assets:
             self._result_queue.put(MaterializeResult(asset_key=key))
 
-    def clear_result_queue(self) -> Iterator[MaterializeResult]:
+    def clear_result_queue(self) -> Iterator[ExtResult]:
         while not self._result_queue.empty():
             yield self._result_queue.get()
 
@@ -86,6 +90,8 @@ def _resolve_metadata_value(self, value: Any, metadata_type: ExtMetadataType) ->
     def handle_message(self, message: ExtMessage) -> None:
         if message["method"] == "report_asset_materialization":
             self._handle_report_asset_materialization(**message["params"])  # type: ignore
+        elif message["method"] == "report_asset_check":
+            self._handle_report_asset_check(**message["params"])  # type: ignore
         elif message["method"] == "log":
             self._handle_log(**message["params"])  # type: ignore
 
@@ -108,6 +114,33 @@ def _handle_report_asset_materialization(
         self._result_queue.put(result)
         self._unmaterialized_assets.remove(resolved_asset_key)
 
+    def _handle_report_asset_check(
+        self,
+        asset_key: str,
+        check_name: str,
+        success: bool,
+        severity: str,
+        metadata: Mapping[str, ExtMetadataValue],
+    ) -> None:
+        check.str_param(asset_key, "asset_key")
+        check.str_param(check_name, "check_name")
+        check.bool_param(success, "success")
+        check.literal_param(severity, "severity", [x.value for x in AssetCheckSeverity])
+        metadata = check.opt_mapping_param(metadata, "metadata", key_type=str)
+        resolved_asset_key = AssetKey.from_user_string(asset_key)
+        resolved_metadata = {
+            k: self._resolve_metadata_value(v["raw_value"], v["type"]) for k, v in metadata.items()
+        }
+        resolved_severity = AssetCheckSeverity(severity)
+        result = AssetCheckResult(
+            asset_key=resolved_asset_key,
+            check_name=check_name,
+            success=success,
+            severity=resolved_severity,
+            metadata=resolved_metadata,
+        )
+        self._result_queue.put(result)
+
     def _handle_log(self, message: str, level: str = "info") -> None:
         check.str_param(message, "message")
         self._context.log.log(level, message)
@@ -138,7 +171,7 @@ def get_external_process_env_vars(self):
             ),
         }
 
-    def get_results(self) -> Iterator[MaterializeResult]:
+    def get_results(self) -> Iterator[ExtResult]:
         yield from self.message_handler.clear_result_queue()
 
 
diff --git a/python_modules/dagster/dagster/_core/ext/subprocess.py b/python_modules/dagster/dagster/_core/ext/subprocess.py
index 1d658082884c2..be8f919763f19 100644
--- a/python_modules/dagster/dagster/_core/ext/subprocess.py
+++ b/python_modules/dagster/dagster/_core/ext/subprocess.py
@@ -5,7 +5,6 @@
 
 from dagster import _check as check
 from dagster._core.definitions.resource_annotation import ResourceParam
-from dagster._core.definitions.result import MaterializeResult
 from dagster._core.errors import DagsterExternalExecutionError
 from dagster._core.execution.context.compute import OpExecutionContext
 from dagster._core.ext.client import (
@@ -13,6 +12,7 @@
     ExtContextInjector,
     ExtMessageReader,
 )
+from dagster._core.ext.context import ExtResult
 from dagster._core.ext.utils import (
     ExtTempFileContextInjector,
     ExtTempFileMessageReader,
@@ -67,7 +67,7 @@ def run(
         extras: Optional[ExtExtras] = None,
         env: Optional[Mapping[str, str]] = None,
         cwd: Optional[str] = None,
-    ) -> Iterator[MaterializeResult]:
+    ) -> Iterator[ExtResult]:
         with ext_protocol(
             context=context,
             context_injector=self.context_injector,
diff --git a/python_modules/libraries/dagster-databricks/dagster_databricks/ext.py b/python_modules/libraries/dagster-databricks/dagster_databricks/ext.py
index deda20b3c665d..19451003635bf 100644
--- a/python_modules/libraries/dagster-databricks/dagster_databricks/ext.py
+++ b/python_modules/libraries/dagster-databricks/dagster_databricks/ext.py
@@ -9,10 +9,10 @@
 
 import dagster._check as check
 from dagster._core.definitions.resource_annotation import ResourceParam
-from dagster._core.definitions.result import MaterializeResult
 from dagster._core.errors import DagsterExternalExecutionError
 from dagster._core.execution.context.compute import OpExecutionContext
 from dagster._core.ext.client import ExtClient, ExtContextInjector, ExtMessageReader
+from dagster._core.ext.context import ExtResult
 from dagster._core.ext.utils import (
     ExtBlobStoreMessageReader,
     ext_protocol,
@@ -67,7 +67,7 @@ def run(
         context: OpExecutionContext,
         extras: Optional[ExtExtras] = None,
         submit_args: Optional[Mapping[str, str]] = None,
-    ) -> Iterator[MaterializeResult]:
+    ) -> Iterator[ExtResult]:
         """Run a Databricks job with the EXT protocol.
 
         Args:
diff --git a/python_modules/libraries/dagster-docker/dagster_docker/ext.py b/python_modules/libraries/dagster-docker/dagster_docker/ext.py
index b7a88238f3a04..8fac849329439 100644
--- a/python_modules/libraries/dagster-docker/dagster_docker/ext.py
+++ b/python_modules/libraries/dagster-docker/dagster_docker/ext.py
@@ -7,7 +7,6 @@
     ResourceParam,
     _check as check,
 )
-from dagster._core.definitions.result import MaterializeResult
 from dagster._core.ext.client import (
     ExtClient,
     ExtContextInjector,
@@ -15,6 +14,7 @@
 )
 from dagster._core.ext.context import (
     ExtMessageHandler,
+    ExtResult,
 )
 from dagster._core.ext.utils import (
     ExtEnvContextInjector,
@@ -95,7 +95,7 @@ def run(
         registry: Optional[Mapping[str, str]] = None,
         container_kwargs: Optional[Mapping[str, Any]] = None,
         extras: Optional[ExtExtras] = None,
-    ) -> Iterator[MaterializeResult]:
+    ) -> Iterator[ExtResult]:
         """Create a docker container and run it to completion, enriched with the ext protocol.
 
         Args:
diff --git a/python_modules/libraries/dagster-k8s/dagster_k8s/ext.py b/python_modules/libraries/dagster-k8s/dagster_k8s/ext.py
index 93bdcf42df06b..8da2226cebdd9 100644
--- a/python_modules/libraries/dagster-k8s/dagster_k8s/ext.py
+++ b/python_modules/libraries/dagster-k8s/dagster_k8s/ext.py
@@ -9,7 +9,6 @@
     _check as check,
 )
 from dagster._core.definitions.resource_annotation import ResourceParam
-from dagster._core.definitions.result import MaterializeResult
 from dagster._core.errors import DagsterInvariantViolationError
 from dagster._core.ext.client import (
     ExtClient,
@@ -19,6 +18,7 @@
 )
 from dagster._core.ext.context import (
     ExtMessageHandler,
+    ExtResult,
 )
 from dagster._core.ext.utils import (
     ExtEnvContextInjector,
@@ -124,7 +124,7 @@ def run(
         base_pod_meta: Optional[Mapping[str, Any]] = None,
         base_pod_spec: Optional[Mapping[str, Any]] = None,
         extras: Optional[ExtExtras] = None,
-    ) -> Iterator[MaterializeResult]:
+    ) -> Iterator[ExtResult]:
         """Publish a kubernetes pod and wait for it to complete, enriched with the ext protocol.
 
         Args: