Skip to content

Commit

Permalink
address feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
smackesey committed Sep 14, 2023
1 parent 5f27ccb commit 066249e
Show file tree
Hide file tree
Showing 5 changed files with 67 additions and 50 deletions.
1 change: 0 additions & 1 deletion python_modules/dagster/dagster/_core/ext/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ def run(
*,
context: OpExecutionContext,
extras: Optional[ExtExtras] = None,
message_reader: Optional["ExtMessageReader"] = None,
) -> None:
...

Expand Down
20 changes: 10 additions & 10 deletions python_modules/dagster/dagster/_core/ext/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,43 +130,43 @@ def read_messages(
self,
handler: "ExtMessageHandler",
) -> Iterator[ExtParams]:
with self.setup():
with self.setup() as params:
is_task_complete = Event()
thread = None
try:
thread = Thread(
target=self._reader_thread,
args=(
handler,
params,
is_task_complete,
),
daemon=True,
)
thread.start()
yield self.get_params()
yield params
finally:
is_task_complete.set()
if thread:
thread.join()

@contextmanager
def setup(self) -> Iterator[None]:
yield

@abstractmethod
def get_params(self) -> ExtParams:
@contextmanager
def setup(self) -> Iterator[ExtParams]:
...

@abstractmethod
def download_messages_chunk(self, index: int) -> Optional[str]:
def download_messages_chunk(self, index: int, params: ExtParams) -> Optional[str]:
...

def _reader_thread(self, handler: "ExtMessageHandler", is_task_complete: Event) -> None:
def _reader_thread(
self, handler: "ExtMessageHandler", params: ExtParams, is_task_complete: Event
) -> None:
start_or_last_download = datetime.datetime.now()
while True:
now = datetime.datetime.now()
if (now - start_or_last_download).seconds > self.interval or is_task_complete.is_set():
chunk = self.download_messages_chunk(self.counter)
chunk = self.download_messages_chunk(self.counter, params)
start_or_last_download = now
if chunk:
for line in chunk.split("\n"):
Expand Down
10 changes: 6 additions & 4 deletions python_modules/libraries/dagster-aws/dagster_aws/ext.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import random
import string
from typing import Optional
from contextlib import contextmanager
from typing import Iterator, Optional

import boto3
import dagster._check as check
Expand All @@ -18,10 +19,11 @@ def __init__(self, *, interval: int = 10, bucket: str, client: boto3.client):
self.key_prefix = "".join(random.choices(string.ascii_letters, k=30))
self.client = client

def get_params(self) -> ExtParams:
return {"bucket": self.bucket, "key_prefix": self.key_prefix}
@contextmanager
def setup(self) -> Iterator[ExtParams]:
yield {"bucket": self.bucket, "key_prefix": self.key_prefix}

def download_messages_chunk(self, index: int) -> Optional[str]:
def download_messages_chunk(self, index: int, params: ExtParams) -> Optional[str]:
key = f"{self.key_prefix}/{index}.json"
try:
obj = self.client.get_object(Bucket=self.bucket, Key=key)
Expand Down
48 changes: 26 additions & 22 deletions python_modules/libraries/dagster-databricks/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,22 +3,22 @@
The docs for `dagster-databricks` can be found
[here](https://docs.dagster.io/_apidocs/libraries/dagster-databricks).

## EXT Example
## ext example

This package includes a prototype API for launching Databricks jobs with
Dagster's ext protocol. There are two ways to use the API:

### (1) `ExtDatabricks` resource

The `ExtDatabricks` resource provides a high-level API for launching
databricks jobs using Dagster's EXT protocol.
databricks jobs using Dagster's ext protocol.

It takes a single `databricks.sdk.service.jobs.SubmitTask` specification. After
setting up EXT communications channels (which by default use DBFS), it injects
the information needed to connect to these channels from Databricks into the
task specification. It then launches a Databricks job by passing the
specification to `WorkspaceClient.jobs.submit`. It polls the job state and
exits gracefully on success or failure:
`ExtDatabricks.run` takes a single `databricks.sdk.service.jobs.SubmitTask`
specification. After setting up ext communications channels (which by default
use DBFS), it injects the information needed to connect to these channels from
Databricks into the task specification. It then launches a Databricks job by
passing the specification to `WorkspaceClient.jobs.submit`. It polls the job
state and exits gracefully on success or failure:


```
Expand Down Expand Up @@ -46,9 +46,10 @@ def databricks_asset(context: AssetExecutionContext, ext: ExtDatabricks):
}
})
# arbitrary json-serializable data you want access to from the ExtContext
# in the databricks runtime
extras = {"foo": "bar"}
# Arbitrary json-serializable data you want access to from the `ExtContext`
# in the databricks runtime. Assume `sample_rate` is a parameter used by
# the target job's business logic.
extras = {"sample_rate": 1.0}
# synchronously execute the databricks job
ext.run(
Expand Down Expand Up @@ -87,10 +88,10 @@ context = init_dagster_ext(
)
# Access the `extras` dict passed when launching the job from Dagster.
foo_extra = context.get_extra("foo")
sample_rate = context.get_extra("sample_rate")
# Stream log message back to Dagster
context.log(f"Extra for key 'foo': {foo_extra}")
context.log(f"Using sample rate: {sample_rate}")
# ... your code that computes and persists the asset
Expand All @@ -106,12 +107,12 @@ context.report_asset_data_version(get_data_version())
### (2) `ext_protocol` context manager

Internally, `ExtDatabricks` is using the `ext_protocol` context manager to set
up communications. If you'd prefer more control over how your databricks job is
launched and polled, you can skip `ExtDatabricks` and use this lower level API
directly. All that is necessary is that (1) your Databricks job be launched within
the scope of the `ext_process` context manager; (2) your job is launched on a
cluster containing the environment variables available on the yielded
`ext_context`.
up communications. If you have existing code to launch/poll the job that you do
not want to change, or you just want more control than is permitted by
`ExtDatabricks`, you can use this lower-level API directly. All that is
necessary is that (1) your Databricks job is launched within the scope of the
`ext_protocol` context manager; and (2) your job is launched on a cluster
containing the environment variables available on the yielded `ext_context`.

```
import os
Expand All @@ -128,9 +129,12 @@ def databricks_asset(context: AssetExecutionContext):
token=os.environ["DATABRICKS_TOKEN"],
)
extras = {"foo": "bar"}
# Arbitrary json-serializable data you want access to from the `ExtContext`
# in the databricks runtime. Assume `sample_rate` is a parameter used by
# the target job's business logic.
extras = {"sample_rate": 1.0}
# Sets up EXT communications channels
# Sets up ext communications channels
with ext_protocol(
context=context,
extras=extras,
Expand All @@ -141,7 +145,7 @@ def databricks_asset(context: AssetExecutionContext):
# Dict[str, str] with environment variables containing ext comms info.
env_vars = ext_context.get_external_process_env_vars()
# Some function that handles launching/montoring of the databricks job.
# Some function that handles launching/monitoring of the databricks job.
# It must ensure that the `env_vars` are set on the executing cluster.
custom_databricks_launch_code(env_vars)
```
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from contextlib import contextmanager
from typing import Iterator, Mapping, Optional

import dagster._check as check
from dagster._core.definitions.resource_annotation import ResourceParam
from dagster._core.errors import DagsterExternalExecutionError
from dagster._core.execution.context.compute import OpExecutionContext
Expand Down Expand Up @@ -38,18 +39,32 @@ class _ExtDatabricks(ExtClient):
description="An optional dict of environment variables to pass to the subprocess.",
)

def __init__(self, client: WorkspaceClient, env: Optional[Mapping[str, str]] = None):
def __init__(
self,
client: WorkspaceClient,
env: Optional[Mapping[str, str]] = None,
context_injector: Optional[ExtContextInjector] = None,
message_reader: Optional[ExtMessageReader] = None,
):
self.client = client
self.env = env
self.context_injector = check.opt_inst_param(
context_injector,
"context_injector",
ExtContextInjector,
) or ExtDbfsContextInjector(client=self.client)
self.message_reader = check.opt_inst_param(
message_reader,
"message_reader",
ExtMessageReader,
) or ExtDbfsMessageReader(client=self.client)

def run(
self,
task: jobs.SubmitTask,
*,
context: OpExecutionContext,
extras: Optional[ExtExtras] = None,
context_injector: Optional[ExtContextInjector] = None,
message_reader: Optional[ExtMessageReader] = None,
submit_args: Optional[Mapping[str, str]] = None,
) -> None:
"""Run a Databricks job with the EXT protocol.
Expand All @@ -66,8 +81,8 @@ def run(
with ext_protocol(
context=context,
extras=extras,
context_injector=context_injector or ExtDbfsContextInjector(client=self.client),
message_reader=message_reader or ExtDbfsMessageReader(client=self.client),
context_injector=self.context_injector,
message_reader=self.message_reader,
) as ext_context:
submit_task_dict = task.as_dict()
submit_task_dict["new_cluster"]["spark_env_vars"] = {
Expand All @@ -83,6 +98,7 @@ def run(

while True:
run = self.client.jobs.get_run(run_id)
context.log.info(f"Run state: {run.state.life_cycle_state}")
if run.state.life_cycle_state in (
jobs.RunLifeCycleState.TERMINATED,
jobs.RunLifeCycleState.SKIPPED,
Expand Down Expand Up @@ -138,17 +154,13 @@ def __init__(self, *, interval: int = 10, client: WorkspaceClient):
self.dbfs_client = files.DbfsAPI(client.api_client)

@contextmanager
def setup(self) -> Iterator[None]:
def setup(self) -> Iterator[ExtParams]:
with dbfs_tempdir(self.dbfs_client) as tempdir:
self.tempdir = tempdir
yield

def get_params(self) -> ExtParams:
return {"path": self.tempdir}
yield {"path": tempdir}

def download_messages_chunk(self, index: int) -> Optional[str]:
assert self.tempdir
message_path = os.path.join(self.tempdir, f"{index}.json")
def download_messages_chunk(self, index: int, params: ExtParams) -> Optional[str]:
message_path = os.path.join(params["path"], f"{index}.json")
try:
raw_message = self.dbfs_client.read(message_path)
return raw_message.data
Expand Down

0 comments on commit 066249e

Please sign in to comment.