From 73b846e52b114deefd6c4cfd672fcd4bea1cc440 Mon Sep 17 00:00:00 2001 From: Sean Mackesey Date: Tue, 10 Oct 2023 14:55:15 -0400 Subject: [PATCH] address feedback --- .../dagster/dagster/_core/pipes/utils.py | 109 ++++++++---------- .../dagster-aws/dagster_aws/pipes.py | 27 ++++- .../dagster_databricks/pipes.py | 105 ++++++++++------- .../dagster_databricks_tests/test_pipes.py | 9 +- 4 files changed, 142 insertions(+), 108 deletions(-) diff --git a/python_modules/dagster/dagster/_core/pipes/utils.py b/python_modules/dagster/dagster/_core/pipes/utils.py index 9051ec76b9534..6891ae1cefbd4 100644 --- a/python_modules/dagster/dagster/_core/pipes/utils.py +++ b/python_modules/dagster/dagster/_core/pipes/utils.py @@ -4,10 +4,10 @@ import sys import tempfile import time -from abc import abstractmethod +from abc import ABC, abstractmethod from contextlib import contextmanager from threading import Event, Thread -from typing import Iterator, Optional +from typing import Iterator, Optional, TextIO from dagster_pipes import ( PIPES_PROTOCOL_VERSION_FIELD, @@ -226,24 +226,31 @@ class PipesBlobStoreMessageReader(PipesMessageReader): counter (starting from 1) on successful write, keeping counters on the read and write end in sync. + If `stdout_reader` or `stderr_reader` are passed, this reader will also start them when + `read_messages` is called. If they are not passed, then the reader performs no stdout/stderr + forwarding. + Args: interval (float): interval in seconds between attempts to download a chunk - forward_stdout (bool): whether to forward stdout from the pipes process to Dagster. - forward_stderr (bool): whether to forward stderr from the pipes process to Dagster. + stdout_reader (Optional[PipesBlobStoreStdioReader]): A reader for reading stdout logs. + stderr_reader (Optional[PipesBlobStoreStdioReader]): A reader for reading stderr logs. 
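+
+    Example (an illustrative sketch of a minimal concrete subclass, reading chunks from a
+    local directory in place of a real blob store; the class and paths below are hypothetical):
+
+        .. code-block:: python
+
+            import os
+            import tempfile
+            from contextlib import contextmanager
+            from typing import Iterator, Optional
+
+            class LocalDirMessageReader(PipesBlobStoreMessageReader):
+                @contextmanager
+                def get_params(self) -> Iterator[PipesParams]:
+                    # The external process writes numbered chunks (1.json, 2.json, ...) here.
+                    with tempfile.TemporaryDirectory() as tempdir:
+                        yield {"path": tempdir}
+
+                def download_messages_chunk(self, index: int, params: PipesParams) -> Optional[str]:
+                    # Return the chunk with this counter value, or None if not yet written.
+                    chunk_path = os.path.join(params["path"], f"{index}.json")
+                    if os.path.exists(chunk_path):
+                        with open(chunk_path, "r") as f:
+                            return f.read()
+                    return None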
""" interval: float counter: int - forward_stdout: bool - forward_stderr: bool + stdout_reader: Optional["PipesBlobStoreStdioReader"] + stderr_reader: Optional["PipesBlobStoreStdioReader"] def __init__( - self, interval: float = 10, forward_stdout: bool = False, forward_stderr: bool = False + self, + interval: float = 10, + stdout_reader: Optional["PipesBlobStoreStdioReader"] = None, + stderr_reader: Optional["PipesBlobStoreStdioReader"] = None, ): self.interval = interval self.counter = 1 - self.forward_stdout = forward_stdout - self.forward_stderr = forward_stderr + self.stdout_reader = stdout_reader + self.stderr_reader = stderr_reader @contextmanager def read_messages( @@ -263,33 +270,25 @@ def read_messages( with self.get_params() as params: is_task_complete = Event() messages_thread = None - stdout_thread = None - stderr_thread = None try: messages_thread = Thread( target=self._messages_thread, args=(handler, params, is_task_complete) ) messages_thread.start() - if self.forward_stdout: - stdout_thread = Thread( - target=self._stdout_thread, args=(params, is_task_complete) - ) - stdout_thread.start() - if self.forward_stderr: - stderr_thread = Thread( - target=self._stderr_thread, args=(params, is_task_complete) - ) - stderr_thread.start() + if self.stdout_reader: + self.stdout_reader.start(params, is_task_complete) + if self.stderr_reader: + self.stderr_reader.start(params, is_task_complete) yield params finally: self.wait_for_stdio_logs(params) is_task_complete.set() if messages_thread: messages_thread.join() - if stdout_thread: - stdout_thread.join() - if stderr_thread: - stderr_thread.join() + if self.stdout_reader: + self.stdout_reader.stop() + if self.stderr_reader: + self.stderr_reader.stop() # In cases where we are forwarding logs, in some cases the logs might not be written out until # after the run completes. We wait for them to exist. @@ -298,8 +297,8 @@ def wait_for_stdio_logs(self, params): while ( datetime.datetime.now() - start_or_last_download ).seconds <= WAIT_FOR_STDIO_LOGS_TIMEOUT and ( - (self.forward_stdout and not self.stdout_log_exists(params)) - or (self.forward_stderr and not self.stderr_log_exists(params)) + (self.stdout_reader and not self.stdout_reader.log_exists(params)) + or (self.stderr_reader and not self.stderr_reader.log_exists(params)) ): time.sleep(5) @@ -316,18 +315,6 @@ def get_params(self) -> Iterator[PipesParams]: @abstractmethod def download_messages_chunk(self, index: int, params: PipesParams) -> Optional[str]: ... 
-    def download_stdout_chunk(self, params: PipesParams) -> Optional[str]:
-        raise NotImplementedError()
-
-    def stdout_log_exists(self, params: PipesParams) -> bool:
-        raise NotImplementedError()
-
-    def download_stderr_chunk(self, params: PipesParams) -> Optional[str]:
-        raise NotImplementedError()
-
-    def stderr_log_exists(self, params: PipesParams) -> bool:
-        raise NotImplementedError()
-
     def _messages_thread(
         self,
         handler: "PipesMessageHandler",
@@ -349,24 +336,28 @@ def _messages_thread(
                 break
             time.sleep(1)
 
-    def _stdout_thread(
-        self,
-        params: PipesParams,
-        is_task_complete: Event,
-    ) -> None:
-        start_or_last_download = datetime.datetime.now()
-        while True:
-            now = datetime.datetime.now()
-            if (now - start_or_last_download).seconds > self.interval or is_task_complete.is_set():
-                start_or_last_download = now
-                chunk = self.download_stdout_chunk(params)
-                if chunk:
-                    sys.stdout.write(chunk)
-                elif is_task_complete.is_set():
-                    break
-            time.sleep(10)
 
-    def _stderr_thread(
+class PipesBlobStoreStdioReader(ABC):
+    def __init__(self, *, interval: float = 10, target_stream: TextIO):
+        self.interval = interval
+        self.target_stream = target_stream
+        self.thread: Optional[Thread] = None
+
+    @abstractmethod
+    def download_log_chunk(self, params: PipesParams) -> Optional[str]: ...
+
+    @abstractmethod
+    def log_exists(self, params: PipesParams) -> bool: ...
+
+    def start(self, params: PipesParams, is_task_complete: Event) -> None:
+        self.thread = Thread(target=self._reader_thread, args=(params, is_task_complete))
+        self.thread.start()
+
+    def stop(self) -> None:
+        if self.thread:
+            self.thread.join()
+
+    def _reader_thread(
         self,
         params: PipesParams,
         is_task_complete: Event,
@@ -376,12 +367,12 @@ def _stderr_thread(
             now = datetime.datetime.now()
             if (now - start_or_last_download).seconds > self.interval or is_task_complete.is_set():
                 start_or_last_download = now
-                chunk = self.download_stderr_chunk(params)
+                chunk = self.download_log_chunk(params)
                 if chunk:
-                    sys.stderr.write(chunk)
+                    self.target_stream.write(chunk)
                 elif is_task_complete.is_set():
                     break
-            time.sleep(10)
+            time.sleep(self.interval)
 
 
 def extract_message_or_forward_to_stdout(handler: "PipesMessageHandler", log_line: str):
diff --git a/python_modules/libraries/dagster-aws/dagster_aws/pipes.py b/python_modules/libraries/dagster-aws/dagster_aws/pipes.py
index df949134e428f..6483ebad3b3ab 100644
--- a/python_modules/libraries/dagster-aws/dagster_aws/pipes.py
+++ b/python_modules/libraries/dagster-aws/dagster_aws/pipes.py
@@ -9,12 +9,33 @@
 from dagster._core.pipes.client import (
     PipesParams,
 )
-from dagster._core.pipes.utils import PipesBlobStoreMessageReader
+from dagster._core.pipes.utils import PipesBlobStoreMessageReader, PipesBlobStoreStdioReader
 
 
 class PipesS3MessageReader(PipesBlobStoreMessageReader):
-    def __init__(self, *, interval: float = 10, bucket: str, client: boto3.client):
-        super().__init__(interval=interval)
+    """Message reader that reads messages by periodically reading message chunks from a specified S3
+    bucket.
+
+    Args:
+        interval (float): interval in seconds between attempts to download a chunk
+        bucket (str): The S3 bucket to read from.
+        client (boto3.client): A boto3 client.
+        stdout_reader (Optional[PipesBlobStoreStdioReader]): A reader for reading stdout logs.
+        stderr_reader (Optional[PipesBlobStoreStdioReader]): A reader for reading stderr logs.
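+
+    Example (an illustrative sketch; the bucket name and client configuration are placeholders):
+
+        .. code-block:: python
+
+            import boto3
+
+            reader = PipesS3MessageReader(
+                bucket="my-dagster-pipes-bucket",  # hypothetical bucket name
+                client=boto3.client("s3"),
+            )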
+ """ + + def __init__( + self, + *, + interval: float = 10, + bucket: str, + client: boto3.client, + stdout_reader: Optional[PipesBlobStoreStdioReader] = None, + stderr_reader: Optional[PipesBlobStoreStdioReader] = None, + ): + super().__init__( + interval=interval, stdout_reader=stdout_reader, stderr_reader=stderr_reader + ) self.bucket = check.str_param(bucket, "bucket") self.key_prefix = "".join(random.choices(string.ascii_letters, k=30)) self.client = client diff --git a/python_modules/libraries/dagster-databricks/dagster_databricks/pipes.py b/python_modules/libraries/dagster-databricks/dagster_databricks/pipes.py index a9aa5719fc7bf..b545577159a8b 100644 --- a/python_modules/libraries/dagster-databricks/dagster_databricks/pipes.py +++ b/python_modules/libraries/dagster-databricks/dagster_databricks/pipes.py @@ -3,9 +3,10 @@ import os import random import string +import sys import time -from contextlib import contextmanager -from typing import Iterator, Literal, Mapping, Optional +from contextlib import ExitStack, contextmanager +from typing import Iterator, Literal, Mapping, Optional, TextIO import dagster._check as check from dagster._annotations import experimental @@ -20,6 +21,7 @@ ) from dagster._core.pipes.utils import ( PipesBlobStoreMessageReader, + PipesBlobStoreStdioReader, open_pipes_session, ) from dagster_pipes import ( @@ -70,7 +72,15 @@ def __init__( message_reader, "message_reader", PipesMessageReader, - ) or PipesDbfsMessageReader(client=self.client) + ) or PipesDbfsMessageReader( + client=self.client, + stdout_reader=PipesDbfsStdioReader( + client=self.client, remote_log_name="stdout", target_stream=sys.stdout + ), + stderr_reader=PipesDbfsStdioReader( + client=self.client, remote_log_name="stderr", target_stream=sys.stderr + ), + ) def run( self, @@ -112,10 +122,11 @@ def run( } cluster_log_root = pipes_session.get_bootstrap_params()[ DAGSTER_PIPES_BOOTSTRAP_PARAM_NAMES["messages"] - ]["cluster_log_root"] - submit_task_dict["new_cluster"]["cluster_log_conf"] = { - "dbfs": {"destination": f"dbfs:{cluster_log_root}"} - } + ].get("cluster_log_root") + if cluster_log_root is not None: + submit_task_dict["new_cluster"]["cluster_log_conf"] = { + "dbfs": {"destination": f"dbfs:{cluster_log_root}"} + } task = jobs.SubmitTask.from_dict(submit_task_dict) run_id = self.client.jobs.submit( tasks=[task], @@ -210,6 +221,8 @@ class PipesDbfsMessageReader(PipesBlobStoreMessageReader): client (WorkspaceClient): A databricks `WorkspaceClient` object. cluster_log_root (Optional[str]): The root path on DBFS where the cluster logs are written. If set, this will be used to read stderr/stdout logs. + stdout_reader (Optional[PipesBlobStoreStdioReader]): A reader for reading stdout logs. + stderr_reader (Optional[PipesBlobStoreStdioReader]): A reader for reading stderr logs. 
""" def __init__( @@ -217,21 +230,22 @@ def __init__( *, interval: int = 10, client: WorkspaceClient, - forward_stdout: bool = False, - forward_stderr: bool = False, + stdout_reader: Optional[PipesBlobStoreStdioReader] = None, + stderr_reader: Optional[PipesBlobStoreStdioReader] = None, ): super().__init__( - interval=interval, forward_stdout=forward_stdout, forward_stderr=forward_stderr + interval=interval, stdout_reader=stdout_reader, stderr_reader=stderr_reader ) self.dbfs_client = files.DbfsAPI(client.api_client) - self.stdio_position = {"stdout": 0, "stderr": 0} @contextmanager def get_params(self) -> Iterator[PipesParams]: - with dbfs_tempdir(self.dbfs_client) as messages_tempdir, dbfs_tempdir( - self.dbfs_client - ) as logs_tempdir: - yield {"path": messages_tempdir, "cluster_log_root": logs_tempdir} + with ExitStack() as stack: + params: PipesParams = {} + params["path"] = stack.enter_context(dbfs_tempdir(self.dbfs_client)) + if self.stdout_reader or self.stderr_reader: + params["cluster_log_root"] = stack.enter_context(dbfs_tempdir(self.dbfs_client)) + yield params def download_messages_chunk(self, index: int, params: PipesParams) -> Optional[str]: message_path = os.path.join(params["path"], f"{index}.json") @@ -246,22 +260,39 @@ def download_messages_chunk(self, index: int, params: PipesParams) -> Optional[s except IOError: return None - def download_stdout_chunk(self, params: PipesParams) -> Optional[str]: - return self._download_stdio_chunk(params, "stdout") + def no_messages_debug_text(self) -> str: + return ( + "Attempted to read messages from a temporary file in dbfs. Expected" + " PipesDbfsMessageWriter to be explicitly passed to open_dagster_pipes in the external" + " process." + ) + - def stdout_log_exists(self, params) -> bool: - return self._get_stdio_log_path(params, "stdout") is not None +class PipesDbfsStdioReader(PipesBlobStoreStdioReader): + """Reader that reads stdout/stderr logs from DBFS. - def download_stderr_chunk(self, params: PipesParams) -> Optional[str]: - return self._download_stdio_chunk(params, "stderr") + Args: + interval (float): interval in seconds between attempts to download a log chunk + remote_log_name (Literal["stdout", "stderr"]): The name of the log file to read. + target_stream (TextIO): The stream to which to forward log chunk that have been read. + client (WorkspaceClient): A databricks `WorkspaceClient` object. 
+ """ - def stderr_log_exists(self, params) -> bool: - return self._get_stdio_log_path(params, "stderr") is not None + def __init__( + self, + *, + interval: float = 10, + remote_log_name: Literal["stdout", "stderr"], + target_stream: TextIO, + client: WorkspaceClient, + ): + super().__init__(interval=interval, target_stream=target_stream) + self.dbfs_client = files.DbfsAPI(client.api_client) + self.remote_log_name = remote_log_name + self.log_position = 0 - def _download_stdio_chunk( - self, params: PipesParams, stream: Literal["stdout", "stderr"] - ) -> Optional[str]: - log_path = self._get_stdio_log_path(params, stream) + def download_log_chunk(self, params: PipesParams) -> Optional[str]: + log_path = self._get_log_path(params) if log_path is None: return None else: @@ -269,27 +300,21 @@ def _download_stdio_chunk( read_response = self.dbfs_client.read(log_path) assert read_response.data content = base64.b64decode(read_response.data).decode("utf-8") - chunk = content[self.stdio_position[stream] :] - self.stdio_position[stream] = len(content) + chunk = content[self.log_position :] + self.log_position = len(content) return chunk except IOError: return None + def log_exists(self, params) -> bool: + return self._get_log_path(params) is not None + # The directory containing logs will not exist until either 5 minutes have elapsed or the # job has finished. - def _get_stdio_log_path( - self, params: PipesParams, stream: Literal["stdout", "stderr"] - ) -> Optional[str]: + def _get_log_path(self, params: PipesParams) -> Optional[str]: log_root_path = os.path.join(params["cluster_log_root"]) child_dirs = list(self.dbfs_client.list(log_root_path)) if len(child_dirs) > 0: - return f"dbfs:{child_dirs[0].path}/driver/{stream}" + return f"dbfs:{child_dirs[0].path}/driver/{self.remote_log_name}" else: return None - - def no_messages_debug_text(self) -> str: - return ( - "Attempted to read messages from a temporary file in dbfs. Expected" - " PipesDbfsMessageWriter to be explicitly passed to open_dagster_pipes in the external" - " process." - ) diff --git a/python_modules/libraries/dagster-databricks/dagster_databricks_tests/test_pipes.py b/python_modules/libraries/dagster-databricks/dagster_databricks_tests/test_pipes.py index 6dfba40bf60ae..1cafa135b26df 100644 --- a/python_modules/libraries/dagster-databricks/dagster_databricks_tests/test_pipes.py +++ b/python_modules/libraries/dagster-databricks/dagster_databricks_tests/test_pipes.py @@ -10,7 +10,7 @@ import pytest from dagster import AssetExecutionContext, asset, materialize from dagster._core.errors import DagsterPipesExecutionError -from dagster_databricks.pipes import PipesDatabricksClient, PipesDbfsMessageReader, dbfs_tempdir +from dagster_databricks.pipes import PipesDatabricksClient, dbfs_tempdir from databricks.sdk import WorkspaceClient from databricks.sdk.service import files, jobs @@ -81,6 +81,8 @@ def get_repo_root() -> str: return path +# Upload the Dagster Pipes wheel to DBFS. Use this fixture to avoid needing to manually reupload +# dagster-pipes if it has changed between test runs. @contextmanager def upload_dagster_pipes_whl(client: WorkspaceClient) -> Iterator[None]: repo_root = get_repo_root() @@ -129,11 +131,6 @@ def number_x(context: AssetExecutionContext, pipes_client: PipesDatabricksClient resources={ "pipes_client": PipesDatabricksClient( client, - message_reader=PipesDbfsMessageReader( - client=client, - forward_stdout=True, - forward_stderr=True, - ), ) }, raise_on_error=False,
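
Note: the test above now relies on the default message reader constructed in
`PipesDatabricksClient.__init__`. For reference, a minimal sketch of wiring the stdio readers
explicitly (it mirrors that default; the `WorkspaceClient` configuration is illustrative and
assumes credentials in the environment):

    import sys

    from dagster_databricks.pipes import (
        PipesDatabricksClient,
        PipesDbfsMessageReader,
        PipesDbfsStdioReader,
    )
    from databricks.sdk import WorkspaceClient

    client = WorkspaceClient()  # illustrative; reads host/token from the environment
    pipes_client = PipesDatabricksClient(
        client,
        message_reader=PipesDbfsMessageReader(
            client=client,
            stdout_reader=PipesDbfsStdioReader(
                client=client, remote_log_name="stdout", target_stream=sys.stdout
            ),
            stderr_reader=PipesDbfsStdioReader(
                client=client, remote_log_name="stderr", target_stream=sys.stderr
            ),
        ),
    )

Omitting `stdout_reader`/`stderr_reader` from `PipesDbfsMessageReader` disables stdio forwarding
entirely.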