Skip to content

Commit

Permalink
formatting changes
Browse files Browse the repository at this point in the history
  • Loading branch information
cpelley committed Sep 26, 2024
1 parent 0038a4a commit 83929fa
Show file tree
Hide file tree
Showing 4 changed files with 84 additions and 45 deletions.
6 changes: 4 additions & 2 deletions dagrunner/execute_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,8 @@ def plugin_executor(
callable_kwargs_init = {}
else:
raise ValueError(
f"expecting 1, 2 or 3 values to unpack for {callable_obj}, got {len(call)}"
f"expecting 1, 2 or 3 values to unpack for {callable_obj}, "
f"got {len(call)}"
)
callable_kwargs_init = (
{} if callable_kwargs_init is None else callable_kwargs_init
Expand All @@ -173,7 +174,8 @@ def plugin_executor(
callable_kwargs = {}
else:
raise ValueError(
f"expecting 1 or 2 values to unpack for {callable_obj}, got {len(call)}"
f"expecting 1 or 2 values to unpack for {callable_obj}, got "
f"{len(call)}"
)
callable_kwargs = {} if callable_kwargs is None else callable_kwargs

Expand Down
40 changes: 26 additions & 14 deletions dagrunner/plugin_framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,11 +189,12 @@ def __call__(
Raises:
- RuntimeError: If the timeout is reached before all files are found.
"""

# Define a key function
def host_and_glob_key(path):
psplit = path.split(':')
host = psplit[0] if ':' in path else '' # Extract host if available
is_glob = psplit[-1] if '*' in psplit[-1] else '' # Glob pattern
psplit = path.split(":")
host = psplit[0] if ":" in path else "" # Extract host if available
is_glob = psplit[-1] if "*" in psplit[-1] else "" # Glob pattern
return (host, is_glob)

time_taken = 0
Expand All @@ -202,18 +203,20 @@ def host_and_glob_key(path):

# Group by host and whether it's a glob pattern
sorted_args = sorted(args, key=host_and_glob_key)
args_by_host = [[key, set(map(lambda path: path.split(':')[-1], group))] for
key, group in itertools.groupby(sorted_args, key=host_and_glob_key)]
args_by_host = [
[key, set(map(lambda path: path.split(":")[-1], group))]
for key, group in itertools.groupby(sorted_args, key=host_and_glob_key)
]

for ind, ((host, globular), paths) in enumerate(args_by_host):
globular = bool(globular)
host_msg = f"{host}:" if host else ''
host_msg = f"{host}:" if host else ""
while time_taken < timeout or not timeout:
if host:
# bash equivalent to python glob (glob on remote host)
expanded_paths = subprocess.run(
f'ssh {host} \'for file in {" ".join(paths)}; do if [ -e "$file" ]; '
'then echo "$file"; fi; done\'',
f'ssh {host} \'for file in {" ".join(paths)}; do if '
'[ -e "$file" ]; then echo "$file"; fi; done\'',
shell=True,
check=True,
text=True,
Expand All @@ -222,10 +225,14 @@ def host_and_glob_key(path):
if expanded_paths:
expanded_paths = expanded_paths.split("\n")
else:
expanded_paths = list(itertools.chain.from_iterable(map(glob, paths)))
expanded_paths = list(
itertools.chain.from_iterable(map(glob, paths))
)
if expanded_paths:
fpaths_found = fpaths_found.union(expanded_paths)
if globular and (not file_count or len(expanded_paths) >= file_count):
if globular and (
not file_count or len(expanded_paths) >= file_count
):
# globular expansion completed
paths = set()
else:
Expand All @@ -235,8 +242,8 @@ def host_and_glob_key(path):
if paths:
if timeout:
print(
f"polling for {host_msg}{paths}, time taken: {time_taken}s of limit "
f"{timeout}s"
f"polling for {host_msg}{paths}, time taken: "
f"{time_taken}s of limit {timeout}s"
)
time.sleep(polling)
time_taken += polling
Expand All @@ -246,10 +253,15 @@ def host_and_glob_key(path):
break

if paths:
raise FileNotFoundError(f"Timeout waiting for: {host_msg}{'; '.join(sorted(paths))}")
raise FileNotFoundError(
f"Timeout waiting for: {host_msg}{'; '.join(sorted(paths))}"
)

if verbose and fpaths_found:
print(f"The following files were polled and found: {'; '.join(sorted(fpaths_found))}")
print(
"The following files were polled and found: "
f"{'; '.join(sorted(fpaths_found))}"
)
return None


Expand Down
18 changes: 13 additions & 5 deletions dagrunner/tests/execute_graph/test_plugin_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ def __call__(
):
return (
f"init_kwargs={self._init_kwargs}; "
f"init_named_arg={self._init_named_arg}; init_named_kwarg={self._init_named_kwarg}; "
f"init_named_arg={self._init_named_arg}; "
f"init_named_kwarg={self._init_named_kwarg}; "
f"call_args={call_args}; call_kwargs={call_kwargs}; "
f"call_named_arg={call_named_arg}; call_named_kwarg={call_named_kwarg}; "
)
Expand Down Expand Up @@ -59,7 +60,8 @@ def __call__(self, *call_args, **call_kwargs):
"init_named_kwarg=sentinel.init_named_kwarg; "
"call_args=(sentinel.arg1, sentinel.arg2); "
"call_kwargs={'call_other_kwarg': sentinel.call_other_kwarg}; "
"call_named_arg=sentinel.call_named_arg; call_named_kwarg=sentinel.call_named_kwarg; "
"call_named_arg=sentinel.call_named_arg; "
"call_named_kwarg=sentinel.call_named_kwarg; "
),
),
# Passing class init args only
Expand Down Expand Up @@ -87,20 +89,26 @@ def __call__(self, *call_args, **call_kwargs):
],
)
def test_pass_class_arg_kwargs(plugin, init_args, call_args, target):
"""Test passing named parameters to plugin class initialisation and __call__ method."""
"""
Test passing named parameters to plugin class initialisation and __call__
method.
"""
args = (mock.sentinel.arg1, mock.sentinel.arg2)
call = tuple([plugin, init_args, call_args])
res = plugin_executor(*args, call=call)
assert res == target


def test_pass_common_args():
"""Passing 'common args', some relevant to class init and some to call method."""
"""
Passing 'common args', some relevant to class init and some to call method.
"""
args = (mock.sentinel.arg1, mock.sentinel.arg2)
common_kwargs = {
"init_named_arg": mock.sentinel.init_named_arg,
"init_named_kwarg": mock.sentinel.init_named_kwarg,
"other_kwargs": mock.sentinel.other_kwargs, # this should be ignored (as not part of class signature)
# this should be ignored (as not part of class signature)
"other_kwargs": mock.sentinel.other_kwargs,
"call_named_arg": mock.sentinel.call_named_arg,
"call_named_kwarg": mock.sentinel.call_named_kwarg,
}
Expand Down
65 changes: 41 additions & 24 deletions dagrunner/tests/plugin_framework/test_DataPolling.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
# This file is part of 'dagrunner' and is released under the BSD 3-Clause license.
# See LICENSE in the root of the repository for full licensing details.
import socket
import subprocess
from glob import glob
from unittest.mock import patch

import pytest
Expand All @@ -20,7 +22,7 @@ def tmp_file(tmp_path_factory):

def call_dp(*filepaths, verbose=True, **kwargs):
dp = DataPolling()
#dp(*filepaths, timeout=0.001, polling=0.002, verbose=verbose, **kwargs)
# dp(*filepaths, timeout=0.001, polling=0.002, verbose=verbose, **kwargs)
dp(*filepaths, timeout=0, polling=0, verbose=verbose, **kwargs)


Expand Down Expand Up @@ -97,8 +99,6 @@ def test_mixture_of_hosts_local(tmp_dir, capsys):
Ensure that we can poll for groups of shared inputs by common host and whether
they are globular or not.
"""
host_tmp_file = f"{socket.gethostname()}:{tmp_file}"

input_paths = [
f"{socket.gethostname()}:{tmp_dir / 'testA0.txt'}",
f"{socket.gethostname()}:{tmp_dir / 'testA1.txt'}",
Expand All @@ -109,47 +109,64 @@ def test_mixture_of_hosts_local(tmp_dir, capsys):
]

target = {
str(pp) for pp in sorted([
tmp_dir / 'testA0.txt',
tmp_dir / 'testA1.txt',
tmp_dir / 'testB0.txt',
tmp_dir / 'testB1.txt',
tmp_dir / 'testC0.txt',
tmp_dir / 'testC1.txt',
tmp_dir / 'testD0.txt',
tmp_dir / 'testD1.txt',
])
str(pp)
for pp in sorted(
[
tmp_dir / "testA0.txt",
tmp_dir / "testA1.txt",
tmp_dir / "testB0.txt",
tmp_dir / "testB1.txt",
tmp_dir / "testC0.txt",
tmp_dir / "testC1.txt",
tmp_dir / "testD0.txt",
tmp_dir / "testD1.txt",
]
)
}

from glob import glob
import subprocess
# Mocking gethostname() so that our host doesn't match against our local host check
# internally.
with patch(
"dagrunner.utils.socket.gethostname", return_value="dummy_host.dummy_domain"
):
# patch plugin_framework.glob with a wrapper so that we can check what it was called with
# patch plugin_framework.glob with a wrapper so that we can check what
# it was called with.
with patch("dagrunner.plugin_framework.glob", wraps=glob) as mock_glob:
with patch("dagrunner.plugin_framework.subprocess.run", wraps=subprocess.run) as mock_subprocrun:
with patch(
"dagrunner.plugin_framework.subprocess.run", wraps=subprocess.run
) as mock_subprocrun:
call_dp(*input_paths)

# check how subprocess.run was called
#####################################
assert len(mock_subprocrun.call_args_list) is 2
assert len(mock_subprocrun.call_args_list) == 2

# group all non glob patterns into a single call (minimising ssh calls)
for res_index, targets in zip(range(2), [[f"{tmp_dir}/testA0.txt", f"{tmp_dir}/testA1.txt"], [f"{tmp_dir}/testB*.txt"]]):
for res_index, targets in zip(
range(2),
[[f"{tmp_dir}/testA0.txt", f"{tmp_dir}/testA1.txt"], [f"{tmp_dir}/testB*.txt"]],
):
objcall = mock_subprocrun.call_args_list[res_index]
results = objcall[0][0].replace(';', '').split()
assert sorted(filter(lambda substr: substr.startswith(str(tmp_dir)), results)) == targets
results = objcall[0][0].replace(";", "").split()
assert (
sorted(filter(lambda substr: substr.startswith(str(tmp_dir)), results))
== targets
)

# check how glob was called (glob supports only 1 path arguments)
#####################################
assert len(mock_glob.call_args_list) is 3
targets = [f'{tmp_dir}/testC0.txt', f'{tmp_dir}/testC1.txt', f'{tmp_dir}/testD*.txt']
assert len(mock_glob.call_args_list) == 3
targets = [
f"{tmp_dir}/testC0.txt",
f"{tmp_dir}/testC1.txt",
f"{tmp_dir}/testD*.txt",
]
for objcall in mock_glob.call_args_list:
assert objcall[0][0] in targets
targets.remove(objcall[0][0])

captured = capsys.readouterr()
assert f"The following files were polled and found: {'; '.join(sorted(target))}" in captured.out
assert (
f"The following files were polled and found: {'; '.join(sorted(target))}"
in captured.out
)

0 comments on commit 83929fa

Please sign in to comment.