Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adversarial algorithm matching original paper's implementation #770

Draft
wants to merge 56 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
56 commits
Select commit Hold shift + click to select a range
b4210c1
Merge py file changes from benchmark-algs
taufeeque9 Jan 4, 2023
97bc063
Clean parallel script
taufeeque9 Jan 10, 2023
9291225
Undo the changes from #653 to the dagger benchmark config files.
ernestum Jan 26, 2023
276d863
Improve readability and interpretability of benchmarking tests.
ernestum Jan 25, 2023
37eb914
Add pxponential beta scheduler for dagger
taufeeque9 Mar 1, 2023
877383b
Ignore coverage for unknown algorithms.
ernestum Feb 2, 2023
c8e55cb
Cleanup and extend tests for beta schedules in dagger.
ernestum Feb 2, 2023
6b9b306
Merge branch 'master' into benchmark-pr
taufeeque9 Feb 6, 2023
8576465
Fix test cases
taufeeque9 Feb 8, 2023
d81eb68
Add optuna to dependencies
taufeeque9 Feb 8, 2023
27467d3
Fix test case
taufeeque9 Feb 8, 2023
b59a768
Merge branch 'master' into benchmark-pr
taufeeque9 Feb 8, 2023
1a3b6b8
Clean up the scripts
taufeeque9 Feb 9, 2023
7a438da
Remove reporter(done) since mean_return is reported by the runs
taufeeque9 Feb 9, 2023
5bc5835
Merge branch 'master' into benchmark-pr
taufeeque9 Feb 20, 2023
2e56de8
Add beta_schedule parameter to dagger script
taufeeque9 Feb 23, 2023
84e854a
Merge branch 'master' into benchmark-pr
taufeeque9 Mar 16, 2023
73d8576
Update config policy kwargs
taufeeque9 Mar 16, 2023
9fdf878
Changes from review
taufeeque9 May 16, 2023
1c1dbc4
Fix errors with some configs
taufeeque9 May 16, 2023
3467af2
Merge branch 'master' into benchmark-pr
taufeeque9 May 16, 2023
44c4e97
Updates based on review
taufeeque9 Jun 14, 2023
4d493ae
Merge branch 'master' into benchmark-pr
taufeeque9 Jun 14, 2023
ab01269
Change metric everywhere
taufeeque9 Jun 14, 2023
f64580e
Merge branch 'master' into benchmark-pr
taufeeque9 Jul 11, 2023
e896d7d
Separate tuning code from parallel.py
taufeeque9 Jul 11, 2023
64c3a8d
Fix docstring
taufeeque9 Jul 11, 2023
8fba0d3
Removing resume option as it is getting tricky to correctly implement
taufeeque9 Jul 11, 2023
12ab31c
Minor fixes
taufeeque9 Jul 11, 2023
19b0f2c
Updates from review
taufeeque9 Jul 16, 2023
046b8d9
fix lint error
taufeeque9 Jul 16, 2023
8eee082
Add documentation for using the tuning script
taufeeque9 Jul 16, 2023
5ce7658
Fix lint error
taufeeque9 Jul 17, 2023
a8be331
Updates from the review
taufeeque9 Jul 18, 2023
4ff006d
Fix file name test errors
taufeeque9 Jul 18, 2023
6933afa
Add tune_run_kwargs in parallel script
taufeeque9 Jul 19, 2023
77f9d9b
Fix test errors
taufeeque9 Jul 19, 2023
54eb8a6
Fix test
taufeeque9 Jul 19, 2023
d50238f
Fix lint
taufeeque9 Jul 19, 2023
3fe22d4
Updates from review
taufeeque9 Jul 19, 2023
c50aa20
Simplify few lines of code
taufeeque9 Jul 20, 2023
000af61
Updates from review
taufeeque9 Aug 4, 2023
8b55134
Fix test
taufeeque9 Aug 4, 2023
f3ba2b5
Revert "Fix test"
taufeeque9 Aug 4, 2023
f8251c7
Fix test
taufeeque9 Aug 4, 2023
664fc37
Convert Dict to Mapping in input argument
taufeeque9 Aug 7, 2023
8690e1d
Ignore coverage in script configurations.
ernestum Aug 30, 2023
dd9eb6a
Pin huggingface_sb3 version.
ernestum Aug 30, 2023
b3930f4
Merge branch 'master' into benchmark-pr
ernestum Sep 26, 2023
40d87ef
Update to the newest seals environment versions.
ernestum Sep 26, 2023
71f6c92
Push gymnasium dependency to 0.29 to ensure mujoco envs work.
ernestum Sep 27, 2023
53c1212
Update adversarial algorithm
taufeeque9 Aug 8, 2023
47b3874
Fix test errors
taufeeque9 Aug 8, 2023
9fa8969
Fix test errors
taufeeque9 Aug 8, 2023
3edf518
Don't enter the generator logging ctx twice.
ernestum Sep 26, 2023
ce8c87d
Update common.py to fix test errors
taufeeque9 Sep 27, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 23 additions & 1 deletion benchmarking/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,27 @@ python -m imitation.scripts.<train_script> <algo> with benchmarking/<config_name

```python
...
ex.add_config('benchmarking/<config_name>.json')
from imitation.scripts.<train_script> import <train_ex>
<train_ex>.run(command_name="<algo>", named_configs=["benchmarking/<config_name>.json"])
```

# Tuning Hyperparameters

The hyperparameters of any algorithm in imitation can be tuned using the `tuning.py` script.
The benchmarking hyperparameter configs were generated by tuning the hyperparameters using
the search space defined in the `tuning_config.py` script. The tuning script proceeds in two
phases: 1) The hyperparameters are tuned using the search space provided, and 2) the best
hyperparameter config found in the first phase based on the maximum mean return is
re-evaluated on a separate set of seeds, and the mean and standard deviation of these trials
are reported.

To tune the hyperparameters of an algorithm using the default search space provided:
```bash
python tuning.py with {algo} 'parallel_run_config.base_named_configs=["{env}"]'
```

In this command, `{algo}` provides the default search space and settings to be used for
the specific algorithm, which is defined in the `tuning_config.py` script and
`'parallel_run_config.base_named_configs=["{env}"]'` sets the environment to tune the algorithm in.
See the documentation of `tuning.py` and `parallel.py` scripts for many other arguments that can be
provided through the command line to change the tuning behavior.
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,6 @@
"n_episodes_eval": 50
},
"environment": {
"gym_id": "seals/Ant-v0"
"gym_id": "seals/Ant-v1"
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,6 @@
"n_episodes_eval": 50
},
"environment": {
"gym_id": "seals/HalfCheetah-v0"
"gym_id": "seals/HalfCheetah-v1"
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,6 @@
"n_episodes_eval": 50
},
"environment": {
"gym_id": "seals/Hopper-v0"
"gym_id": "seals/Hopper-v1"
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
},
"expert": {
"loader_kwargs": {
"gym_id": "seals/Swimmer-v0",
"gym_id": "seals/Swimmer-v1",
"organization": "HumanCompatibleAI"
}
},
Expand Down Expand Up @@ -81,6 +81,6 @@
"n_episodes_eval": 50
},
"environment": {
"gym_id": "seals/Swimmer-v0"
"gym_id": "seals/Swimmer-v1"
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
},
"expert": {
"loader_kwargs": {
"gym_id": "seals/Walker2d-v0",
"gym_id": "seals/Walker2d-v1",
"organization": "HumanCompatibleAI"
}
},
Expand Down Expand Up @@ -81,6 +81,6 @@
"n_episodes_eval": 50
},
"environment": {
"gym_id": "seals/Walker2d-v0"
"gym_id": "seals/Walker2d-v1"
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,6 @@
"n_episodes_eval": 50
},
"environment": {
"gym_id": "seals/Ant-v0"
"gym_id": "seals/Ant-v1"
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,6 @@
"n_episodes_eval": 50
},
"environment": {
"gym_id": "seals/HalfCheetah-v0"
"gym_id": "seals/HalfCheetah-v1"
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,6 @@
"n_episodes_eval": 50
},
"environment": {
"gym_id": "seals/Hopper-v0"
"gym_id": "seals/Hopper-v1"
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,6 @@
"n_episodes_eval": 50
},
"environment": {
"gym_id": "seals/Swimmer-v0"
"gym_id": "seals/Swimmer-v1"
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,6 @@
"n_episodes_eval": 50
},
"environment": {
"gym_id": "seals/Walker2d-v0"
"gym_id": "seals/Walker2d-v1"
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,6 @@
"n_episodes_eval": 50
},
"environment": {
"gym_id": "seals/Ant-v0"
"gym_id": "seals/Ant-v1"
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,6 @@
"n_episodes_eval": 50
},
"environment": {
"gym_id": "seals/HalfCheetah-v0"
"gym_id": "seals/HalfCheetah-v1"
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,6 @@
"n_episodes_eval": 50
},
"environment": {
"gym_id": "seals/Hopper-v0"
"gym_id": "seals/Hopper-v1"
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,6 @@
"n_episodes_eval": 50
},
"environment": {
"gym_id": "seals/Swimmer-v0"
"gym_id": "seals/Swimmer-v1"
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,6 @@
"n_episodes_eval": 50
},
"environment": {
"gym_id": "seals/Walker2d-v0"
"gym_id": "seals/Walker2d-v1"
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,6 @@
"n_episodes_eval": 50
},
"environment": {
"gym_id": "seals/Ant-v0"
"gym_id": "seals/Ant-v1"
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,6 @@
"n_episodes_eval": 50
},
"environment": {
"gym_id": "seals/HalfCheetah-v0"
"gym_id": "seals/HalfCheetah-v1"
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,6 @@
"n_episodes_eval": 50
},
"environment": {
"gym_id": "seals/Hopper-v0"
"gym_id": "seals/Hopper-v1"
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
},
"expert": {
"loader_kwargs": {
"gym_id": "seals/Swimmer-v0",
"gym_id": "seals/Swimmer-v1",
"organization": "HumanCompatibleAI"
}
},
Expand Down Expand Up @@ -81,6 +81,6 @@
"n_episodes_eval": 50
},
"environment": {
"gym_id": "seals/Swimmer-v0"
"gym_id": "seals/Swimmer-v1"
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
},
"expert": {
"loader_kwargs": {
"gym_id": "seals/Walker2d-v0",
"gym_id": "seals/Walker2d-v1",
"organization": "HumanCompatibleAI"
}
},
Expand Down Expand Up @@ -81,6 +81,6 @@
"n_episodes_eval": 50
},
"environment": {
"gym_id": "seals/Walker2d-v0"
"gym_id": "seals/Walker2d-v1"
}
}
174 changes: 174 additions & 0 deletions benchmarking/tuning.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
"""Tunes the hyperparameters of the algorithms."""

import copy
import pathlib
from typing import Any, Dict

import numpy as np
import ray
from pandas.api import types as pd_types
from ray.tune.search import optuna
from sacred.observers import FileStorageObserver
from tuning_config import parallel_ex, tuning_ex


@tuning_ex.main
def tune(
parallel_run_config: Dict[str, Any],
eval_best_trial_resource_multiplier: int = 1,
num_eval_seeds: int = 5,
) -> None:
"""Tune hyperparameters of imitation algorithms using parallel script.

Args:
parallel_run_config: Dictionary of arguments to pass to the parallel script.
eval_best_trial_resource_multiplier: Factor by which to multiply the
number of cpus per trial in `resources_per_trial`. This is useful for
allocating more resources per trial to the evaluation trials than the
resources for hyperparameter tuning since number of evaluation trials
is usually much smaller than the number of tuning trials.
num_eval_seeds: Number of distinct seeds to evaluate the best trial on.
Set to 0 to disable evaluation.

Raises:
ValueError: If no trials are returned by the parallel run of tuning.
"""
updated_parallel_run_config = copy.deepcopy(parallel_run_config)
search_alg = optuna.OptunaSearch()
if "tune_run_kwargs" in updated_parallel_run_config:
updated_parallel_run_config["tune_run_kwargs"]["search_alg"] = search_alg
else:
updated_parallel_run_config["tune_run_kwargs"] = dict(search_alg=search_alg)
run = parallel_ex.run(config_updates=updated_parallel_run_config)
experiment_analysis = run.result
if not experiment_analysis.trials:
raise ValueError(
"No trials found. Please ensure that the `experiment_checkpoint_path` "
"in `parallel_run_config` is passed correctly "
"or that the tuning run finished properly.",
)

return_key = "imit_stats/monitor_return_mean"
if updated_parallel_run_config["sacred_ex_name"] == "train_rl":
return_key = "monitor_return_mean"
best_trial = find_best_trial(experiment_analysis, return_key, print_return=True)

if num_eval_seeds > 0: # evaluate the best trial
resources_per_trial_eval = copy.deepcopy(
updated_parallel_run_config["resources_per_trial"],
)
# update cpus per trial only if it is provided in `resources_per_trial`
# Uses the default values (cpu=1) if it is not provided
if "cpu" in updated_parallel_run_config["resources_per_trial"]:
resources_per_trial_eval["cpu"] *= eval_best_trial_resource_multiplier
evaluate_trial(
best_trial,
num_eval_seeds,
updated_parallel_run_config["run_name"] + "_best_hp_eval",
updated_parallel_run_config,
resources_per_trial_eval,
return_key,
)


def find_best_trial(
experiment_analysis: ray.tune.analysis.ExperimentAnalysis,
return_key: str,
print_return: bool = False,
) -> ray.tune.experiment.Trial:
"""Find the trial with the best mean return across all seeds.

Args:
experiment_analysis: The result of a parallel/tuning experiment.
return_key: The key of the return metric in the results dataframe.
print_return: Whether to print the mean and std of the returns
of the best trial.

Returns:
best_trial: The trial with the best mean return across all seeds.
"""
df = experiment_analysis.results_df
# convert object dtype to str required by df.groupby
for col in df.columns:
if pd_types.is_object_dtype(df[col]):
df[col] = df[col].astype("str")
# group into separate HP configs
grp_keys = [c for c in df.columns if c.startswith("config") and "seed" not in c]
grps = df.groupby(grp_keys)
# store mean return of runs across all seeds in a group
df["mean_return"] = grps[return_key].transform(lambda x: x.mean())
best_config_df = df[df["mean_return"] == df["mean_return"].max()]
row = best_config_df.iloc[0]
best_config_tag = row["experiment_tag"]
assert experiment_analysis.trials is not None # for mypy
best_trial = [
t for t in experiment_analysis.trials if best_config_tag in t.experiment_tag
][0]

if print_return:
all_returns = df[df["mean_return"] == row["mean_return"]][return_key]
all_returns = all_returns.to_numpy()
print("All returns:", all_returns)
print("Mean return:", row["mean_return"])
print("Std return:", np.std(all_returns))
print("Total seeds:", len(all_returns))
return best_trial


def evaluate_trial(
trial: ray.tune.experiment.Trial,
num_eval_seeds: int,
run_name: str,
parallel_run_config: Dict[str, Any],
resources_per_trial: Dict[str, int],
return_key: str,
print_return: bool = False,
):
"""Evaluate a given trial of a parallel run on a separate set of seeds.

Args:
trial: The trial to evaluate.
num_eval_seeds: Number of distinct seeds to evaluate the best trial on.
run_name: The name of the evaluation run.
parallel_run_config: Dictionary of arguments passed to the parallel
script to get best_trial.
resources_per_trial: Resources to be used for each evaluation trial.
return_key: The key of the return metric in the results dataframe.
print_return: Whether to print the mean and std of the evaluation returns.

Returns:
eval_run: The result of the evaluation run.
"""
config = trial.config
config["config_updates"].update(
seed=ray.tune.grid_search(list(range(100, 100 + num_eval_seeds))),
)
eval_config_updates = parallel_run_config.copy()
eval_config_updates.update(
run_name=run_name,
num_samples=1,
search_space=config,
resources_per_trial=resources_per_trial,
search_alg=None,
repeat=1,
experiment_checkpoint_path="",
)
eval_run = parallel_ex.run(config_updates=eval_config_updates)
eval_result = eval_run.result
returns = eval_result.results_df[return_key].to_numpy()
if print_return:
print("All returns:", returns)
print("Mean:", np.mean(returns))
print("Std:", np.std(returns))
return eval_run


def main_console():
observer_path = pathlib.Path.cwd() / "output" / "sacred" / "tuning"
observer = FileStorageObserver(observer_path)
tuning_ex.observers.append(observer)
tuning_ex.run_commandline()


if __name__ == "__main__": # pragma: no cover
main_console()
Loading