diff --git a/dagrunner/execute_graph.py b/dagrunner/execute_graph.py index 3d89abc..1cd768e 100755 --- a/dagrunner/execute_graph.py +++ b/dagrunner/execute_graph.py @@ -161,7 +161,8 @@ def plugin_executor( callable_kwargs_init = {} else: raise ValueError( - f"expecting 1, 2 or 3 values to unpack for {callable_obj}, got {len(call)}" + f"expecting 1, 2 or 3 values to unpack for {callable_obj}, " + f"got {len(call)}" ) callable_kwargs_init = ( {} if callable_kwargs_init is None else callable_kwargs_init @@ -173,7 +174,8 @@ def plugin_executor( callable_kwargs = {} else: raise ValueError( - f"expecting 1 or 2 values to unpack for {callable_obj}, got {len(call)}" + f"expecting 1 or 2 values to unpack for {callable_obj}, got " + f"{len(call)}" ) callable_kwargs = {} if callable_kwargs is None else callable_kwargs diff --git a/dagrunner/plugin_framework.py b/dagrunner/plugin_framework.py index 05fd497..162f99d 100644 --- a/dagrunner/plugin_framework.py +++ b/dagrunner/plugin_framework.py @@ -189,11 +189,12 @@ def __call__( Raises: - RuntimeError: If the timeout is reached before all files are found. """ + # Define a key function def host_and_glob_key(path): - psplit = path.split(':') - host = psplit[0] if ':' in path else '' # Extract host if available - is_glob = psplit[-1] if '*' in psplit[-1] else '' # Glob pattern + psplit = path.split(":") + host = psplit[0] if ":" in path else "" # Extract host if available + is_glob = psplit[-1] if "*" in psplit[-1] else "" # Glob pattern return (host, is_glob) time_taken = 0 @@ -202,18 +203,20 @@ def host_and_glob_key(path): # Group by host and whether it's a glob pattern sorted_args = sorted(args, key=host_and_glob_key) - args_by_host = [[key, set(map(lambda path: path.split(':')[-1], group))] for - key, group in itertools.groupby(sorted_args, key=host_and_glob_key)] + args_by_host = [ + [key, set(map(lambda path: path.split(":")[-1], group))] + for key, group in itertools.groupby(sorted_args, key=host_and_glob_key) + ] for ind, ((host, globular), paths) in enumerate(args_by_host): globular = bool(globular) - host_msg = f"{host}:" if host else '' + host_msg = f"{host}:" if host else "" while time_taken < timeout or not timeout: if host: # bash equivalent to python glob (glob on remote host) expanded_paths = subprocess.run( - f'ssh {host} \'for file in {" ".join(paths)}; do if [ -e "$file" ]; ' - 'then echo "$file"; fi; done\'', + f'ssh {host} \'for file in {" ".join(paths)}; do if ' + '[ -e "$file" ]; then echo "$file"; fi; done\'', shell=True, check=True, text=True, @@ -222,10 +225,14 @@ def host_and_glob_key(path): if expanded_paths: expanded_paths = expanded_paths.split("\n") else: - expanded_paths = list(itertools.chain.from_iterable(map(glob, paths))) + expanded_paths = list( + itertools.chain.from_iterable(map(glob, paths)) + ) if expanded_paths: fpaths_found = fpaths_found.union(expanded_paths) - if globular and (not file_count or len(expanded_paths) >= file_count): + if globular and ( + not file_count or len(expanded_paths) >= file_count + ): # globular expansion completed paths = set() else: @@ -235,8 +242,8 @@ def host_and_glob_key(path): if paths: if timeout: print( - f"polling for {host_msg}{paths}, time taken: {time_taken}s of limit " - f"{timeout}s" + f"polling for {host_msg}{paths}, time taken: " + f"{time_taken}s of limit {timeout}s" ) time.sleep(polling) time_taken += polling @@ -246,10 +253,15 @@ def host_and_glob_key(path): break if paths: - raise FileNotFoundError(f"Timeout waiting for: {host_msg}{'; '.join(sorted(paths))}") + raise FileNotFoundError( + f"Timeout waiting for: {host_msg}{'; '.join(sorted(paths))}" + ) if verbose and fpaths_found: - print(f"The following files were polled and found: {'; '.join(sorted(fpaths_found))}") + print( + "The following files were polled and found: " + f"{'; '.join(sorted(fpaths_found))}" + ) return None diff --git a/dagrunner/tests/execute_graph/test_plugin_executor.py b/dagrunner/tests/execute_graph/test_plugin_executor.py index 81c89b1..a5a0ead 100644 --- a/dagrunner/tests/execute_graph/test_plugin_executor.py +++ b/dagrunner/tests/execute_graph/test_plugin_executor.py @@ -20,7 +20,8 @@ def __call__( ): return ( f"init_kwargs={self._init_kwargs}; " - f"init_named_arg={self._init_named_arg}; init_named_kwarg={self._init_named_kwarg}; " + f"init_named_arg={self._init_named_arg}; " + f"init_named_kwarg={self._init_named_kwarg}; " f"call_args={call_args}; call_kwargs={call_kwargs}; " f"call_named_arg={call_named_arg}; call_named_kwarg={call_named_kwarg}; " ) @@ -59,7 +60,8 @@ def __call__(self, *call_args, **call_kwargs): "init_named_kwarg=sentinel.init_named_kwarg; " "call_args=(sentinel.arg1, sentinel.arg2); " "call_kwargs={'call_other_kwarg': sentinel.call_other_kwarg}; " - "call_named_arg=sentinel.call_named_arg; call_named_kwarg=sentinel.call_named_kwarg; " + "call_named_arg=sentinel.call_named_arg; " + "call_named_kwarg=sentinel.call_named_kwarg; " ), ), # Passing class init args only @@ -87,7 +89,10 @@ def __call__(self, *call_args, **call_kwargs): ], ) def test_pass_class_arg_kwargs(plugin, init_args, call_args, target): - """Test passing named parameters to plugin class initialisation and __call__ method.""" + """ + Test passing named parameters to plugin class initialisation and __call__ + method. + """ args = (mock.sentinel.arg1, mock.sentinel.arg2) call = tuple([plugin, init_args, call_args]) res = plugin_executor(*args, call=call) @@ -95,12 +100,15 @@ def test_pass_class_arg_kwargs(plugin, init_args, call_args, target): def test_pass_common_args(): - """Passing 'common args', some relevant to class init and some to call method.""" + """ + Passing 'common args', some relevant to class init and some to call method. + """ args = (mock.sentinel.arg1, mock.sentinel.arg2) common_kwargs = { "init_named_arg": mock.sentinel.init_named_arg, "init_named_kwarg": mock.sentinel.init_named_kwarg, - "other_kwargs": mock.sentinel.other_kwargs, # this should be ignored (as not part of class signature) + # this should be ignored (as not part of class signature) + "other_kwargs": mock.sentinel.other_kwargs, "call_named_arg": mock.sentinel.call_named_arg, "call_named_kwarg": mock.sentinel.call_named_kwarg, } diff --git a/dagrunner/tests/plugin_framework/test_DataPolling.py b/dagrunner/tests/plugin_framework/test_DataPolling.py index 7476c94..f1d83bf 100644 --- a/dagrunner/tests/plugin_framework/test_DataPolling.py +++ b/dagrunner/tests/plugin_framework/test_DataPolling.py @@ -3,6 +3,8 @@ # This file is part of 'dagrunner' and is released under the BSD 3-Clause license. # See LICENSE in the root of the repository for full licensing details. import socket +import subprocess +from glob import glob from unittest.mock import patch import pytest @@ -20,7 +22,7 @@ def tmp_file(tmp_path_factory): def call_dp(*filepaths, verbose=True, **kwargs): dp = DataPolling() - #dp(*filepaths, timeout=0.001, polling=0.002, verbose=verbose, **kwargs) + # dp(*filepaths, timeout=0.001, polling=0.002, verbose=verbose, **kwargs) dp(*filepaths, timeout=0, polling=0, verbose=verbose, **kwargs) @@ -97,8 +99,6 @@ def test_mixture_of_hosts_local(tmp_dir, capsys): Ensure that we can poll for groups of shared inputs by common host and whether they are globular or not. """ - host_tmp_file = f"{socket.gethostname()}:{tmp_file}" - input_paths = [ f"{socket.gethostname()}:{tmp_dir / 'testA0.txt'}", f"{socket.gethostname()}:{tmp_dir / 'testA1.txt'}", @@ -109,47 +109,64 @@ def test_mixture_of_hosts_local(tmp_dir, capsys): ] target = { - str(pp) for pp in sorted([ - tmp_dir / 'testA0.txt', - tmp_dir / 'testA1.txt', - tmp_dir / 'testB0.txt', - tmp_dir / 'testB1.txt', - tmp_dir / 'testC0.txt', - tmp_dir / 'testC1.txt', - tmp_dir / 'testD0.txt', - tmp_dir / 'testD1.txt', - ]) + str(pp) + for pp in sorted( + [ + tmp_dir / "testA0.txt", + tmp_dir / "testA1.txt", + tmp_dir / "testB0.txt", + tmp_dir / "testB1.txt", + tmp_dir / "testC0.txt", + tmp_dir / "testC1.txt", + tmp_dir / "testD0.txt", + tmp_dir / "testD1.txt", + ] + ) } - from glob import glob - import subprocess # Mocking gethostname() so that our host doesn't match against our local host check # internally. with patch( "dagrunner.utils.socket.gethostname", return_value="dummy_host.dummy_domain" ): - # patch plugin_framework.glob with a wrapper so that we can check what it was called with + # patch plugin_framework.glob with a wrapper so that we can check what + # it was called with. with patch("dagrunner.plugin_framework.glob", wraps=glob) as mock_glob: - with patch("dagrunner.plugin_framework.subprocess.run", wraps=subprocess.run) as mock_subprocrun: + with patch( + "dagrunner.plugin_framework.subprocess.run", wraps=subprocess.run + ) as mock_subprocrun: call_dp(*input_paths) # check how subprocess.run was called ##################################### - assert len(mock_subprocrun.call_args_list) is 2 + assert len(mock_subprocrun.call_args_list) == 2 # group all non glob patterns into a single call (minimising ssh calls) - for res_index, targets in zip(range(2), [[f"{tmp_dir}/testA0.txt", f"{tmp_dir}/testA1.txt"], [f"{tmp_dir}/testB*.txt"]]): + for res_index, targets in zip( + range(2), + [[f"{tmp_dir}/testA0.txt", f"{tmp_dir}/testA1.txt"], [f"{tmp_dir}/testB*.txt"]], + ): objcall = mock_subprocrun.call_args_list[res_index] - results = objcall[0][0].replace(';', '').split() - assert sorted(filter(lambda substr: substr.startswith(str(tmp_dir)), results)) == targets + results = objcall[0][0].replace(";", "").split() + assert ( + sorted(filter(lambda substr: substr.startswith(str(tmp_dir)), results)) + == targets + ) # check how glob was called (glob supports only 1 path arguments) ##################################### - assert len(mock_glob.call_args_list) is 3 - targets = [f'{tmp_dir}/testC0.txt', f'{tmp_dir}/testC1.txt', f'{tmp_dir}/testD*.txt'] + assert len(mock_glob.call_args_list) == 3 + targets = [ + f"{tmp_dir}/testC0.txt", + f"{tmp_dir}/testC1.txt", + f"{tmp_dir}/testD*.txt", + ] for objcall in mock_glob.call_args_list: assert objcall[0][0] in targets targets.remove(objcall[0][0]) captured = capsys.readouterr() - assert f"The following files were polled and found: {'; '.join(sorted(target))}" in captured.out \ No newline at end of file + assert ( + f"The following files were polled and found: {'; '.join(sorted(target))}" + in captured.out + )