Skip to content

Commit

Permalink
address feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
smackesey committed Oct 11, 2023
1 parent 367bbeb commit 204dcaa
Show file tree
Hide file tree
Showing 6 changed files with 201 additions and 117 deletions.
1 change: 1 addition & 0 deletions python_modules/dagster/dagster/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -485,6 +485,7 @@
from dagster._core.pipes.subprocess import PipesSubprocessClient as PipesSubprocessClient
from dagster._core.pipes.utils import (
PipesBlobStoreMessageReader as PipesBlobStoreMessageReader,
PipesBlobStoreStdioReader as PipesBlobStoreStdioReader,
PipesEnvContextInjector as PipesEnvContextInjector,
PipesFileContextInjector as PipesFileContextInjector,
PipesFileMessageReader as PipesFileMessageReader,
Expand Down
145 changes: 85 additions & 60 deletions python_modules/dagster/dagster/_core/pipes/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@
import sys
import tempfile
import time
from abc import abstractmethod
from abc import ABC, abstractmethod
from contextlib import contextmanager
from threading import Event, Thread
from typing import Iterator, Optional
from typing import Iterator, Optional, TextIO

from dagster_pipes import (
PIPES_PROTOCOL_VERSION_FIELD,
Expand Down Expand Up @@ -213,6 +213,7 @@ def no_messages_debug_text(self) -> str:
WAIT_FOR_STDIO_LOGS_TIMEOUT = 60


@experimental
class PipesBlobStoreMessageReader(PipesMessageReader):
"""Message reader that reads a sequence of message chunks written by an external process into a
blob store such as S3, Azure blob storage, or GCS.
Expand All @@ -226,24 +227,37 @@ class PipesBlobStoreMessageReader(PipesMessageReader):
counter (starting from 1) on successful write, keeping counters on the read and write end in
sync.
If `stdout_reader` or `stderr_reader` are passed, this reader will also start them when
`read_messages` is called. If they are not passed, then the reader performs no stdout/stderr
forwarding.
Args:
interval (float): interval in seconds between attempts to download a chunk
forward_stdout (bool): whether to forward stdout from the pipes process to Dagster.
forward_stderr (bool): whether to forward stderr from the pipes process to Dagster.
stdout_reader (Optional[PipesBlobStoreStdioReader]): A reader for reading stdout logs.
stderr_reader (Optional[PipesBlobStoreStdioReader]): A reader for reading stderr logs.
"""

interval: float
counter: int
forward_stdout: bool
forward_stderr: bool
stdout_reader: "PipesBlobStoreStdioReader"
stderr_reader: "PipesBlobStoreStdioReader"

def __init__(
self, interval: float = 10, forward_stdout: bool = False, forward_stderr: bool = False
self,
interval: float = 10,
stdout_reader: Optional["PipesBlobStoreStdioReader"] = None,
stderr_reader: Optional["PipesBlobStoreStdioReader"] = None,
):
self.interval = interval
self.counter = 1
self.forward_stdout = forward_stdout
self.forward_stderr = forward_stderr
self.stdout_reader = (
check.opt_inst_param(stdout_reader, "stdout_reader", PipesBlobStoreStdioReader)
or PipesNoOpStdioReader()
)
self.stderr_reader = (
check.opt_inst_param(stderr_reader, "stderr_reader", PipesBlobStoreStdioReader)
or PipesNoOpStdioReader()
)

@contextmanager
def read_messages(
Expand All @@ -263,33 +277,21 @@ def read_messages(
with self.get_params() as params:
is_task_complete = Event()
messages_thread = None
stdout_thread = None
stderr_thread = None
try:
messages_thread = Thread(
target=self._messages_thread, args=(handler, params, is_task_complete)
)
messages_thread.start()
if self.forward_stdout:
stdout_thread = Thread(
target=self._stdout_thread, args=(params, is_task_complete)
)
stdout_thread.start()
if self.forward_stderr:
stderr_thread = Thread(
target=self._stderr_thread, args=(params, is_task_complete)
)
stderr_thread.start()
self.stdout_reader.start(params, is_task_complete)
self.stderr_reader.start(params, is_task_complete)
yield params
finally:
self.wait_for_stdio_logs(params)
is_task_complete.set()
if messages_thread:
messages_thread.join()
if stdout_thread:
stdout_thread.join()
if stderr_thread:
stderr_thread.join()
self.stdout_reader.stop()
self.stderr_reader.stop()

# When forwarding logs, the logs might not be written out until after the run completes,
# so we wait (up to WAIT_FOR_STDIO_LOGS_TIMEOUT seconds) for them to exist.
Expand All @@ -298,8 +300,8 @@ def wait_for_stdio_logs(self, params):
while (
datetime.datetime.now() - start_or_last_download
).seconds <= WAIT_FOR_STDIO_LOGS_TIMEOUT and (
(self.forward_stdout and not self.stdout_log_exists(params))
or (self.forward_stderr and not self.stderr_log_exists(params))
(self.stdout_reader and not self.stdout_reader.is_ready(params))
or (self.stderr_reader and not self.stderr_reader.is_ready(params))
):
time.sleep(5)

Expand All @@ -316,18 +318,6 @@ def get_params(self) -> Iterator[PipesParams]:
@abstractmethod
def download_messages_chunk(self, index: int, params: PipesParams) -> Optional[str]: ...

def download_stdout_chunk(self, params: PipesParams) -> Optional[str]:
raise NotImplementedError()

def stdout_log_exists(self, params: PipesParams) -> bool:
raise NotImplementedError()

def download_stderr_chunk(self, params: PipesParams) -> Optional[str]:
raise NotImplementedError()

def stderr_log_exists(self, params: PipesParams) -> bool:
raise NotImplementedError()

def _messages_thread(
self,
handler: "PipesMessageHandler",
Expand All @@ -349,39 +339,74 @@ def _messages_thread(
break
time.sleep(1)

def _stdout_thread(
self,
params: PipesParams,
is_task_complete: Event,
) -> None:
start_or_last_download = datetime.datetime.now()
while True:
now = datetime.datetime.now()
if (now - start_or_last_download).seconds > self.interval or is_task_complete.is_set():
start_or_last_download = now
chunk = self.download_stdout_chunk(params)
if chunk:
sys.stdout.write(chunk)
elif is_task_complete.is_set():
break
time.sleep(10)

def _stderr_thread(
class PipesBlobStoreStdioReader(ABC):
@abstractmethod
def start(self, params: PipesParams, is_task_complete: Event) -> None: ...

@abstractmethod
def stop(self) -> None: ...

@abstractmethod
def is_ready(self, params: PipesParams) -> bool: ...


@experimental
class PipesChunkedStdioReader(PipesBlobStoreStdioReader):
"""Reader for reading stdout/stderr logs from a blob store such as S3, Azure blob storage, or GCS.
Args:
interval (float): interval in seconds between attempts to download a chunk.
target_stream (TextIO): The stream to which to write the logs. Typically `sys.stdout` or `sys.stderr`.
"""

def __init__(self, *, interval: float = 10, target_stream: TextIO):
self.interval = interval
self.target_stream = target_stream
self.thread: Optional[Thread] = None

@abstractmethod
def download_log_chunk(self, params: PipesParams) -> Optional[str]: ...

def start(self, params: PipesParams, is_task_complete: Event) -> None:
self.thread = Thread(target=self._reader_thread, args=(params, is_task_complete))
self.thread.start()

def stop(self) -> None:
if self.thread:
self.thread.join()

def _reader_thread(
self,
params: PipesParams,
is_task_complete: Event,
) -> None:
start_or_last_download = datetime.datetime.now()
while True:
now = datetime.datetime.now()
if (now - start_or_last_download).seconds > self.interval or is_task_complete.is_set():
if (
(now - start_or_last_download).seconds > self.interval or is_task_complete.is_set()
) and self.is_ready(params):
start_or_last_download = now
chunk = self.download_stderr_chunk(params)
chunk = self.download_log_chunk(params)
if chunk:
sys.stderr.write(chunk)
self.target_stream.write(chunk)
elif is_task_complete.is_set():
break
time.sleep(10)
time.sleep(self.interval)


class PipesNoOpStdioReader(PipesBlobStoreStdioReader):
"""Default implementation for a pipes stdio reader that does nothing."""

def start(self, params: PipesParams, is_task_complete: Event) -> None:
pass

def stop(self) -> None:
pass

def is_ready(self, params: PipesParams) -> bool:
return True


def extract_message_or_forward_to_stdout(handler: "PipesMessageHandler", log_line: str):
Expand Down
31 changes: 28 additions & 3 deletions python_modules/libraries/dagster-aws/dagster_aws/pipes.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,37 @@
from dagster._core.pipes.client import (
PipesParams,
)
from dagster._core.pipes.utils import PipesBlobStoreMessageReader
from dagster._core.pipes.utils import PipesBlobStoreMessageReader, PipesBlobStoreStdioReader


class PipesS3MessageReader(PipesBlobStoreMessageReader):
def __init__(self, *, interval: float = 10, bucket: str, client: boto3.client):
super().__init__(interval=interval)
"""Message reader that reads messages by periodically reading message chunks from a specified S3
bucket.
If `stdout_reader` or `stderr_reader` are passed, this reader will also start them when
`read_messages` is called. If they are not passed, then the reader performs no stdout/stderr
forwarding.
Args:
interval (float): interval in seconds between attempts to download a chunk
bucket (str): The S3 bucket to read from.
client (boto3.client): A boto3 client.
stdout_reader (Optional[PipesBlobStoreStdioReader]): A reader for reading stdout logs.
stderr_reader (Optional[PipesBlobStoreStdioReader]): A reader for reading stderr logs.
"""

def __init__(
self,
*,
interval: float = 10,
bucket: str,
client: boto3.client,
stdout_reader: Optional[PipesBlobStoreStdioReader] = None,
stderr_reader: Optional[PipesBlobStoreStdioReader] = None,
):
super().__init__(
interval=interval, stdout_reader=stdout_reader, stderr_reader=stderr_reader
)
self.bucket = check.str_param(bucket, "bucket")
self.key_prefix = "".join(random.choices(string.ascii_letters, k=30))
self.client = client
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
PipesDatabricksClient as PipesDatabricksClient,
PipesDbfsContextInjector as PipesDbfsContextInjector,
PipesDbfsMessageReader as PipesDbfsMessageReader,
PipesDbfsStdioReader as PipesDbfsStdioReader,
)
from .resources import (
DatabricksClientResource as DatabricksClientResource,
Expand Down
Loading

0 comments on commit 204dcaa

Please sign in to comment.