From 066249e49d939e9cb24e19282bcc2d6d9f05add0 Mon Sep 17 00:00:00 2001 From: Sean Mackesey Date: Thu, 14 Sep 2023 10:24:27 -0400 Subject: [PATCH] address feedback --- .../dagster/dagster/_core/ext/client.py | 1 - .../dagster/dagster/_core/ext/utils.py | 20 ++++---- .../libraries/dagster-aws/dagster_aws/ext.py | 10 ++-- .../libraries/dagster-databricks/README.md | 48 ++++++++++--------- .../dagster_databricks/ext.py | 38 ++++++++++----- 5 files changed, 67 insertions(+), 50 deletions(-) diff --git a/python_modules/dagster/dagster/_core/ext/client.py b/python_modules/dagster/dagster/_core/ext/client.py index d4ad967e6f529..9e86b438053e8 100644 --- a/python_modules/dagster/dagster/_core/ext/client.py +++ b/python_modules/dagster/dagster/_core/ext/client.py @@ -21,7 +21,6 @@ def run( *, context: OpExecutionContext, extras: Optional[ExtExtras] = None, - message_reader: Optional["ExtMessageReader"] = None, ) -> None: ... diff --git a/python_modules/dagster/dagster/_core/ext/utils.py b/python_modules/dagster/dagster/_core/ext/utils.py index 949be2723aa00..3a601a1ccf416 100644 --- a/python_modules/dagster/dagster/_core/ext/utils.py +++ b/python_modules/dagster/dagster/_core/ext/utils.py @@ -130,7 +130,7 @@ def read_messages( self, handler: "ExtMessageHandler", ) -> Iterator[ExtParams]: - with self.setup(): + with self.setup() as params: is_task_complete = Event() thread = None try: @@ -138,35 +138,35 @@ def read_messages( target=self._reader_thread, args=( handler, + params, is_task_complete, ), daemon=True, ) thread.start() - yield self.get_params() + yield params finally: is_task_complete.set() if thread: thread.join() - @contextmanager - def setup(self) -> Iterator[None]: - yield - @abstractmethod - def get_params(self) -> ExtParams: + @contextmanager + def setup(self) -> Iterator[ExtParams]: ... @abstractmethod - def download_messages_chunk(self, index: int) -> Optional[str]: + def download_messages_chunk(self, index: int, params: ExtParams) -> Optional[str]: ... 
- def _reader_thread(self, handler: "ExtMessageHandler", is_task_complete: Event) -> None: + def _reader_thread( + self, handler: "ExtMessageHandler", params: ExtParams, is_task_complete: Event + ) -> None: start_or_last_download = datetime.datetime.now() while True: now = datetime.datetime.now() if (now - start_or_last_download).seconds > self.interval or is_task_complete.is_set(): - chunk = self.download_messages_chunk(self.counter) + chunk = self.download_messages_chunk(self.counter, params) start_or_last_download = now if chunk: for line in chunk.split("\n"): diff --git a/python_modules/libraries/dagster-aws/dagster_aws/ext.py b/python_modules/libraries/dagster-aws/dagster_aws/ext.py index 32cfeb5593b01..2b8aad86087aa 100644 --- a/python_modules/libraries/dagster-aws/dagster_aws/ext.py +++ b/python_modules/libraries/dagster-aws/dagster_aws/ext.py @@ -1,6 +1,7 @@ import random import string -from typing import Optional +from contextlib import contextmanager +from typing import Iterator, Optional import boto3 import dagster._check as check @@ -18,10 +19,11 @@ def __init__(self, *, interval: int = 10, bucket: str, client: boto3.client): self.key_prefix = "".join(random.choices(string.ascii_letters, k=30)) self.client = client - def get_params(self) -> ExtParams: - return {"bucket": self.bucket, "key_prefix": self.key_prefix} + @contextmanager + def setup(self) -> Iterator[ExtParams]: + yield {"bucket": self.bucket, "key_prefix": self.key_prefix} - def download_messages_chunk(self, index: int) -> Optional[str]: + def download_messages_chunk(self, index: int, params: ExtParams) -> Optional[str]: key = f"{self.key_prefix}/{index}.json" try: obj = self.client.get_object(Bucket=self.bucket, Key=key) diff --git a/python_modules/libraries/dagster-databricks/README.md b/python_modules/libraries/dagster-databricks/README.md index 1f2b6259fee6e..7b4068c1656c0 100644 --- a/python_modules/libraries/dagster-databricks/README.md +++ 
b/python_modules/libraries/dagster-databricks/README.md @@ -3,7 +3,7 @@ The docs for `dagster-databricks` can be found [here](https://docs.dagster.io/_apidocs/libraries/dagster-databricks). -## EXT Example +## ext example This package includes a prototype API for launching databricks jobs with Dagster's EXT protocol. There are two ways to use the API: @@ -11,14 +11,14 @@ Dagster's EXT protocol. There are two ways to use the API: ### (1) `ExtDatabricks` resource The `ExtDatabricks` resource provides a high-level API for launching -databricks jobs using Dagster's EXT protocol. +databricks jobs using Dagster's ext protocol. -It takes a single `databricks.sdk.service.jobs.SubmitTask` specification. After -setting up EXT communications channels (which by default use DBFS), it injects -the information needed to connect to these channels from Databricks into the -task specification. It then launches a Databricks job by passing the -specification to `WorkspaceClient.jobs.submit`. It polls the job state and -exits gracefully on success or failure: +`ExtDatabricks.run` takes a single `databricks.sdk.service.jobs.SubmitTask` +specification. After setting up ext communications channels (which by default +use DBFS), it injects the information needed to connect to these channels from +Databricks into the task specification. It then launches a Databricks job by +passing the specification to `WorkspaceClient.jobs.submit`. It polls the job +state and exits gracefully on success or failure: ``` @@ -46,9 +46,10 @@ def databricks_asset(context: AssetExecutionContext, ext: ExtDatabricks): } }) - # arbitrary json-serializable data you want access to from the ExtContext - # in the databricks runtime - extras = {"foo": "bar"} + # Arbitrary json-serializable data you want access to from the `ExtContext` + # in the databricks runtime. Assume `sample_rate` is a parameter used by + # the target job's business logic. 
+ extras = {"sample_rate": 1.0} # synchronously execute the databricks job ext.run( @@ -87,10 +88,10 @@ context = init_dagster_ext( ) # Access the `extras` dict passed when launching the job from Dagster. -foo_extra = context.get_extra("foo") +sample_rate = context.get_extra("sample_rate") # Stream log message back to Dagster -context.log(f"Extra for key 'foo': {foo_extra}") +context.log(f"Using sample rate: {sample_rate}") # ... your code that computes and persists the asset @@ -106,12 +107,12 @@ context.report_asset_data_version(get_data_version()) ### (2) `ext_protocol` context manager Internally, `ExtDatabricks` is using the `ext_protocol` context manager to set -up communications. If you'd prefer more control over how your databricks job is -launched and polled, you can skip `ExtDatabricks` and use this lower level API -directly. All that is necessary is that (1) your Databricks job be launched within -the scope of the `ext_process` context manager; (2) your job is launched on a -cluster containing the environment variables available on the yielded -`ext_context`. +up communications. If you have existing code to launch/poll the job you do not +want to change, or you just want more control than is permitted by +`ExtDatabricks`, you can use this lower level API directly. All that is +necessary is that (1) your Databricks job be launched within the scope of the +`ext_protocol` context manager; (2) your job is launched on a cluster containing +the environment variables available on the yielded `ext_context`. ``` import os @@ -128,9 +129,12 @@ def databricks_asset(context: AssetExecutionContext): token=os.environ["DATABRICKS_TOKEN"], ) - extras = {"foo": "bar"} + # Arbitrary json-serializable data you want access to from the `ExtContext` + # in the databricks runtime. Assume `sample_rate` is a parameter used by + # the target job's business logic. 
+ extras = {"sample_rate": 1.0} - # Sets up EXT communications channels + # Sets up ext communications channels with ext_protocol( context=context, extras=extras, @@ -141,7 +145,7 @@ def databricks_asset(context: AssetExecutionContext): # Dict[str, str] with environment variables containing ext comms info. env_vars = ext_context.get_external_process_env_vars() - # Some function that handles launching/montoring of the databricks job. + # Some function that handles launching/monitoring of the databricks job. # It must ensure that the `env_vars` are set on the executing cluster. custom_databricks_launch_code(env_vars) ``` diff --git a/python_modules/libraries/dagster-databricks/dagster_databricks/ext.py b/python_modules/libraries/dagster-databricks/dagster_databricks/ext.py index 65bcc0e0d4afb..d4516e0a8638d 100644 --- a/python_modules/libraries/dagster-databricks/dagster_databricks/ext.py +++ b/python_modules/libraries/dagster-databricks/dagster_databricks/ext.py @@ -7,6 +7,7 @@ from contextlib import contextmanager from typing import Iterator, Mapping, Optional +import dagster._check as check from dagster._core.definitions.resource_annotation import ResourceParam from dagster._core.errors import DagsterExternalExecutionError from dagster._core.execution.context.compute import OpExecutionContext @@ -38,9 +39,25 @@ class _ExtDatabricks(ExtClient): description="An optional dict of environment variables to pass to the subprocess.", ) - def __init__(self, client: WorkspaceClient, env: Optional[Mapping[str, str]] = None): + def __init__( + self, + client: WorkspaceClient, + env: Optional[Mapping[str, str]] = None, + context_injector: Optional[ExtContextInjector] = None, + message_reader: Optional[ExtMessageReader] = None, + ): self.client = client self.env = env + self.context_injector = check.opt_inst_param( + context_injector, + "context_injector", + ExtContextInjector, + ) or ExtDbfsContextInjector(client=self.client) + self.message_reader = check.opt_inst_param( + 
message_reader, + "message_reader", + ExtMessageReader, + ) or ExtDbfsMessageReader(client=self.client) def run( self, @@ -48,8 +65,6 @@ def run( *, context: OpExecutionContext, extras: Optional[ExtExtras] = None, - context_injector: Optional[ExtContextInjector] = None, - message_reader: Optional[ExtMessageReader] = None, submit_args: Optional[Mapping[str, str]] = None, ) -> None: """Run a Databricks job with the EXT protocol. @@ -66,8 +81,8 @@ def run( with ext_protocol( context=context, extras=extras, - context_injector=context_injector or ExtDbfsContextInjector(client=self.client), - message_reader=message_reader or ExtDbfsMessageReader(client=self.client), + context_injector=self.context_injector, + message_reader=self.message_reader, ) as ext_context: submit_task_dict = task.as_dict() submit_task_dict["new_cluster"]["spark_env_vars"] = { @@ -83,6 +98,7 @@ def run( while True: run = self.client.jobs.get_run(run_id) + context.log.info(f"Run state: {run.state.life_cycle_state}") if run.state.life_cycle_state in ( jobs.RunLifeCycleState.TERMINATED, jobs.RunLifeCycleState.SKIPPED, @@ -138,17 +154,13 @@ def __init__(self, *, interval: int = 10, client: WorkspaceClient): self.dbfs_client = files.DbfsAPI(client.api_client) @contextmanager - def setup(self) -> Iterator[None]: + def setup(self) -> Iterator[ExtParams]: with dbfs_tempdir(self.dbfs_client) as tempdir: self.tempdir = tempdir - yield - - def get_params(self) -> ExtParams: - return {"path": self.tempdir} + yield {"path": tempdir} - def download_messages_chunk(self, index: int) -> Optional[str]: - assert self.tempdir - message_path = os.path.join(self.tempdir, f"{index}.json") + def download_messages_chunk(self, index: int, params: ExtParams) -> Optional[str]: + message_path = os.path.join(params["path"], f"{index}.json") try: raw_message = self.dbfs_client.read(message_path) return raw_message.data