-
Notifications
You must be signed in to change notification settings - Fork 0
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
BUG: fix common arg handling #45
Changes from 3 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,25 +3,25 @@ | |
# | ||
# This file is part of 'dagrunner' and is released under the BSD 3-Clause license. | ||
# See LICENSE in the root of the repository for full licensing details. | ||
import importlib | ||
import inspect | ||
import logging | ||
import warnings | ||
from functools import partial | ||
|
||
import importlib | ||
import networkx as nx | ||
|
||
import dask | ||
import networkx as nx | ||
from dask.base import tokenize | ||
from dask.utils import apply | ||
|
||
from dagrunner.plugin_framework import NodeAwarePlugin | ||
from dagrunner.runner.schedulers import SCHEDULERS | ||
from dagrunner.utils import ( | ||
TimeIt, | ||
function_to_argparse, | ||
logger, | ||
) | ||
from dagrunner.plugin_framework import NodeAwarePlugin | ||
from dagrunner.runner.schedulers import SCHEDULERS | ||
from dagrunner.utils.visualisation import visualise_graph | ||
from dagrunner.utils import logger | ||
|
||
|
||
class SkipBranch(Exception): | ||
|
@@ -38,12 +38,17 @@ class SkipBranch(Exception): | |
pass | ||
|
||
|
||
def _get_common_args_matching_signature(callable_obj, common_kwargs): | ||
"""Get subset of arguments which match the callable signature.""" | ||
def _get_common_args_matching_signature(callable_obj, common_kwargs, keys=None): | ||
""" | ||
Get subset of arguments which match the callable signature. | ||
|
||
Also additionally include those 'keys' provided | ||
""" | ||
keys = [] if keys is None else keys | ||
return { | ||
key: value | ||
for key, value in common_kwargs.items() | ||
if key in inspect.signature(callable_obj).parameters | ||
if key in inspect.signature(callable_obj).parameters or key in keys | ||
} | ||
|
||
|
||
|
@@ -56,19 +61,22 @@ def plugin_executor( | |
**node_properties, | ||
): | ||
""" | ||
Executes a plugin function or method with the provided arguments and keyword arguments. | ||
Executes a plugin callable with the provided arguments and keyword arguments. | ||
|
||
Args: | ||
- `*args`: Positional arguments to be passed to the plugin function or method. | ||
- `call`: A tuple containing the callable object or python dot path to one, keyword arguments | ||
to instantiate this class (optional and where this callable is a class) and finally the keyword | ||
arguments to be passed to this callable. | ||
- `*args`: Positional arguments to be passed to the plugin callable. | ||
- `call`: A tuple containing the callable object or python dot path to one, keyword | ||
arguments to instantiate this class (optional and where this callable is a class) | ||
and finally the keyword arguments to be passed to this callable. | ||
- `verbose`: A boolean indicating whether to print verbose output. | ||
- `dry_run`: A boolean indicating whether to perform a dry run without executing the plugin. | ||
- `common_kwargs`: A dictionary of optional keyword arguments to apply to all applicable plugins. | ||
That is, being passed to the plugin initialisation and or call if such keywords are expected | ||
from the plugin. This is a useful alternative to global or environment variable usage. | ||
- `**node_properties`: Node properties. These will be passed to 'node-aware' plugins. | ||
- `dry_run`: A boolean indicating whether to perform a dry run without executing | ||
the plugin. | ||
- `common_kwargs`: A dictionary of optional keyword arguments to apply to all | ||
applicable plugins. That is, being passed to the plugin initialisation and or | ||
call if such keywords are expected from the plugin. This is a useful alternative | ||
to global or environment variable usage. | ||
- `**node_properties`: Node properties. These will be passed to 'node-aware' | ||
plugins. | ||
|
||
Returns: | ||
- The result of executing the plugin function or method. | ||
|
@@ -124,9 +132,9 @@ def plugin_executor( | |
callable_obj = callable_obj(**callable_kwargs_init) | ||
call_msg = f"(**{callable_kwargs_init})" | ||
|
||
callable_kwargs = callable_kwargs | { | ||
key: value for key, value in common_kwargs.items() if key in callable_kwargs | ||
} # based on overriding arguments | ||
callable_kwargs = callable_kwargs | _get_common_args_matching_signature( | ||
callable_obj, common_kwargs, callable_kwargs.keys() | ||
) # based on overriding arguments | ||
Comment on lines
+135
to
+137
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Changes that weren't included in #43, hidden pytest failures |
||
|
||
msg = f"{obj_name}{call_msg}(*{args}, **{callable_kwargs})" | ||
if verbose: | ||
|
@@ -223,8 +231,8 @@ def __init__( | |
Keyword arguments to pass to the networkx graph callable. Optional. | ||
- `plugin_executor` (callable): | ||
A callable object that executes a plugin function or method with the provided | ||
arguments and keyword arguments. By default, uses the `plugin_executor` function. | ||
Optional. | ||
arguments and keyword arguments. By default, uses the `plugin_executor` | ||
function. Optional. | ||
- `scheduler` (str): | ||
Accepted values include "ray", "multiprocessing" and those recognised | ||
by dask: "threads", "processes" and "single-threaded" (useful for debugging). | ||
|
@@ -249,7 +257,8 @@ def __init__( | |
self._plugin_executor = plugin_executor | ||
if scheduler not in SCHEDULERS: | ||
raise ValueError( | ||
f"scheduler '{scheduler}' not recognised, please choose from {list(SCHEDULERS.keys())}" | ||
f"scheduler '{scheduler}' not recognised, please choose from " | ||
f"{list(SCHEDULERS.keys())}" | ||
) | ||
self._scheduler = SCHEDULERS[scheduler] | ||
self._num_workers = num_workers | ||
|
@@ -264,11 +273,13 @@ def nxgraph(self): | |
|
||
def _process_graph(self): | ||
""" | ||
Create flattened dictionary describing the relationship between each of our nodes. | ||
Create flattened dictionary describing the relationship between nodes. | ||
|
||
Here we wrap our nodes to ensure common parameters are share across all | ||
executed nodes (e.g. dry-run, verbose). | ||
|
||
TODO: Potentially support 'clobber' i.e. partial graph execution from a graph failure recovery. | ||
TODO: Potentially support 'clobber' i.e. partial graph execution from a graph | ||
failure recovery. | ||
""" | ||
executor = partial( | ||
self._plugin_executor, | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -12,6 +12,24 @@ | |
CALLING_MOD = os.path.basename(sys.argv[0]) | ||
|
||
|
||
def assert_help_str(help_str, tar): | ||
# Remove line wraps for easier comparison | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Line wrap changes between versions of argparse, so we remove them. Again a cause of failures on CI with success locally. |
||
formatted_help_str = [] | ||
help_strs = help_str.split("\n") | ||
for line in help_strs: | ||
if line.startswith(" "): | ||
# more than 2 space indent | ||
formatted_help_str[-1] += f" {line.lstrip()}" | ||
else: | ||
formatted_help_str.append(line) | ||
help_str = "\n".join(formatted_help_str) | ||
|
||
if "optional arguments:" in help_str: | ||
# older versions of argparse use "optional arguments: instead of options:" | ||
tar = tar.replace("options:", "optional arguments:") | ||
assert help_str == tar | ||
Comment on lines
+27
to
+30
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Supporting old and new argparse versions (not sure when this exactly changed right now). Was a cause of failures on CI and local succeeds. |
||
|
||
|
||
def get_parser_help_string(parser): | ||
buffer = StringIO() | ||
parser.print_help(file=buffer) | ||
|
@@ -56,7 +74,7 @@ def test_basic_args_kwargs_optional(): | |
--arg2 ARG2 Description of arg2. | ||
--arg3 Description of arg3. Optional. | ||
""" | ||
assert help_str == tar | ||
assert_help_str(help_str, tar) | ||
|
||
args = parser.parse_args(["3", "--arg2", "arg2", "--arg3"]) | ||
assert args.arg1 == 3 | ||
|
@@ -93,10 +111,9 @@ def test_kwargs_param_expand(): | |
|
||
options: | ||
-h, --help show this help message and exit | ||
--kwargs key value Optional global keyword arguments to apply to all | ||
applicable plugins. Key-value pair argument. | ||
--kwargs key value Optional global keyword arguments to apply to all applicable plugins. Key-value pair argument. | ||
""" | ||
assert help_str == tar | ||
assert_help_str(help_str, tar) | ||
|
||
args = parser.parse_args(["--kwargs", "key1", "val1", "--kwargs", "key2", "val2"]) | ||
assert "kwargs" in args | ||
|
@@ -133,11 +150,10 @@ def test_dict_param(): | |
|
||
options: | ||
-h, --help show this help message and exit | ||
--dkwarg1 key value Description of kwarg1. Optional. Key-value pair | ||
argument. | ||
--dkwarg1 key value Description of kwarg1. Optional. Key-value pair argument. | ||
--dkwarg2 key value Description of kwarg2. Key-value pair argument. | ||
""" | ||
assert help_str == tar | ||
assert_help_str(help_str, tar) | ||
args = parser.parse_args( | ||
[ | ||
"--dkwarg1", | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Exclude logging test in CI until #44