Skip to content

Commit

Permalink
address feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
smackesey committed Oct 10, 2023
1 parent 7f92f57 commit 83f3d8c
Show file tree
Hide file tree
Showing 4 changed files with 107 additions and 101 deletions.
87 changes: 36 additions & 51 deletions python_modules/dagster/dagster/_core/pipes/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@
import sys
import tempfile
import time
from abc import abstractmethod
from abc import ABC, abstractmethod
from contextlib import contextmanager
from threading import Event, Thread
from typing import Iterator, Optional
from typing import Iterator, Optional, TextIO

from dagster_pipes import (
PIPES_PROTOCOL_VERSION_FIELD,
Expand Down Expand Up @@ -234,16 +234,19 @@ class PipesBlobStoreMessageReader(PipesMessageReader):

interval: float
counter: int
forward_stdout: bool
forward_stderr: bool
stdout_reader: Optional["PipesBlobStoreStdioReader"]
stderr_reader: Optional["PipesBlobStoreStdioReader"]

def __init__(
self, interval: float = 10, forward_stdout: bool = False, forward_stderr: bool = False
self,
interval: float = 10,
stdout_reader: Optional["PipesBlobStoreStdioReader"] = None,
stderr_reader: Optional["PipesBlobStoreStdioReader"] = None,
):
self.interval = interval
self.counter = 1
self.forward_stdout = forward_stdout
self.forward_stderr = forward_stderr
self.stdout_reader = stdout_reader
self.stderr_reader = stderr_reader

@contextmanager
def read_messages(
Expand All @@ -270,16 +273,10 @@ def read_messages(
target=self._messages_thread, args=(handler, params, is_task_complete)
)
messages_thread.start()
if self.forward_stdout:
stdout_thread = Thread(
target=self._stdout_thread, args=(params, is_task_complete)
)
stdout_thread.start()
if self.forward_stderr:
stderr_thread = Thread(
target=self._stderr_thread, args=(params, is_task_complete)
)
stderr_thread.start()
if self.stdout_reader:
stdout_thread = self.stdout_reader.start_thread(params, is_task_complete)
if self.stderr_reader:
stderr_thread = self.stderr_reader.start_thread(params, is_task_complete)
yield params
finally:
self.wait_for_stdio_logs(params)
Expand All @@ -298,8 +295,8 @@ def wait_for_stdio_logs(self, params):
while (
datetime.datetime.now() - start_or_last_download
).seconds <= WAIT_FOR_STDIO_LOGS_TIMEOUT and (
(self.forward_stdout and not self.stdout_log_exists(params))
or (self.forward_stderr and not self.stderr_log_exists(params))
(self.stdout_reader and not self.stdout_reader.log_exists(params))
or (self.stderr_reader and not self.stderr_reader.log_exists(params))
):
time.sleep(5)

Expand All @@ -316,18 +313,6 @@ def get_params(self) -> Iterator[PipesParams]:
@abstractmethod
def download_messages_chunk(self, index: int, params: PipesParams) -> Optional[str]: ...

def download_stdout_chunk(self, params: PipesParams) -> Optional[str]:
raise NotImplementedError()

def stdout_log_exists(self, params: PipesParams) -> bool:
raise NotImplementedError()

def download_stderr_chunk(self, params: PipesParams) -> Optional[str]:
raise NotImplementedError()

def stderr_log_exists(self, params: PipesParams) -> bool:
raise NotImplementedError()

def _messages_thread(
self,
handler: "PipesMessageHandler",
Expand All @@ -349,24 +334,24 @@ def _messages_thread(
break
time.sleep(1)

def _stdout_thread(
self,
params: PipesParams,
is_task_complete: Event,
) -> None:
start_or_last_download = datetime.datetime.now()
while True:
now = datetime.datetime.now()
if (now - start_or_last_download).seconds > self.interval or is_task_complete.is_set():
start_or_last_download = now
chunk = self.download_stdout_chunk(params)
if chunk:
sys.stdout.write(chunk)
elif is_task_complete.is_set():
break
time.sleep(10)

def _stderr_thread(
class PipesBlobStoreStdioReader(ABC):
# Store the settings used by the reader loop.
# - interval: polling period; the loop compares it against elapsed seconds and sleeps for it.
# - target_stream: text stream that each downloaded log chunk is written to.
def __init__(self, *, interval: float = 10, target_stream: TextIO):
self.interval = interval
self.target_stream = target_stream

# Subclass hook: fetch log text to forward. The reader loop writes any truthy return value to
# target_stream; a falsy value (None or "") means there is nothing to forward right now.
@abstractmethod
def download_log_chunk(self, params: PipesParams) -> Optional[str]: ...

# Subclass hook: report whether the remote log can be found for the given params.
# PipesBlobStoreMessageReader.wait_for_stdio_logs keeps polling while this returns False.
@abstractmethod
def log_exists(self, params: PipesParams) -> bool: ...

# Start a background thread running `_reader_thread` with the given params and completion
# event, and return the started Thread so the caller can keep a handle to it.
def start_thread(self, params: PipesParams, is_task_complete: Event) -> Thread:
thread = Thread(target=self._reader_thread, args=(params, is_task_complete))
thread.start()
return thread

def _reader_thread(
self,
params: PipesParams,
is_task_complete: Event,
Expand All @@ -376,12 +361,12 @@ def _stderr_thread(
now = datetime.datetime.now()
if (now - start_or_last_download).seconds > self.interval or is_task_complete.is_set():
start_or_last_download = now
chunk = self.download_stderr_chunk(params)
chunk = self.download_log_chunk(params)
if chunk:
sys.stderr.write(chunk)
self.target_stream.write(chunk)
elif is_task_complete.is_set():
break
time.sleep(10)
time.sleep(self.interval)


def extract_message_or_forward_to_stdout(handler: "PipesMessageHandler", log_line: str):
Expand Down
16 changes: 13 additions & 3 deletions python_modules/libraries/dagster-aws/dagster_aws/pipes.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,22 @@
from dagster._core.pipes.client import (
PipesParams,
)
from dagster._core.pipes.utils import PipesBlobStoreMessageReader
from dagster._core.pipes.utils import PipesBlobStoreMessageReader, PipesBlobStoreStdioReader


class PipesS3MessageReader(PipesBlobStoreMessageReader):
def __init__(self, *, interval: float = 10, bucket: str, client: boto3.client):
super().__init__(interval=interval)
def __init__(
self,
*,
interval: float = 10,
bucket: str,
client: boto3.client,
stdout_reader: Optional[PipesBlobStoreStdioReader] = None,
stderr_reader: Optional[PipesBlobStoreStdioReader] = None,
):
super().__init__(
interval=interval, stdout_reader=stdout_reader, stderr_reader=stderr_reader
)
self.bucket = check.str_param(bucket, "bucket")
self.key_prefix = "".join(random.choices(string.ascii_letters, k=30))
self.client = client
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@
import os
import random
import string
import sys
import time
from contextlib import contextmanager
from typing import Iterator, Literal, Mapping, Optional
from contextlib import ExitStack, contextmanager
from typing import Iterator, Literal, Mapping, Optional, TextIO

import dagster._check as check
from dagster._annotations import experimental
Expand All @@ -20,6 +21,7 @@
)
from dagster._core.pipes.utils import (
PipesBlobStoreMessageReader,
PipesBlobStoreStdioReader,
open_pipes_session,
)
from dagster_pipes import (
Expand Down Expand Up @@ -70,7 +72,15 @@ def __init__(
message_reader,
"message_reader",
PipesMessageReader,
) or PipesDbfsMessageReader(client=self.client)
) or PipesDbfsMessageReader(
client=self.client,
stdout_reader=PipesDbfsStdioReader(
client=self.client, remote_log_name="stdout", target_stream=sys.stdout
),
stderr_reader=PipesDbfsStdioReader(
client=self.client, remote_log_name="stderr", target_stream=sys.stderr
),
)

def run(
self,
Expand Down Expand Up @@ -112,10 +122,11 @@ def run(
}
cluster_log_root = pipes_session.get_bootstrap_params()[
DAGSTER_PIPES_BOOTSTRAP_PARAM_NAMES["messages"]
]["cluster_log_root"]
submit_task_dict["new_cluster"]["cluster_log_conf"] = {
"dbfs": {"destination": f"dbfs:{cluster_log_root}"}
}
].get("cluster_log_root")
if cluster_log_root is not None:
submit_task_dict["new_cluster"]["cluster_log_conf"] = {
"dbfs": {"destination": f"dbfs:{cluster_log_root}"}
}
task = jobs.SubmitTask.from_dict(submit_task_dict)
run_id = self.client.jobs.submit(
tasks=[task],
Expand Down Expand Up @@ -217,21 +228,22 @@ def __init__(
*,
interval: int = 10,
client: WorkspaceClient,
forward_stdout: bool = False,
forward_stderr: bool = False,
stdout_reader: Optional[PipesBlobStoreStdioReader] = None,
stderr_reader: Optional[PipesBlobStoreStdioReader] = None,
):
super().__init__(
interval=interval, forward_stdout=forward_stdout, forward_stderr=forward_stderr
interval=interval, stdout_reader=stdout_reader, stderr_reader=stderr_reader
)
self.dbfs_client = files.DbfsAPI(client.api_client)
self.stdio_position = {"stdout": 0, "stderr": 0}

@contextmanager
def get_params(self) -> Iterator[PipesParams]:
with dbfs_tempdir(self.dbfs_client) as messages_tempdir, dbfs_tempdir(
self.dbfs_client
) as logs_tempdir:
yield {"path": messages_tempdir, "cluster_log_root": logs_tempdir}
with ExitStack() as stack:
params: PipesParams = {}
params["path"] = stack.enter_context(dbfs_tempdir(self.dbfs_client))
if self.stdout_reader or self.stderr_reader:
params["cluster_log_root"] = stack.enter_context(dbfs_tempdir(self.dbfs_client))
yield params

def download_messages_chunk(self, index: int, params: PipesParams) -> Optional[str]:
message_path = os.path.join(params["path"], f"{index}.json")
Expand All @@ -246,50 +258,52 @@ def download_messages_chunk(self, index: int, params: PipesParams) -> Optional[s
except IOError:
return None

def download_stdout_chunk(self, params: PipesParams) -> Optional[str]:
return self._download_stdio_chunk(params, "stdout")

def stdout_log_exists(self, params) -> bool:
return self._get_stdio_log_path(params, "stdout") is not None
# Return a fixed, human-readable hint describing where this reader looked for messages (a
# temporary file in dbfs) and which writer the external process is expected to use.
def no_messages_debug_text(self) -> str:
return (
"Attempted to read messages from a temporary file in dbfs. Expected"
" PipesDbfsMessageWriter to be explicitly passed to open_dagster_pipes in the external"
" process."
)

def download_stderr_chunk(self, params: PipesParams) -> Optional[str]:
return self._download_stdio_chunk(params, "stderr")

def stderr_log_exists(self, params) -> bool:
return self._get_stdio_log_path(params, "stderr") is not None
class PipesDbfsStdioReader(PipesBlobStoreStdioReader):
# Configure a reader for one driver log file ("stdout" or "stderr") stored on DBFS.
# - interval, target_stream: passed through unchanged to the base class.
# - remote_log_name: file name looked up under the cluster log directory's `driver/` folder.
# - client: only used here to build the DbfsAPI handle used for listing and reading.
def __init__(
self,
*,
interval: float = 10,
remote_log_name: Literal["stdout", "stderr"],
target_stream: TextIO,
client: WorkspaceClient,
):
super().__init__(interval=interval, target_stream=target_stream)
self.dbfs_client = files.DbfsAPI(client.api_client)
self.remote_log_name = remote_log_name
# Length of the decoded log content already returned by download_log_chunk; starts at 0.
self.log_position = 0

def _download_stdio_chunk(
self, params: PipesParams, stream: Literal["stdout", "stderr"]
) -> Optional[str]:
log_path = self._get_stdio_log_path(params, stream)
def download_log_chunk(self, params: PipesParams) -> Optional[str]:
log_path = self._get_log_path(params)
if log_path is None:
return None
else:
try:
read_response = self.dbfs_client.read(log_path)
assert read_response.data
content = base64.b64decode(read_response.data).decode("utf-8")
chunk = content[self.stdio_position[stream] :]
self.stdio_position[stream] = len(content)
chunk = content[self.log_position :]
self.log_position = len(content)
return chunk
except IOError:
return None

# True once `_get_log_path` can resolve a path for this reader's log, i.e. once the log root
# given in params has at least one child directory.
def log_exists(self, params) -> bool:
return self._get_log_path(params) is not None

# The logs live in a child directory of `cluster_log_root`. That directory will not exist until
# either 5 minutes have elapsed or the job has finished, so the lookup below returns None until
# it appears.
def _get_stdio_log_path(
self, params: PipesParams, stream: Literal["stdout", "stderr"]
) -> Optional[str]:
def _get_log_path(self, params: PipesParams) -> Optional[str]:
log_root_path = os.path.join(params["cluster_log_root"])
child_dirs = list(self.dbfs_client.list(log_root_path))
if len(child_dirs) > 0:
return f"dbfs:{child_dirs[0].path}/driver/{stream}"
return f"dbfs:{child_dirs[0].path}/driver/{self.remote_log_name}"
else:
return None

def no_messages_debug_text(self) -> str:
return (
"Attempted to read messages from a temporary file in dbfs. Expected"
" PipesDbfsMessageWriter to be explicitly passed to open_dagster_pipes in the external"
" process."
)
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import pytest
from dagster import AssetExecutionContext, asset, materialize
from dagster._core.errors import DagsterPipesExecutionError
from dagster_databricks.pipes import PipesDatabricksClient, PipesDbfsMessageReader, dbfs_tempdir
from dagster_databricks.pipes import PipesDatabricksClient, dbfs_tempdir
from databricks.sdk import WorkspaceClient
from databricks.sdk.service import files, jobs

Expand Down Expand Up @@ -81,6 +81,8 @@ def get_repo_root() -> str:
return path


# Upload the Dagster Pipes wheel to DBFS. Use this fixture to avoid needing to manually reupload
# dagster-pipes if it has changed between test runs.
@contextmanager
def upload_dagster_pipes_whl(client: WorkspaceClient) -> Iterator[None]:
repo_root = get_repo_root()
Expand Down Expand Up @@ -129,11 +131,6 @@ def number_x(context: AssetExecutionContext, pipes_client: PipesDatabricksClient
resources={
"pipes_client": PipesDatabricksClient(
client,
message_reader=PipesDbfsMessageReader(
client=client,
forward_stdout=True,
forward_stderr=True,
),
)
},
raise_on_error=False,
Expand Down

0 comments on commit 83f3d8c

Please sign in to comment.