Skip to content

Commit

Permalink
[ext] Use MaterializeResult for ext_protocol
Browse files Browse the repository at this point in the history
  • Loading branch information
smackesey committed Sep 19, 2023
1 parent fe57d60 commit a0ae406
Show file tree
Hide file tree
Showing 7 changed files with 62 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ def test_ext_subprocess(
def foo(context: AssetExecutionContext, ext: ExtSubprocess):
extras = {"bar": "baz"}
cmd = [_PYTHON_EXECUTABLE, external_script]
ext.run(
return ext.run(
cmd,
context=context,
extras=extras,
Expand Down Expand Up @@ -201,7 +201,7 @@ def script_fn():
def foo(context: AssetExecutionContext, ext: ExtSubprocess):
with temp_script(script_fn) as script_path:
cmd = [_PYTHON_EXECUTABLE, script_path]
ext.run(cmd, context=context)
return ext.run(cmd, context=context)

with instance_for_test() as instance:
materialize(
Expand Down Expand Up @@ -313,6 +313,9 @@ def subproc_run(context: AssetExecutionContext):
extras=extras,
) as ext_context:
subprocess.run(cmd, env=ext_context.get_external_process_env_vars(), check=False)
_ext_context = ext_context
mat_results = _ext_context.get_materialize_results()
return mat_results[0] if len(mat_results) == 1 else mat_results

with instance_for_test() as instance:
materialize(
Expand Down
6 changes: 4 additions & 2 deletions python_modules/dagster/dagster/_core/ext/client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from abc import ABC, abstractmethod
from contextlib import contextmanager
from typing import TYPE_CHECKING, Iterator, Optional
from typing import TYPE_CHECKING, Iterator, Optional, Tuple, Union

from dagster_ext import (
ExtContextData,
Expand All @@ -11,6 +11,8 @@
from dagster._core.execution.context.compute import OpExecutionContext

if TYPE_CHECKING:
from dagster._core.definitions.result import MaterializeResult

from .context import ExtMessageHandler


Expand All @@ -21,7 +23,7 @@ def run(
*,
context: OpExecutionContext,
extras: Optional[ExtExtras] = None,
) -> None: ...
) -> Union["MaterializeResult", Tuple["MaterializeResult", ...]]: ...


class ExtContextInjector(ABC):
Expand Down
35 changes: 29 additions & 6 deletions python_modules/dagster/dagster/_core/ext/context.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import Any, Mapping, Optional, get_args
from typing import Any, Dict, Mapping, Optional, Tuple, get_args

from dagster_ext import (
DAGSTER_EXT_ENV_KEYS,
Expand All @@ -19,6 +19,7 @@
from dagster._core.definitions.events import AssetKey
from dagster._core.definitions.metadata import MetadataValue, normalize_metadata_value
from dagster._core.definitions.partition_key_range import PartitionKeyRange
from dagster._core.definitions.result import MaterializeResult
from dagster._core.definitions.time_window_partitions import TimeWindow
from dagster._core.execution.context.compute import OpExecutionContext
from dagster._core.execution.context.invocation import BoundOpExecutionContext
Expand All @@ -27,6 +28,16 @@
class ExtMessageHandler:
def __init__(self, context: OpExecutionContext) -> None:
self._context = context
self._metadata: Dict[AssetKey, Dict[str, MetadataValue]] = {}
self._data_versions: Dict[AssetKey, DataVersion] = {}

@property
def metadata(self) -> Dict[AssetKey, Dict[str, MetadataValue]]:
return self._metadata

@property
def data_versions(self) -> Dict[AssetKey, DataVersion]:
return self._data_versions

# Type ignores because we currently validate in individual handlers
def handle_message(self, message: ExtMessage) -> None:
Expand All @@ -43,10 +54,9 @@ def _handle_report_asset_metadata(
check.str_param(asset_key, "asset_key")
check.str_param(label, "label")
check.opt_literal_param(type, "type", get_args(ExtMetadataType))
key = AssetKey.from_user_string(asset_key)
output_name = self._context.output_for_asset_key(key)
resolved_asset_key = AssetKey.from_user_string(asset_key)
metadata_value = self._resolve_metadata_value(value, type)
self._context.add_output_metadata({label: metadata_value}, output_name)
self._metadata.setdefault(resolved_asset_key, {})[label] = metadata_value

def _resolve_metadata_value(
self, value: Any, metadata_type: Optional[ExtMetadataType]
Expand Down Expand Up @@ -85,8 +95,8 @@ def _resolve_metadata_value(
def _handle_report_asset_data_version(self, asset_key: str, data_version: str) -> None:
check.str_param(asset_key, "asset_key")
check.str_param(data_version, "data_version")
key = AssetKey.from_user_string(asset_key)
self._context.set_data_version(key, DataVersion(data_version))
resolved_asset_key = AssetKey.from_user_string(asset_key)
self._data_versions[resolved_asset_key] = DataVersion(data_version)

def _handle_log(self, message: str, level: str = "info") -> None:
check.str_param(message, "message")
Expand Down Expand Up @@ -118,6 +128,19 @@ def get_external_process_env_vars(self):
),
}

def get_materialize_results(self) -> Tuple[MaterializeResult, ...]:
return tuple(
self._materialize_result_for_asset(AssetKey.from_user_string(key))
for key in self.context_data["asset_keys"] or []
)

def _materialize_result_for_asset(self, asset_key: AssetKey):
return MaterializeResult(
asset_key=asset_key,
metadata=self.message_handler.metadata.get(asset_key),
data_version=self.message_handler.data_versions.get(asset_key),
)


def build_external_execution_context_data(
context: OpExecutionContext,
Expand Down
8 changes: 6 additions & 2 deletions python_modules/dagster/dagster/_core/ext/subprocess.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from subprocess import Popen
from typing import Mapping, Optional, Sequence, Union
from typing import Mapping, Optional, Sequence, Tuple, Union

from dagster_ext import ExtExtras

from dagster import _check as check
from dagster._core.definitions.resource_annotation import ResourceParam
from dagster._core.definitions.result import MaterializeResult
from dagster._core.errors import DagsterExternalExecutionError
from dagster._core.execution.context.compute import OpExecutionContext
from dagster._core.ext.client import (
Expand Down Expand Up @@ -66,7 +67,7 @@ def run(
extras: Optional[ExtExtras] = None,
env: Optional[Mapping[str, str]] = None,
cwd: Optional[str] = None,
) -> None:
) -> Union[MaterializeResult, Tuple[MaterializeResult, ...]]:
with ext_protocol(
context=context,
context_injector=self.context_injector,
Expand All @@ -88,6 +89,9 @@ def run(
raise DagsterExternalExecutionError(
f"External execution process failed with code {process.returncode}"
)
_ext_context = ext_context
mat_results = _ext_context.get_materialize_results()
return mat_results[0] if len(mat_results) == 1 else mat_results


ExtSubprocess = ResourceParam[_ExtSubprocess]
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@
import string
import time
from contextlib import contextmanager
from typing import Iterator, Mapping, Optional
from typing import Iterator, Mapping, Optional, Tuple, Union

import dagster._check as check
from dagster._core.definitions.resource_annotation import ResourceParam
from dagster._core.definitions.result import MaterializeResult
from dagster._core.errors import DagsterExternalExecutionError
from dagster._core.execution.context.compute import OpExecutionContext
from dagster._core.ext.client import ExtClient, ExtContextInjector, ExtMessageReader
Expand Down Expand Up @@ -66,7 +67,7 @@ def run(
context: OpExecutionContext,
extras: Optional[ExtExtras] = None,
submit_args: Optional[Mapping[str, str]] = None,
) -> None:
) -> Union[MaterializeResult, Tuple[MaterializeResult, ...]]:
"""Run a Databricks job with the EXT protocol.
Args:
Expand Down Expand Up @@ -106,7 +107,7 @@ def run(
jobs.RunLifeCycleState.SKIPPED,
):
if run.state.result_state == jobs.RunResultState.SUCCESS:
return
break
else:
raise DagsterExternalExecutionError(
f"Error running Databricks job: {run.state.state_message}"
Expand All @@ -116,6 +117,8 @@ def run(
f"Error running Databricks job: {run.state.state_message}"
)
time.sleep(5)
_ext_context = ext_context
return _ext_context.get_materialize_results()


ExtDatabricks = ResourceParam[_ExtDatabricks]
Expand Down
7 changes: 5 additions & 2 deletions python_modules/libraries/dagster-docker/dagster_docker/ext.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
from contextlib import contextmanager
from typing import Any, Iterator, Mapping, Optional, Sequence, Union
from typing import Any, Iterator, Mapping, Optional, Sequence, Tuple, Union

import docker
from dagster import (
OpExecutionContext,
ResourceParam,
_check as check,
)
from dagster._core.definitions.result import MaterializeResult
from dagster._core.ext.client import (
ExtClient,
ExtContextInjector,
Expand Down Expand Up @@ -94,7 +95,7 @@ def run(
registry: Optional[Mapping[str, str]] = None,
container_kwargs: Optional[Mapping[str, Any]] = None,
extras: Optional[ExtExtras] = None,
) -> None:
) -> Union[MaterializeResult, Tuple[MaterializeResult, ...]]:
"""Create a docker container and run it to completion, enriched with the ext protocol.
Args:
Expand Down Expand Up @@ -162,6 +163,8 @@ def run(
raise DagsterExtError(f"Container exited with non-zero status code: {result}")
finally:
container.stop()
_ext_context = ext_context
return _ext_context.get_materialize_results()

def _create_container(
self,
Expand Down
11 changes: 7 additions & 4 deletions python_modules/libraries/dagster-k8s/dagster_k8s/ext.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
import random
import string
from contextlib import contextmanager
from typing import Any, Iterator, Mapping, Optional, Sequence, Union
from typing import Any, Iterator, Mapping, Optional, Sequence, Tuple, Union

import kubernetes
from dagster import (
OpExecutionContext,
_check as check,
)
from dagster._core.definitions.resource_annotation import ResourceParam
from dagster._core.definitions.result import MaterializeResult
from dagster._core.errors import DagsterInvariantViolationError
from dagster._core.ext.client import (
ExtClient,
Expand Down Expand Up @@ -123,7 +124,7 @@ def run(
base_pod_meta: Optional[Mapping[str, Any]] = None,
base_pod_spec: Optional[Mapping[str, Any]] = None,
extras: Optional[ExtExtras] = None,
) -> None:
) -> Union[MaterializeResult, Tuple[MaterializeResult, ...]]:
"""Publish a kubernetes pod and wait for it to complete, enriched with the ext protocol.
Args:
Expand Down Expand Up @@ -158,15 +159,15 @@ def run(
extras=extras,
context_injector=self.context_injector,
message_reader=self.message_reader,
) as ext_process:
) as ext_context:
namespace = namespace or "default"
pod_name = get_pod_name(context.run_id, context.op.name)
pod_body = build_pod_body(
pod_name=pod_name,
image=image,
command=command,
env_vars={
**ext_process.get_external_process_env_vars(),
**ext_context.get_external_process_env_vars(),
**(self.env or {}),
**(env or {}),
},
Expand Down Expand Up @@ -196,6 +197,8 @@ def run(
)
finally:
client.core_api.delete_namespaced_pod(pod_name, namespace)
_ext_context = ext_context
return _ext_context.get_materialize_results()


def build_pod_body(
Expand Down

0 comments on commit a0ae406

Please sign in to comment.