diff --git a/python_modules/dagster-pipes/dagster_pipes/__init__.py b/python_modules/dagster-pipes/dagster_pipes/__init__.py index c98b30dc03dcb..ed560eda434d3 100644 --- a/python_modules/dagster-pipes/dagster_pipes/__init__.py +++ b/python_modules/dagster-pipes/dagster_pipes/__init__.py @@ -11,7 +11,8 @@ from abc import ABC, abstractmethod from contextlib import ExitStack, contextmanager from io import StringIO -from threading import Event, Lock, Thread +from queue import Queue +from threading import Event, Thread from typing import ( IO, TYPE_CHECKING, @@ -81,7 +82,7 @@ def _make_message(method: str, params: Optional[Mapping[str, Any]]) -> "PipesMes class PipesMessage(TypedDict): - """A message sent from the orchestration process to the external process.""" + """A message sent from the external process to the orchestration process.""" __dagster_pipes_version: str method: str @@ -534,19 +535,17 @@ class PipesBlobStoreMessageWriterChannel(PipesMessageWriterChannel): def __init__(self, *, interval: float = 10): self._interval = interval - self._lock = Lock() - self._buffer = [] + self._buffer: Queue[PipesMessage] = Queue() self._counter = 1 def write_message(self, message: PipesMessage) -> None: - with self._lock: - self._buffer.append(message) + self._buffer.put(message) def flush_messages(self) -> Sequence[PipesMessage]: - with self._lock: - messages = list(self._buffer) - self._buffer.clear() - return messages + items = [] + while not self._buffer.empty(): + items.append(self._buffer.get()) + return items @abstractmethod def upload_messages_chunk(self, payload: StringIO, index: int) -> None: ... @@ -567,15 +566,15 @@ def buffered_upload_loop(self) -> Iterator[None]: def _upload_loop(self, is_task_complete: Event) -> None: start_or_last_upload = datetime.datetime.now() while True: - num_pending = len(self._buffer) now = datetime.datetime.now() - if num_pending == 0 and is_task_complete.is_set(): + if self._buffer.empty() and is_task_complete.is_set(): break elif is_task_complete.is_set() or (now - start_or_last_upload).seconds > self._interval: payload = "\n".join([json.dumps(message) for message in self.flush_messages()]) - self.upload_messages_chunk(StringIO(payload), self._counter) - start_or_last_upload = now - self._counter += 1 + if len(payload) > 0: + self.upload_messages_chunk(StringIO(payload), self._counter) + start_or_last_upload = now + self._counter += 1 time.sleep(1) diff --git a/python_modules/dagster/dagster/_core/pipes/utils.py b/python_modules/dagster/dagster/_core/pipes/utils.py index b2c6e04928cc0..9051ec76b9534 100644 --- a/python_modules/dagster/dagster/_core/pipes/utils.py +++ b/python_modules/dagster/dagster/_core/pipes/utils.py @@ -144,7 +144,7 @@ def read_messages( self, handler: "PipesMessageHandler", ) -> Iterator[PipesParams]: - """Set up a thread to read streaming messages from teh external process by tailing the + """Set up a thread to read streaming messages from the external process by tailing the target file. Args: @@ -208,6 +208,11 @@ def no_messages_debug_text(self) -> str: return "Attempted to read messages from a local temporary file." +# Number of seconds to wait after an external process has completed for stdio logs to become +# available. If this is exceeded, proceed with exiting without picking up logs. +WAIT_FOR_STDIO_LOGS_TIMEOUT = 60 + + class PipesBlobStoreMessageReader(PipesMessageReader): """Message reader that reads a sequence of message chunks written by an external process into a blob store such as S3, Azure blob storage, or GCS. @@ -223,14 +228,22 @@ class PipesBlobStoreMessageReader(PipesMessageReader): Args: interval (float): interval in seconds between attempts to download a chunk + forward_stdout (bool): whether to forward stdout from the pipes process to Dagster. + forward_stderr (bool): whether to forward stderr from the pipes process to Dagster. """ interval: float counter: int + forward_stdout: bool + forward_stderr: bool - def __init__(self, interval: float = 10): + def __init__( + self, interval: float = 10, forward_stdout: bool = False, forward_stderr: bool = False + ): self.interval = interval self.counter = 1 + self.forward_stdout = forward_stdout + self.forward_stderr = forward_stderr @contextmanager def read_messages( @@ -249,23 +262,46 @@ def read_messages( """ with self.get_params() as params: is_task_complete = Event() - thread = None + messages_thread = None + stdout_thread = None + stderr_thread = None try: - thread = Thread( - target=self._reader_thread, - args=( - handler, - params, - is_task_complete, - ), - daemon=True, + messages_thread = Thread( + target=self._messages_thread, args=(handler, params, is_task_complete) ) - thread.start() + messages_thread.start() + if self.forward_stdout: + stdout_thread = Thread( + target=self._stdout_thread, args=(params, is_task_complete) + ) + stdout_thread.start() + if self.forward_stderr: + stderr_thread = Thread( + target=self._stderr_thread, args=(params, is_task_complete) + ) + stderr_thread.start() yield params finally: + self.wait_for_stdio_logs(params) is_task_complete.set() - if thread: - thread.join() + if messages_thread: + messages_thread.join() + if stdout_thread: + stdout_thread.join() + if stderr_thread: + stderr_thread.join() + + # In cases where we are forwarding logs, in some cases the logs might not be written out until + # after the run completes. We wait for them to exist. + def wait_for_stdio_logs(self, params): + start_or_last_download = datetime.datetime.now() + while ( + datetime.datetime.now() - start_or_last_download + ).seconds <= WAIT_FOR_STDIO_LOGS_TIMEOUT and ( + (self.forward_stdout and not self.stdout_log_exists(params)) + or (self.forward_stderr and not self.stderr_log_exists(params)) + ): + time.sleep(5) @abstractmethod @contextmanager @@ -280,15 +316,30 @@ def get_params(self) -> Iterator[PipesParams]: @abstractmethod def download_messages_chunk(self, index: int, params: PipesParams) -> Optional[str]: ... - def _reader_thread( - self, handler: "PipesMessageHandler", params: PipesParams, is_task_complete: Event + def download_stdout_chunk(self, params: PipesParams) -> Optional[str]: + raise NotImplementedError() + + def stdout_log_exists(self, params: PipesParams) -> bool: + raise NotImplementedError() + + def download_stderr_chunk(self, params: PipesParams) -> Optional[str]: + raise NotImplementedError() + + def stderr_log_exists(self, params: PipesParams) -> bool: + raise NotImplementedError() + + def _messages_thread( + self, + handler: "PipesMessageHandler", + params: PipesParams, + is_task_complete: Event, ) -> None: start_or_last_download = datetime.datetime.now() while True: now = datetime.datetime.now() if (now - start_or_last_download).seconds > self.interval or is_task_complete.is_set(): - chunk = self.download_messages_chunk(self.counter, params) start_or_last_download = now + chunk = self.download_messages_chunk(self.counter, params) if chunk: for line in chunk.split("\n"): message = json.loads(line) @@ -298,6 +349,40 @@ def _reader_thread( break time.sleep(1) + def _stdout_thread( + self, + params: PipesParams, + is_task_complete: Event, + ) -> None: + start_or_last_download = datetime.datetime.now() + while True: + now = datetime.datetime.now() + if (now - start_or_last_download).seconds > self.interval or is_task_complete.is_set(): + start_or_last_download = now + chunk = self.download_stdout_chunk(params) + if chunk: + sys.stdout.write(chunk) + elif is_task_complete.is_set(): + break + time.sleep(10) + + def _stderr_thread( + self, + params: PipesParams, + is_task_complete: Event, + ) -> None: + start_or_last_download = datetime.datetime.now() + while True: + now = datetime.datetime.now() + if (now - start_or_last_download).seconds > self.interval or is_task_complete.is_set(): + start_or_last_download = now + chunk = self.download_stderr_chunk(params) + if chunk: + sys.stderr.write(chunk) + elif is_task_complete.is_set(): + break + time.sleep(10) + def extract_message_or_forward_to_stdout(handler: "PipesMessageHandler", log_line: str): # exceptions as control flow, you love to see it @@ -305,6 +390,8 @@ def extract_message_or_forward_to_stdout(handler: "PipesMessageHandler", log_lin message = json.loads(log_line) if PIPES_PROTOCOL_VERSION_FIELD in message.keys(): handler.handle_message(message) + else: + sys.stdout.writelines((log_line, "\n")) except Exception: # move non-message logs in to stdout for compute log capture sys.stdout.writelines((log_line, "\n")) diff --git a/python_modules/libraries/dagster-databricks/dagster_databricks/pipes.py b/python_modules/libraries/dagster-databricks/dagster_databricks/pipes.py index 2720fe796e56c..a9aa5719fc7bf 100644 --- a/python_modules/libraries/dagster-databricks/dagster_databricks/pipes.py +++ b/python_modules/libraries/dagster-databricks/dagster_databricks/pipes.py @@ -5,7 +5,7 @@ import string import time from contextlib import contextmanager -from typing import Iterator, Mapping, Optional +from typing import Iterator, Literal, Mapping, Optional import dagster._check as check from dagster._annotations import experimental @@ -23,6 +23,7 @@ open_pipes_session, ) from dagster_pipes import ( + DAGSTER_PIPES_BOOTSTRAP_PARAM_NAMES, PipesContextData, PipesExtras, PipesParams, @@ -109,6 +110,12 @@ def run( **(self.env or {}), **pipes_session.get_bootstrap_env_vars(), } + cluster_log_root = pipes_session.get_bootstrap_params()[ + DAGSTER_PIPES_BOOTSTRAP_PARAM_NAMES["messages"] + ]["cluster_log_root"] + submit_task_dict["new_cluster"]["cluster_log_conf"] = { + "dbfs": {"destination": f"dbfs:{cluster_log_root}"} + } task = jobs.SubmitTask.from_dict(submit_task_dict) run_id = self.client.jobs.submit( tasks=[task], @@ -135,6 +142,7 @@ def run( f"Error running Databricks job: {run.state.state_message}" ) time.sleep(5) + time.sleep(30) # 30 seconds to make sure logs are flushed return PipesClientCompletedInvocation(tuple(pipes_session.get_results())) @@ -200,22 +208,35 @@ class PipesDbfsMessageReader(PipesBlobStoreMessageReader): Args: interval (float): interval in seconds between attempts to download a chunk client (WorkspaceClient): A databricks `WorkspaceClient` object. + cluster_log_root (Optional[str]): The root path on DBFS where the cluster logs are written. + If set, this will be used to read stderr/stdout logs. """ - def __init__(self, *, interval: int = 10, client: WorkspaceClient): - super().__init__(interval=interval) + def __init__( + self, + *, + interval: int = 10, + client: WorkspaceClient, + forward_stdout: bool = False, + forward_stderr: bool = False, + ): + super().__init__( + interval=interval, forward_stdout=forward_stdout, forward_stderr=forward_stderr + ) self.dbfs_client = files.DbfsAPI(client.api_client) + self.stdio_position = {"stdout": 0, "stderr": 0} @contextmanager def get_params(self) -> Iterator[PipesParams]: - with dbfs_tempdir(self.dbfs_client) as tempdir: - yield {"path": tempdir} + with dbfs_tempdir(self.dbfs_client) as messages_tempdir, dbfs_tempdir( + self.dbfs_client + ) as logs_tempdir: + yield {"path": messages_tempdir, "cluster_log_root": logs_tempdir} def download_messages_chunk(self, index: int, params: PipesParams) -> Optional[str]: message_path = os.path.join(params["path"], f"{index}.json") try: raw_message = self.dbfs_client.read(message_path) - # Files written to dbfs using the Python IO interface used in PipesDbfsMessageWriter are # base64-encoded. return base64.b64decode(raw_message.data).decode("utf-8") @@ -225,6 +246,47 @@ def download_messages_chunk(self, index: int, params: PipesParams) -> Optional[s except IOError: return None + def download_stdout_chunk(self, params: PipesParams) -> Optional[str]: + return self._download_stdio_chunk(params, "stdout") + + def stdout_log_exists(self, params) -> bool: + return self._get_stdio_log_path(params, "stdout") is not None + + def download_stderr_chunk(self, params: PipesParams) -> Optional[str]: + return self._download_stdio_chunk(params, "stderr") + + def stderr_log_exists(self, params) -> bool: + return self._get_stdio_log_path(params, "stderr") is not None + + def _download_stdio_chunk( + self, params: PipesParams, stream: Literal["stdout", "stderr"] + ) -> Optional[str]: + log_path = self._get_stdio_log_path(params, stream) + if log_path is None: + return None + else: + try: + read_response = self.dbfs_client.read(log_path) + assert read_response.data + content = base64.b64decode(read_response.data).decode("utf-8") + chunk = content[self.stdio_position[stream] :] + self.stdio_position[stream] = len(content) + return chunk + except IOError: + return None + + # The directory containing logs will not exist until either 5 minutes have elapsed or the + # job has finished. + def _get_stdio_log_path( + self, params: PipesParams, stream: Literal["stdout", "stderr"] + ) -> Optional[str]: + log_root_path = os.path.join(params["cluster_log_root"]) + child_dirs = list(self.dbfs_client.list(log_root_path)) + if len(child_dirs) > 0: + return f"dbfs:{child_dirs[0].path}/driver/{stream}" + else: + return None + def no_messages_debug_text(self) -> str: return ( "Attempted to read messages from a temporary file in dbfs. Expected" diff --git a/python_modules/libraries/dagster-databricks/dagster_databricks_tests/test_pipes.py b/python_modules/libraries/dagster-databricks/dagster_databricks_tests/test_pipes.py index 45d6656c101cd..6dfba40bf60ae 100644 --- a/python_modules/libraries/dagster-databricks/dagster_databricks_tests/test_pipes.py +++ b/python_modules/libraries/dagster-databricks/dagster_databricks_tests/test_pipes.py @@ -1,6 +1,8 @@ import base64 import inspect import os +import re +import subprocess import textwrap from contextlib import contextmanager from typing import Any, Callable, Iterator @@ -8,7 +10,7 @@ import pytest from dagster import AssetExecutionContext, asset, materialize from dagster._core.errors import DagsterPipesExecutionError -from dagster_databricks.pipes import PipesDatabricksClient, dbfs_tempdir +from dagster_databricks.pipes import PipesDatabricksClient, PipesDbfsMessageReader, dbfs_tempdir from databricks.sdk import WorkspaceClient from databricks.sdk.service import files, jobs @@ -16,6 +18,8 @@ def script_fn(): + import sys + from dagster_pipes import ( PipesDbfsContextLoader, PipesDbfsMessageWriter, @@ -23,11 +27,13 @@ def script_fn(): ) with open_dagster_pipes( - context_loader=PipesDbfsContextLoader(), message_writer=PipesDbfsMessageWriter() + context_loader=PipesDbfsContextLoader(), + message_writer=PipesDbfsMessageWriter(), ) as context: multiplier = context.get_extra("multiplier") value = 2 * multiplier - + print("hello from databricks stdout") # noqa: T201 + print("hello from databricks stderr", file=sys.stderr) # noqa: T201 context.log.info(f"{context.asset_key}: {2} * {multiplier} = {value}") context.report_asset_materialization( metadata={"value": value}, @@ -62,8 +68,31 @@ def client() -> WorkspaceClient: TASK_KEY = "DAGSTER_PIPES_TASK" +DAGSTER_PIPES_WHL_FILENAME = "dagster_pipes-1!0+dev-py3-none-any.whl" + # This has been manually uploaded to a test DBFS workspace. -DAGSTER_PIPES_WHL_PATH = "dbfs:/FileStore/jars/dagster_pipes-1!0+dev-py3-none-any.whl" +DAGSTER_PIPES_WHL_PATH = f"dbfs:/FileStore/jars/{DAGSTER_PIPES_WHL_FILENAME}" + + +def get_repo_root() -> str: + path = os.path.dirname(__file__) + while not os.path.exists(os.path.join(path, ".git")): + path = os.path.dirname(path) + return path + + +@contextmanager +def upload_dagster_pipes_whl(client: WorkspaceClient) -> Iterator[None]: + repo_root = get_repo_root() + orig_wd = os.getcwd() + dagster_pipes_root = os.path.join(repo_root, "python_modules", "dagster-pipes") + os.chdir(dagster_pipes_root) + subprocess.check_call(["python", "setup.py", "bdist_wheel"]) + subprocess.check_call( + ["dbfs", "cp", "--overwrite", f"dist/{DAGSTER_PIPES_WHL_FILENAME}", DAGSTER_PIPES_WHL_PATH] + ) + os.chdir(orig_wd) + yield def _make_submit_task(path: str) -> jobs.SubmitTask: @@ -83,25 +112,38 @@ def _make_submit_task(path: str) -> jobs.SubmitTask: @pytest.mark.skipif(IS_BUILDKITE, reason="Not configured to run on BK yet.") -def test_pipes_client(client: WorkspaceClient): +def test_pipes_client(capsys, client: WorkspaceClient): @asset def number_x(context: AssetExecutionContext, pipes_client: PipesDatabricksClient): - with temp_script(script_fn, client) as script_path: - task = _make_submit_task(script_path) - return pipes_client.run( - task=task, - context=context, - extras={"multiplier": 2, "storage_root": "fake"}, - ).get_results() + with upload_dagster_pipes_whl(client): + with temp_script(script_fn, client) as script_path: + task = _make_submit_task(script_path) + return pipes_client.run( + task=task, + context=context, + extras={"multiplier": 2, "storage_root": "fake"}, + ).get_results() result = materialize( [number_x], - resources={"pipes_client": PipesDatabricksClient(client)}, + resources={ + "pipes_client": PipesDatabricksClient( + client, + message_reader=PipesDbfsMessageReader( + client=client, + forward_stdout=True, + forward_stderr=True, + ), + ) + }, raise_on_error=False, ) assert result.success mats = result.asset_materializations_for_node(number_x.op.name) assert mats[0].metadata["value"].value == 4 + captured = capsys.readouterr() + assert re.search(r"hello from databricks stdout\n", captured.out, re.MULTILINE) + assert re.search(r"hello from databricks stderr\n", captured.err, re.MULTILINE) @pytest.mark.skipif(IS_BUILDKITE, reason="Not configured to run on BK yet.")