{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "5c06148d9ff6b57",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "# Benchmark analysis\n",
    "\n",
    "This notebook loads all the optuna studies in the \"tuning\" folder and arranges them in a dataframe. It also loads the performance of the best model from the paper and the rerun results.\n",
    "\n",
    "Studies are discovered with the glob `*/*.log` relative to the current working directory, so run the notebook from inside the \"tuning\" folder.\n",
    "\n",
    "It can serve as a starting point for further analysis."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "benchmark-analysis-imports",
   "metadata": {},
   "outputs": [],
   "source": [
    "import datetime\n",
    "from collections import Counter\n",
    "from pathlib import Path\n",
    "\n",
    "import numpy as np\n",
    "import optuna\n",
    "import pandas as pd\n",
    "from IPython.display import display\n",
    "from optuna.trial import TrialState\n",
    "\n",
    "import imitation.util.sacred_file_parsing as sfp"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "benchmark-analysis-load-heading",
   "metadata": {},
   "source": [
    "## Load the optuna studies"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "31e6f532-15c3-494a-8a3a-de25ecc1ee90",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load all the studies into a dataframe\n",
    "\n",
    "\n",
    "def summarize_study(log_file):\n",
    "    \"\"\"Collect the statistics of the optuna study stored in one journal file.\n",
    "\n",
    "    Args:\n",
    "        log_file: path to the optuna journal file of a single study. Sacred\n",
    "            rerun results are looked up in the \"reruns\" folder next to it.\n",
    "\n",
    "    Returns:\n",
    "        A dict with one entry per dataframe column. The \"best_value\",\n",
    "        \"best_trial_duration\" and \"mean_duration\" entries are only present\n",
    "        when at least one trial of the study has completed.\n",
    "    \"\"\"\n",
    "    d = dict()\n",
    "\n",
    "    d['logfile'] = log_file\n",
    "\n",
    "    study = optuna.load_study(\n",
    "        storage=optuna.storages.JournalStorage(\n",
    "            optuna.storages.JournalFileStorage(str(log_file))\n",
    "        ),\n",
    "        # in our case, we have one journal file per study so the study name can be\n",
    "        # inferred\n",
    "        study_name=None,\n",
    "    )\n",
    "    d['study'] = study\n",
    "    d['study_name'] = study.study_name\n",
    "\n",
    "    # Count the trial states once and reuse the counter for every column.\n",
    "    trial_state_counter = Counter(t.state for t in study.trials)\n",
    "    n_completed_trials = trial_state_counter[TrialState.COMPLETE]\n",
    "    d['trials'] = n_completed_trials\n",
    "    d['trials_running'] = trial_state_counter[TrialState.RUNNING]\n",
    "    d['trials_failed'] = trial_state_counter[TrialState.FAIL]\n",
    "    d['all_trials'] = len(study.trials)\n",
    "\n",
    "    if n_completed_trials > 0:\n",
    "        d['best_value'] = round(study.best_trial.value, 2)\n",
    "\n",
    "    # Study names are expected to look like \"tuning_<algo>_with_<env>\".\n",
    "    study_segments = study.study_name.split(\"_\")\n",
    "    assert len(study_segments) > 3, f\"unexpected study name {study.study_name!r}\"\n",
    "    tuning, algo, with_ = study_segments[:3]\n",
    "    assert (tuning, with_) == (\"tuning\", \"with\"), f\"unexpected study name {study.study_name!r}\"\n",
    "\n",
    "    d['algo'] = algo\n",
    "    d['env'] = \"_\".join(study_segments[3:])\n",
    "\n",
    "    # study.best_trial raises and the mean would divide by zero when no trial\n",
    "    # has completed yet, so these columns are left empty for such studies.\n",
    "    if n_completed_trials > 0:\n",
    "        d['best_trial_duration'] = study.best_trial.duration\n",
    "        completed_durations = [t.duration for t in study.trials if t.state == TrialState.COMPLETE]\n",
    "        d['mean_duration'] = sum(completed_durations, datetime.timedelta()) / n_completed_trials\n",
    "\n",
    "    reruns_folder = log_file.parent / \"reruns\"\n",
    "    d['rerun_values'] = [\n",
    "        round(run['result']['imit_stats']['monitor_return_mean'], 2)\n",
    "        for conf, run in sfp.find_sacred_runs(reruns_folder, only_completed_runs=True)\n",
    "    ]\n",
    "\n",
    "    return d\n",
    "\n",
    "\n",
    "# sorted() keeps the row order independent of the order in which the file\n",
    "# system lists the journal files.\n",
    "experiment_log_files = sorted(Path().glob(\"*/*.log\"))\n",
    "\n",
    "raw_study_data = [summarize_study(log_file) for log_file in experiment_log_files]\n",
    "\n",
    "study_data = pd.DataFrame(raw_study_data)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "benchmark-analysis-reference-heading",
   "metadata": {},
   "source": [
    "## Add the reference values"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b604bc7e-2e61-4f7f-acfe-87b57e8a2f5a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Add performance of the best model from the paper\n",
    "\n",
    "environments = [\n",
    "    \"seals_ant\",\n",
    "    \"seals_half_cheetah\",\n",
    "    \"seals_hopper\",\n",
    "    \"seals_swimmer\",\n",
    "    \"seals_walker\",\n",
    "    \"seals_humanoid\",\n",
    "    \"seals_cartpole\",\n",
    "    \"pendulum\",\n",
    "    \"seals_mountain_car\"\n",
    "]\n",
    "\n",
    "# In the tables below \"-\" marks an environment without a value; those\n",
    "# entries are skipped when the rows are built.\n",
    "pc_paper_700 = dict(\n",
    "    seals_ant=200,\n",
    "    seals_half_cheetah=4700,\n",
    "    seals_hopper=4500,\n",
    "    seals_swimmer=170,\n",
    "    seals_walker=4900,\n",
    "    seals_humanoid=\"-\",\n",
    "    seals_cartpole=\"-\",\n",
    "    pendulum=1300,\n",
    "    seals_mountain_car=\"-\",\n",
    ")\n",
    "\n",
    "pc_paper_1400 = dict(\n",
    "    seals_ant=100,\n",
    "    seals_half_cheetah=5600,\n",
    "    seals_hopper=4500,\n",
    "    seals_swimmer=175,\n",
    "    seals_walker=5900,\n",
    "    seals_humanoid=\"-\",\n",
    "    seals_cartpole=\"-\",\n",
    "    pendulum=750,\n",
    "    seals_mountain_car=\"-\",\n",
    ")\n",
    "\n",
    "rl_paper = dict(\n",
    "    seals_ant=16,\n",
    "    seals_half_cheetah=420,\n",
    "    seals_hopper=4210,\n",
    "    seals_swimmer=175,\n",
    "    seals_walker=5370,\n",
    "    seals_humanoid=\"-\",\n",
    "    seals_cartpole=\"-\",\n",
    "    pendulum=1300,\n",
    "    seals_mountain_car=\"-\",\n",
    ")\n",
    "\n",
    "rl_ours = dict(\n",
    "    seals_ant=3034,\n",
    "    seals_half_cheetah=1675.76,\n",
    "    seals_hopper=203.45,\n",
    "    seals_swimmer=292.84,\n",
    "    seals_walker=2465.56,\n",
    "    seals_humanoid=3224.12,\n",
    "    seals_cartpole=500.00,\n",
    "    pendulum=-189.25,\n",
    "    seals_mountain_car=-97.00,\n",
    ")\n",
    "\n",
    "reference_data = []\n",
    "\n",
    "for algo, values_by_env in dict(\n",
    "    pc_paper_700=pc_paper_700,\n",
    "    pc_paper_1400=pc_paper_1400,\n",
    "    rl_paper=rl_paper,\n",
    "    rl_ours=rl_ours,\n",
    ").items():\n",
    "    assert set(values_by_env) == set(environments), f\"{algo} does not list every environment\"\n",
    "    for env, value in values_by_env.items():\n",
    "        if value == \"-\":\n",
    "            continue\n",
    "        reference_data.append(dict(\n",
    "            algo=algo,\n",
    "            env=env,\n",
    "            best_value=value,\n",
    "        ))\n",
    "\n",
    "# The rows are combined into a new list instead of being appended to\n",
    "# raw_study_data, so re-running this cell does not duplicate them.\n",
    "study_data = pd.DataFrame(raw_study_data + reference_data)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "benchmark-analysis-results-heading",
   "metadata": {},
   "source": [
    "## Results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2e9ae5ca-5002-411b-beaf-cb98eb12f54c",
   "metadata": {},
   "outputs": [],
   "source": [
    "def rerun_spread(values):\n",
    "    \"\"\"Return the standard deviation of the rerun returns of one row.\n",
    "\n",
    "    Rows without a list of reruns (missing value or empty list) yield NaN,\n",
    "    so they never pass the \"> 0\" filter and numpy emits no warning.\n",
    "    \"\"\"\n",
    "    if not isinstance(values, list) or len(values) == 0:\n",
    "        return np.nan\n",
    "    return float(np.std(values))\n",
    "\n",
    "\n",
    "# reindex() adds a missing column filled with NaN instead of raising, which\n",
    "# keeps this cell working when no study (and hence no rerun column) was found.\n",
    "summary = study_data.reindex(columns=[\"algo\", \"env\", \"best_value\", \"rerun_values\"])\n",
    "\n",
    "print(\"Benchmark Data\")\n",
    "display(summary[[\"algo\", \"env\", \"best_value\"]])\n",
    "\n",
    "print(\"Rerun Data\")\n",
    "display(summary[summary[\"rerun_values\"].map(rerun_spread) > 0])"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}