From 5072a064acefdc0835c525a006f937bd7e8f1b22 Mon Sep 17 00:00:00 2001
From: Sean Mackesey
Date: Wed, 11 Oct 2023 16:28:17 -0400
Subject: [PATCH] [pipes] databricks unstructured log forwarding (#16674)

## Summary & Motivation

This adds stdout/stderr forwarding to the dagster-databricks pipes integration. It was a long road getting here with several dead ends.

The approach in this PR is to extend `PipesBlobStoreMessageReader` with optional `stdout_reader`/`stderr_reader` params (instances of the new `PipesBlobStoreStdioReader`), which provide hooks for downloading stdout/stderr chunks. When these readers are passed, a thread is launched for each stream (alongside the message chunk thread) that periodically downloads stdout/stderr chunks and writes them to the corresponding orchestration-process streams. (A rough usage sketch and two implementation sketches are appended at the end of this description.)

For DBFS, instead of using an incrementing counter (as is used for message chunks), the stdout/stderr chunk downloader tracks a character offset into the log file for its stream. We repeatedly download the full file, and on each download we forward only the content past that offset. Repeatedly downloading the full file and applying the offset on the orchestration end can surely be improved to download starting from the offset, but I did not implement that yet (there are some concerns around getting the indexing right, given that the files are stored base64-encoded, I think with padding).

While other integrations may need changes on the pipes (external process) end, none were needed for Databricks: a DBFS location for `stdout`/`stderr` is configured when the job is launched, and Databricks itself writes the driver logs there, so the external process doesn't need to do anything. This introduces a potential asymmetry between `PipesMessageReader` and `PipesMessageWriter`; we will probably end up with just a `PipesReader`.

There are definitely some other rough patches here:

- Databricks does not let you directly configure the directory where stdout/stderr are written. Instead you set a root directory, and the logs are stored in `<log_root>/<cluster_id>/driver/{stdout,stderr}`. This is a problem because the cluster id does not exist until the job is launched (and you can't set it manually), and the message reader is set up before the job is launched, so it doesn't know where to look.
  - I got around this by setting the log root to a temporary directory and polling that directory for the first child to appear, which will be where the logs are stored. This is not ideal, because users may want to retain the logs in DBFS.
  - Another approach would be to send the cluster id back in the new `opened` message, but threading it into the message reader would require additional plumbing work.

For those who want to play with this, the workflow is to repeatedly run `dagster_databricks_tests/test_pipes.py::test_pipes_client`. This requires `DATABRICKS_HOST` and `DATABRICKS_TOKEN` to be set in your env. `DATABRICKS_HOST` should be `https://dbc-07902917-6487.cloud.databricks.com`. `DATABRICKS_TOKEN` should be set to a value you generate by going to User Settings > Developer > Access Tokens in the Databricks UI.

## How I Tested These Changes

Tested via `capsys` to make sure logs are forwarded.
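
## Appendix: usage and implementation sketches

As a rough usage sketch: `PipesDatabricksClient` already constructs equivalent DBFS readers when no `message_reader` is passed, so the explicit wiring below is only needed if you want to customize forwarding. The cluster spec, task key, and script path are placeholders.

```python
import sys

from dagster import AssetExecutionContext, asset, materialize
from dagster_databricks import (
    PipesDatabricksClient,
    PipesDbfsMessageReader,
    PipesDbfsStdioReader,
)
from databricks.sdk import WorkspaceClient
from databricks.sdk.service import jobs

client = WorkspaceClient()

# Forward the Databricks driver's stdout/stderr (synced to DBFS via the cluster
# log conf) back to the local process streams.
message_reader = PipesDbfsMessageReader(
    client=client,
    stdout_reader=PipesDbfsStdioReader(
        client=client, remote_log_name="stdout", target_stream=sys.stdout
    ),
    stderr_reader=PipesDbfsStdioReader(
        client=client, remote_log_name="stderr", target_stream=sys.stderr
    ),
)


@asset
def databricks_asset(context: AssetExecutionContext, pipes_client: PipesDatabricksClient):
    # Placeholder task spec -- point this at a real cluster config and a script that
    # calls open_dagster_pipes with PipesDbfsContextLoader/PipesDbfsMessageWriter.
    task = jobs.SubmitTask.from_dict(
        {
            "task_key": "DAGSTER_PIPES_TASK",
            "new_cluster": {
                "spark_version": "12.2.x-scala2.12",
                "node_type_id": "i3.xlarge",
                "num_workers": 1,
            },
            "libraries": [{"whl": "dbfs:/FileStore/jars/dagster_pipes-1!0+dev-py3-none-any.whl"}],
            "spark_python_task": {"python_file": "dbfs:/path/to/external_script.py"},
        }
    )
    return pipes_client.run(task=task, context=context).get_results()


if __name__ == "__main__":
    materialize(
        [databricks_asset],
        resources={"pipes_client": PipesDatabricksClient(client, message_reader=message_reader)},
    )
```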
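
The offset-tracking described above amounts to the following simplified, standalone version of `PipesDbfsStdioReader.download_log_chunk` (path resolution and error handling elided):

```python
import base64
from typing import Tuple

from databricks.sdk.service.files import DbfsAPI


def download_log_chunk(dbfs_client: DbfsAPI, log_path: str, log_position: int) -> Tuple[str, int]:
    # DBFS returns file contents base64-encoded: decode the whole file, then keep
    # only the tail past the character offset that has already been forwarded.
    raw = dbfs_client.read(log_path)
    content = base64.b64decode(raw.data).decode("utf-8")
    chunk = content[log_position:]
    return chunk, len(content)  # caller writes `chunk` and stores the new offset
```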
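
The cluster-log-root workaround, condensed: on the orchestration side, `PipesDatabricksClient.run` sets `new_cluster.cluster_log_conf` to `{"dbfs": {"destination": f"dbfs:{cluster_log_root}"}}`, where `cluster_log_root` is a DBFS tempdir created by the message reader. On the read side, something like the following (a simplified version of `PipesDbfsStdioReader._get_log_path`) polls that tempdir until the cluster-id directory appears:

```python
from typing import Optional

from databricks.sdk.service.files import DbfsAPI


def resolve_driver_log_path(
    dbfs_client: DbfsAPI, cluster_log_root: str, log_name: str
) -> Optional[str]:
    # Driver logs land at <cluster_log_root>/<cluster_id>/driver/<stdout|stderr>, but
    # the cluster id is unknown until Databricks creates the directory, so poll for
    # the first child of the log root and return None until it shows up.
    child_dirs = list(dbfs_client.list(cluster_log_root))
    if not child_dirs:
        return None
    return f"dbfs:{child_dirs[0].path}/driver/{log_name}"
```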
--- .../dagster-pipes/dagster_pipes/__init__.py | 29 ++-- python_modules/dagster/dagster/__init__.py | 1 + .../dagster/dagster/_core/pipes/utils.py | 150 +++++++++++++++--- .../dagster-aws/dagster_aws/pipes.py | 31 +++- .../dagster_databricks/__init__.py | 1 + .../dagster_databricks/pipes.py | 115 ++++++++++++-- .../dagster_databricks_tests/test_pipes.py | 63 ++++++-- 7 files changed, 332 insertions(+), 58 deletions(-) diff --git a/python_modules/dagster-pipes/dagster_pipes/__init__.py b/python_modules/dagster-pipes/dagster_pipes/__init__.py index 5ef63494de3e8..78905faab8883 100644 --- a/python_modules/dagster-pipes/dagster_pipes/__init__.py +++ b/python_modules/dagster-pipes/dagster_pipes/__init__.py @@ -11,7 +11,8 @@ from abc import ABC, abstractmethod from contextlib import ExitStack, contextmanager from io import StringIO -from threading import Event, Lock, Thread +from queue import Queue +from threading import Event, Thread from typing import ( IO, TYPE_CHECKING, @@ -66,7 +67,7 @@ def _make_message(method: str, params: Optional[Mapping[str, Any]]) -> "PipesMes class PipesMessage(TypedDict): - """A message sent from the orchestration process to the external process.""" + """A message sent from the external process to the orchestration process.""" __dagster_pipes_version: str method: str @@ -513,19 +514,17 @@ class PipesBlobStoreMessageWriterChannel(PipesMessageWriterChannel): def __init__(self, *, interval: float = 10): self._interval = interval - self._lock = Lock() - self._buffer = [] + self._buffer: Queue[PipesMessage] = Queue() self._counter = 1 def write_message(self, message: PipesMessage) -> None: - with self._lock: - self._buffer.append(message) + self._buffer.put(message) def flush_messages(self) -> Sequence[PipesMessage]: - with self._lock: - messages = list(self._buffer) - self._buffer.clear() - return messages + items = [] + while not self._buffer.empty(): + items.append(self._buffer.get()) + return items @abstractmethod def upload_messages_chunk(self, payload: StringIO, index: int) -> None: ... 
@@ -546,15 +545,15 @@ def buffered_upload_loop(self) -> Iterator[None]: def _upload_loop(self, is_task_complete: Event) -> None: start_or_last_upload = datetime.datetime.now() while True: - num_pending = len(self._buffer) now = datetime.datetime.now() - if num_pending == 0 and is_task_complete.is_set(): + if self._buffer.empty() and is_task_complete.is_set(): break elif is_task_complete.is_set() or (now - start_or_last_upload).seconds > self._interval: payload = "\n".join([json.dumps(message) for message in self.flush_messages()]) - self.upload_messages_chunk(StringIO(payload), self._counter) - start_or_last_upload = now - self._counter += 1 + if len(payload) > 0: + self.upload_messages_chunk(StringIO(payload), self._counter) + start_or_last_upload = now + self._counter += 1 time.sleep(1) diff --git a/python_modules/dagster/dagster/__init__.py b/python_modules/dagster/dagster/__init__.py index 17176ae4ef2e7..d4e60a04402da 100644 --- a/python_modules/dagster/dagster/__init__.py +++ b/python_modules/dagster/dagster/__init__.py @@ -485,6 +485,7 @@ from dagster._core.pipes.subprocess import PipesSubprocessClient as PipesSubprocessClient from dagster._core.pipes.utils import ( PipesBlobStoreMessageReader as PipesBlobStoreMessageReader, + PipesBlobStoreStdioReader as PipesBlobStoreStdioReader, PipesEnvContextInjector as PipesEnvContextInjector, PipesFileContextInjector as PipesFileContextInjector, PipesFileMessageReader as PipesFileMessageReader, diff --git a/python_modules/dagster/dagster/_core/pipes/utils.py b/python_modules/dagster/dagster/_core/pipes/utils.py index b2c6e04928cc0..866b8b005fd2f 100644 --- a/python_modules/dagster/dagster/_core/pipes/utils.py +++ b/python_modules/dagster/dagster/_core/pipes/utils.py @@ -4,10 +4,10 @@ import sys import tempfile import time -from abc import abstractmethod +from abc import ABC, abstractmethod from contextlib import contextmanager from threading import Event, Thread -from typing import Iterator, Optional +from typing import Iterator, Optional, TextIO from dagster_pipes import ( PIPES_PROTOCOL_VERSION_FIELD, @@ -144,7 +144,7 @@ def read_messages( self, handler: "PipesMessageHandler", ) -> Iterator[PipesParams]: - """Set up a thread to read streaming messages from teh external process by tailing the + """Set up a thread to read streaming messages from the external process by tailing the target file. Args: @@ -208,6 +208,12 @@ def no_messages_debug_text(self) -> str: return "Attempted to read messages from a local temporary file." +# Number of seconds to wait after an external process has completed for stdio logs to become +# available. If this is exceeded, proceed with exiting without picking up logs. +WAIT_FOR_STDIO_LOGS_TIMEOUT = 60 + + +@experimental class PipesBlobStoreMessageReader(PipesMessageReader): """Message reader that reads a sequence of message chunks written by an external process into a blob store such as S3, Azure blob storage, or GCS. @@ -221,16 +227,37 @@ class PipesBlobStoreMessageReader(PipesMessageReader): counter (starting from 1) on successful write, keeping counters on the read and write end in sync. + If `stdout_reader` or `stderr_reader` are passed, this reader will also start them when + `read_messages` is called. If they are not passed, then the reader performs no stdout/stderr + forwarding. + Args: interval (float): interval in seconds between attempts to download a chunk + stdout_reader (Optional[PipesBlobStoreStdioReader]): A reader for reading stdout logs. 
+ stderr_reader (Optional[PipesBlobStoreStdioReader]): A reader for reading stderr logs. """ interval: float counter: int + stdout_reader: "PipesBlobStoreStdioReader" + stderr_reader: "PipesBlobStoreStdioReader" - def __init__(self, interval: float = 10): + def __init__( + self, + interval: float = 10, + stdout_reader: Optional["PipesBlobStoreStdioReader"] = None, + stderr_reader: Optional["PipesBlobStoreStdioReader"] = None, + ): self.interval = interval self.counter = 1 + self.stdout_reader = ( + check.opt_inst_param(stdout_reader, "stdout_reader", PipesBlobStoreStdioReader) + or PipesNoOpStdioReader() + ) + self.stderr_reader = ( + check.opt_inst_param(stderr_reader, "stderr_reader", PipesBlobStoreStdioReader) + or PipesNoOpStdioReader() + ) @contextmanager def read_messages( @@ -249,23 +276,34 @@ def read_messages( """ with self.get_params() as params: is_task_complete = Event() - thread = None + messages_thread = None try: - thread = Thread( - target=self._reader_thread, - args=( - handler, - params, - is_task_complete, - ), - daemon=True, + messages_thread = Thread( + target=self._messages_thread, args=(handler, params, is_task_complete) ) - thread.start() + messages_thread.start() + self.stdout_reader.start(params, is_task_complete) + self.stderr_reader.start(params, is_task_complete) yield params finally: + self.wait_for_stdio_logs(params) is_task_complete.set() - if thread: - thread.join() + if messages_thread: + messages_thread.join() + self.stdout_reader.stop() + self.stderr_reader.stop() + + # In cases where we are forwarding logs, in some cases the logs might not be written out until + # after the run completes. We wait for them to exist. + def wait_for_stdio_logs(self, params): + start_or_last_download = datetime.datetime.now() + while ( + datetime.datetime.now() - start_or_last_download + ).seconds <= WAIT_FOR_STDIO_LOGS_TIMEOUT and ( + (self.stdout_reader and not self.stdout_reader.is_ready(params)) + or (self.stderr_reader and not self.stderr_reader.is_ready(params)) + ): + time.sleep(5) @abstractmethod @contextmanager @@ -280,15 +318,18 @@ def get_params(self) -> Iterator[PipesParams]: @abstractmethod def download_messages_chunk(self, index: int, params: PipesParams) -> Optional[str]: ... - def _reader_thread( - self, handler: "PipesMessageHandler", params: PipesParams, is_task_complete: Event + def _messages_thread( + self, + handler: "PipesMessageHandler", + params: PipesParams, + is_task_complete: Event, ) -> None: start_or_last_download = datetime.datetime.now() while True: now = datetime.datetime.now() if (now - start_or_last_download).seconds > self.interval or is_task_complete.is_set(): - chunk = self.download_messages_chunk(self.counter, params) start_or_last_download = now + chunk = self.download_messages_chunk(self.counter, params) if chunk: for line in chunk.split("\n"): message = json.loads(line) @@ -299,12 +340,83 @@ def _reader_thread( time.sleep(1) +class PipesBlobStoreStdioReader(ABC): + @abstractmethod + def start(self, params: PipesParams, is_task_complete: Event) -> None: ... + + @abstractmethod + def stop(self) -> None: ... + + @abstractmethod + def is_ready(self, params: PipesParams) -> bool: ... + + +@experimental +class PipesChunkedStdioReader(PipesBlobStoreStdioReader): + """Reader for reading stdout/stderr logs from a blob store such as S3, Azure blob storage, or GCS. + + Args: + interval (float): interval in seconds between attempts to download a chunk. + target_stream (TextIO): The stream to which to write the logs. 
Typcially `sys.stdout` or `sys.stderr`. + """ + + def __init__(self, *, interval: float = 10, target_stream: TextIO): + self.interval = interval + self.target_stream = target_stream + self.thread: Optional[Thread] = None + + @abstractmethod + def download_log_chunk(self, params: PipesParams) -> Optional[str]: ... + + def start(self, params: PipesParams, is_task_complete: Event) -> None: + self.thread = Thread(target=self._reader_thread, args=(params, is_task_complete)) + self.thread.start() + + def stop(self) -> None: + if self.thread: + self.thread.join() + + def _reader_thread( + self, + params: PipesParams, + is_task_complete: Event, + ) -> None: + start_or_last_download = datetime.datetime.now() + while True: + now = datetime.datetime.now() + if ( + (now - start_or_last_download).seconds > self.interval or is_task_complete.is_set() + ) and self.is_ready(params): + start_or_last_download = now + chunk = self.download_log_chunk(params) + if chunk: + self.target_stream.write(chunk) + elif is_task_complete.is_set(): + break + time.sleep(self.interval) + + +class PipesNoOpStdioReader(PipesBlobStoreStdioReader): + """Default implementation for a pipes stdio reader that does nothing.""" + + def start(self, params: PipesParams, is_task_complete: Event) -> None: + pass + + def stop(self) -> None: + pass + + def is_ready(self, params: PipesParams) -> bool: + return True + + def extract_message_or_forward_to_stdout(handler: "PipesMessageHandler", log_line: str): # exceptions as control flow, you love to see it try: message = json.loads(log_line) if PIPES_PROTOCOL_VERSION_FIELD in message.keys(): handler.handle_message(message) + else: + sys.stdout.writelines((log_line, "\n")) except Exception: # move non-message logs in to stdout for compute log capture sys.stdout.writelines((log_line, "\n")) diff --git a/python_modules/libraries/dagster-aws/dagster_aws/pipes.py b/python_modules/libraries/dagster-aws/dagster_aws/pipes.py index df949134e428f..53e6a21adba51 100644 --- a/python_modules/libraries/dagster-aws/dagster_aws/pipes.py +++ b/python_modules/libraries/dagster-aws/dagster_aws/pipes.py @@ -9,12 +9,37 @@ from dagster._core.pipes.client import ( PipesParams, ) -from dagster._core.pipes.utils import PipesBlobStoreMessageReader +from dagster._core.pipes.utils import PipesBlobStoreMessageReader, PipesBlobStoreStdioReader class PipesS3MessageReader(PipesBlobStoreMessageReader): - def __init__(self, *, interval: float = 10, bucket: str, client: boto3.client): - super().__init__(interval=interval) + """Message reader that reads messages by periodically reading message chunks from a specified S3 + bucket. + + If `stdout_reader` or `stderr_reader` are passed, this reader will also start them when + `read_messages` is called. If they are not passed, then the reader performs no stdout/stderr + forwarding. + + Args: + interval (float): interval in seconds between attempts to download a chunk + bucket (str): The S3 bucket to read from. + client (WorkspaceClient): A boto3 client. + stdout_reader (Optional[PipesBlobStoreStdioReader]): A reader for reading stdout logs. + stderr_reader (Optional[PipesBlobStoreStdioReader]): A reader for reading stderr logs. 
+ """ + + def __init__( + self, + *, + interval: float = 10, + bucket: str, + client: boto3.client, + stdout_reader: Optional[PipesBlobStoreStdioReader] = None, + stderr_reader: Optional[PipesBlobStoreStdioReader] = None, + ): + super().__init__( + interval=interval, stdout_reader=stdout_reader, stderr_reader=stderr_reader + ) self.bucket = check.str_param(bucket, "bucket") self.key_prefix = "".join(random.choices(string.ascii_letters, k=30)) self.client = client diff --git a/python_modules/libraries/dagster-databricks/dagster_databricks/__init__.py b/python_modules/libraries/dagster-databricks/dagster_databricks/__init__.py index b7b5210dab1f1..979c2d3fd5ec8 100644 --- a/python_modules/libraries/dagster-databricks/dagster_databricks/__init__.py +++ b/python_modules/libraries/dagster-databricks/dagster_databricks/__init__.py @@ -28,6 +28,7 @@ PipesDatabricksClient as PipesDatabricksClient, PipesDbfsContextInjector as PipesDbfsContextInjector, PipesDbfsMessageReader as PipesDbfsMessageReader, + PipesDbfsStdioReader as PipesDbfsStdioReader, ) from .resources import ( DatabricksClientResource as DatabricksClientResource, diff --git a/python_modules/libraries/dagster-databricks/dagster_databricks/pipes.py b/python_modules/libraries/dagster-databricks/dagster_databricks/pipes.py index 39ecd08d28316..9fc9817aa0fde 100644 --- a/python_modules/libraries/dagster-databricks/dagster_databricks/pipes.py +++ b/python_modules/libraries/dagster-databricks/dagster_databricks/pipes.py @@ -3,9 +3,10 @@ import os import random import string +import sys import time -from contextlib import contextmanager -from typing import Iterator, Mapping, Optional +from contextlib import ExitStack, contextmanager +from typing import Iterator, Literal, Mapping, Optional, TextIO import dagster._check as check from dagster._annotations import experimental @@ -20,9 +21,12 @@ ) from dagster._core.pipes.utils import ( PipesBlobStoreMessageReader, + PipesBlobStoreStdioReader, + PipesChunkedStdioReader, open_pipes_session, ) from dagster_pipes import ( + DAGSTER_PIPES_BOOTSTRAP_PARAM_NAMES, PipesContextData, PipesExtras, PipesParams, @@ -31,6 +35,10 @@ from databricks.sdk.service import files, jobs from pydantic import Field +# Number of seconds between status checks on Databricks jobs launched by the +# `PipesDatabricksClient`. 
+_RUN_POLL_INTERVAL = 5 + @experimental class _PipesDatabricksClient(PipesClient): @@ -69,7 +77,15 @@ def __init__( message_reader, "message_reader", PipesMessageReader, - ) or PipesDbfsMessageReader(client=self.client) + ) or PipesDbfsMessageReader( + client=self.client, + stdout_reader=PipesDbfsStdioReader( + client=self.client, remote_log_name="stdout", target_stream=sys.stdout + ), + stderr_reader=PipesDbfsStdioReader( + client=self.client, remote_log_name="stderr", target_stream=sys.stderr + ), + ) @classmethod def _is_dagster_maintained(cls) -> bool: @@ -113,6 +129,13 @@ def run( **(self.env or {}), **pipes_session.get_bootstrap_env_vars(), } + cluster_log_root = pipes_session.get_bootstrap_params()[ + DAGSTER_PIPES_BOOTSTRAP_PARAM_NAMES["messages"] + ].get("cluster_log_root") + if cluster_log_root is not None: + submit_task_dict["new_cluster"]["cluster_log_conf"] = { + "dbfs": {"destination": f"dbfs:{cluster_log_root}"} + } task = jobs.SubmitTask.from_dict(submit_task_dict) run_id = self.client.jobs.submit( tasks=[task], @@ -138,7 +161,7 @@ def run( raise DagsterPipesExecutionError( f"Error running Databricks job: {run.state.state_message}" ) - time.sleep(5) + time.sleep(_RUN_POLL_INTERVAL) return PipesClientCompletedInvocation(tuple(pipes_session.get_results())) @@ -201,25 +224,45 @@ class PipesDbfsMessageReader(PipesBlobStoreMessageReader): """Message reader that reads messages by periodically reading message chunks from an automatically-generated temporary directory on DBFS. + If `stdout_reader` or `stderr_reader` are passed, this reader will also start them when + `read_messages` is called. If they are not passed, then the reader performs no stdout/stderr + forwarding. + Args: interval (float): interval in seconds between attempts to download a chunk client (WorkspaceClient): A databricks `WorkspaceClient` object. + cluster_log_root (Optional[str]): The root path on DBFS where the cluster logs are written. + If set, this will be used to read stderr/stdout logs. + stdout_reader (Optional[PipesBlobStoreStdioReader]): A reader for reading stdout logs. + stderr_reader (Optional[PipesBlobStoreStdioReader]): A reader for reading stderr logs. """ - def __init__(self, *, interval: int = 10, client: WorkspaceClient): - super().__init__(interval=interval) + def __init__( + self, + *, + interval: float = 10, + client: WorkspaceClient, + stdout_reader: Optional[PipesBlobStoreStdioReader] = None, + stderr_reader: Optional[PipesBlobStoreStdioReader] = None, + ): + super().__init__( + interval=interval, stdout_reader=stdout_reader, stderr_reader=stderr_reader + ) self.dbfs_client = files.DbfsAPI(client.api_client) @contextmanager def get_params(self) -> Iterator[PipesParams]: - with dbfs_tempdir(self.dbfs_client) as tempdir: - yield {"path": tempdir} + with ExitStack() as stack: + params: PipesParams = {} + params["path"] = stack.enter_context(dbfs_tempdir(self.dbfs_client)) + if self.stdout_reader or self.stderr_reader: + params["cluster_log_root"] = stack.enter_context(dbfs_tempdir(self.dbfs_client)) + yield params def download_messages_chunk(self, index: int, params: PipesParams) -> Optional[str]: message_path = os.path.join(params["path"], f"{index}.json") try: raw_message = self.dbfs_client.read(message_path) - # Files written to dbfs using the Python IO interface used in PipesDbfsMessageWriter are # base64-encoded. 
return base64.b64decode(raw_message.data).decode("utf-8") @@ -235,3 +278,57 @@ def no_messages_debug_text(self) -> str: " PipesDbfsMessageWriter to be explicitly passed to open_dagster_pipes in the external" " process." ) + + +@experimental +class PipesDbfsStdioReader(PipesChunkedStdioReader): + """Reader that reads stdout/stderr logs from DBFS. + + Args: + interval (float): interval in seconds between attempts to download a log chunk + remote_log_name (Literal["stdout", "stderr"]): The name of the log file to read. + target_stream (TextIO): The stream to which to forward log chunk that have been read. + client (WorkspaceClient): A databricks `WorkspaceClient` object. + """ + + def __init__( + self, + *, + interval: float = 10, + remote_log_name: Literal["stdout", "stderr"], + target_stream: TextIO, + client: WorkspaceClient, + ): + super().__init__(interval=interval, target_stream=target_stream) + self.dbfs_client = files.DbfsAPI(client.api_client) + self.remote_log_name = remote_log_name + self.log_position = 0 + self.log_path = None + + def download_log_chunk(self, params: PipesParams) -> Optional[str]: + log_path = self._get_log_path(params) + if log_path is None: + return None + else: + try: + read_response = self.dbfs_client.read(log_path) + assert read_response.data + content = base64.b64decode(read_response.data).decode("utf-8") + chunk = content[self.log_position :] + self.log_position = len(content) + return chunk + except IOError: + return None + + def is_ready(self, params: PipesParams) -> bool: + return self._get_log_path(params) is not None + + # The directory containing logs will not exist until either 5 minutes have elapsed or the + # job has finished. + def _get_log_path(self, params: PipesParams) -> Optional[str]: + if self.log_path is None: + log_root_path = os.path.join(params["cluster_log_root"]) + child_dirs = list(self.dbfs_client.list(log_root_path)) + if len(child_dirs) > 0: + self.log_path = f"dbfs:{child_dirs[0].path}/driver/{self.remote_log_name}" + return self.log_path diff --git a/python_modules/libraries/dagster-databricks/dagster_databricks_tests/test_pipes.py b/python_modules/libraries/dagster-databricks/dagster_databricks_tests/test_pipes.py index 45d6656c101cd..1cafa135b26df 100644 --- a/python_modules/libraries/dagster-databricks/dagster_databricks_tests/test_pipes.py +++ b/python_modules/libraries/dagster-databricks/dagster_databricks_tests/test_pipes.py @@ -1,6 +1,8 @@ import base64 import inspect import os +import re +import subprocess import textwrap from contextlib import contextmanager from typing import Any, Callable, Iterator @@ -16,6 +18,8 @@ def script_fn(): + import sys + from dagster_pipes import ( PipesDbfsContextLoader, PipesDbfsMessageWriter, @@ -23,11 +27,13 @@ def script_fn(): ) with open_dagster_pipes( - context_loader=PipesDbfsContextLoader(), message_writer=PipesDbfsMessageWriter() + context_loader=PipesDbfsContextLoader(), + message_writer=PipesDbfsMessageWriter(), ) as context: multiplier = context.get_extra("multiplier") value = 2 * multiplier - + print("hello from databricks stdout") # noqa: T201 + print("hello from databricks stderr", file=sys.stderr) # noqa: T201 context.log.info(f"{context.asset_key}: {2} * {multiplier} = {value}") context.report_asset_materialization( metadata={"value": value}, @@ -62,8 +68,33 @@ def client() -> WorkspaceClient: TASK_KEY = "DAGSTER_PIPES_TASK" +DAGSTER_PIPES_WHL_FILENAME = "dagster_pipes-1!0+dev-py3-none-any.whl" + # This has been manually uploaded to a test DBFS workspace. 
-DAGSTER_PIPES_WHL_PATH = "dbfs:/FileStore/jars/dagster_pipes-1!0+dev-py3-none-any.whl" +DAGSTER_PIPES_WHL_PATH = f"dbfs:/FileStore/jars/{DAGSTER_PIPES_WHL_FILENAME}" + + +def get_repo_root() -> str: + path = os.path.dirname(__file__) + while not os.path.exists(os.path.join(path, ".git")): + path = os.path.dirname(path) + return path + + +# Upload the Dagster Pipes wheel to DBFS. Use this fixture to avoid needing to manually reupload +# dagster-pipes if it has changed between test runs. +@contextmanager +def upload_dagster_pipes_whl(client: WorkspaceClient) -> Iterator[None]: + repo_root = get_repo_root() + orig_wd = os.getcwd() + dagster_pipes_root = os.path.join(repo_root, "python_modules", "dagster-pipes") + os.chdir(dagster_pipes_root) + subprocess.check_call(["python", "setup.py", "bdist_wheel"]) + subprocess.check_call( + ["dbfs", "cp", "--overwrite", f"dist/{DAGSTER_PIPES_WHL_FILENAME}", DAGSTER_PIPES_WHL_PATH] + ) + os.chdir(orig_wd) + yield def _make_submit_task(path: str) -> jobs.SubmitTask: @@ -83,25 +114,33 @@ def _make_submit_task(path: str) -> jobs.SubmitTask: @pytest.mark.skipif(IS_BUILDKITE, reason="Not configured to run on BK yet.") -def test_pipes_client(client: WorkspaceClient): +def test_pipes_client(capsys, client: WorkspaceClient): @asset def number_x(context: AssetExecutionContext, pipes_client: PipesDatabricksClient): - with temp_script(script_fn, client) as script_path: - task = _make_submit_task(script_path) - return pipes_client.run( - task=task, - context=context, - extras={"multiplier": 2, "storage_root": "fake"}, - ).get_results() + with upload_dagster_pipes_whl(client): + with temp_script(script_fn, client) as script_path: + task = _make_submit_task(script_path) + return pipes_client.run( + task=task, + context=context, + extras={"multiplier": 2, "storage_root": "fake"}, + ).get_results() result = materialize( [number_x], - resources={"pipes_client": PipesDatabricksClient(client)}, + resources={ + "pipes_client": PipesDatabricksClient( + client, + ) + }, raise_on_error=False, ) assert result.success mats = result.asset_materializations_for_node(number_x.op.name) assert mats[0].metadata["value"].value == 4 + captured = capsys.readouterr() + assert re.search(r"hello from databricks stdout\n", captured.out, re.MULTILINE) + assert re.search(r"hello from databricks stderr\n", captured.err, re.MULTILINE) @pytest.mark.skipif(IS_BUILDKITE, reason="Not configured to run on BK yet.")