diff --git a/setup.py b/setup.py index 1c2c85af6..1a76e49fb 100644 --- a/setup.py +++ b/setup.py @@ -13,7 +13,7 @@ IS_NOT_WINDOWS = os.name != "nt" -PARALLEL_REQUIRE = ["ray[debug,tune]~=2.0.0"] +PARALLEL_REQUIRE = ["ray[debug,tune]~=2.9.0"] ATARI_REQUIRE = [ "seals[atari]~=0.2.1", ] diff --git a/src/imitation/scripts/parallel.py b/src/imitation/scripts/parallel.py index d5e5e2378..76a068224 100644 --- a/src/imitation/scripts/parallel.py +++ b/src/imitation/scripts/parallel.py @@ -188,13 +188,12 @@ def _ray_tune_sacred_wrapper( `ex.run`) and `reporter`. The function returns the run result. """ - def inner(config: Mapping[str, Any], reporter) -> Mapping[str, Any]: + def inner(config: Mapping[str, Any]) -> Mapping[str, Any]: """Trainable function with the correct signature for `ray.tune`. Args: config: Keyword arguments for `ex.run()`, where `ex` is the `sacred.Experiment` instance associated with `sacred_ex_name`. - reporter: Callback to report progress to Ray. Returns: Result from `ray.Run` object.