diff --git a/python_modules/automation/tox.ini b/python_modules/automation/tox.ini index 513092e7ca6c0..f77dcceb41ed8 100644 --- a/python_modules/automation/tox.ini +++ b/python_modules/automation/tox.ini @@ -6,6 +6,7 @@ download = True passenv = CI_PULL_REQUEST COVERALLS_REPO_TOKEN BUILDKITE* deps = -e ../dagster[test] + -e ../dagster-ext -e ../dagster-graphql -e ../libraries/dagster-managed-elements -e ../libraries/dagster-airbyte diff --git a/python_modules/dagster-ext/dagster_ext/__init__.py b/python_modules/dagster-ext/dagster_ext/__init__.py index 8e66f660efd07..8c370f0623a67 100644 --- a/python_modules/dagster-ext/dagster_ext/__init__.py +++ b/python_modules/dagster-ext/dagster_ext/__init__.py @@ -559,8 +559,7 @@ def load_context(self, params: ExtParams) -> Iterator[ExtContextData]: unmounted_path = _assert_env_param_type(params, "path", str, self.__class__) path = os.path.join("/dbfs", unmounted_path.lstrip("/")) with open(path, "r") as f: - data = json.load(f) - yield data + yield json.load(f) class ExtDbfsMessageWriter(ExtBlobStoreMessageWriter): diff --git a/python_modules/dagster/dagster/_core/ext/client.py b/python_modules/dagster/dagster/_core/ext/client.py index d4ad967e6f529..9e86b438053e8 100644 --- a/python_modules/dagster/dagster/_core/ext/client.py +++ b/python_modules/dagster/dagster/_core/ext/client.py @@ -21,7 +21,6 @@ def run( *, context: OpExecutionContext, extras: Optional[ExtExtras] = None, - message_reader: Optional["ExtMessageReader"] = None, ) -> None: ... diff --git a/python_modules/dagster/dagster/_core/ext/utils.py b/python_modules/dagster/dagster/_core/ext/utils.py index 949be2723aa00..aeaf6aea0fb82 100644 --- a/python_modules/dagster/dagster/_core/ext/utils.py +++ b/python_modules/dagster/dagster/_core/ext/utils.py @@ -130,7 +130,7 @@ def read_messages( self, handler: "ExtMessageHandler", ) -> Iterator[ExtParams]: - with self.setup(): + with self.get_params() as params: is_task_complete = Event() thread = None try: @@ -138,35 +138,35 @@ def read_messages( target=self._reader_thread, args=( handler, + params, is_task_complete, ), daemon=True, ) thread.start() - yield self.get_params() + yield params finally: is_task_complete.set() if thread: thread.join() - @contextmanager - def setup(self) -> Iterator[None]: - yield - @abstractmethod - def get_params(self) -> ExtParams: + @contextmanager + def get_params(self) -> Iterator[ExtParams]: ... @abstractmethod - def download_messages_chunk(self, index: int) -> Optional[str]: + def download_messages_chunk(self, index: int, params: ExtParams) -> Optional[str]: ... 
- def _reader_thread(self, handler: "ExtMessageHandler", is_task_complete: Event) -> None: + def _reader_thread( + self, handler: "ExtMessageHandler", params: ExtParams, is_task_complete: Event + ) -> None: start_or_last_download = datetime.datetime.now() while True: now = datetime.datetime.now() if (now - start_or_last_download).seconds > self.interval or is_task_complete.is_set(): - chunk = self.download_messages_chunk(self.counter) + chunk = self.download_messages_chunk(self.counter, params) start_or_last_download = now if chunk: for line in chunk.split("\n"): diff --git a/python_modules/libraries/dagster-aws/dagster_aws/ext.py b/python_modules/libraries/dagster-aws/dagster_aws/ext.py index 32cfeb5593b01..807c482f02971 100644 --- a/python_modules/libraries/dagster-aws/dagster_aws/ext.py +++ b/python_modules/libraries/dagster-aws/dagster_aws/ext.py @@ -1,6 +1,7 @@ import random import string -from typing import Optional +from contextlib import contextmanager +from typing import Iterator, Optional import boto3 import dagster._check as check @@ -18,10 +19,11 @@ def __init__(self, *, interval: int = 10, bucket: str, client: boto3.client): self.key_prefix = "".join(random.choices(string.ascii_letters, k=30)) self.client = client - def get_params(self) -> ExtParams: - return {"bucket": self.bucket, "key_prefix": self.key_prefix} + @contextmanager + def get_params(self) -> Iterator[ExtParams]: + yield {"bucket": self.bucket, "key_prefix": self.key_prefix} - def download_messages_chunk(self, index: int) -> Optional[str]: + def download_messages_chunk(self, index: int, params: ExtParams) -> Optional[str]: key = f"{self.key_prefix}/{index}.json" try: obj = self.client.get_object(Bucket=self.bucket, Key=key) diff --git a/python_modules/libraries/dagster-databricks/README.md b/python_modules/libraries/dagster-databricks/README.md index 1f2b6259fee6e..b5d0ab2166409 100644 --- a/python_modules/libraries/dagster-databricks/README.md +++ b/python_modules/libraries/dagster-databricks/README.md @@ -3,7 +3,7 @@ The docs for `dagster-databricks` can be found [here](https://docs.dagster.io/_apidocs/libraries/dagster-databricks). -## EXT Example +## ext example This package includes a prototype API for launching databricks jobs with Dagster's EXT protocol. There are two ways to use the API: @@ -11,14 +11,14 @@ Dagster's EXT protocol. There are two ways to use the API: ### (1) `ExtDatabricks` resource The `ExtDatabricks` resource provides a high-level API for launching -databricks jobs using Dagster's EXT protocol. +databricks jobs using Dagster's ext protocol. -It takes a single `databricks.sdk.service.jobs.SubmitTask` specification. After -setting up EXT communications channels (which by default use DBFS), it injects -the information needed to connect to these channels from Databricks into the -task specification. It then launches a Databricks job by passing the -specification to `WorkspaceClient.jobs.submit`. It polls the job state and -exits gracefully on success or failure: +`ExtDatabricks.run` takes a single `databricks.sdk.service.jobs.SubmitTask` +specification. After setting up ext communications channels (which by default +use DBFS), it injects the information needed to connect to these channels from +Databricks into the task specification. It then launches a Databricks job by +passing the specification to `WorkspaceClient.jobs.submit`. 
It polls the job
+state and exits gracefully on success or failure:
 
 ```
@@ -46,9 +46,10 @@ def databricks_asset(context: AssetExecutionContext, ext: ExtDatabricks):
         }
     })
 
-    # arbitrary json-serializable data you want access to from the ExtContext
-    # in the databricks runtime
-    extras = {"foo": "bar"}
+    # Arbitrary json-serializable data you want access to from the `ExtContext`
+    # in the databricks runtime. Assume `sample_rate` is a parameter used by
+    # the target job's business logic.
+    extras = {"sample_rate": 1.0}
 
     # synchronously execute the databricks job
     ext.run(
@@ -87,10 +88,10 @@ context = init_dagster_ext(
 )
 
 # Access the `extras` dict passed when launching the job from Dagster.
-foo_extra = context.get_extra("foo")
+sample_rate = context.get_extra("sample_rate")
 
 # Stream log message back to Dagster
-context.log(f"Extra for key 'foo': {foo_extra}")
+context.log(f"Using sample rate: {sample_rate}")
 
 # ... your code that computes and persists the asset
 
@@ -105,13 +106,12 @@ context.report_asset_data_version(get_data_version())
 
 ### (2) `ext_protocol` context manager
 
-Internally, `ExtDatabricks` is using the `ext_protocol` context manager to set
-up communications. If you'd prefer more control over how your databricks job is
-launched and polled, you can skip `ExtDatabricks` and use this lower level API
-directly. All that is necessary is that (1) your Databricks job be launched within
-the scope of the `ext_process` context manager; (2) your job is launched on a
-cluster containing the environment variables available on the yielded
-`ext_context`.
+If you have existing code to launch/poll the job that you do not want to
+change, or you just want more control than `ExtDatabricks` permits, you can use
+`ext_protocol`. All that is necessary is that (1) your Databricks job be
+launched within the scope of the `ext_protocol` context manager; (2) your job
+is launched on a cluster containing the environment variables available on the
+yielded `ext_context`.
 
 ```
 import os
@@ -128,9 +128,12 @@ def databricks_asset(context: AssetExecutionContext):
         token=os.environ["DATABRICKS_TOKEN"],
     )
 
-    extras = {"foo": "bar"}
+    # Arbitrary json-serializable data you want access to from the `ExtContext`
+    # in the databricks runtime. Assume `sample_rate` is a parameter used by
+    # the target job's business logic.
+    extras = {"sample_rate": 1.0}
 
-    # Sets up EXT communications channels
+    # Sets up ext communications channels
     with ext_protocol(
         context=context,
         extras=extras,
@@ -141,7 +144,7 @@ def databricks_asset(context: AssetExecutionContext):
         # Dict[str, str] with environment variables containing ext comms info.
         env_vars = ext_context.get_external_process_env_vars()
 
-        # Some function that handles launching/montoring of the databricks job.
+        # Some function that handles launching/monitoring of the databricks job.
         # It must ensure that the `env_vars` are set on the executing cluster.
custom_databricks_launch_code(env_vars) ``` diff --git a/python_modules/libraries/dagster-databricks/dagster_databricks/ext.py b/python_modules/libraries/dagster-databricks/dagster_databricks/ext.py index 65bcc0e0d4afb..22bb89034cd07 100644 --- a/python_modules/libraries/dagster-databricks/dagster_databricks/ext.py +++ b/python_modules/libraries/dagster-databricks/dagster_databricks/ext.py @@ -7,6 +7,7 @@ from contextlib import contextmanager from typing import Iterator, Mapping, Optional +import dagster._check as check from dagster._core.definitions.resource_annotation import ResourceParam from dagster._core.errors import DagsterExternalExecutionError from dagster._core.execution.context.compute import OpExecutionContext @@ -38,9 +39,25 @@ class _ExtDatabricks(ExtClient): description="An optional dict of environment variables to pass to the subprocess.", ) - def __init__(self, client: WorkspaceClient, env: Optional[Mapping[str, str]] = None): + def __init__( + self, + client: WorkspaceClient, + env: Optional[Mapping[str, str]] = None, + context_injector: Optional[ExtContextInjector] = None, + message_reader: Optional[ExtMessageReader] = None, + ): self.client = client self.env = env + self.context_injector = check.opt_inst_param( + context_injector, + "context_injector", + ExtContextInjector, + ) or ExtDbfsContextInjector(client=self.client) + self.message_reader = check.opt_inst_param( + message_reader, + "message_reader", + ExtMessageReader, + ) or ExtDbfsMessageReader(client=self.client) def run( self, @@ -48,8 +65,6 @@ def run( *, context: OpExecutionContext, extras: Optional[ExtExtras] = None, - context_injector: Optional[ExtContextInjector] = None, - message_reader: Optional[ExtMessageReader] = None, submit_args: Optional[Mapping[str, str]] = None, ) -> None: """Run a Databricks job with the EXT protocol. 
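With this change, the DBFS context injector and message reader are configured once on the client object rather than passed to every `run` call. A minimal sketch of what construction might look like after this diff; the import path is an assumption (the diff only shows these names defined in `dagster_databricks/ext.py`), and the README's `ExtDatabricks` annotation is the resource-param alias for the `_ExtDatabricks` class constructed here:

```
import os

from databricks.sdk import WorkspaceClient

# Assumed import location; adjust to wherever these names are exported.
from dagster_databricks.ext import (
    ExtDbfsContextInjector,
    ExtDbfsMessageReader,
    _ExtDatabricks,
)

client = WorkspaceClient(
    host=os.environ["DATABRICKS_HOST"],
    token=os.environ["DATABRICKS_TOKEN"],
)

# Channels are now fixed at construction time. Omitting these kwargs falls
# back to the DBFS-based defaults built from `client`.
ext = _ExtDatabricks(
    client,
    context_injector=ExtDbfsContextInjector(client=client),
    message_reader=ExtDbfsMessageReader(client=client, interval=10),
)

# `run` no longer accepts `context_injector`/`message_reader`; per-invocation
# arguments are just the task, context, extras, and submit_args:
#   ext.run(task, context=context, extras={"sample_rate": 1.0})
```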
@@ -66,8 +81,8 @@ def run(
         with ext_protocol(
             context=context,
             extras=extras,
-            context_injector=context_injector or ExtDbfsContextInjector(client=self.client),
-            message_reader=message_reader or ExtDbfsMessageReader(client=self.client),
+            context_injector=self.context_injector,
+            message_reader=self.message_reader,
         ) as ext_context:
             submit_task_dict = task.as_dict()
             submit_task_dict["new_cluster"]["spark_env_vars"] = {
@@ -83,6 +98,9 @@ def run(
 
         while True:
             run = self.client.jobs.get_run(run_id)
+            context.log.info(
+                f"Databricks run {run_id} current state: {run.state.life_cycle_state}"
+            )
             if run.state.life_cycle_state in (
                 jobs.RunLifeCycleState.TERMINATED,
                 jobs.RunLifeCycleState.SKIPPED,
@@ -131,26 +149,22 @@ def inject_context(self, context: "ExtContextData") -> Iterator[ExtParams]:
 
 
 class ExtDbfsMessageReader(ExtBlobStoreMessageReader):
-    tempdir: Optional[str] = None
-
     def __init__(self, *, interval: int = 10, client: WorkspaceClient):
         super().__init__(interval=interval)
         self.dbfs_client = files.DbfsAPI(client.api_client)
 
     @contextmanager
-    def setup(self) -> Iterator[None]:
+    def get_params(self) -> Iterator[ExtParams]:
         with dbfs_tempdir(self.dbfs_client) as tempdir:
-            self.tempdir = tempdir
-            yield
-
-    def get_params(self) -> ExtParams:
-        return {"path": self.tempdir}
+            yield {"path": tempdir}
 
-    def download_messages_chunk(self, index: int) -> Optional[str]:
-        assert self.tempdir
-        message_path = os.path.join(self.tempdir, f"{index}.json")
+    def download_messages_chunk(self, index: int, params: ExtParams) -> Optional[str]:
+        message_path = os.path.join(params["path"], f"{index}.json")
         try:
             raw_message = self.dbfs_client.read(message_path)
             return raw_message.data
+        # An IOError here is expected: it is raised whenever the next message chunk
+        # doesn't yet exist. Swallowing it is equivalent to a no-op on a status check
+        # that shows a nonexistent file.
         except IOError:
             return None
diff --git a/python_modules/libraries/dagster-databricks/setup.py b/python_modules/libraries/dagster-databricks/setup.py
index f936714a1605d..9fbf10a81d4f4 100644
--- a/python_modules/libraries/dagster-databricks/setup.py
+++ b/python_modules/libraries/dagster-databricks/setup.py
@@ -34,6 +34,7 @@ def get_version() -> str:
     include_package_data=True,
     install_requires=[
         f"dagster{pin}",
+        f"dagster-ext-process{pin}",
         f"dagster-pyspark{pin}",
         "databricks-cli~=0.17",  # TODO: Remove this dependency in the next minor release.
         "databricks_api",  # TODO: Remove this dependency in the next minor release.
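Taken together, the `ExtBlobStoreMessageReader` changes define a small contract for subclasses: yield the channel params from a `get_params` context manager (setup and teardown now live there), and accept those params in `download_messages_chunk`. Below is a hedged sketch of a minimal subclass against that contract; the local-directory reader is illustrative only, and the import paths are assumptions based on where these names appear in this diff:

```
import os
import tempfile
from contextlib import contextmanager
from typing import Iterator, Optional

# Assumed import paths.
from dagster._core.ext.utils import ExtBlobStoreMessageReader
from dagster_ext import ExtParams


class LocalDirMessageReader(ExtBlobStoreMessageReader):
    """Illustrative reader that polls `<dir>/<index>.json` chunks on local disk."""

    @contextmanager
    def get_params(self) -> Iterator[ExtParams]:
        # Setup/teardown of the channel is scoped to the `with` block that
        # `read_messages` opens; no state needs to be stashed on `self`.
        with tempfile.TemporaryDirectory() as tempdir:
            yield {"path": tempdir}

    def download_messages_chunk(self, index: int, params: ExtParams) -> Optional[str]:
        # The params yielded above are threaded through to each download call.
        chunk_path = os.path.join(params["path"], f"{index}.json")
        try:
            with open(chunk_path, "r") as f:
                return f.read()
        except IOError:
            # The next chunk not existing yet is the normal "no new messages" case.
            return None
```

This params threading is what lets `get_params` become a context manager: the reader thread receives the yielded params directly instead of reading attributes that the old `setup` hook had to stash on `self`.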
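For orientation, the other half of the chunk protocol: `read_messages` polls `{index}.json` blobs from an incrementing counter and splits each chunk on newlines. `ExtBlobStoreMessageWriter`'s internals are not shown in this diff, so the stand-alone sketch below only illustrates the naming and framing convention the reader expects; it is not the real writer API, and the starting index is an assumption:

```
import json
import os
from typing import Any, Mapping, Sequence


class DirChunkWriterSketch:
    """Hypothetical writer emitting sequential `<index>.json` chunks to a directory."""

    def __init__(self, path: str):
        self.path = path
        # NOTE: the reader's starting counter value isn't shown in this diff;
        # 0 is an assumption.
        self.index = 0

    def write_chunk(self, messages: Sequence[Mapping[str, Any]]) -> None:
        # The reader splits a chunk on "\n" and handles each line as one message,
        # so a chunk is a batch of newline-delimited JSON objects.
        chunk_path = os.path.join(self.path, f"{self.index}.json")
        with open(chunk_path, "w") as f:
            f.write("\n".join(json.dumps(message) for message in messages))
        self.index += 1
```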