diff --git a/python_modules/dagster-pipes/dagster_pipes/__init__.py b/python_modules/dagster-pipes/dagster_pipes/__init__.py
index 5ef63494de3e8..78905faab8883 100644
--- a/python_modules/dagster-pipes/dagster_pipes/__init__.py
+++ b/python_modules/dagster-pipes/dagster_pipes/__init__.py
@@ -11,7 +11,8 @@
 from abc import ABC, abstractmethod
 from contextlib import ExitStack, contextmanager
 from io import StringIO
-from threading import Event, Lock, Thread
+from queue import Queue
+from threading import Event, Thread
 from typing import (
     IO,
     TYPE_CHECKING,
@@ -66,7 +67,7 @@ def _make_message(method: str, params: Optional[Mapping[str, Any]]) -> "PipesMes


 class PipesMessage(TypedDict):
-    """A message sent from the orchestration process to the external process."""
+    """A message sent from the external process to the orchestration process."""

     __dagster_pipes_version: str
     method: str
@@ -513,19 +514,17 @@ class PipesBlobStoreMessageWriterChannel(PipesMessageWriterChannel):

     def __init__(self, *, interval: float = 10):
         self._interval = interval
-        self._lock = Lock()
-        self._buffer = []
+        self._buffer: Queue[PipesMessage] = Queue()
         self._counter = 1

     def write_message(self, message: PipesMessage) -> None:
-        with self._lock:
-            self._buffer.append(message)
+        self._buffer.put(message)

     def flush_messages(self) -> Sequence[PipesMessage]:
-        with self._lock:
-            messages = list(self._buffer)
-            self._buffer.clear()
-            return messages
+        items = []
+        while not self._buffer.empty():
+            items.append(self._buffer.get())
+        return items

     @abstractmethod
     def upload_messages_chunk(self, payload: StringIO, index: int) -> None: ...
@@ -546,15 +545,15 @@ def buffered_upload_loop(self) -> Iterator[None]:
     def _upload_loop(self, is_task_complete: Event) -> None:
         start_or_last_upload = datetime.datetime.now()
         while True:
-            num_pending = len(self._buffer)
             now = datetime.datetime.now()
-            if num_pending == 0 and is_task_complete.is_set():
+            if self._buffer.empty() and is_task_complete.is_set():
                 break
             elif is_task_complete.is_set() or (now - start_or_last_upload).seconds > self._interval:
                 payload = "\n".join([json.dumps(message) for message in self.flush_messages()])
-                self.upload_messages_chunk(StringIO(payload), self._counter)
-                start_or_last_upload = now
-                self._counter += 1
+                if len(payload) > 0:
+                    self.upload_messages_chunk(StringIO(payload), self._counter)
+                    start_or_last_upload = now
+                    self._counter += 1
             time.sleep(1)


diff --git a/python_modules/dagster/dagster/__init__.py b/python_modules/dagster/dagster/__init__.py
index 17176ae4ef2e7..d4e60a04402da 100644
--- a/python_modules/dagster/dagster/__init__.py
+++ b/python_modules/dagster/dagster/__init__.py
@@ -485,6 +485,7 @@
 from dagster._core.pipes.subprocess import PipesSubprocessClient as PipesSubprocessClient
 from dagster._core.pipes.utils import (
     PipesBlobStoreMessageReader as PipesBlobStoreMessageReader,
+    PipesBlobStoreStdioReader as PipesBlobStoreStdioReader,
     PipesEnvContextInjector as PipesEnvContextInjector,
     PipesFileContextInjector as PipesFileContextInjector,
     PipesFileMessageReader as PipesFileMessageReader,
diff --git a/python_modules/dagster/dagster/_core/pipes/utils.py b/python_modules/dagster/dagster/_core/pipes/utils.py
index b2c6e04928cc0..866b8b005fd2f 100644
--- a/python_modules/dagster/dagster/_core/pipes/utils.py
+++ b/python_modules/dagster/dagster/_core/pipes/utils.py
@@ -4,10 +4,10 @@
 import sys
 import tempfile
 import time
-from abc import abstractmethod
+from abc import ABC, abstractmethod
 from contextlib import contextmanager
 from threading import Event, Thread
-from typing import Iterator, Optional
+from typing import Iterator, Optional, TextIO

 from dagster_pipes import (
     PIPES_PROTOCOL_VERSION_FIELD,
@@ -144,7 +144,7 @@ def read_messages(
         self,
         handler: "PipesMessageHandler",
     ) -> Iterator[PipesParams]:
-        """Set up a thread to read streaming messages from teh external process by tailing the
+        """Set up a thread to read streaming messages from the external process by tailing the
         target file.

         Args:
@@ -208,6 +208,12 @@ def no_messages_debug_text(self) -> str:
         return "Attempted to read messages from a local temporary file."


+# Number of seconds to wait after an external process has completed for stdio logs to become
+# available. If this is exceeded, proceed with exiting without picking up logs.
+WAIT_FOR_STDIO_LOGS_TIMEOUT = 60
+
+
+@experimental
 class PipesBlobStoreMessageReader(PipesMessageReader):
     """Message reader that reads a sequence of message chunks written by an external process into
     a blob store such as S3, Azure blob storage, or GCS.
@@ -221,16 +227,37 @@ class PipesBlobStoreMessageReader(PipesMessageReader):
     counter (starting from 1) on successful write, keeping counters on the read and write end in
     sync.

+    If `stdout_reader` or `stderr_reader` are passed, this reader will also start them when
+    `read_messages` is called. If they are not passed, then the reader performs no stdout/stderr
+    forwarding.
+
     Args:
         interval (float): interval in seconds between attempts to download a chunk
+        stdout_reader (Optional[PipesBlobStoreStdioReader]): A reader for reading stdout logs.
+        stderr_reader (Optional[PipesBlobStoreStdioReader]): A reader for reading stderr logs.
     """

     interval: float
     counter: int
+    stdout_reader: "PipesBlobStoreStdioReader"
+    stderr_reader: "PipesBlobStoreStdioReader"

-    def __init__(self, interval: float = 10):
+    def __init__(
+        self,
+        interval: float = 10,
+        stdout_reader: Optional["PipesBlobStoreStdioReader"] = None,
+        stderr_reader: Optional["PipesBlobStoreStdioReader"] = None,
+    ):
         self.interval = interval
         self.counter = 1
+        self.stdout_reader = (
+            check.opt_inst_param(stdout_reader, "stdout_reader", PipesBlobStoreStdioReader)
+            or PipesNoOpStdioReader()
+        )
+        self.stderr_reader = (
+            check.opt_inst_param(stderr_reader, "stderr_reader", PipesBlobStoreStdioReader)
+            or PipesNoOpStdioReader()
+        )

     @contextmanager
     def read_messages(
@@ -249,23 +276,34 @@
         """
         with self.get_params() as params:
             is_task_complete = Event()
-            thread = None
+            messages_thread = None
             try:
-                thread = Thread(
-                    target=self._reader_thread,
-                    args=(
-                        handler,
-                        params,
-                        is_task_complete,
-                    ),
-                    daemon=True,
+                messages_thread = Thread(
+                    target=self._messages_thread, args=(handler, params, is_task_complete)
                 )
-                thread.start()
+                messages_thread.start()
+                self.stdout_reader.start(params, is_task_complete)
+                self.stderr_reader.start(params, is_task_complete)
                 yield params
             finally:
+                self.wait_for_stdio_logs(params)
                 is_task_complete.set()
-                if thread:
-                    thread.join()
+                if messages_thread:
+                    messages_thread.join()
+                self.stdout_reader.stop()
+                self.stderr_reader.stop()

+    # When forwarding logs, the logs might not be written out until after the external process
+    # has completed, so we wait for them to appear before tearing down the readers.
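+    # Readiness is polled every 5 seconds; if WAIT_FOR_STDIO_LOGS_TIMEOUT elapses first, we
+    # give up and proceed without the logs rather than hang the run indefinitely.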
+    def wait_for_stdio_logs(self, params: PipesParams) -> None:
+        start_or_last_download = datetime.datetime.now()
+        while (
+            datetime.datetime.now() - start_or_last_download
+        ).seconds <= WAIT_FOR_STDIO_LOGS_TIMEOUT and (
+            (self.stdout_reader and not self.stdout_reader.is_ready(params))
+            or (self.stderr_reader and not self.stderr_reader.is_ready(params))
+        ):
+            time.sleep(5)

     @abstractmethod
     @contextmanager
@@ -280,15 +318,18 @@ def get_params(self) -> Iterator[PipesParams]:
         ...

     @abstractmethod
     def download_messages_chunk(self, index: int, params: PipesParams) -> Optional[str]: ...

-    def _reader_thread(
-        self, handler: "PipesMessageHandler", params: PipesParams, is_task_complete: Event
+    def _messages_thread(
+        self,
+        handler: "PipesMessageHandler",
+        params: PipesParams,
+        is_task_complete: Event,
     ) -> None:
         start_or_last_download = datetime.datetime.now()
         while True:
             now = datetime.datetime.now()
             if (now - start_or_last_download).seconds > self.interval or is_task_complete.is_set():
-                chunk = self.download_messages_chunk(self.counter, params)
                 start_or_last_download = now
+                chunk = self.download_messages_chunk(self.counter, params)
                 if chunk:
                     for line in chunk.split("\n"):
                         message = json.loads(line)
@@ -299,12 +340,83 @@
             time.sleep(1)


+class PipesBlobStoreStdioReader(ABC):
+    @abstractmethod
+    def start(self, params: PipesParams, is_task_complete: Event) -> None: ...
+
+    @abstractmethod
+    def stop(self) -> None: ...
+
+    @abstractmethod
+    def is_ready(self, params: PipesParams) -> bool: ...
+
+
+@experimental
+class PipesChunkedStdioReader(PipesBlobStoreStdioReader):
+    """Reader for reading stdout/stderr logs from a blob store such as S3, Azure blob storage, or GCS.
+
+    Args:
+        interval (float): interval in seconds between attempts to download a chunk.
+        target_stream (TextIO): The stream to which to write the logs. Typically `sys.stdout` or `sys.stderr`.
+    """
+
+    def __init__(self, *, interval: float = 10, target_stream: TextIO):
+        self.interval = interval
+        self.target_stream = target_stream
+        self.thread: Optional[Thread] = None
+
+    @abstractmethod
+    def download_log_chunk(self, params: PipesParams) -> Optional[str]: ...
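+
+    # Subclasses return any new log content since the previous call from `download_log_chunk`,
+    # or None if the log cannot be read yet; `is_ready` gates the first download attempt.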
+    def start(self, params: PipesParams, is_task_complete: Event) -> None:
+        self.thread = Thread(target=self._reader_thread, args=(params, is_task_complete))
+        self.thread.start()
+
+    def stop(self) -> None:
+        if self.thread:
+            self.thread.join()
+
+    def _reader_thread(
+        self,
+        params: PipesParams,
+        is_task_complete: Event,
+    ) -> None:
+        start_or_last_download = datetime.datetime.now()
+        while True:
+            now = datetime.datetime.now()
+            if (
+                (now - start_or_last_download).seconds > self.interval or is_task_complete.is_set()
+            ) and self.is_ready(params):
+                start_or_last_download = now
+                chunk = self.download_log_chunk(params)
+                if chunk:
+                    self.target_stream.write(chunk)
+                elif is_task_complete.is_set():
+                    break
+            time.sleep(self.interval)
+
+
+class PipesNoOpStdioReader(PipesBlobStoreStdioReader):
+    """Default implementation for a pipes stdio reader that does nothing."""
+
+    def start(self, params: PipesParams, is_task_complete: Event) -> None:
+        pass
+
+    def stop(self) -> None:
+        pass
+
+    def is_ready(self, params: PipesParams) -> bool:
+        return True
+

 def extract_message_or_forward_to_stdout(handler: "PipesMessageHandler", log_line: str):
     # exceptions as control flow, you love to see it
     try:
         message = json.loads(log_line)
         if PIPES_PROTOCOL_VERSION_FIELD in message.keys():
             handler.handle_message(message)
+        else:
+            sys.stdout.writelines((log_line, "\n"))
     except Exception:
         # move non-message logs in to stdout for compute log capture
         sys.stdout.writelines((log_line, "\n"))
diff --git a/python_modules/libraries/dagster-aws/dagster_aws/pipes.py b/python_modules/libraries/dagster-aws/dagster_aws/pipes.py
index df949134e428f..53e6a21adba51 100644
--- a/python_modules/libraries/dagster-aws/dagster_aws/pipes.py
+++ b/python_modules/libraries/dagster-aws/dagster_aws/pipes.py
@@ -9,12 +9,37 @@
 from dagster._core.pipes.client import (
     PipesParams,
 )
-from dagster._core.pipes.utils import PipesBlobStoreMessageReader
+from dagster._core.pipes.utils import PipesBlobStoreMessageReader, PipesBlobStoreStdioReader


 class PipesS3MessageReader(PipesBlobStoreMessageReader):
-    def __init__(self, *, interval: float = 10, bucket: str, client: boto3.client):
-        super().__init__(interval=interval)
+    """Message reader that reads messages by periodically reading message chunks from a specified S3
+    bucket.
+
+    If `stdout_reader` or `stderr_reader` are passed, this reader will also start them when
+    `read_messages` is called. If they are not passed, then the reader performs no stdout/stderr
+    forwarding.
+
+    Args:
+        interval (float): interval in seconds between attempts to download a chunk
+        bucket (str): The S3 bucket to read from.
+        client (boto3.client): A boto3 client.
+        stdout_reader (Optional[PipesBlobStoreStdioReader]): A reader for reading stdout logs.
+        stderr_reader (Optional[PipesBlobStoreStdioReader]): A reader for reading stderr logs.
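+
+    A minimal construction sketch (the bucket name is a placeholder; credentials are assumed
+    to be available to `boto3` via the environment):
+
+        import boto3
+
+        reader = PipesS3MessageReader(bucket="my-bucket", client=boto3.client("s3"))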
+ """ + + def __init__( + self, + *, + interval: float = 10, + bucket: str, + client: boto3.client, + stdout_reader: Optional[PipesBlobStoreStdioReader] = None, + stderr_reader: Optional[PipesBlobStoreStdioReader] = None, + ): + super().__init__( + interval=interval, stdout_reader=stdout_reader, stderr_reader=stderr_reader + ) self.bucket = check.str_param(bucket, "bucket") self.key_prefix = "".join(random.choices(string.ascii_letters, k=30)) self.client = client diff --git a/python_modules/libraries/dagster-databricks/dagster_databricks/__init__.py b/python_modules/libraries/dagster-databricks/dagster_databricks/__init__.py index b7b5210dab1f1..979c2d3fd5ec8 100644 --- a/python_modules/libraries/dagster-databricks/dagster_databricks/__init__.py +++ b/python_modules/libraries/dagster-databricks/dagster_databricks/__init__.py @@ -28,6 +28,7 @@ PipesDatabricksClient as PipesDatabricksClient, PipesDbfsContextInjector as PipesDbfsContextInjector, PipesDbfsMessageReader as PipesDbfsMessageReader, + PipesDbfsStdioReader as PipesDbfsStdioReader, ) from .resources import ( DatabricksClientResource as DatabricksClientResource, diff --git a/python_modules/libraries/dagster-databricks/dagster_databricks/pipes.py b/python_modules/libraries/dagster-databricks/dagster_databricks/pipes.py index 39ecd08d28316..9fc9817aa0fde 100644 --- a/python_modules/libraries/dagster-databricks/dagster_databricks/pipes.py +++ b/python_modules/libraries/dagster-databricks/dagster_databricks/pipes.py @@ -3,9 +3,10 @@ import os import random import string +import sys import time -from contextlib import contextmanager -from typing import Iterator, Mapping, Optional +from contextlib import ExitStack, contextmanager +from typing import Iterator, Literal, Mapping, Optional, TextIO import dagster._check as check from dagster._annotations import experimental @@ -20,9 +21,12 @@ ) from dagster._core.pipes.utils import ( PipesBlobStoreMessageReader, + PipesBlobStoreStdioReader, + PipesChunkedStdioReader, open_pipes_session, ) from dagster_pipes import ( + DAGSTER_PIPES_BOOTSTRAP_PARAM_NAMES, PipesContextData, PipesExtras, PipesParams, @@ -31,6 +35,10 @@ from databricks.sdk.service import files, jobs from pydantic import Field +# Number of seconds between status checks on Databricks jobs launched by the +# `PipesDatabricksClient`. 
+_RUN_POLL_INTERVAL = 5
+

 @experimental
 class _PipesDatabricksClient(PipesClient):
@@ -69,7 +77,15 @@ def __init__(
             message_reader,
             "message_reader",
             PipesMessageReader,
-        ) or PipesDbfsMessageReader(client=self.client)
+        ) or PipesDbfsMessageReader(
+            client=self.client,
+            stdout_reader=PipesDbfsStdioReader(
+                client=self.client, remote_log_name="stdout", target_stream=sys.stdout
+            ),
+            stderr_reader=PipesDbfsStdioReader(
+                client=self.client, remote_log_name="stderr", target_stream=sys.stderr
+            ),
+        )

     @classmethod
     def _is_dagster_maintained(cls) -> bool:
@@ -113,6 +129,13 @@ def run(
             **(self.env or {}),
             **pipes_session.get_bootstrap_env_vars(),
         }
+        cluster_log_root = pipes_session.get_bootstrap_params()[
+            DAGSTER_PIPES_BOOTSTRAP_PARAM_NAMES["messages"]
+        ].get("cluster_log_root")
+        if cluster_log_root is not None:
+            submit_task_dict["new_cluster"]["cluster_log_conf"] = {
+                "dbfs": {"destination": f"dbfs:{cluster_log_root}"}
+            }
         task = jobs.SubmitTask.from_dict(submit_task_dict)
         run_id = self.client.jobs.submit(
             tasks=[task],
@@ -138,7 +161,7 @@ def run(
                 raise DagsterPipesExecutionError(
                     f"Error running Databricks job: {run.state.state_message}"
                 )
-            time.sleep(5)
+            time.sleep(_RUN_POLL_INTERVAL)

         return PipesClientCompletedInvocation(tuple(pipes_session.get_results()))

@@ -201,25 +224,45 @@ class PipesDbfsMessageReader(PipesBlobStoreMessageReader):
     """Message reader that reads messages by periodically reading message chunks from an
     automatically-generated temporary directory on DBFS.

+    If `stdout_reader` or `stderr_reader` are passed, this reader will also start them when
+    `read_messages` is called, reading stdout/stderr logs from a temporary `cluster_log_root`
+    directory that this reader provisions on DBFS. If they are not passed, then the reader
+    performs no stdout/stderr forwarding.
+
     Args:
         interval (float): interval in seconds between attempts to download a chunk
         client (WorkspaceClient): A databricks `WorkspaceClient` object.
+        stdout_reader (Optional[PipesBlobStoreStdioReader]): A reader for reading stdout logs.
+        stderr_reader (Optional[PipesBlobStoreStdioReader]): A reader for reading stderr logs.
     """

-    def __init__(self, *, interval: int = 10, client: WorkspaceClient):
-        super().__init__(interval=interval)
+    def __init__(
+        self,
+        *,
+        interval: float = 10,
+        client: WorkspaceClient,
+        stdout_reader: Optional[PipesBlobStoreStdioReader] = None,
+        stderr_reader: Optional[PipesBlobStoreStdioReader] = None,
+    ):
+        super().__init__(
+            interval=interval, stdout_reader=stdout_reader, stderr_reader=stderr_reader
+        )
         self.dbfs_client = files.DbfsAPI(client.api_client)

     @contextmanager
     def get_params(self) -> Iterator[PipesParams]:
-        with dbfs_tempdir(self.dbfs_client) as tempdir:
-            yield {"path": tempdir}
+        with ExitStack() as stack:
+            params: PipesParams = {}
+            params["path"] = stack.enter_context(dbfs_tempdir(self.dbfs_client))
+            if self.stdout_reader or self.stderr_reader:
+                params["cluster_log_root"] = stack.enter_context(dbfs_tempdir(self.dbfs_client))
+            yield params

     def download_messages_chunk(self, index: int, params: PipesParams) -> Optional[str]:
         message_path = os.path.join(params["path"], f"{index}.json")
         try:
             raw_message = self.dbfs_client.read(message_path)
-
             # Files written to dbfs using the Python IO interface used in PipesDbfsMessageWriter are
             # base64-encoded.
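+            # (The SDK returns the file contents in the `data` field of the read response.)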
             return base64.b64decode(raw_message.data).decode("utf-8")
@@ -235,3 +278,57 @@ def no_messages_debug_text(self) -> str:
             " PipesDbfsMessageWriter to be explicitly passed to open_dagster_pipes in the external"
             " process."
         )
+
+
+@experimental
+class PipesDbfsStdioReader(PipesChunkedStdioReader):
+    """Reader that reads stdout/stderr logs from DBFS.
+
+    Args:
+        interval (float): interval in seconds between attempts to download a log chunk
+        remote_log_name (Literal["stdout", "stderr"]): The name of the log file to read.
+        target_stream (TextIO): The stream to which to forward log chunks that have been read.
+        client (WorkspaceClient): A databricks `WorkspaceClient` object.
+    """
+
+    def __init__(
+        self,
+        *,
+        interval: float = 10,
+        remote_log_name: Literal["stdout", "stderr"],
+        target_stream: TextIO,
+        client: WorkspaceClient,
+    ):
+        super().__init__(interval=interval, target_stream=target_stream)
+        self.dbfs_client = files.DbfsAPI(client.api_client)
+        self.remote_log_name = remote_log_name
+        self.log_position = 0
+        self.log_path: Optional[str] = None
+
+    def download_log_chunk(self, params: PipesParams) -> Optional[str]:
+        log_path = self._get_log_path(params)
+        if log_path is None:
+            return None
+        else:
+            try:
+                read_response = self.dbfs_client.read(log_path)
+                assert read_response.data
+                content = base64.b64decode(read_response.data).decode("utf-8")
+                chunk = content[self.log_position :]
+                self.log_position = len(content)
+                return chunk
+            except IOError:
+                return None
+
+    def is_ready(self, params: PipesParams) -> bool:
+        return self._get_log_path(params) is not None
+
+    # The directory containing logs will not exist until either 5 minutes have elapsed or the
+    # job has finished.
+    def _get_log_path(self, params: PipesParams) -> Optional[str]:
+        if self.log_path is None:
+            log_root_path = params["cluster_log_root"]
+            child_dirs = list(self.dbfs_client.list(log_root_path))
+            if len(child_dirs) > 0:
+                self.log_path = f"dbfs:{child_dirs[0].path}/driver/{self.remote_log_name}"
+        return self.log_path
diff --git a/python_modules/libraries/dagster-databricks/dagster_databricks_tests/test_pipes.py b/python_modules/libraries/dagster-databricks/dagster_databricks_tests/test_pipes.py
index 45d6656c101cd..1cafa135b26df 100644
--- a/python_modules/libraries/dagster-databricks/dagster_databricks_tests/test_pipes.py
+++ b/python_modules/libraries/dagster-databricks/dagster_databricks_tests/test_pipes.py
@@ -1,6 +1,8 @@
 import base64
 import inspect
 import os
+import re
+import subprocess
 import textwrap
 from contextlib import contextmanager
 from typing import Any, Callable, Iterator
@@ -16,6 +18,8 @@


 def script_fn():
+    import sys
+
     from dagster_pipes import (
         PipesDbfsContextLoader,
         PipesDbfsMessageWriter,
@@ -23,11 +27,13 @@ def script_fn():
     )

     with open_dagster_pipes(
-        context_loader=PipesDbfsContextLoader(), message_writer=PipesDbfsMessageWriter()
+        context_loader=PipesDbfsContextLoader(),
+        message_writer=PipesDbfsMessageWriter(),
     ) as context:
         multiplier = context.get_extra("multiplier")
         value = 2 * multiplier
-
+        print("hello from databricks stdout")  # noqa: T201
+        print("hello from databricks stderr", file=sys.stderr)  # noqa: T201
         context.log.info(f"{context.asset_key}: {2} * {multiplier} = {value}")
         context.report_asset_materialization(
             metadata={"value": value},
@@ -62,8 +68,33 @@ def client() -> WorkspaceClient:

 TASK_KEY = "DAGSTER_PIPES_TASK"

+DAGSTER_PIPES_WHL_FILENAME = "dagster_pipes-1!0+dev-py3-none-any.whl"
+
 # This has been manually uploaded to a test DBFS workspace.
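+# (The `upload_dagster_pipes_whl` fixture below rebuilds and re-uploads the wheel on demand.)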
-DAGSTER_PIPES_WHL_PATH = "dbfs:/FileStore/jars/dagster_pipes-1!0+dev-py3-none-any.whl"
+DAGSTER_PIPES_WHL_PATH = f"dbfs:/FileStore/jars/{DAGSTER_PIPES_WHL_FILENAME}"
+
+
+def get_repo_root() -> str:
+    path = os.path.dirname(__file__)
+    while not os.path.exists(os.path.join(path, ".git")):
+        path = os.path.dirname(path)
+    return path
+
+
+# Upload the Dagster Pipes wheel to DBFS. Use this fixture to avoid needing to manually reupload
+# dagster-pipes if it has changed between test runs.
+@contextmanager
+def upload_dagster_pipes_whl(client: WorkspaceClient) -> Iterator[None]:
+    repo_root = get_repo_root()
+    orig_wd = os.getcwd()
+    dagster_pipes_root = os.path.join(repo_root, "python_modules", "dagster-pipes")
+    os.chdir(dagster_pipes_root)
+    try:
+        subprocess.check_call(["python", "setup.py", "bdist_wheel"])
+        subprocess.check_call(
+            ["dbfs", "cp", "--overwrite", f"dist/{DAGSTER_PIPES_WHL_FILENAME}", DAGSTER_PIPES_WHL_PATH]
+        )
+    finally:
+        os.chdir(orig_wd)
+    yield


 def _make_submit_task(path: str) -> jobs.SubmitTask:
@@ -83,25 +114,33 @@ def _make_submit_task(path: str) -> jobs.SubmitTask:


 @pytest.mark.skipif(IS_BUILDKITE, reason="Not configured to run on BK yet.")
-def test_pipes_client(client: WorkspaceClient):
+def test_pipes_client(capsys, client: WorkspaceClient):
     @asset
     def number_x(context: AssetExecutionContext, pipes_client: PipesDatabricksClient):
-        with temp_script(script_fn, client) as script_path:
-            task = _make_submit_task(script_path)
-            return pipes_client.run(
-                task=task,
-                context=context,
-                extras={"multiplier": 2, "storage_root": "fake"},
-            ).get_results()
+        with upload_dagster_pipes_whl(client):
+            with temp_script(script_fn, client) as script_path:
+                task = _make_submit_task(script_path)
+                return pipes_client.run(
+                    task=task,
+                    context=context,
+                    extras={"multiplier": 2, "storage_root": "fake"},
+                ).get_results()

     result = materialize(
         [number_x],
-        resources={"pipes_client": PipesDatabricksClient(client)},
+        resources={
+            "pipes_client": PipesDatabricksClient(
+                client,
+            )
+        },
         raise_on_error=False,
     )
     assert result.success
     mats = result.asset_materializations_for_node(number_x.op.name)
     assert mats[0].metadata["value"].value == 4
+    captured = capsys.readouterr()
+    assert re.search(r"hello from databricks stdout\n", captured.out, re.MULTILINE)
+    assert re.search(r"hello from databricks stderr\n", captured.err, re.MULTILINE)


 @pytest.mark.skipif(IS_BUILDKITE, reason="Not configured to run on BK yet.")