[pipes] databricks unstructured log forwarding (#16674)
## Summary & Motivation

This adds stdout/stderr forwarding to the dagster-databricks pipes
integration. It was a long road getting here with several dead ends.

The approach in this PR is to extend `PipesBlobStoreMessageReader` with
optional `stdout_reader`/`stderr_reader` params and corresponding hooks
for downloading stdout/stderr chunks. When a reader is provided for a
stream, a thread is launched for it (alongside the message chunk
thread) that periodically downloads that stream's chunks and writes
them to the corresponding orchestration process stream.

On the DBFS side (`PipesDbfsStdioReader`), instead of using an
incrementing counter (as is used for messages), the stdout/stderr chunk
downloaders track a character offset into the file for the
corresponding stream. We repeatedly download the full file and, on each
download, forward only the content starting from that offset.
Repeatedly downloading the full file and applying the offset on the
orchestration end could surely be improved to download starting from
the offset, but I did not implement that yet (there are some concerns
around getting the indexing right, given that the files are stored as
base64, I think with padding).
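
The actual `PipesDbfsStdioReader` lives in the dagster-databricks changes, which are not expanded in this view. As a rough illustration only, here is a minimal sketch of the offset-tracking scheme built on the `PipesChunkedStdioReader` hook added in this PR; the class name is hypothetical, and the `WorkspaceClient.dbfs.read`/`dbfs.get_status` calls (returning base64-encoded contents) are assumptions about the Databricks SDK rather than code from this diff.

```python
import base64
import sys
from typing import Optional, TextIO

from dagster._core.pipes.utils import PipesChunkedStdioReader
from dagster_pipes import PipesParams
from databricks.sdk import WorkspaceClient


class DbfsOffsetStdioReader(PipesChunkedStdioReader):  # hypothetical name
    """Re-download the whole DBFS log file each poll; forward only the unseen tail."""

    def __init__(
        self,
        *,
        client: WorkspaceClient,
        remote_path: str,
        target_stream: TextIO = sys.stdout,
        interval: float = 10,
    ):
        super().__init__(interval=interval, target_stream=target_stream)
        self.client = client
        self.remote_path = remote_path
        self.offset = 0  # characters already forwarded to the orchestration stream

    def is_ready(self, params: PipesParams) -> bool:
        try:
            self.client.dbfs.get_status(self.remote_path)
            return True
        except Exception:
            return False  # the log file has not been written yet

    def download_log_chunk(self, params: PipesParams) -> Optional[str]:
        # DBFS returns file contents base64-encoded; decode, slice off what was
        # already forwarded, and advance the offset.
        response = self.client.dbfs.read(self.remote_path)
        text = base64.b64decode(response.data).decode("utf-8", errors="replace")
        chunk, self.offset = text[self.offset :], len(text)
        return chunk or None
```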

While it is possible that other integrations will need to make changes
on the pipes (external process) end, I didn't need to make any for
Databricks, because a DBFS location for `stdout`/`stderr` is configured
when launching the job, so the external process itself doesn't need to
do anything. This introduces a potential asymmetry between the
`PipesMessageReader` and `PipesMessageWriter`; probably we will end up
with just a `PipesReader`.
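
For reference, the launch-time configuration amounts to pointing `cluster_log_conf` on the new cluster spec at a DBFS root; Databricks then writes the driver logs beneath it. A sketch of the relevant fragment of a job-submission payload, with illustrative values (the field names follow the Databricks clusters/jobs API and are not taken from this diff):

```python
# Fragment of a Databricks job-submission payload (values are illustrative).
# With cluster_log_conf set, Databricks writes driver logs to
# <destination>/<cluster-id>/driver/{stdout,stderr}.
new_cluster = {
    "spark_version": "13.3.x-scala2.12",  # illustrative
    "node_type_id": "i3.xlarge",          # illustrative
    "num_workers": 0,
    "cluster_log_conf": {
        "dbfs": {"destination": "dbfs:/tmp/pipes-stdio-logs"},  # illustrative root
    },
}
```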

There are definitely some other rough patches here:

- Databricks does not let you directly configure the directory where
stdout/stderr are written. Instead you set a root directory, and then
logs get stored in `<root>/<cluster-id>/driver/{stdout,stderr}`. This
introduces a difficulty because the cluster id does not exist until the
job is launched (and you can't set it manually). Because the message
reader gets set up before the job is launched, the message reader
doesn't know where to look.
- I got around this by setting the log root to a temporary directory and
polling that directory for the first child to appear, which will be
where the logs are stored (a sketch of this polling appears after this
list). This is not ideal because users may want to retain the logs in
DBFS.
- Another approach would be to send the cluster id back in the new
`opened` message, but to then thread this into the message reader
requires additional plumbing work.
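
A minimal sketch of that polling workaround, assuming the Databricks SDK's `WorkspaceClient.dbfs.list` API; the function name and timeout are illustrative, and the real implementation in `dagster_databricks` (not expanded in this view) may differ:

```python
import time

from databricks.sdk import WorkspaceClient


def wait_for_driver_log_dir(client: WorkspaceClient, log_root: str, timeout: float = 300.0) -> str:
    """Poll the temporary log root until Databricks creates the <cluster-id> child
    directory, then return the driver log directory that holds stdout/stderr.
    """
    deadline = time.time() + timeout
    while time.time() < deadline:
        try:
            children = list(client.dbfs.list(log_root))
        except Exception:
            children = []  # the log root may not exist until the cluster starts logging
        if children:
            # The only child is the cluster id that Databricks assigned at launch.
            return f"{children[0].path}/driver"
        time.sleep(5)
    raise TimeoutError(f"no cluster log directory appeared under {log_root}")
```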
 
For those who want to play with this, the workflow is to repeatedly run
`dagster_databricks_tests/test_pipes.py::test_pipes_client`. This
requires you to have `DATABRICKS_HOST` and `DATABRICKS_TOKEN` set in
your env. `DATABRICKS_HOST` should be
`https://dbc-07902917-6487.cloud.databricks.com`. `DATABRICKS_TOKEN`
should be set to a value you generate by going to User Settings >
Developer > Access Tokens in the Databricks UI.

## How I Tested These Changes

Tested via `capsys` to make sure logs are forwarded.
smackesey committed Oct 11, 2023
1 parent 4926467 commit 5072a06
Showing 7 changed files with 332 additions and 58 deletions.
29 changes: 14 additions & 15 deletions python_modules/dagster-pipes/dagster_pipes/__init__.py
@@ -11,7 +11,8 @@
from abc import ABC, abstractmethod
from contextlib import ExitStack, contextmanager
from io import StringIO
from threading import Event, Lock, Thread
from queue import Queue
from threading import Event, Thread
from typing import (
IO,
TYPE_CHECKING,
@@ -66,7 +67,7 @@ def _make_message(method: str, params: Optional[Mapping[str, Any]]) -> "PipesMes


class PipesMessage(TypedDict):
"""A message sent from the orchestration process to the external process."""
"""A message sent from the external process to the orchestration process."""

__dagster_pipes_version: str
method: str
@@ -513,19 +514,17 @@ class PipesBlobStoreMessageWriterChannel(PipesMessageWriterChannel):

def __init__(self, *, interval: float = 10):
self._interval = interval
self._lock = Lock()
self._buffer = []
self._buffer: Queue[PipesMessage] = Queue()
self._counter = 1

def write_message(self, message: PipesMessage) -> None:
with self._lock:
self._buffer.append(message)
self._buffer.put(message)

def flush_messages(self) -> Sequence[PipesMessage]:
with self._lock:
messages = list(self._buffer)
self._buffer.clear()
return messages
items = []
while not self._buffer.empty():
items.append(self._buffer.get())
return items

@abstractmethod
def upload_messages_chunk(self, payload: StringIO, index: int) -> None: ...
@@ -546,15 +545,15 @@ def buffered_upload_loop(self) -> Iterator[None]:
def _upload_loop(self, is_task_complete: Event) -> None:
start_or_last_upload = datetime.datetime.now()
while True:
num_pending = len(self._buffer)
now = datetime.datetime.now()
if num_pending == 0 and is_task_complete.is_set():
if self._buffer.empty() and is_task_complete.is_set():
break
elif is_task_complete.is_set() or (now - start_or_last_upload).seconds > self._interval:
payload = "\n".join([json.dumps(message) for message in self.flush_messages()])
self.upload_messages_chunk(StringIO(payload), self._counter)
start_or_last_upload = now
self._counter += 1
if len(payload) > 0:
self.upload_messages_chunk(StringIO(payload), self._counter)
start_or_last_upload = now
self._counter += 1
time.sleep(1)


1 change: 1 addition & 0 deletions python_modules/dagster/dagster/__init__.py
@@ -485,6 +485,7 @@
from dagster._core.pipes.subprocess import PipesSubprocessClient as PipesSubprocessClient
from dagster._core.pipes.utils import (
PipesBlobStoreMessageReader as PipesBlobStoreMessageReader,
PipesBlobStoreStdioReader as PipesBlobStoreStdioReader,
PipesEnvContextInjector as PipesEnvContextInjector,
PipesFileContextInjector as PipesFileContextInjector,
PipesFileMessageReader as PipesFileMessageReader,
150 changes: 131 additions & 19 deletions python_modules/dagster/dagster/_core/pipes/utils.py
@@ -4,10 +4,10 @@
import sys
import tempfile
import time
from abc import abstractmethod
from abc import ABC, abstractmethod
from contextlib import contextmanager
from threading import Event, Thread
from typing import Iterator, Optional
from typing import Iterator, Optional, TextIO

from dagster_pipes import (
PIPES_PROTOCOL_VERSION_FIELD,
@@ -144,7 +144,7 @@ def read_messages(
self,
handler: "PipesMessageHandler",
) -> Iterator[PipesParams]:
"""Set up a thread to read streaming messages from teh external process by tailing the
"""Set up a thread to read streaming messages from the external process by tailing the
target file.
Args:
@@ -208,6 +208,12 @@ def no_messages_debug_text(self) -> str:
return "Attempted to read messages from a local temporary file."


# Number of seconds to wait after an external process has completed for stdio logs to become
# available. If this is exceeded, proceed with exiting without picking up logs.
WAIT_FOR_STDIO_LOGS_TIMEOUT = 60


@experimental
class PipesBlobStoreMessageReader(PipesMessageReader):
"""Message reader that reads a sequence of message chunks written by an external process into a
blob store such as S3, Azure blob storage, or GCS.
@@ -221,16 +227,37 @@ class PipesBlobStoreMessageReader(PipesMessageReader):
counter (starting from 1) on successful write, keeping counters on the read and write end in
sync.
If `stdout_reader` or `stderr_reader` are passed, this reader will also start them when
`read_messages` is called. If they are not passed, then the reader performs no stdout/stderr
forwarding.
Args:
interval (float): interval in seconds between attempts to download a chunk
stdout_reader (Optional[PipesBlobStoreStdioReader]): A reader for reading stdout logs.
stderr_reader (Optional[PipesBlobStoreStdioReader]): A reader for reading stderr logs.
"""

interval: float
counter: int
stdout_reader: "PipesBlobStoreStdioReader"
stderr_reader: "PipesBlobStoreStdioReader"

def __init__(self, interval: float = 10):
def __init__(
self,
interval: float = 10,
stdout_reader: Optional["PipesBlobStoreStdioReader"] = None,
stderr_reader: Optional["PipesBlobStoreStdioReader"] = None,
):
self.interval = interval
self.counter = 1
self.stdout_reader = (
check.opt_inst_param(stdout_reader, "stdout_reader", PipesBlobStoreStdioReader)
or PipesNoOpStdioReader()
)
self.stderr_reader = (
check.opt_inst_param(stderr_reader, "stderr_reader", PipesBlobStoreStdioReader)
or PipesNoOpStdioReader()
)

@contextmanager
def read_messages(
@@ -249,23 +276,34 @@ def read_messages(
"""
with self.get_params() as params:
is_task_complete = Event()
thread = None
messages_thread = None
try:
thread = Thread(
target=self._reader_thread,
args=(
handler,
params,
is_task_complete,
),
daemon=True,
messages_thread = Thread(
target=self._messages_thread, args=(handler, params, is_task_complete)
)
thread.start()
messages_thread.start()
self.stdout_reader.start(params, is_task_complete)
self.stderr_reader.start(params, is_task_complete)
yield params
finally:
self.wait_for_stdio_logs(params)
is_task_complete.set()
if thread:
thread.join()
if messages_thread:
messages_thread.join()
self.stdout_reader.stop()
self.stderr_reader.stop()

# When forwarding logs, the logs might not be written out until after the run
# completes. We wait for them to exist before exiting.
def wait_for_stdio_logs(self, params):
start_or_last_download = datetime.datetime.now()
while (
datetime.datetime.now() - start_or_last_download
).seconds <= WAIT_FOR_STDIO_LOGS_TIMEOUT and (
(self.stdout_reader and not self.stdout_reader.is_ready(params))
or (self.stderr_reader and not self.stderr_reader.is_ready(params))
):
time.sleep(5)

@abstractmethod
@contextmanager
@@ -280,15 +318,18 @@ def get_params(self) -> Iterator[PipesParams]:
@abstractmethod
def download_messages_chunk(self, index: int, params: PipesParams) -> Optional[str]: ...

def _reader_thread(
self, handler: "PipesMessageHandler", params: PipesParams, is_task_complete: Event
def _messages_thread(
self,
handler: "PipesMessageHandler",
params: PipesParams,
is_task_complete: Event,
) -> None:
start_or_last_download = datetime.datetime.now()
while True:
now = datetime.datetime.now()
if (now - start_or_last_download).seconds > self.interval or is_task_complete.is_set():
chunk = self.download_messages_chunk(self.counter, params)
start_or_last_download = now
chunk = self.download_messages_chunk(self.counter, params)
if chunk:
for line in chunk.split("\n"):
message = json.loads(line)
@@ -299,12 +340,83 @@ def _reader_thread(
time.sleep(1)


class PipesBlobStoreStdioReader(ABC):
@abstractmethod
def start(self, params: PipesParams, is_task_complete: Event) -> None: ...

@abstractmethod
def stop(self) -> None: ...

@abstractmethod
def is_ready(self, params: PipesParams) -> bool: ...


@experimental
class PipesChunkedStdioReader(PipesBlobStoreStdioReader):
"""Reader for reading stdout/stderr logs from a blob store such as S3, Azure blob storage, or GCS.
Args:
interval (float): interval in seconds between attempts to download a chunk.
target_stream (TextIO): The stream to which to write the logs. Typically `sys.stdout` or `sys.stderr`.
"""

def __init__(self, *, interval: float = 10, target_stream: TextIO):
self.interval = interval
self.target_stream = target_stream
self.thread: Optional[Thread] = None

@abstractmethod
def download_log_chunk(self, params: PipesParams) -> Optional[str]: ...

def start(self, params: PipesParams, is_task_complete: Event) -> None:
self.thread = Thread(target=self._reader_thread, args=(params, is_task_complete))
self.thread.start()

def stop(self) -> None:
if self.thread:
self.thread.join()

def _reader_thread(
self,
params: PipesParams,
is_task_complete: Event,
) -> None:
start_or_last_download = datetime.datetime.now()
while True:
now = datetime.datetime.now()
if (
(now - start_or_last_download).seconds > self.interval or is_task_complete.is_set()
) and self.is_ready(params):
start_or_last_download = now
chunk = self.download_log_chunk(params)
if chunk:
self.target_stream.write(chunk)
elif is_task_complete.is_set():
break
time.sleep(self.interval)


class PipesNoOpStdioReader(PipesBlobStoreStdioReader):
"""Default implementation for a pipes stdio reader that does nothing."""

def start(self, params: PipesParams, is_task_complete: Event) -> None:
pass

def stop(self) -> None:
pass

def is_ready(self, params: PipesParams) -> bool:
return True


def extract_message_or_forward_to_stdout(handler: "PipesMessageHandler", log_line: str):
# exceptions as control flow, you love to see it
try:
message = json.loads(log_line)
if PIPES_PROTOCOL_VERSION_FIELD in message.keys():
handler.handle_message(message)
else:
sys.stdout.writelines((log_line, "\n"))
except Exception:
# move non-message logs in to stdout for compute log capture
sys.stdout.writelines((log_line, "\n"))
31 changes: 28 additions & 3 deletions python_modules/libraries/dagster-aws/dagster_aws/pipes.py
@@ -9,12 +9,37 @@
from dagster._core.pipes.client import (
PipesParams,
)
from dagster._core.pipes.utils import PipesBlobStoreMessageReader
from dagster._core.pipes.utils import PipesBlobStoreMessageReader, PipesBlobStoreStdioReader


class PipesS3MessageReader(PipesBlobStoreMessageReader):
def __init__(self, *, interval: float = 10, bucket: str, client: boto3.client):
super().__init__(interval=interval)
"""Message reader that reads messages by periodically reading message chunks from a specified S3
bucket.
If `stdout_reader` or `stderr_reader` are passed, this reader will also start them when
`read_messages` is called. If they are not passed, then the reader performs no stdout/stderr
forwarding.
Args:
interval (float): interval in seconds between attempts to download a chunk
bucket (str): The S3 bucket to read from.
client (boto3.client): A boto3 client.
stdout_reader (Optional[PipesBlobStoreStdioReader]): A reader for reading stdout logs.
stderr_reader (Optional[PipesBlobStoreStdioReader]): A reader for reading stderr logs.
"""

def __init__(
self,
*,
interval: float = 10,
bucket: str,
client: boto3.client,
stdout_reader: Optional[PipesBlobStoreStdioReader] = None,
stderr_reader: Optional[PipesBlobStoreStdioReader] = None,
):
super().__init__(
interval=interval, stdout_reader=stdout_reader, stderr_reader=stderr_reader
)
self.bucket = check.str_param(bucket, "bucket")
self.key_prefix = "".join(random.choices(string.ascii_letters, k=30))
self.client = client
@@ -28,6 +28,7 @@
PipesDatabricksClient as PipesDatabricksClient,
PipesDbfsContextInjector as PipesDbfsContextInjector,
PipesDbfsMessageReader as PipesDbfsMessageReader,
PipesDbfsStdioReader as PipesDbfsStdioReader,
)
from .resources import (
DatabricksClientResource as DatabricksClientResource,