diff --git a/dagrunner/execute_graph.py b/dagrunner/execute_graph.py index 1cd768e..aff757a 100755 --- a/dagrunner/execute_graph.py +++ b/dagrunner/execute_graph.py @@ -130,7 +130,9 @@ def plugin_executor( arg for arg in args if arg is not None ] # support plugins that have no return value if call is None: - raise ValueError("call is a required argument") + raise ValueError( + f"call is a required argument\nnode_properties: {node_properties}" + ) if verbose: print(f"args: {args}") print(f"call: {call}") @@ -162,7 +164,7 @@ def plugin_executor( else: raise ValueError( f"expecting 1, 2 or 3 values to unpack for {callable_obj}, " - f"got {len(call)}" + f"got {len(call)}\nnode_properties: {node_properties}" ) callable_kwargs_init = ( {} if callable_kwargs_init is None else callable_kwargs_init @@ -175,7 +177,7 @@ def plugin_executor( else: raise ValueError( f"expecting 1 or 2 values to unpack for {callable_obj}, got " - f"{len(call)}" + f"{len(call)}\nnode_properties: {node_properties}" ) callable_kwargs = {} if callable_kwargs is None else callable_kwargs @@ -188,7 +190,14 @@ def plugin_executor( callable_kwargs_init | _get_common_args_matching_signature(callable_obj, common_kwargs) ) - callable_obj = callable_obj(**callable_kwargs_init) + try: + callable_obj = callable_obj(**callable_kwargs_init) + except Exception as err: + msg = ( + f"Failed to initialise {obj_name} with {callable_kwargs_init}" + f"\nnode_properties: {node_properties}" + ) + raise RuntimeError(msg) from err call_msg = f"(**{callable_kwargs_init})" callable_kwargs = callable_kwargs | _get_common_args_matching_signature( @@ -203,7 +212,14 @@ def plugin_executor( with TimeIt() as timer, dask.config.set( scheduler="single-threaded" ), CaptureProcMemory() as mem: - res = callable_obj(*args, **callable_kwargs) + try: + res = callable_obj(*args, **callable_kwargs) + except Exception as err: + msg = ( + f"Failed to execute {obj_name} with {args}, {callable_kwargs}" + f"\nnode_properties: {node_properties}" + ) + raise RuntimeError(msg) from err msg = f"{str(timer)}; {msg}; {mem.max()}" logging.info(msg) @@ -292,8 +308,9 @@ def __init__( function. Optional. - `scheduler` (str): Accepted values include "ray", "multiprocessing" and those recognised - by dask: "threads", "processes" and "single-threaded" (useful for debugging). - See https://docs.dask.org/en/latest/scheduling.html. Optional. + by dask: "threads", "processes" and "single-threaded" (useful for debugging) + and "distributed". See https://docs.dask.org/en/latest/scheduling.html. + Optional. - `num_workers` (int): Number of processes or threads to use. Optional. - `dry_run` (bool): diff --git a/dagrunner/runner/schedulers/dask.py b/dagrunner/runner/schedulers/dask.py index 1e82745..8628528 100644 --- a/dagrunner/runner/schedulers/dask.py +++ b/dagrunner/runner/schedulers/dask.py @@ -89,6 +89,10 @@ def __init__(self, num_workers, profiler_filepath=None, **kwargs): self._profiler_output = profiler_filepath self._kwargs = kwargs self._cluster = None + self._client = None + # bug: dashboard cannot be disabled + # see https://github.com/dask/distributed/issues/8136 + self._dashboard_address = None def __enter__(self): """Create a local cluster and connect a client to it.""" @@ -96,9 +100,12 @@ def __enter__(self): n_workers=self._num_workers, processes=True, threads_per_worker=1, + dashboard_address=self._dashboard_address, **self._kwargs, ) - Client(self._cluster) + self._client = Client(self._cluster) + if self._dashboard_address: + print(f"dashboard link: {self._client.dashboard_link}") return self def __exit__(self, exc_type, exc_value, exc_traceback): @@ -182,6 +189,7 @@ def run(self, dask_graph, verbose=False): scheduler=self._scheduler, num_workers=self._num_workers, chunksize=1, + **self._kwargs, ) visualize( [prof, rprof, cprof], @@ -193,7 +201,10 @@ def run(self, dask_graph, verbose=False): print(f"{max([res.mem for res in rprof.results])}MB total memory used") else: res = self._dask_container.compute( - scheduler=self._scheduler, num_workers=self._num_workers, chunksize=1 + scheduler=self._scheduler, + num_workers=self._num_workers, + chunksize=1, + **self._kwargs, ) return res diff --git a/dagrunner/tests/execute_graph/test_plugin_executor.py b/dagrunner/tests/execute_graph/test_plugin_executor.py index a5a0ead..5ca0036 100644 --- a/dagrunner/tests/execute_graph/test_plugin_executor.py +++ b/dagrunner/tests/execute_graph/test_plugin_executor.py @@ -160,3 +160,47 @@ def test_pass_common_args_override(): ) res = plugin_executor(*args, call=call, common_kwargs=common_kwargs) assert res == target + + +def test_missing_call_args(): + """Raise an error if 'call' arg isn't provided.""" + kwargs = {"key1": mock.sentinel.value1, "key2": mock.sentinel.value2} + msg = f"call is a required argument\nnode_properties: {kwargs}" + with pytest.raises(ValueError, match=msg): + plugin_executor(mock.sentinel.arg1, **kwargs) + + +def test_class_plugin_unexpected_tuple_unpack(): + """Expecting inits kwargs and call kwargs but no more.""" + msg = "expecting 1, 2 or 3 values to unpack.*got 4" + with pytest.raises(ValueError, match=msg): + plugin_executor(mock.sentinel.arg1, call=(DummyPlugin, {}, {}, {})) + + +def test_callable_plugin_unexpected_tuple_unpack(): + """Expecting call kwargs but no more.""" + msg = "expecting 1 or 2 values to unpack.*got 3" + with pytest.raises(ValueError, match=msg): + plugin_executor(mock.sentinel.arg1, call=(DummyPluginNoNamedParam(), {}, {})) + + +class BadDummyInitPlugin: + def __init__(self, **kwargs): + raise ValueError("some error") + + def __call__(self, *args, **kwargs): + pass + + +def test_extended_init_failure_context(): + with pytest.raises(RuntimeError, match="Failed to initialise"): + plugin_executor(mock.sentinel.arg1, call=(BadDummyInitPlugin,)) + + +def bad_call_plugin(*args): + raise ValueError("some error") + + +def test_extended_call_plugin_failure_context(): + with pytest.raises(RuntimeError, match="Failed to execute"): + plugin_executor(mock.sentinel.arg1, call=(bad_call_plugin,)) diff --git a/docs/dagrunner.execute_graph.md b/docs/dagrunner.execute_graph.md index 2540bdb..ca8b51a 100644 --- a/docs/dagrunner.execute_graph.md +++ b/docs/dagrunner.execute_graph.md @@ -16,7 +16,7 @@ see [function: dagrunner.utils.visualisation.visualise_graph](dagrunner.utils.vi ## class: `ExecuteGraph` -[Source](../dagrunner/execute_graph.py#L266) +[Source](../dagrunner/execute_graph.py#L282) ### Call Signature: @@ -26,7 +26,7 @@ ExecuteGraph(networkx_graph: str, networkx_graph_kwargs: dict = None,