Commit
smackesey committed Oct 11, 2023
1 parent 2c03d47 commit 367bbeb
Showing 4 changed files with 241 additions and 51 deletions.
29 changes: 14 additions & 15 deletions python_modules/dagster-pipes/dagster_pipes/__init__.py
@@ -11,7 +11,8 @@
from abc import ABC, abstractmethod
from contextlib import ExitStack, contextmanager
from io import StringIO
from threading import Event, Lock, Thread
from queue import Queue
from threading import Event, Thread
from typing import (
IO,
TYPE_CHECKING,
@@ -81,7 +82,7 @@ def _make_message(method: str, params: Optional[Mapping[str, Any]]) -> "PipesMessage":


class PipesMessage(TypedDict):
"""A message sent from the orchestration process to the external process."""
"""A message sent from the external process to the orchestration process."""

__dagster_pipes_version: str
method: str
@@ -534,19 +535,17 @@ class PipesBlobStoreMessageWriterChannel(PipesMessageWriterChannel):

def __init__(self, *, interval: float = 10):
self._interval = interval
self._lock = Lock()
self._buffer = []
self._buffer: Queue[PipesMessage] = Queue()
self._counter = 1

def write_message(self, message: PipesMessage) -> None:
with self._lock:
self._buffer.append(message)
self._buffer.put(message)

def flush_messages(self) -> Sequence[PipesMessage]:
with self._lock:
messages = list(self._buffer)
self._buffer.clear()
return messages
items = []
while not self._buffer.empty():
items.append(self._buffer.get())
return items

@abstractmethod
def upload_messages_chunk(self, payload: StringIO, index: int) -> None: ...
@@ -567,15 +566,15 @@ def buffered_upload_loop(self) -> Iterator[None]:
def _upload_loop(self, is_task_complete: Event) -> None:
start_or_last_upload = datetime.datetime.now()
while True:
num_pending = len(self._buffer)
now = datetime.datetime.now()
if num_pending == 0 and is_task_complete.is_set():
if self._buffer.empty() and is_task_complete.is_set():
break
elif is_task_complete.is_set() or (now - start_or_last_upload).seconds > self._interval:
payload = "\n".join([json.dumps(message) for message in self.flush_messages()])
self.upload_messages_chunk(StringIO(payload), self._counter)
start_or_last_upload = now
self._counter += 1
if len(payload) > 0:
self.upload_messages_chunk(StringIO(payload), self._counter)
start_or_last_upload = now
self._counter += 1
time.sleep(1)
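
A note on the buffer change above: queue.Queue is internally synchronized, so write_message can put() from the caller's thread while the upload loop drains concurrently, without the explicit Lock the old list-based buffer required. A minimal stdlib-only sketch of the same drain pattern (DrainingBuffer and its method names are hypothetical; it uses get_nowait() instead of the empty()/get() pair, which behaves the same with a single consumer):

from queue import Empty, Queue
from typing import List

class DrainingBuffer:
    def __init__(self) -> None:
        # Queue handles its own locking, so concurrent put()/get() across
        # threads is safe without an explicit Lock.
        self._queue: Queue[str] = Queue()

    def write(self, item: str) -> None:
        self._queue.put(item)

    def drain(self) -> List[str]:
        # Drain without blocking: get_nowait() raises Empty once the queue
        # runs dry, sidestepping the empty()/get() race that can occur when
        # multiple consumers drain the same queue.
        items: List[str] = []
        while True:
            try:
                items.append(self._queue.get_nowait())
            except Empty:
                return items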


121 changes: 104 additions & 17 deletions python_modules/dagster/dagster/_core/pipes/utils.py
@@ -144,7 +144,7 @@ def read_messages(
self,
handler: "PipesMessageHandler",
) -> Iterator[PipesParams]:
"""Set up a thread to read streaming messages from teh external process by tailing the
"""Set up a thread to read streaming messages from the external process by tailing the
target file.
Args:
@@ -208,6 +208,11 @@ def no_messages_debug_text(self) -> str:
return "Attempted to read messages from a local temporary file."


# Number of seconds to wait after an external process has completed for stdio logs to become
# available. If this is exceeded, proceed with exiting without picking up logs.
WAIT_FOR_STDIO_LOGS_TIMEOUT = 60


class PipesBlobStoreMessageReader(PipesMessageReader):
"""Message reader that reads a sequence of message chunks written by an external process into a
blob store such as S3, Azure blob storage, or GCS.
@@ -223,14 +228,22 @@ class PipesBlobStoreMessageReader(PipesMessageReader):
Args:
interval (float): interval in seconds between attempts to download a chunk
forward_stdout (bool): whether to forward stdout from the pipes process to Dagster.
forward_stderr (bool): whether to forward stderr from the pipes process to Dagster.
"""

interval: float
counter: int
forward_stdout: bool
forward_stderr: bool

def __init__(self, interval: float = 10):
def __init__(
self, interval: float = 10, forward_stdout: bool = False, forward_stderr: bool = False
):
self.interval = interval
self.counter = 1
self.forward_stdout = forward_stdout
self.forward_stderr = forward_stderr

@contextmanager
def read_messages(
@@ -249,23 +262,46 @@ def read_messages(
"""
with self.get_params() as params:
is_task_complete = Event()
thread = None
messages_thread = None
stdout_thread = None
stderr_thread = None
try:
thread = Thread(
target=self._reader_thread,
args=(
handler,
params,
is_task_complete,
),
daemon=True,
messages_thread = Thread(
target=self._messages_thread, args=(handler, params, is_task_complete)
)
thread.start()
messages_thread.start()
if self.forward_stdout:
stdout_thread = Thread(
target=self._stdout_thread, args=(params, is_task_complete)
)
stdout_thread.start()
if self.forward_stderr:
stderr_thread = Thread(
target=self._stderr_thread, args=(params, is_task_complete)
)
stderr_thread.start()
yield params
finally:
self.wait_for_stdio_logs(params)
is_task_complete.set()
if thread:
thread.join()
if messages_thread:
messages_thread.join()
if stdout_thread:
stdout_thread.join()
if stderr_thread:
stderr_thread.join()

# In cases where we are forwarding logs, in some cases the logs might not be written out until
# after the run completes. We wait for them to exist.
def wait_for_stdio_logs(self, params):
start_or_last_download = datetime.datetime.now()
while (
datetime.datetime.now() - start_or_last_download
).seconds <= WAIT_FOR_STDIO_LOGS_TIMEOUT and (
(self.forward_stdout and not self.stdout_log_exists(params))
or (self.forward_stderr and not self.stderr_log_exists(params))
):
time.sleep(5)

@abstractmethod
@contextmanager
@@ -280,15 +316,30 @@ def get_params(self) -> Iterator[PipesParams]:
@abstractmethod
def download_messages_chunk(self, index: int, params: PipesParams) -> Optional[str]: ...

def _reader_thread(
self, handler: "PipesMessageHandler", params: PipesParams, is_task_complete: Event
def download_stdout_chunk(self, params: PipesParams) -> Optional[str]:
raise NotImplementedError()

def stdout_log_exists(self, params: PipesParams) -> bool:
raise NotImplementedError()

def download_stderr_chunk(self, params: PipesParams) -> Optional[str]:
raise NotImplementedError()

def stderr_log_exists(self, params: PipesParams) -> bool:
raise NotImplementedError()
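
For orientation, here is a hypothetical concrete reader against the local filesystem (not part of this commit; LocalDirMessageReader is an invented name), showing which hooks a subclass supplies: get_params yields the handshake parameters, download_messages_chunk fetches numbered chunks, and the stdout/stderr hooks only matter when forwarding is enabled:

import os
from contextlib import contextmanager
from tempfile import TemporaryDirectory
from typing import Iterator, Optional

class LocalDirMessageReader(PipesBlobStoreMessageReader):
    @contextmanager
    def get_params(self) -> Iterator[PipesParams]:
        # The yielded params are shared with the external process and passed
        # back into every download_* hook below.
        with TemporaryDirectory() as tempdir:
            yield {"path": tempdir}

    def download_messages_chunk(self, index: int, params: PipesParams) -> Optional[str]:
        chunk_path = os.path.join(params["path"], f"{index}.json")
        if not os.path.exists(chunk_path):
            return None
        with open(chunk_path) as f:
            return f.read()

    def stdout_log_exists(self, params: PipesParams) -> bool:
        return os.path.exists(os.path.join(params["path"], "stdout"))

    def download_stdout_chunk(self, params: PipesParams) -> Optional[str]:
        # Naive version: re-reads the whole file each poll. A real reader
        # should track an offset, as the DBFS reader below does, so the
        # forwarding thread does not emit duplicates.
        path = os.path.join(params["path"], "stdout")
        if not os.path.exists(path):
            return None
        with open(path) as f:
            return f.read()

    def no_messages_debug_text(self) -> str:
        return "Attempted to read messages from a local directory."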

def _messages_thread(
self,
handler: "PipesMessageHandler",
params: PipesParams,
is_task_complete: Event,
) -> None:
start_or_last_download = datetime.datetime.now()
while True:
now = datetime.datetime.now()
if (now - start_or_last_download).seconds > self.interval or is_task_complete.is_set():
chunk = self.download_messages_chunk(self.counter, params)
start_or_last_download = now
chunk = self.download_messages_chunk(self.counter, params)
if chunk:
for line in chunk.split("\n"):
message = json.loads(line)
@@ -298,13 +349,49 @@ def _reader_thread(
break
time.sleep(1)

def _stdout_thread(
self,
params: PipesParams,
is_task_complete: Event,
) -> None:
start_or_last_download = datetime.datetime.now()
while True:
now = datetime.datetime.now()
if (now - start_or_last_download).seconds > self.interval or is_task_complete.is_set():
start_or_last_download = now
chunk = self.download_stdout_chunk(params)
if chunk:
sys.stdout.write(chunk)
elif is_task_complete.is_set():
break
time.sleep(10)

def _stderr_thread(
self,
params: PipesParams,
is_task_complete: Event,
) -> None:
start_or_last_download = datetime.datetime.now()
while True:
now = datetime.datetime.now()
if (now - start_or_last_download).seconds > self.interval or is_task_complete.is_set():
start_or_last_download = now
chunk = self.download_stderr_chunk(params)
if chunk:
sys.stderr.write(chunk)
elif is_task_complete.is_set():
break
time.sleep(10)
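
All three polling threads share one shape: wake periodically, download when the interval has elapsed or the task has completed, and exit only after a final empty drain once the completion event is set. Distilled into one function (a sketch; poll_fn is a hypothetical stand-in for the per-stream download-and-forward step, returning whether it produced output):

import datetime
import time
from threading import Event
from typing import Callable

def poll_until_complete(poll_fn: Callable[[], bool], is_task_complete: Event, interval: float) -> None:
    last_poll = datetime.datetime.now()
    while True:
        now = datetime.datetime.now()
        if (now - last_poll).seconds > interval or is_task_complete.is_set():
            last_poll = now
            had_output = poll_fn()
            # An empty poll after completion means nothing is left to drain.
            if not had_output and is_task_complete.is_set():
                break
        time.sleep(1)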


def extract_message_or_forward_to_stdout(handler: "PipesMessageHandler", log_line: str):
# exceptions as control flow, you love to see it
try:
message = json.loads(log_line)
if PIPES_PROTOCOL_VERSION_FIELD in message.keys():
handler.handle_message(message)
else:
sys.stdout.writelines((log_line, "\n"))
except Exception:
# move non-message logs in to stdout for compute log capture
sys.stdout.writelines((log_line, "\n"))
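
The routing rule above is easy to exercise in isolation. An illustrative check (the stub handler is hypothetical; PIPES_PROTOCOL_VERSION_FIELD corresponds to the __dagster_pipes_version key carried by every PipesMessage):

import json

class StubHandler:
    def __init__(self) -> None:
        self.messages = []

    def handle_message(self, message) -> None:
        self.messages.append(message)

handler = StubHandler()
# A well-formed Pipes message carries the protocol-version field, so it is
# routed to the handler ...
extract_message_or_forward_to_stdout(
    handler,
    json.dumps({"__dagster_pipes_version": "0.1", "method": "log", "params": None}),
)
# ... while ordinary program output (or anything that fails to parse as
# JSON) falls through to stdout for compute log capture.
extract_message_or_forward_to_stdout(handler, "plain log line")
assert len(handler.messages) == 1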
python_modules/libraries/dagster-databricks/dagster_databricks/pipes.py
Expand Up @@ -5,7 +5,7 @@
import string
import time
from contextlib import contextmanager
from typing import Iterator, Mapping, Optional
from typing import Iterator, Literal, Mapping, Optional

import dagster._check as check
from dagster._annotations import experimental
@@ -23,6 +23,7 @@
open_pipes_session,
)
from dagster_pipes import (
DAGSTER_PIPES_BOOTSTRAP_PARAM_NAMES,
PipesContextData,
PipesExtras,
PipesParams,
@@ -109,6 +110,12 @@ def run(
**(self.env or {}),
**pipes_session.get_bootstrap_env_vars(),
}
cluster_log_root = pipes_session.get_bootstrap_params()[
DAGSTER_PIPES_BOOTSTRAP_PARAM_NAMES["messages"]
]["cluster_log_root"]
submit_task_dict["new_cluster"]["cluster_log_conf"] = {
"dbfs": {"destination": f"dbfs:{cluster_log_root}"}
}
task = jobs.SubmitTask.from_dict(submit_task_dict)
run_id = self.client.jobs.submit(
tasks=[task],
@@ -135,6 +142,7 @@ def run(
f"Error running Databricks job: {run.state.state_message}"
)
time.sleep(5)
time.sleep(30) # 30 seconds to make sure logs are flushed
return PipesClientCompletedInvocation(tuple(pipes_session.get_results()))


@@ -200,22 +208,35 @@ class PipesDbfsMessageReader(PipesBlobStoreMessageReader):
Args:
interval (float): interval in seconds between attempts to download a chunk
client (WorkspaceClient): A databricks `WorkspaceClient` object.
cluster_log_root (Optional[str]): The root path on DBFS where the cluster logs are written.
If set, this will be used to read stderr/stdout logs.
"""

def __init__(self, *, interval: int = 10, client: WorkspaceClient):
super().__init__(interval=interval)
def __init__(
self,
*,
interval: int = 10,
client: WorkspaceClient,
forward_stdout: bool = False,
forward_stderr: bool = False,
):
super().__init__(
interval=interval, forward_stdout=forward_stdout, forward_stderr=forward_stderr
)
self.dbfs_client = files.DbfsAPI(client.api_client)
self.stdio_position = {"stdout": 0, "stderr": 0}

@contextmanager
def get_params(self) -> Iterator[PipesParams]:
with dbfs_tempdir(self.dbfs_client) as tempdir:
yield {"path": tempdir}
with dbfs_tempdir(self.dbfs_client) as messages_tempdir, dbfs_tempdir(
self.dbfs_client
) as logs_tempdir:
yield {"path": messages_tempdir, "cluster_log_root": logs_tempdir}

def download_messages_chunk(self, index: int, params: PipesParams) -> Optional[str]:
message_path = os.path.join(params["path"], f"{index}.json")
try:
raw_message = self.dbfs_client.read(message_path)

# Files written to dbfs using the Python IO interface used in PipesDbfsMessageWriter are
# base64-encoded.
return base64.b64decode(raw_message.data).decode("utf-8")
@@ -225,6 +246,47 @@ def download_messages_chunk(self, index: int, params: PipesParams) -> Optional[str]:
except IOError:
return None

def download_stdout_chunk(self, params: PipesParams) -> Optional[str]:
return self._download_stdio_chunk(params, "stdout")

def stdout_log_exists(self, params) -> bool:
return self._get_stdio_log_path(params, "stdout") is not None

def download_stderr_chunk(self, params: PipesParams) -> Optional[str]:
return self._download_stdio_chunk(params, "stderr")

def stderr_log_exists(self, params) -> bool:
return self._get_stdio_log_path(params, "stderr") is not None

def _download_stdio_chunk(
self, params: PipesParams, stream: Literal["stdout", "stderr"]
) -> Optional[str]:
log_path = self._get_stdio_log_path(params, stream)
if log_path is None:
return None
else:
try:
read_response = self.dbfs_client.read(log_path)
assert read_response.data
content = base64.b64decode(read_response.data).decode("utf-8")
chunk = content[self.stdio_position[stream] :]
self.stdio_position[stream] = len(content)
return chunk
except IOError:
return None
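
The stdio download above is cursor-based: each poll fetches the full log written so far and slices off everything before the last seen offset, so each byte is forwarded exactly once even though the file is re-read. Distilled (fetch_log is a hypothetical stand-in for the DBFS read-and-decode call):

def tail_new_output(fetch_log, position: int):
    content = fetch_log()       # full log contents so far
    chunk = content[position:]  # only the part not yet forwarded
    return chunk, len(content)  # the new cursor is the new total length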

# The directory containing logs will not exist until either 5 minutes have elapsed or the
# job has finished.
def _get_stdio_log_path(
self, params: PipesParams, stream: Literal["stdout", "stderr"]
) -> Optional[str]:
log_root_path = os.path.join(params["cluster_log_root"])
child_dirs = list(self.dbfs_client.list(log_root_path))
if len(child_dirs) > 0:
return f"dbfs:{child_dirs[0].path}/driver/{stream}"
else:
return None
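
The single-child-directory lookup above relies on Databricks cluster log delivery, which writes logs into a cluster-id subdirectory of the configured root; roughly (illustrative layout):

<cluster_log_root>/
    <cluster_id>/        # the one child directory picked up above
        driver/
            stdout
            stderr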

def no_messages_debug_text(self) -> str:
return (
"Attempted to read messages from a temporary file in dbfs. Expected"