From 5ae23678ca2dfb5578dd8cc1e574adc5c5b0dad3 Mon Sep 17 00:00:00 2001 From: Sean Mackesey Date: Wed, 13 Sep 2023 10:30:12 -0400 Subject: [PATCH 1/2] fix --- docs/sphinx/conf.py | 1 + .../dagster-ext/dagster_ext/__init__.py | 41 ++++- .../libraries/dagster-databricks/README.md | 143 ++++++++++++++++ .../dagster_databricks/__init__.py | 6 + .../dagster_databricks/ext.py | 156 ++++++++++++++++++ .../test_external_asset.py | 114 +++++++++++++ .../libraries/dagster-databricks/setup.py | 2 +- .../libraries/dagster-databricks/tox.ini | 1 + 8 files changed, 461 insertions(+), 3 deletions(-) create mode 100644 python_modules/libraries/dagster-databricks/dagster_databricks/ext.py create mode 100644 python_modules/libraries/dagster-databricks/dagster_databricks_tests/test_external_asset.py diff --git a/docs/sphinx/conf.py b/docs/sphinx/conf.py index 341188be73834..760fed7fe23a1 100644 --- a/docs/sphinx/conf.py +++ b/docs/sphinx/conf.py @@ -17,6 +17,7 @@ ### dagster packages "../../python_modules/automation", "../../python_modules/dagster", + "../../python_modules/dagster-ext", "../../python_modules/dagster-graphql", "../../python_modules/dagit", "../../python_modules/dagster-webserver", diff --git a/python_modules/dagster-ext/dagster_ext/__init__.py b/python_modules/dagster-ext/dagster_ext/__init__.py index 54527e58bc1ba..8e66f660efd07 100644 --- a/python_modules/dagster-ext/dagster_ext/__init__.py +++ b/python_modules/dagster-ext/dagster_ext/__init__.py @@ -411,6 +411,17 @@ def _upload_loop(self, is_task_complete: Event) -> None: time.sleep(1) +class ExtBufferedFilesystemMessageWriterChannel(ExtBlobStoreMessageWriterChannel): + def __init__(self, path: str, *, interval: float = 10): + super().__init__(interval=interval) + self._path = path + + def upload_messages_chunk(self, payload: IO, index: int) -> None: + message_path = os.path.join(self._path, f"{index}.json") + with open(message_path, "w") as f: + f.write(payload.read()) + + # ######################## # ##### IO - DEFAULT # ######################## @@ -500,7 +511,6 @@ class ExtS3MessageWriter(ExtBlobStoreMessageWriter): # client is a boto3.client("s3") object def __init__(self, client: Any, *, interval: float = 10): super().__init__(interval=interval) - self._interval = _assert_param_type(interval, float, self.__class__.__name__, "interval") # Not checking client type for now because it's a boto3.client object and we don't want to # depend on boto3. 
self._client = client @@ -515,7 +525,7 @@ def make_channel( client=self._client, bucket=bucket, key_prefix=key_prefix, - interval=self._interval, + interval=self.interval, ) @@ -538,6 +548,33 @@ def upload_messages_chunk(self, payload: IO, index: int) -> None: ) +# ######################## +# ##### IO - DBFS +# ######################## + + +class ExtDbfsContextLoader(ExtContextLoader): + @contextmanager + def load_context(self, params: ExtParams) -> Iterator[ExtContextData]: + unmounted_path = _assert_env_param_type(params, "path", str, self.__class__) + path = os.path.join("/dbfs", unmounted_path.lstrip("/")) + with open(path, "r") as f: + data = json.load(f) + yield data + + +class ExtDbfsMessageWriter(ExtBlobStoreMessageWriter): + def make_channel( + self, + params: ExtParams, + ) -> "ExtBufferedFilesystemMessageWriterChannel": + unmounted_path = _assert_env_param_type(params, "path", str, self.__class__) + return ExtBufferedFilesystemMessageWriterChannel( + path=os.path.join("/dbfs", unmounted_path.lstrip("/")), + interval=self.interval, + ) + + # ######################## # ##### CONTEXT # ######################## diff --git a/python_modules/libraries/dagster-databricks/README.md b/python_modules/libraries/dagster-databricks/README.md index 86c4d2d87f3ae..1f2b6259fee6e 100644 --- a/python_modules/libraries/dagster-databricks/README.md +++ b/python_modules/libraries/dagster-databricks/README.md @@ -2,3 +2,146 @@ The docs for `dagster-databricks` can be found [here](https://docs.dagster.io/_apidocs/libraries/dagster-databricks). + +## EXT Example + +This package includes a prototype API for launching databricks jobs with +Dagster's EXT protocol. There are two ways to use the API: + +### (1) `ExtDatabricks` resource + +The `ExtDatabricks` resource provides a high-level API for launching +databricks jobs using Dagster's EXT protocol. + +It takes a single `databricks.sdk.service.jobs.SubmitTask` specification. After +setting up EXT communications channels (which by default use DBFS), it injects +the information needed to connect to these channels from Databricks into the +task specification. It then launches a Databricks job by passing the +specification to `WorkspaceClient.jobs.submit`. It polls the job state and +exits gracefully on success or failure: + + +``` +import os +from dagster import AssetExecutionContext, Definitions, asset +from dagster_databricks import ExtDatabricks +from databricks.sdk import WorkspaceClient +from databricks.sdk.service import jobs + +@asset +def databricks_asset(context: AssetExecutionContext, ext: ExtDatabricks): + + # task specification will be passed to databricks as-is, except for the + # injection of environment variables + task = jobs.SubmitTask.from_dict({ + "new_cluster": { ... 
}, + "libraries": [ + # must include dagster-ext-process + {"pypi": {"package": "dagster-ext-process"}}, + ], + "task_key": "some-key", + "spark_python_task": { + "python_file": "dbfs:/myscript.py", + "source": jobs.Source.WORKSPACE, + } + }) + + # arbitrary json-serializable data you want access to from the ExtContext + # in the databricks runtime + extras = {"foo": "bar"} + + # synchronously execute the databricks job + ext.run( + task=task, + context=context, + extras=extras, + ) + +client = WorkspaceClient( + host=os.environ["DATABRICKS_HOST"], + token=os.environ["DATABRICKS_TOKEN"], +) + +defs = Definitions( + assets=[databricks_asset], + resources = {"ext": ExtDatabricks(client)} +) +``` + +`ExtDatabricks.run` requires that the targeted python script +(`dbfs:/myscript.py` above) already exist in DBFS. Here is what it might look +like: + +``` +### dbfs:/myscript.py + +# `dagster_ext` must be available in the databricks python environment +from dagster_ext import ExtDbfsContextLoader, ExtDbfsMessageWriter, init_dagster_ext + +# Sets up communication channels and downloads the context data sent from Dagster. +# Note that while other `context_loader` and `message_writer` settings are +# possible, it is recommended to use the below settings for Databricks. +context = init_dagster_ext( + context_loader=ExtDbfsContextLoader(), + message_writer=ExtDbfsMessageWriter() +) + +# Access the `extras` dict passed when launching the job from Dagster. +foo_extra = context.get_extra("foo") + +# Stream log message back to Dagster +context.log(f"Extra for key 'foo': {foo_extra}") + +# ... your code that computes and persists the asset + +# Stream arbitrary metadata back to Dagster. This will be attached to the +# associated `AssetMaterialization` +context.report_asset_metadata("some_metric", get_metric(), metadata_type="text") + +# Stream data version back to Dagster. This will also be attached to the +# associated `AssetMaterialization`. +context.report_asset_data_version(get_data_version()) +``` + +### (2) `ext_protocol` context manager + +Internally, `ExtDatabricks` is using the `ext_protocol` context manager to set +up communications. If you'd prefer more control over how your databricks job is +launched and polled, you can skip `ExtDatabricks` and use this lower level API +directly. All that is necessary is that (1) your Databricks job be launched within +the scope of the `ext_process` context manager; (2) your job is launched on a +cluster containing the environment variables available on the yielded +`ext_context`. + +``` +import os + +from dagster import AssetExecutionContext, ext_protocol +from dagster_databricks import ExtDbfsContextInjector, ExtDbfsMessageReader +from databricks.sdk import WorkspaceClient + +@asset +def databricks_asset(context: AssetExecutionContext): + + client = WorkspaceClient( + host=os.environ["DATABRICKS_HOST"], + token=os.environ["DATABRICKS_TOKEN"], + ) + + extras = {"foo": "bar"} + + # Sets up EXT communications channels + with ext_protocol( + context=context, + extras=extras, + context_injector=ExtDbfsContextInjector(client=client), + message_reader=ExtDbfsMessageReader(client=client), + ) as ext_context: + + # Dict[str, str] with environment variables containing ext comms info. + env_vars = ext_context.get_external_process_env_vars() + + # Some function that handles launching/montoring of the databricks job. + # It must ensure that the `env_vars` are set on the executing cluster. 
+ custom_databricks_launch_code(env_vars) +``` diff --git a/python_modules/libraries/dagster-databricks/dagster_databricks/__init__.py b/python_modules/libraries/dagster-databricks/dagster_databricks/__init__.py index 77165104fdab8..e5551b434fa6a 100644 --- a/python_modules/libraries/dagster-databricks/dagster_databricks/__init__.py +++ b/python_modules/libraries/dagster-databricks/dagster_databricks/__init__.py @@ -20,6 +20,12 @@ DatabricksPySparkStepLauncher as DatabricksPySparkStepLauncher, databricks_pyspark_step_launcher as databricks_pyspark_step_launcher, ) +from .ext import ( + ExtDatabricks as ExtDatabricks, + ExtDbfsContextInjector as ExtDbfsContextInjector, + ExtDbfsMessageReader as ExtDbfsMessageReader, + dbfs_tempdir as dbfs_tempdir, +) from .ops import ( create_databricks_run_now_op as create_databricks_run_now_op, create_databricks_submit_run_op as create_databricks_submit_run_op, diff --git a/python_modules/libraries/dagster-databricks/dagster_databricks/ext.py b/python_modules/libraries/dagster-databricks/dagster_databricks/ext.py new file mode 100644 index 0000000000000..65bcc0e0d4afb --- /dev/null +++ b/python_modules/libraries/dagster-databricks/dagster_databricks/ext.py @@ -0,0 +1,156 @@ +import base64 +import json +import os +import random +import string +import time +from contextlib import contextmanager +from typing import Iterator, Mapping, Optional + +from dagster._core.definitions.resource_annotation import ResourceParam +from dagster._core.errors import DagsterExternalExecutionError +from dagster._core.execution.context.compute import OpExecutionContext +from dagster._core.ext.client import ExtClient, ExtContextInjector, ExtMessageReader +from dagster._core.ext.utils import ( + ExtBlobStoreMessageReader, + ext_protocol, +) +from dagster_ext import ( + ExtContextData, + ExtExtras, + ExtParams, +) +from databricks.sdk import WorkspaceClient +from databricks.sdk.service import files, jobs +from pydantic import Field + + +class _ExtDatabricks(ExtClient): + """Ext client for databricks. + + Args: + client (WorkspaceClient): A databricks workspace client. + env (Optional[Mapping[str,str]]: An optional dict of environment variables to pass to the databricks job. + """ + + env: Optional[Mapping[str, str]] = Field( + default=None, + description="An optional dict of environment variables to pass to the subprocess.", + ) + + def __init__(self, client: WorkspaceClient, env: Optional[Mapping[str, str]] = None): + self.client = client + self.env = env + + def run( + self, + task: jobs.SubmitTask, + *, + context: OpExecutionContext, + extras: Optional[ExtExtras] = None, + context_injector: Optional[ExtContextInjector] = None, + message_reader: Optional[ExtMessageReader] = None, + submit_args: Optional[Mapping[str, str]] = None, + ) -> None: + """Run a Databricks job with the EXT protocol. + + Args: + task (databricks.sdk.service.jobs.SubmitTask): Specification of the databricks + task to run. Environment variables used by dagster-ext will be set under the + `spark_env_vars` key of the `new_cluster` field (if there is an existing dictionary + here, the EXT environment variables will be merged in). Everything else will be + passed unaltered under the `tasks` arg to `WorkspaceClient.jobs.submit`. + submit_args (Optional[Mapping[str, str]]): Additional keyword arguments that will be + forwarded as-is to `WorkspaceClient.jobs.submit`. 
+ """ + with ext_protocol( + context=context, + extras=extras, + context_injector=context_injector or ExtDbfsContextInjector(client=self.client), + message_reader=message_reader or ExtDbfsMessageReader(client=self.client), + ) as ext_context: + submit_task_dict = task.as_dict() + submit_task_dict["new_cluster"]["spark_env_vars"] = { + **submit_task_dict["new_cluster"].get("spark_env_vars", {}), + **(self.env or {}), + **ext_context.get_external_process_env_vars(), + } + task = jobs.SubmitTask.from_dict(submit_task_dict) + run_id = self.client.jobs.submit( + tasks=[task], + **(submit_args or {}), + ).bind()["run_id"] + + while True: + run = self.client.jobs.get_run(run_id) + if run.state.life_cycle_state in ( + jobs.RunLifeCycleState.TERMINATED, + jobs.RunLifeCycleState.SKIPPED, + ): + if run.state.result_state == jobs.RunResultState.SUCCESS: + return + else: + raise DagsterExternalExecutionError( + f"Error running Databricks job: {run.state.state_message}" + ) + elif run.state.life_cycle_state == jobs.RunLifeCycleState.INTERNAL_ERROR: + raise DagsterExternalExecutionError( + f"Error running Databricks job: {run.state.state_message}" + ) + time.sleep(5) + + +ExtDatabricks = ResourceParam[_ExtDatabricks] + +_CONTEXT_FILENAME = "context.json" + + +@contextmanager +def dbfs_tempdir(dbfs_client: files.DbfsAPI) -> Iterator[str]: + dirname = "".join(random.choices(string.ascii_letters, k=30)) + tempdir = f"/tmp/{dirname}" + dbfs_client.mkdirs(tempdir) + try: + yield tempdir + finally: + dbfs_client.delete(tempdir, recursive=True) + + +class ExtDbfsContextInjector(ExtContextInjector): + def __init__(self, *, client: WorkspaceClient): + super().__init__() + self.dbfs_client = files.DbfsAPI(client.api_client) + + @contextmanager + def inject_context(self, context: "ExtContextData") -> Iterator[ExtParams]: + with dbfs_tempdir(self.dbfs_client) as tempdir: + path = os.path.join(tempdir, _CONTEXT_FILENAME) + contents = base64.b64encode(json.dumps(context).encode("utf-8")).decode("utf-8") + self.dbfs_client.put(path, contents=contents, overwrite=True) + yield {"path": path} + + +class ExtDbfsMessageReader(ExtBlobStoreMessageReader): + tempdir: Optional[str] = None + + def __init__(self, *, interval: int = 10, client: WorkspaceClient): + super().__init__(interval=interval) + self.dbfs_client = files.DbfsAPI(client.api_client) + + @contextmanager + def setup(self) -> Iterator[None]: + with dbfs_tempdir(self.dbfs_client) as tempdir: + self.tempdir = tempdir + yield + + def get_params(self) -> ExtParams: + return {"path": self.tempdir} + + def download_messages_chunk(self, index: int) -> Optional[str]: + assert self.tempdir + message_path = os.path.join(self.tempdir, f"{index}.json") + try: + raw_message = self.dbfs_client.read(message_path) + return raw_message.data + except IOError: + return None diff --git a/python_modules/libraries/dagster-databricks/dagster_databricks_tests/test_external_asset.py b/python_modules/libraries/dagster-databricks/dagster_databricks_tests/test_external_asset.py new file mode 100644 index 0000000000000..83e763e279be0 --- /dev/null +++ b/python_modules/libraries/dagster-databricks/dagster_databricks_tests/test_external_asset.py @@ -0,0 +1,114 @@ +import base64 +import inspect +import os +import textwrap +from contextlib import contextmanager +from typing import Any, Callable, Iterator + +import pytest +from dagster import AssetExecutionContext, asset, materialize +from dagster._core.errors import DagsterExternalExecutionError +from dagster_databricks import ExtDatabricks, 
dbfs_tempdir +from databricks.sdk import WorkspaceClient +from databricks.sdk.service import files, jobs + +IS_BUILDKITE = os.getenv("BUILDKITE") is not None + + +def script_fn(): + from dagster_ext import ExtDbfsContextLoader, ExtDbfsMessageWriter, init_dagster_ext + + context = init_dagster_ext( + context_loader=ExtDbfsContextLoader(), message_writer=ExtDbfsMessageWriter() + ) + + multiplier = context.get_extra("multiplier") + value = 2 * multiplier + + context.log(f"{context.asset_key}: {2} * {multiplier} = {value}") + context.report_asset_metadata("value", value) + context.report_asset_data_version("alpha") + + +@contextmanager +def temp_script(script_fn: Callable[[], Any], client: WorkspaceClient) -> Iterator[str]: + # drop the signature line + source = textwrap.dedent(inspect.getsource(script_fn).split("\n", 1)[1]) + dbfs_client = files.DbfsAPI(client.api_client) + with dbfs_tempdir(dbfs_client) as tempdir: + script_path = os.path.join(tempdir, "script.py") + contents = base64.b64encode(source.encode("utf-8")).decode("utf-8") + dbfs_client.put(script_path, contents=contents, overwrite=True) + yield script_path + + +@pytest.fixture +def client() -> WorkspaceClient: + return WorkspaceClient( + host=os.environ["DATABRICKS_HOST"], + token=os.environ["DATABRICKS_TOKEN"], + ) + + +CLUSTER_DEFAULTS = { + "spark_version": "12.2.x-scala2.12", + "node_type_id": "i3.xlarge", + "num_workers": 0, +} + +TASK_KEY = "DAGSTER_EXT_TASK" + +# This has been manually uploaded to a test DBFS workspace. +DAGSTER_EXTERNALS_WHL_PATH = "dbfs:/FileStore/jars/dagster_ext-1!0+dev-py3-none-any.whl" + + +def _make_submit_task(path: str) -> jobs.SubmitTask: + return jobs.SubmitTask.from_dict( + { + "new_cluster": CLUSTER_DEFAULTS, + "libraries": [ + {"whl": DAGSTER_EXTERNALS_WHL_PATH}, + ], + "task_key": TASK_KEY, + "spark_python_task": { + "python_file": f"dbfs:{path}", + "source": jobs.Source.WORKSPACE, + }, + } + ) + + +@pytest.mark.skipif(IS_BUILDKITE, reason="Not configured to run on BK yet.") +def test_basic(client: WorkspaceClient): + @asset + def number_x(context: AssetExecutionContext, ext_databricks: ExtDatabricks): + with temp_script(script_fn, client) as script_path: + task = _make_submit_task(script_path) + ext_databricks.run( + task=task, + context=context, + extras={"multiplier": 2, "storage_root": "fake"}, + ) + + result = materialize( + [number_x], + resources={"ext_databricks": ExtDatabricks(client)}, + raise_on_error=False, + ) + assert result.success + mats = result.asset_materializations_for_node(number_x.op.name) + assert mats[0].metadata["path"] + + +@pytest.mark.skipif(IS_BUILDKITE, reason="Not configured to run on BK yet.") +def test_nonexistent_entry_point(client: WorkspaceClient): + @asset + def fake(context: AssetExecutionContext, ext_databricks: ExtDatabricks): + task = _make_submit_task("/fake/fake") + ext_databricks.run(task=task, context=context) + + with pytest.raises(DagsterExternalExecutionError, match=r"Cannot read the python file"): + materialize( + [fake], + resources={"ext_databricks": ExtDatabricks(client)}, + ) diff --git a/python_modules/libraries/dagster-databricks/setup.py b/python_modules/libraries/dagster-databricks/setup.py index 45bf18da0ff24..f936714a1605d 100644 --- a/python_modules/libraries/dagster-databricks/setup.py +++ b/python_modules/libraries/dagster-databricks/setup.py @@ -37,7 +37,7 @@ def get_version() -> str: f"dagster-pyspark{pin}", "databricks-cli~=0.17", # TODO: Remove this dependency in the next minor release. 
"databricks_api", # TODO: Remove this dependency in the next minor release. - "databricks-sdk<0.7", # Breaking changes occur in minor versions. + "databricks-sdk<0.9", # Breaking changes occur in minor versions. ], zip_safe=False, ) diff --git a/python_modules/libraries/dagster-databricks/tox.ini b/python_modules/libraries/dagster-databricks/tox.ini index 3c8493fef3d80..c8897f744e50c 100644 --- a/python_modules/libraries/dagster-databricks/tox.ini +++ b/python_modules/libraries/dagster-databricks/tox.ini @@ -6,6 +6,7 @@ download = True passenv = CI_* COVERALLS_REPO_TOKEN DATABRICKS_* BUILDKITE* SSH_* deps = -e ../../dagster[test] + -e ../../dagster-ext -e ../dagster-aws -e ../dagster-azure -e ../dagster-spark From 096980a672527b5a784240a539f2217e6141a8ee Mon Sep 17 00:00:00 2001 From: Sean Mackesey Date: Thu, 14 Sep 2023 10:24:27 -0400 Subject: [PATCH 2/2] address feedback --- python_modules/automation/tox.ini | 1 + .../dagster-ext/dagster_ext/__init__.py | 3 +- .../dagster/dagster/_core/ext/client.py | 1 - .../dagster/dagster/_core/ext/utils.py | 20 ++++---- .../libraries/dagster-aws/dagster_aws/ext.py | 10 ++-- .../libraries/dagster-databricks/README.md | 49 ++++++++++--------- .../dagster_databricks/ext.py | 46 +++++++++++------ .../libraries/dagster-databricks/setup.py | 1 + 8 files changed, 75 insertions(+), 56 deletions(-) diff --git a/python_modules/automation/tox.ini b/python_modules/automation/tox.ini index 513092e7ca6c0..f77dcceb41ed8 100644 --- a/python_modules/automation/tox.ini +++ b/python_modules/automation/tox.ini @@ -6,6 +6,7 @@ download = True passenv = CI_PULL_REQUEST COVERALLS_REPO_TOKEN BUILDKITE* deps = -e ../dagster[test] + -e ../dagster-ext -e ../dagster-graphql -e ../libraries/dagster-managed-elements -e ../libraries/dagster-airbyte diff --git a/python_modules/dagster-ext/dagster_ext/__init__.py b/python_modules/dagster-ext/dagster_ext/__init__.py index 8e66f660efd07..8c370f0623a67 100644 --- a/python_modules/dagster-ext/dagster_ext/__init__.py +++ b/python_modules/dagster-ext/dagster_ext/__init__.py @@ -559,8 +559,7 @@ def load_context(self, params: ExtParams) -> Iterator[ExtContextData]: unmounted_path = _assert_env_param_type(params, "path", str, self.__class__) path = os.path.join("/dbfs", unmounted_path.lstrip("/")) with open(path, "r") as f: - data = json.load(f) - yield data + yield json.load(f) class ExtDbfsMessageWriter(ExtBlobStoreMessageWriter): diff --git a/python_modules/dagster/dagster/_core/ext/client.py b/python_modules/dagster/dagster/_core/ext/client.py index d4ad967e6f529..9e86b438053e8 100644 --- a/python_modules/dagster/dagster/_core/ext/client.py +++ b/python_modules/dagster/dagster/_core/ext/client.py @@ -21,7 +21,6 @@ def run( *, context: OpExecutionContext, extras: Optional[ExtExtras] = None, - message_reader: Optional["ExtMessageReader"] = None, ) -> None: ... 
diff --git a/python_modules/dagster/dagster/_core/ext/utils.py b/python_modules/dagster/dagster/_core/ext/utils.py index 949be2723aa00..aeaf6aea0fb82 100644 --- a/python_modules/dagster/dagster/_core/ext/utils.py +++ b/python_modules/dagster/dagster/_core/ext/utils.py @@ -130,7 +130,7 @@ def read_messages( self, handler: "ExtMessageHandler", ) -> Iterator[ExtParams]: - with self.setup(): + with self.get_params() as params: is_task_complete = Event() thread = None try: @@ -138,35 +138,35 @@ def read_messages( target=self._reader_thread, args=( handler, + params, is_task_complete, ), daemon=True, ) thread.start() - yield self.get_params() + yield params finally: is_task_complete.set() if thread: thread.join() - @contextmanager - def setup(self) -> Iterator[None]: - yield - @abstractmethod - def get_params(self) -> ExtParams: + @contextmanager + def get_params(self) -> Iterator[ExtParams]: ... @abstractmethod - def download_messages_chunk(self, index: int) -> Optional[str]: + def download_messages_chunk(self, index: int, params: ExtParams) -> Optional[str]: ... - def _reader_thread(self, handler: "ExtMessageHandler", is_task_complete: Event) -> None: + def _reader_thread( + self, handler: "ExtMessageHandler", params: ExtParams, is_task_complete: Event + ) -> None: start_or_last_download = datetime.datetime.now() while True: now = datetime.datetime.now() if (now - start_or_last_download).seconds > self.interval or is_task_complete.is_set(): - chunk = self.download_messages_chunk(self.counter) + chunk = self.download_messages_chunk(self.counter, params) start_or_last_download = now if chunk: for line in chunk.split("\n"): diff --git a/python_modules/libraries/dagster-aws/dagster_aws/ext.py b/python_modules/libraries/dagster-aws/dagster_aws/ext.py index 32cfeb5593b01..807c482f02971 100644 --- a/python_modules/libraries/dagster-aws/dagster_aws/ext.py +++ b/python_modules/libraries/dagster-aws/dagster_aws/ext.py @@ -1,6 +1,7 @@ import random import string -from typing import Optional +from contextlib import contextmanager +from typing import Iterator, Optional import boto3 import dagster._check as check @@ -18,10 +19,11 @@ def __init__(self, *, interval: int = 10, bucket: str, client: boto3.client): self.key_prefix = "".join(random.choices(string.ascii_letters, k=30)) self.client = client - def get_params(self) -> ExtParams: - return {"bucket": self.bucket, "key_prefix": self.key_prefix} + @contextmanager + def get_params(self) -> Iterator[ExtParams]: + yield {"bucket": self.bucket, "key_prefix": self.key_prefix} - def download_messages_chunk(self, index: int) -> Optional[str]: + def download_messages_chunk(self, index: int, params: ExtParams) -> Optional[str]: key = f"{self.key_prefix}/{index}.json" try: obj = self.client.get_object(Bucket=self.bucket, Key=key) diff --git a/python_modules/libraries/dagster-databricks/README.md b/python_modules/libraries/dagster-databricks/README.md index 1f2b6259fee6e..b5d0ab2166409 100644 --- a/python_modules/libraries/dagster-databricks/README.md +++ b/python_modules/libraries/dagster-databricks/README.md @@ -3,7 +3,7 @@ The docs for `dagster-databricks` can be found [here](https://docs.dagster.io/_apidocs/libraries/dagster-databricks). -## EXT Example +## ext example This package includes a prototype API for launching databricks jobs with Dagster's EXT protocol. There are two ways to use the API: @@ -11,14 +11,14 @@ Dagster's EXT protocol. 
There are two ways to use the API: ### (1) `ExtDatabricks` resource The `ExtDatabricks` resource provides a high-level API for launching -databricks jobs using Dagster's EXT protocol. +databricks jobs using Dagster's ext protocol. -It takes a single `databricks.sdk.service.jobs.SubmitTask` specification. After -setting up EXT communications channels (which by default use DBFS), it injects -the information needed to connect to these channels from Databricks into the -task specification. It then launches a Databricks job by passing the -specification to `WorkspaceClient.jobs.submit`. It polls the job state and -exits gracefully on success or failure: +`ExtDatabricks.run` takes a single `databricks.sdk.service.jobs.SubmitTask` +specification. After setting up ext communications channels (which by default +use DBFS), it injects the information needed to connect to these channels from +Databricks into the task specification. It then launches a Databricks job by +passing the specification to `WorkspaceClient.jobs.submit`. It polls the job +state and exits gracefully on success or failure: ``` @@ -46,9 +46,10 @@ def databricks_asset(context: AssetExecutionContext, ext: ExtDatabricks): } }) - # arbitrary json-serializable data you want access to from the ExtContext - # in the databricks runtime - extras = {"foo": "bar"} + # Arbitrary json-serializable data you want access to from the `ExtContext` + # in the databricks runtime. Assume `sample_rate` is a parameter used by + # the target job's business logic. + extras = {"sample_rate": 1.0} # synchronously execute the databricks job ext.run( @@ -87,10 +88,10 @@ context = init_dagster_ext( ) # Access the `extras` dict passed when launching the job from Dagster. -foo_extra = context.get_extra("foo") +sample_rate = context.get_extra("sample_rate") # Stream log message back to Dagster -context.log(f"Extra for key 'foo': {foo_extra}") +context.log(f"Using sample rate: {sample_rate}") # ... your code that computes and persists the asset @@ -105,13 +106,12 @@ context.report_asset_data_version(get_data_version()) ### (2) `ext_protocol` context manager -Internally, `ExtDatabricks` is using the `ext_protocol` context manager to set -up communications. If you'd prefer more control over how your databricks job is -launched and polled, you can skip `ExtDatabricks` and use this lower level API -directly. All that is necessary is that (1) your Databricks job be launched within -the scope of the `ext_process` context manager; (2) your job is launched on a -cluster containing the environment variables available on the yielded -`ext_context`. +If you have existing code to launch/poll the job you do not want to change, or +you just want more control than is permitted by `ExtDatabricks`, you can use +`ext_protocol`. All that is necessary is that (1) your Databricks job be +launched within the scope of the `ext_process` context manager; (2) your job is +launched on a cluster containing the environment variables available on the +yielded `ext_context`. ``` import os @@ -128,9 +128,12 @@ def databricks_asset(context: AssetExecutionContext): token=os.environ["DATABRICKS_TOKEN"], ) - extras = {"foo": "bar"} + # Arbitrary json-serializable data you want access to from the `ExtContext` + # in the databricks runtime. Assume `sample_rate` is a parameter used by + # the target job's business logic. 
+ extras = {"sample_rate": 1.0} - # Sets up EXT communications channels + # Sets up ext communications channels with ext_protocol( context=context, extras=extras, @@ -141,7 +144,7 @@ def databricks_asset(context: AssetExecutionContext): # Dict[str, str] with environment variables containing ext comms info. env_vars = ext_context.get_external_process_env_vars() - # Some function that handles launching/montoring of the databricks job. + # Some function that handles launching/monitoring of the databricks job. # It must ensure that the `env_vars` are set on the executing cluster. custom_databricks_launch_code(env_vars) ``` diff --git a/python_modules/libraries/dagster-databricks/dagster_databricks/ext.py b/python_modules/libraries/dagster-databricks/dagster_databricks/ext.py index 65bcc0e0d4afb..22bb89034cd07 100644 --- a/python_modules/libraries/dagster-databricks/dagster_databricks/ext.py +++ b/python_modules/libraries/dagster-databricks/dagster_databricks/ext.py @@ -7,6 +7,7 @@ from contextlib import contextmanager from typing import Iterator, Mapping, Optional +import dagster._check as check from dagster._core.definitions.resource_annotation import ResourceParam from dagster._core.errors import DagsterExternalExecutionError from dagster._core.execution.context.compute import OpExecutionContext @@ -38,9 +39,25 @@ class _ExtDatabricks(ExtClient): description="An optional dict of environment variables to pass to the subprocess.", ) - def __init__(self, client: WorkspaceClient, env: Optional[Mapping[str, str]] = None): + def __init__( + self, + client: WorkspaceClient, + env: Optional[Mapping[str, str]] = None, + context_injector: Optional[ExtContextInjector] = None, + message_reader: Optional[ExtMessageReader] = None, + ): self.client = client self.env = env + self.context_injector = check.opt_inst_param( + context_injector, + "context_injector", + ExtContextInjector, + ) or ExtDbfsContextInjector(client=self.client) + self.message_reader = check.opt_inst_param( + message_reader, + "message_reader", + ExtMessageReader, + ) or ExtDbfsMessageReader(client=self.client) def run( self, @@ -48,8 +65,6 @@ def run( *, context: OpExecutionContext, extras: Optional[ExtExtras] = None, - context_injector: Optional[ExtContextInjector] = None, - message_reader: Optional[ExtMessageReader] = None, submit_args: Optional[Mapping[str, str]] = None, ) -> None: """Run a Databricks job with the EXT protocol. 
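With the change above, `context_injector` and `message_reader` move from `run` to the resource constructor, alongside the optional `env` mapping. A hedged usage sketch of the construction-time configuration follows; the extra environment variable is a placeholder, and the DBFS-backed injector/reader remain the defaults.

```python
import os

from dagster import Definitions
from dagster_databricks import ExtDatabricks
from databricks.sdk import WorkspaceClient

client = WorkspaceClient(
    host=os.environ["DATABRICKS_HOST"],
    token=os.environ["DATABRICKS_TOKEN"],
)

# Context injection and message reading default to DBFS-backed implementations;
# pass context_injector= / message_reader= here to override them for all runs.
ext_databricks = ExtDatabricks(
    client,
    env={"MY_PLACEHOLDER_VAR": "value"},  # merged into the cluster's spark_env_vars
)

defs = Definitions(
    assets=[],  # e.g. [databricks_asset] as in the README example above
    resources={"ext_databricks": ext_databricks},
)
```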
@@ -66,8 +81,8 @@ def run( with ext_protocol( context=context, extras=extras, - context_injector=context_injector or ExtDbfsContextInjector(client=self.client), - message_reader=message_reader or ExtDbfsMessageReader(client=self.client), + context_injector=self.context_injector, + message_reader=self.message_reader, ) as ext_context: submit_task_dict = task.as_dict() submit_task_dict["new_cluster"]["spark_env_vars"] = { @@ -83,6 +98,9 @@ def run( while True: run = self.client.jobs.get_run(run_id) + context.log.info( + f"Databricks run {run_id} current state: {run.state.life_cycle_state}" + ) if run.state.life_cycle_state in ( jobs.RunLifeCycleState.TERMINATED, jobs.RunLifeCycleState.SKIPPED, @@ -131,26 +149,22 @@ def inject_context(self, context: "ExtContextData") -> Iterator[ExtParams]: class ExtDbfsMessageReader(ExtBlobStoreMessageReader): - tempdir: Optional[str] = None - def __init__(self, *, interval: int = 10, client: WorkspaceClient): super().__init__(interval=interval) self.dbfs_client = files.DbfsAPI(client.api_client) @contextmanager - def setup(self) -> Iterator[None]: + def get_params(self) -> Iterator[ExtParams]: with dbfs_tempdir(self.dbfs_client) as tempdir: - self.tempdir = tempdir - yield - - def get_params(self) -> ExtParams: - return {"path": self.tempdir} + yield {"path": tempdir} - def download_messages_chunk(self, index: int) -> Optional[str]: - assert self.tempdir - message_path = os.path.join(self.tempdir, f"{index}.json") + def download_messages_chunk(self, index: int, params: ExtParams) -> Optional[str]: + message_path = os.path.join(params["path"], f"{index}.json") try: raw_message = self.dbfs_client.read(message_path) return raw_message.data + # An error here is an expected result, since an IOError will be thrown if the next message + # chunk doesn't yet exist. Swallowing the error here is equivalent to doing a no-op on a + # status check showing a non-existent file. except IOError: return None diff --git a/python_modules/libraries/dagster-databricks/setup.py b/python_modules/libraries/dagster-databricks/setup.py index f936714a1605d..9fbf10a81d4f4 100644 --- a/python_modules/libraries/dagster-databricks/setup.py +++ b/python_modules/libraries/dagster-databricks/setup.py @@ -34,6 +34,7 @@ def get_version() -> str: include_package_data=True, install_requires=[ f"dagster{pin}", + f"dagster-ext-process{pin}", f"dagster-pyspark{pin}", "databricks-cli~=0.17", # TODO: Remove this dependency in the next minor release. "databricks_api", # TODO: Remove this dependency in the next minor release.
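To round out the examples above: the README notes that `dagster_ext` must be importable in the Databricks Python environment, and the final hunk pins `dagster-ext-process` as a dependency of `dagster-databricks` itself. A minimal submit-task specification for `ExtDatabricks.run` might look like the sketch below; the cluster settings, task key, and script path are placeholders mirroring the README and test examples in this PR.

```python
from databricks.sdk.service import jobs


def make_submit_task(script_path: str) -> jobs.SubmitTask:
    # spark_env_vars for ext communications are injected by ExtDatabricks.run,
    # so they are deliberately omitted from new_cluster here.
    return jobs.SubmitTask.from_dict(
        {
            "new_cluster": {
                "spark_version": "12.2.x-scala2.12",
                "node_type_id": "i3.xlarge",
                "num_workers": 0,
            },
            "libraries": [
                # makes dagster_ext importable inside the Databricks job
                {"pypi": {"package": "dagster-ext-process"}},
            ],
            "task_key": "dagster-ext-task",
            "spark_python_task": {
                "python_file": script_path,  # e.g. "dbfs:/myscript.py"
                "source": jobs.Source.WORKSPACE,
            },
        }
    )
```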