diff --git a/dagrunner/plugin_framework.py b/dagrunner/plugin_framework.py index 351059a..9672c2c 100644 --- a/dagrunner/plugin_framework.py +++ b/dagrunner/plugin_framework.py @@ -2,19 +2,17 @@ # # This file is part of 'dagrunner' and is released under the BSD 3-Clause license. # See LICENSE in the root of the repository for full licensing details. -import itertools import json import os import pickle import shutil import string import subprocess -import time import warnings from abc import ABC, abstractmethod from glob import glob -from dagrunner.utils import process_path, Singleton +from dagrunner.utils import process_path, Singleton, data_polling, stage_to_dir class _EventBase: @@ -103,62 +101,21 @@ def __call__(self, *args, **kwargs): return subprocess.run(*args, **kwargs, shell=True, check=True) -def _stage_to_dir(*args, staging_dir, verbose=False): - """ - Copy input filepaths to a staging area and update paths. - - Hard link copies are preferred (same host) and physical copies are made otherwise. - File name, size and modification time are used to evaluate if the destination file - exists already (matching criteria of rsync). If exists already, skip the copy. - Staged files are named: `__` to avoid - collision with identically names files. 
- """ - os.makedirs(staging_dir, exist_ok=True) - args = list(args) - for ind, arg in enumerate(args): - host, fpath = None, arg - if ":" in arg: - host, fpath = arg.split(":") - - if host: - source_mtime_size = subprocess.run( - ["ssh", host, "stat", "-c", "%Y_%s", fpath], - check=True, - text=True, - capture_output=True, - ).stdout.strip() - else: - source_mtime_size = ( - f"{int(os.path.getmtime(fpath))}_{os.path.getsize(fpath)}" - ) - - target = os.path.join( - staging_dir, f"{source_mtime_size}_{os.path.basename(fpath)}" - ) - if not os.path.exists(target): - if host: - rsync_command = ["scp", "-p", f"{host}:{fpath}", target] - subprocess.run( - rsync_command, check=True, text=True, capture_output=True - ) - else: - try: - os.link(arg, target) - except Exception: - warnings.warn( - f"Failed to hard link {arg} to {target}. Copying instead." - ) - shutil.copy2(arg, target) - else: - warnings.warn(f"Staged file {target} already exists. Skipping copy.") +class Load(Plugin): + def __init__(self, staging_dir=None, ignore_missing=False, verbose=False): + """ + Load data from a file. - args[ind] = target - if verbose: - print(f"Staged {arg} to {args[ind]}") - return args + The `load` method must be implemented by the subclass. + Args: + - staging_dir: Directory to stage files in. + - verbose: Print verbose output. + """ + self._staging_dir = staging_dir + self._verbose = verbose + self._ignore_missing = ignore_missing -class Load(Plugin): @abstractmethod def load(self, *args, **kwargs): """ @@ -176,175 +133,64 @@ def load(self, *args, **kwargs): """ raise NotImplementedError - def __call__(self, *args, staging_dir=None, verbose=False, **kwargs): + def __call__(self, *args, **kwargs): """ Load data from a file or list of files. Args: - *args: List of filepaths to load. `:` syntax supported for loading files from a remote host. - - staging_dir: Directory to stage files in. If the staging directory doesn't - exist, then create it. - - verbose: Print verbose output. 
+ - **kwargs: Keyword arguments to pass to. """ args = list(map(process_path, args)) if ( any([arg.split(":")[0] for arg in args if ":" in arg]) - and staging_dir is None + and self._staging_dir is None ): raise ValueError( "Staging directory must be specified for loading remote files." ) - if staging_dir and args: - args = _stage_to_dir(*args, staging_dir=staging_dir, verbose=verbose) + if self._staging_dir and args: + try: + args = stage_to_dir(*args, staging_dir=self._staging_dir, verbose=self._verbose) + except FileNotFoundError as e: + if self._ignore_missing: + warnings.warn(str(e)) + return SKIP_EVENT + raise e + else: + missing_files = [arg for arg in args if not os.path.exists(arg)] + if missing_files: + if self._ignore_missing: + warnings.warn("Ignoring missing files.") + return SKIP_EVENT + else: + raise FileNotFoundError( + f"Missing files: {', '.join(missing_files)}" + ) return self.load(*args, **kwargs) class DataPolling(Plugin): - def __init__(self, timeout=60 * 2, polling=1, file_count=None, error_on_missing=True, return_found=False, verbose=False): - """ - Args: - - timeout (int): Timeout in seconds (default is 120 seconds). - - polling (int): Time interval in seconds between each poll (default is 1 - second). - - file_count (int): Expected number of files to be found for globular - expansion (default is >= 1 files per pattern). - - error_on_missing (bool): Raise an exception if files are missing. - - return_found (bool): Return list of files found if True, otherwise return - None. error_on_missing and return_found cannot both be False. - When False, the DataPolling plugin acts only to indicate successful - completion of the poll, thereby triggering the execution of steps dependent - on this poll. - wile setting to True will mean that the plugin additionally passes along - files it has found. - - verbose (bool): Print verbose output.
- """ + """A trigger plugin that completes when data is successfully polled.""" + def __init__(self, timeout=60 * 2, polling=1, file_count=None, verbose=False): self._timeout = timeout self._polling = polling self._file_count = file_count - self._error_on_missing = error_on_missing - self._return_found = return_found self._verbose = verbose - if not error_on_missing and not return_found: - raise ValueError( - "error_on_missing and return_found cannot both be False" - ) - def __call__(self, *args): - """ - Poll for the availability of files - - Poll for data and return when all data is available or otherwise raise an - exception if the timeout is reached. - - Args: - - *args: Variable length argument list of file patterns to be checked. - `:` syntax supported for files on a remote host. - - Returns: - - None, SKIP_EVENT or set: - Returning 'None' where 'return_found' is False, returning 'SKIP_EVENT' if - 'error_on_missing' is True and returning a 'set' of files found if - 'return_found' is True. - - Raises: - - RuntimeError: If the timeout is reached before all files are found. 
- """ - - # Define a key function - def host_and_glob_key(path): - psplit = path.split(":") - host = psplit[0] if ":" in path else "" # Extract host if available - is_glob = psplit[-1] if "*" in psplit[-1] else "" # Glob pattern - return (host, is_glob) - - time_taken = 0 - fpaths_found = set() - fpaths_not_found = set() - args = list(map(process_path, args)) - - # Group by host and whether it's a glob pattern - sorted_args = sorted(args, key=host_and_glob_key) - args_by_host = [ - [key, set(map(lambda path: path.split(":")[-1], group))] - for key, group in itertools.groupby(sorted_args, key=host_and_glob_key) - ] - - for ind, ((host, globular), paths) in enumerate(args_by_host): - globular = bool(globular) - host_msg = f"{host}:" if host else "" - while True: - if host: - # bash equivalent to python glob (glob on remote host) - expanded_paths = subprocess.run( - f'ssh {host} \'for file in {" ".join(paths)}; do if ' - '[ -e "$file" ]; then echo "$file"; fi; done\'', - shell=True, - check=True, - text=True, - capture_output=True, - ).stdout.strip() - if expanded_paths: - expanded_paths = expanded_paths.split("\n") - else: - expanded_paths = list( - itertools.chain.from_iterable(map(glob, paths)) - ) - if expanded_paths: - if host: - fpaths_found = fpaths_found.union(set([f"{host}:{path}" for path in expanded_paths])) - else: - fpaths_found = fpaths_found.union(expanded_paths) - if globular and ( - not self._file_count or len(expanded_paths) >= self._file_count - ): - # globular expansion completed - paths = set() - else: - # Remove paths we have found - paths = paths - set(expanded_paths) - - if paths: - if self._timeout and time_taken < self._timeout: - print( - f"self._polling for {host_msg}{paths}, time taken: " - f"{time_taken}s of limit {self._timeout}s" - ) - time.sleep(self._polling) - time_taken += self._polling - else: - break - else: - break - - if paths: - if host: - fpaths_not_found = fpaths_not_found.union(set([f"{host}:{path}" for path in paths])) - 
else: - fpaths_not_found = fpaths_not_found.union(paths) - if self._error_on_missing: - raise FileNotFoundError( - f"Timeout waiting for: {host_msg}{'; '.join(sorted(paths))}" - ) - elif not self._return_found: - return SKIP_EVENT - - if self._verbose and fpaths_found: - print( - "The following files were polled and found: " - f"{'; '.join(sorted(fpaths_found))}" - ) - if fpaths_not_found: - warnings.warn( - "The following files were polled and NOT found: " - f"{'; '.join(sorted(fpaths_not_found))}" - ) - if self._return_found: - return fpaths_found - return None + fpaths_found, fpaths_not_found = data_polling(*args, timeout=self._timeout, polling=self._polling, file_count=self._file_count, fail_fast=True, verbose=self._verbose) + if fpaths_not_found: + raise FileNotFoundError( + f"Timeout waiting for: {'; '.join(sorted(fpaths_not_found))}" + ) + if self._verbose: + msg = f"These files were polled and found: {'; '.join(sorted(fpaths_found))}" + print(msg) + return class Input(NodeAwarePlugin): diff --git a/dagrunner/utils/__init__.py b/dagrunner/utils/__init__.py index ea64eae..adec3c0 100644 --- a/dagrunner/utils/__init__.py +++ b/dagrunner/utils/__init__.py @@ -3,12 +3,17 @@ # This file is part of 'dagrunner' and is released under the BSD 3-Clause license. # See LICENSE in the root of the repository for full licensing details.
import argparse +from glob import glob import inspect +import itertools import os import socket +import shutil +import subprocess import threading import time from typing import Iterable +import warnings from abc import ABC, abstractmethod import dagrunner.utils._doc_styles as doc_styles @@ -408,3 +413,193 @@ def function_to_argparse_parse_args(*args, **kwargs): print(f"CLI call arguments: {args}") kwargs = args.pop("kwargs", None) or {} return args, kwargs + + +def data_polling(*args, timeout=60 * 2, polling=1, file_count=None, fail_fast=True, verbose=False): + """ + Poll for the availability of files + + Poll for data and return when all data is available or otherwise raise an + exception if the timeout is reached. + + Args: + - *args: Variable length argument list of file patterns to be checked. + `:` syntax supported for files on a remote host. + + Args: + - timeout (int): Timeout in seconds (default is 120 seconds). + - polling (int): Time interval in seconds between each poll (default is 1 + second). + - file_count (int): Expected number of files to be found for globular + expansion (default is >= 1 files per pattern). + - fail_fast (bool): Stop when a file is not found (default is True). + - verbose (bool): Print verbose output. 
+ """ + # Define a key function + def host_and_glob_key(path): + psplit = path.split(":") + host = psplit[0] if ":" in path else "" # Extract host if available + is_glob = psplit[-1] if "*" in psplit[-1] else "" # Glob pattern + return (host, is_glob) + + time_taken = 0 + fpaths_found = set() + fpaths_not_found = set() + args = list(map(process_path, args)) + + # Group by host and whether it's a glob pattern + sorted_args = sorted(args, key=host_and_glob_key) + args_by_host = [ + [key, set(map(lambda path: path.split(":")[-1], group))] + for key, group in itertools.groupby(sorted_args, key=host_and_glob_key) + ] + + for ind, ((host, globular), paths) in enumerate(args_by_host): + globular = bool(globular) + host_msg = f"{host}:" if host else "" + while True: + if host: + # bash equivalent to python glob (glob on remote host) + expanded_paths = subprocess.run( + f'ssh {host} \'for file in {" ".join(paths)}; do if ' + '[ -e "$file" ]; then echo "$file"; fi; done\'', + shell=True, + check=True, + text=True, + capture_output=True, + ).stdout.strip() + if expanded_paths: + expanded_paths = expanded_paths.split("\n") + else: + expanded_paths = list( + itertools.chain.from_iterable(map(glob, paths)) + ) + if expanded_paths: + if host: + fpaths_found = fpaths_found.union(set([f"{host}:{path}" for path in expanded_paths])) + else: + fpaths_found = fpaths_found.union(expanded_paths) + if globular and ( + not file_count or len(expanded_paths) >= file_count + ): + # globular expansion completed + paths = set() + else: + # Remove paths we have found + paths = paths - set(expanded_paths) + + if paths: + if timeout and time_taken < timeout: + if verbose: + print( + f"polling for {host_msg}{paths}, time taken: " + f"{time_taken}s of limit {timeout}s" + ) + time.sleep(polling) + time_taken += polling + else: + break + else: + break + + if paths: + if host: + fpaths_not_found = fpaths_not_found.union(set([f"{host}:{path}" for path in paths])) + else: + fpaths_not_found = 
fpaths_not_found.union(paths) + + if fail_fast: + break + + return fpaths_found, fpaths_not_found + + +class _RemotePathHandler: + def __init__(self, fpath): + self._fpath = fpath + self._host, self._lpath = None, fpath + if ":" in fpath: + self._host, self._lpath = fpath.split(":") + + @property + def host(self): + return self._host + + @property + def local_path(self): + return self._lpath + + def __str__(self): + return self._fpath + + def exists(self): + if self._host: + # check if file exists on remote host + exists = subprocess.run( + ["ssh", self._host, "test", "-e", self._lpath], check=False + ).returncode == 0 + else: + exists = os.path.exists(self._lpath) + return exists + + def get_identity(self): + """An identity derived from modification time and file size in bytes""" + if self._host: + mtime = subprocess.run( + ["ssh", self._host, "stat", "-c", "%Y_%s", self._lpath], + check=True, + text=True, + capture_output=True, + ).stdout.strip() + else: + mtime = f"{int(os.path.getmtime(self._lpath))}_{os.path.getsize(self._lpath)}" + return mtime + + def copy(self, target): + if self._host: + rsync_command = ["scp", "-p", f"{self._host}:{self._lpath}", target] + subprocess.run( + rsync_command, check=True, text=True, capture_output=True + ) + else: + try: + os.link(self._lpath, target) + except Exception: + warnings.warn( + f"Failed to hard link {self._lpath} to {target}. Copying instead." + ) + shutil.copy2(self._lpath, target) + + +def stage_to_dir(*args, staging_dir, verbose=False): + """ + Copy input filepaths to a staging area and update paths. + + Hard link copies are preferred (same host) and physical copies are made otherwise. + File name, size and modification time are used to evaluate if the destination file + exists already (matching criteria of rsync). If exists already, skip the copy. + Staged files are named: `__` to avoid + collision with identically named files.
+ """ + os.makedirs(staging_dir, exist_ok=True) + args = list(args) + for ind, arg in enumerate(args): + + fpath = _RemotePathHandler(arg) + if not fpath.exists(): + raise FileNotFoundError(f"File '{fpath}' not found.") + + source_mtime_size = fpath.get_identity() + + target = os.path.join( + staging_dir, f"{source_mtime_size}_{os.path.basename(arg.split(':')[-1])}" + ) + if not os.path.exists(target): + fpath.copy(target) + else: + warnings.warn(f"Staged file {target} already exists. Skipping copy.") + + args[ind] = target + if verbose: + print(f"Staged {arg} to {args[ind]}") + return args \ No newline at end of file