Skip to content

Commit

Permalink
address feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
smackesey committed Oct 10, 2023
1 parent 7f92f57 commit 73b846e
Show file tree
Hide file tree
Showing 4 changed files with 142 additions and 108 deletions.
109 changes: 50 additions & 59 deletions python_modules/dagster/dagster/_core/pipes/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@
import sys
import tempfile
import time
from abc import abstractmethod
from abc import ABC, abstractmethod
from contextlib import contextmanager
from threading import Event, Thread
from typing import Iterator, Optional
from typing import Iterator, Optional, TextIO

from dagster_pipes import (
PIPES_PROTOCOL_VERSION_FIELD,
Expand Down Expand Up @@ -226,24 +226,31 @@ class PipesBlobStoreMessageReader(PipesMessageReader):
counter (starting from 1) on successful write, keeping counters on the read and write end in
sync.
If `stdout_reader` or `stderr_reader` are passed, this reader will also start them when
`read_messages` is called. If they are not passed, then the reader performs no stdout/stderr
forwarding.
Args:
interval (float): interval in seconds between attempts to download a chunk
forward_stdout (bool): whether to forward stdout from the pipes process to Dagster.
forward_stderr (bool): whether to forward stderr from the pipes process to Dagster.
stdout_reader (Optional[PipesBlobStoreStdioReader]): A reader for reading stdout logs.
stderr_reader (Optional[PipesBlobStoreStdioReader]): A reader for reading stderr logs.
"""

interval: float
counter: int
forward_stdout: bool
forward_stderr: bool
stdout_reader: Optional["PipesBlobStoreStdioReader"]
stderr_reader: Optional["PipesBlobStoreStdioReader"]

def __init__(
self, interval: float = 10, forward_stdout: bool = False, forward_stderr: bool = False
self,
interval: float = 10,
stdout_reader: Optional["PipesBlobStoreStdioReader"] = None,
stderr_reader: Optional["PipesBlobStoreStdioReader"] = None,
):
self.interval = interval
self.counter = 1
self.forward_stdout = forward_stdout
self.forward_stderr = forward_stderr
self.stdout_reader = stdout_reader
self.stderr_reader = stderr_reader

@contextmanager
def read_messages(
Expand All @@ -263,33 +270,25 @@ def read_messages(
with self.get_params() as params:
is_task_complete = Event()
messages_thread = None
stdout_thread = None
stderr_thread = None
try:
messages_thread = Thread(
target=self._messages_thread, args=(handler, params, is_task_complete)
)
messages_thread.start()
if self.forward_stdout:
stdout_thread = Thread(
target=self._stdout_thread, args=(params, is_task_complete)
)
stdout_thread.start()
if self.forward_stderr:
stderr_thread = Thread(
target=self._stderr_thread, args=(params, is_task_complete)
)
stderr_thread.start()
if self.stdout_reader:
self.stdout_reader.start(params, is_task_complete)
if self.stderr_reader:
self.stderr_reader.start(params, is_task_complete)
yield params
finally:
self.wait_for_stdio_logs(params)
is_task_complete.set()
if messages_thread:
messages_thread.join()
if stdout_thread:
stdout_thread.join()
if stderr_thread:
stderr_thread.join()
if self.stdout_reader:
self.stdout_reader.stop()
if self.stderr_reader:
self.stderr_reader.stop()

# In cases where we are forwarding logs, in some cases the logs might not be written out until
# after the run completes. We wait for them to exist.
Expand All @@ -298,8 +297,8 @@ def wait_for_stdio_logs(self, params):
while (
datetime.datetime.now() - start_or_last_download
).seconds <= WAIT_FOR_STDIO_LOGS_TIMEOUT and (
(self.forward_stdout and not self.stdout_log_exists(params))
or (self.forward_stderr and not self.stderr_log_exists(params))
(self.stdout_reader and not self.stdout_reader.log_exists(params))
or (self.stderr_reader and not self.stderr_reader.log_exists(params))
):
time.sleep(5)

Expand All @@ -316,18 +315,6 @@ def get_params(self) -> Iterator[PipesParams]:
@abstractmethod
def download_messages_chunk(self, index: int, params: PipesParams) -> Optional[str]: ...

def download_stdout_chunk(self, params: PipesParams) -> Optional[str]:
raise NotImplementedError()

def stdout_log_exists(self, params: PipesParams) -> bool:
raise NotImplementedError()

def download_stderr_chunk(self, params: PipesParams) -> Optional[str]:
raise NotImplementedError()

def stderr_log_exists(self, params: PipesParams) -> bool:
raise NotImplementedError()

def _messages_thread(
self,
handler: "PipesMessageHandler",
Expand All @@ -349,24 +336,28 @@ def _messages_thread(
break
time.sleep(1)

def _stdout_thread(
self,
params: PipesParams,
is_task_complete: Event,
) -> None:
start_or_last_download = datetime.datetime.now()
while True:
now = datetime.datetime.now()
if (now - start_or_last_download).seconds > self.interval or is_task_complete.is_set():
start_or_last_download = now
chunk = self.download_stdout_chunk(params)
if chunk:
sys.stdout.write(chunk)
elif is_task_complete.is_set():
break
time.sleep(10)

def _stderr_thread(
class PipesBlobStoreStdioReader(ABC):
def __init__(self, *, interval: float = 10, target_stream: TextIO):
    """Initialize a poll-based reader that forwards external log chunks to a local stream.

    Args:
        interval (float): Seconds between attempts to download a log chunk.
        target_stream (TextIO): Stream that downloaded chunks are written to.
    """
    # No reader thread exists until start() is called.
    self.thread: Optional[Thread] = None
    self.target_stream = target_stream
    self.interval = interval

@abstractmethod
def download_log_chunk(self, params: PipesParams) -> Optional[str]:
    """Return the next available chunk of log text, or None when no new data exists.

    Implemented by subclasses against a specific blob store; a truthy return value
    is written verbatim to ``target_stream`` by the reader thread.
    """
    ...

@abstractmethod
def log_exists(self, params: PipesParams) -> bool:
    """Return True once the external log object described by ``params`` exists.

    Used by the message reader to wait for logs to appear before shutting down
    forwarding (the external process may flush logs only after the run completes).
    """
    ...

def start(self, params: PipesParams, is_task_complete: Event) -> None:
    """Launch the background thread that polls the blob store and forwards log chunks.

    Args:
        params (PipesParams): Parameters identifying where the external log lives.
        is_task_complete (Event): Set by the caller to signal the thread to wind down.
    """
    reader = Thread(target=self._reader_thread, args=(params, is_task_complete))
    self.thread = reader
    reader.start()

def stop(self) -> None:
    """Block until the reader thread finishes; a no-op if start() was never called."""
    if self.thread is not None:
        self.thread.join()

def _reader_thread(
self,
params: PipesParams,
is_task_complete: Event,
Expand All @@ -376,12 +367,12 @@ def _stderr_thread(
now = datetime.datetime.now()
if (now - start_or_last_download).seconds > self.interval or is_task_complete.is_set():
start_or_last_download = now
chunk = self.download_stderr_chunk(params)
chunk = self.download_log_chunk(params)
if chunk:
sys.stderr.write(chunk)
self.target_stream.write(chunk)
elif is_task_complete.is_set():
break
time.sleep(10)
time.sleep(self.interval)


def extract_message_or_forward_to_stdout(handler: "PipesMessageHandler", log_line: str):
Expand Down
27 changes: 24 additions & 3 deletions python_modules/libraries/dagster-aws/dagster_aws/pipes.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,33 @@
from dagster._core.pipes.client import (
PipesParams,
)
from dagster._core.pipes.utils import PipesBlobStoreMessageReader
from dagster._core.pipes.utils import PipesBlobStoreMessageReader, PipesBlobStoreStdioReader


class PipesS3MessageReader(PipesBlobStoreMessageReader):
def __init__(self, *, interval: float = 10, bucket: str, client: boto3.client):
super().__init__(interval=interval)
"""Message reader that reads messages by periodically reading message chunks from a specified S3
bucket.
Args:
interval (float): interval in seconds between attempts to download a chunk
bucket (str): The S3 bucket to read from.
client (boto3.client): A boto3 client.
stdout_reader (Optional[PipesBlobStoreStdioReader]): A reader for reading stdout logs.
stderr_reader (Optional[PipesBlobStoreStdioReader]): A reader for reading stderr logs.
"""

def __init__(
self,
*,
interval: float = 10,
bucket: str,
client: boto3.client,
stdout_reader: Optional[PipesBlobStoreStdioReader] = None,
stderr_reader: Optional[PipesBlobStoreStdioReader] = None,
):
super().__init__(
interval=interval, stdout_reader=stdout_reader, stderr_reader=stderr_reader
)
self.bucket = check.str_param(bucket, "bucket")
self.key_prefix = "".join(random.choices(string.ascii_letters, k=30))
self.client = client
Expand Down
Loading

0 comments on commit 73b846e

Please sign in to comment.