From 0f488360f693b798611e5698c267e507bef102c0 Mon Sep 17 00:00:00 2001 From: "Mauricio A. Rovira Galvez" <8482308+marovira@users.noreply.github.com> Date: Fri, 22 Nov 2024 16:17:48 -0800 Subject: [PATCH] [brief] Big expansion for ensuring optuna trials can be resumed. [detailed] - The previous system was unable to cope with situations where studies failed due to external reasons (system shutdowns, errors, process being killed, etc). To address this, the function was moved out of the plug-in class (since it didn't really need to be there) and now focuses on creating a new study with the correct trials moved over from the previous version. This ensures that if any trials fail or the process is killed and a trial is left running, it is correctly handled. - Extends the system to allow checkpoints to be made from the samplers. This ensures reproducibility even if the study is stopped. --- src/helios/plugins/optuna.py | 215 ++++++++++++++++++++++++----------- test/test_plugins.py | 127 +++++++++++++++------ 2 files changed, 242 insertions(+), 100 deletions(-) diff --git a/src/helios/plugins/optuna.py b/src/helios/plugins/optuna.py index a64efea..009a3d6 100644 --- a/src/helios/plugins/optuna.py +++ b/src/helios/plugins/optuna.py @@ -2,8 +2,10 @@ import optuna except ImportError as e: raise ImportError("error: OptunaPlugin requires Optuna to be installed") from e +import gc +import pathlib +import pickle import typing -import warnings import torch @@ -20,6 +22,152 @@ _ORIG_NUMBER_KEY = "hl:id" +def resume_study( + study_args: dict[str, typing.Any], + failed_states: typing.Sequence = (optuna.trial.TrialState.FAIL,), +) -> optuna.Study: + """ + Resume a study that stopped because of a failure. + + The goal of this function is to allow studies that failed due to an error (either an + exception, system error etc.) to continue utilising the built-in checkpoint system + from Helios. To accomplish this, the function will do the following: + + #. Grab all the trials from the study created by the given ``study_args``, splitting + them into three groups: completed, failed, and failed but completed. + #. Create a new study with the same name and storage. This new study will get all of + the completed trials of the original, and will have the failed trials re-enqueued. + + .. warning:: + This function should **only** be called before trials are started. + .. warning:: + This function **requires** ``RDBStorage`` as the storage argument for + ``optuna.create_study``. + + The ``failed_states`` argument can be used to set additional trial states to be + considered as "failures". This can be useful when dealing with special cases where + trials were either completed or pruned but need to be re-run. + + This function works in tandem with + :py:meth:`~helios.plugins.optuna.OptunaPlugin.configure_model` to ensure that when + the failed trial is re-run, the original save name is restored so any saved + checkpoints can be re-used so the trial can continue instead of starting from + scratch. + + .. note:: + Only trials that fail but haven't been completed will be enqueued by this + function. If a trial fails and is completed later on, it will be skipped. + + Args: + study_args: dictionary of arguments for ``optuna.create_study``. + failed_states: the trial states that are considered to be failures and should + be re-enqueued. + """ + if "storage" not in study_args: + raise TypeError("error: RDB storage is required for resuming studies") + if "load_if_exists" not in study_args or not study_args["load_if_exists"]: + raise KeyError("error: study must be created with 'load_if_exists' set to True") + + storage_str: str = study_args["storage"] + if not isinstance(storage_str, str): + raise TypeError("error: only strings are supported for 'storage'") + + storage = pathlib.Path(storage_str.removeprefix("sqlite:///")).resolve() + + # Step 1: create the study with the current DB and grab all the trials. + study = optuna.create_study(**study_args) + complete: list[optuna.trial.FrozenTrial] = [] + failed_but_completed: list[optuna.trial.FrozenTrial] = [] + failed: dict[int, optuna.trial.FrozenTrial] = {} + + for trial in study.trials: + if ( + trial.state == optuna.trial.TrialState.COMPLETE + and _ORIG_NUMBER_KEY in trial.user_attrs + ): + failed_but_completed.append(trial) + elif trial.state == optuna.trial.TrialState.COMPLETE: + complete.append(trial) + elif trial.state in failed_states: + failed[trial.number] = trial + + # Make sure that any trials that failed but were completed are pruned from the failed + # trials list. + for trial in failed_but_completed: + trial_num = trial.user_attrs[_ORIG_NUMBER_KEY] + failed.pop(trial_num, None) + + # Make sure that the study is cleared out before we attempt to rename the storage + del study + gc.collect() + + # Step 2: rename the DB and create a new empty study. + tmp_storage = storage.parent / (storage.stem + "_tmp" + storage.suffix) + storage.rename(tmp_storage) + study = optuna.create_study(**study_args) + + # Step 3: move all the trials into the new study, re-setting all trials that failed. + for trial in complete: + study.add_trial(trial) + + for _, trial in failed.items(): + study.enqueue_trial(trial.params, {_ORIG_NUMBER_KEY: trial.number}) + + # Once everything's done, clean up the temp storage. + tmp_storage.unlink() + + return study + + +def checkpoint_sampler(trial: optuna.Trial, chkpt_root: pathlib.Path) -> None: + """ + Create a checkpoint with the state of the sampler. + + This function can be used to ensure that if a study is restarted, the state of the + sampler is recovered so trials can be reproducible. The function will automatically + create a checkpoint using ``torch.save``. + + .. note:: + It is recommended that this function be called at the start of the objective + function to ensure the checkpoint is made correctly, but it can be called at any + time. + + Args: + trial: the current trial. + chkpt_root: the root where the checkpoints will be saved. + """ + chkpt_path = chkpt_root / (f"sampler_trial-{trial.number}.pkl") + + sampler = trial.study.sampler + with chkpt_path.open("wb") as outfile: + pickle.dump(sampler, outfile) + + +def restore_sampler(chkpt_root: pathlib.Path) -> optuna.samplers.BaseSampler: + """ + Restore the sampler from a previously saved checkpoint. + + This function can be used in tandem with + :py:func:`~helios.plugins.optuna.checkpoint_sampler` to ensure that the last + checkpoint is loaded and the correct state is restored for the sampler. This function + **needs** to be called before ``optuna.create_study`` is called. + + Args: + chkpt_root: the root where the checkpoints are stored. + + Returns: + The restored sampler. + """ + + def key(path: pathlib.Path) -> int: + return int(path.stem.split("-")[-1]) + + chkpts = list(chkpt_root.glob("*.pkl")) + chkpts.sort(key=key) + sampler = pickle.load(chkpts[-1].open("rb")) + return sampler + + @hlp.PLUGIN_REGISTRY.register class OptunaPlugin(hlp.Plugin): """ @@ -93,71 +241,6 @@ def trial(self) -> optuna.Trial: def trial(self, t: optuna.Trial) -> None: self._trial = t - @classmethod - def enqueue_failed_trials( - cls, - study: optuna.study.Study, - failed_states: typing.Sequence = (optuna.trial.TrialState.FAIL,), - ) -> None: - """ - Enqueue any failed trials so they can be re-run. - - This will add any failed trials from a previous run. This is used for cases when - the study had to be stopped due to an error, exception, or by the user, allowing - trials that didn't finish to complete. - - .. warning:: - This function should **only** be called before trials are started. - - The ``failed_states`` argument can be used to set additional trial states to be - considered as "failures". This can be useful when dealing with special cases where - trials were either completed or pruned but need to be re-run. - - This function works in tandem with - :py:meth:`~helios.plugins.optuna.OptunaPlugin.configure_model` to ensure that when - the failed trial is re-run, the original save name is restored so any saved - checkpoints can be re-used so the trial can continue instead of starting from - scratch. - - .. note:: - Only trials that fail but haven't been completed will be enqueued by this - function. If a trial fails and is completed later on, it will be skipped. - - .. warning:: - Depending on the reason for a trial failing, it is possible for this function - to re-add trials that will continue to fail. If you require special handling, - you may override this function to achieve your desired behaviour. - - - Args: - study: the study to get the failed trials from and enqueue them. - failed_states: the trial states that are considered to be failures and should - be re-enqueued. - """ - if optuna.trial.TrialState.COMPLETE in failed_states: - warnings.warn( - "warning: re-enqueuing completed trials could lead to incorrect results", - stacklevel=2, - ) - failed_but_completed: list[optuna.trial.FrozenTrial] = [] - failed: dict[int, optuna.trial.FrozenTrial] = {} - for trial in study.trials: - if ( - trial.state == optuna.trial.TrialState.COMPLETE - and _ORIG_NUMBER_KEY in trial.user_attrs - ): - failed_but_completed.append(trial) - - if trial.state in failed_states: - failed[trial.number] = trial - - for trial in failed_but_completed: - trial_num = trial.user_attrs[_ORIG_NUMBER_KEY] - failed.pop(trial_num, None) - - for _, trial in failed.items(): - study.enqueue_trial(trial.params, {_ORIG_NUMBER_KEY: trial.number}) - def configure_trainer(self, trainer: hlt.Trainer) -> None: """ Configure the trainer with the required settings. diff --git a/test/test_plugins.py b/test/test_plugins.py index b9d9421..98285c4 100644 --- a/test/test_plugins.py +++ b/test/test_plugins.py @@ -1,3 +1,5 @@ +import dataclasses as dc +import functools import pathlib import typing @@ -7,8 +9,12 @@ import helios.model as hlm import helios.plugins as hlp +import helios.plugins.optuna as hlpo import helios.trainer as hlt -from helios.plugins.optuna import _ORIG_NUMBER_KEY, OptunaPlugin +from helios.core import rng + +# Ignore the use of private members so we can test them correctly. +# ruff: noqa: SLF001 class ExceptionPlugin(hlp.Plugin): @@ -124,12 +130,12 @@ def test_configure(self) -> None: ) class TestOptunaPlugin: def test_plugin_id(self) -> None: - assert hasattr(OptunaPlugin, "plugin_id") - assert OptunaPlugin.plugin_id == "optuna" + assert hasattr(hlpo.OptunaPlugin, "plugin_id") + assert hlpo.OptunaPlugin.plugin_id == "optuna" def test_invalid_storage(self) -> None: def objective(trial: optuna.Trial) -> int: - plugin = OptunaPlugin(trial, "accuracy") + plugin = hlpo.OptunaPlugin(trial, "accuracy") plugin.is_distributed = True with pytest.raises(ValueError): plugin.setup() @@ -141,7 +147,7 @@ def objective(trial: optuna.Trial) -> int: def test_configure(self) -> None: def objective(trial: optuna.Trial) -> int: - plugin = OptunaPlugin(trial, "accuracy") + plugin = hlpo.OptunaPlugin(trial, "accuracy") trainer = hlt.Trainer() plugin.configure_trainer(trainer) @@ -171,7 +177,7 @@ def check_in_sequence(self, val: typing.Any, t: type, seq: typing.Container) -> def test_suggest(self) -> None: def objective(trial: optuna.Trial) -> int: - plugin = OptunaPlugin(trial, "accuracy") + plugin = hlpo.OptunaPlugin(trial, "accuracy") with pytest.raises(KeyError): plugin.suggest("foo", "bar") @@ -199,7 +205,7 @@ def objective(trial: optuna.Trial) -> int: def test_state_dict(self) -> None: def objective(trial: optuna.Trial) -> int: - plugin = OptunaPlugin(trial, "accuracy") + plugin = hlpo.OptunaPlugin(trial, "accuracy") x = plugin.suggest("float", "x", low=-10, high=10) state_dict = plugin.state_dict() @@ -214,42 +220,41 @@ def objective(trial: optuna.Trial) -> int: def test_resume_trial(self, tmp_path: pathlib.Path) -> None: num_trials = 10 + storage_path = tmp_path / "trial_test.db" successful_trials = [False for _ in range(num_trials)] offset = 0 + study_args = { + "study_name": "trial_test", + "storage": f"sqlite:///{storage_path}", + "load_if_exists": True, + } def objective(trial: optuna.Trial) -> float: nonlocal offset - plugin = OptunaPlugin(trial, "accuracy") + plugin = hlpo.OptunaPlugin(trial, "accuracy") plugin.setup() model = PluginModel("plugin-model") plugin.configure_model(model) trial_num = trial.number - if _ORIG_NUMBER_KEY in trial.user_attrs: - trial_num = trial.user_attrs[_ORIG_NUMBER_KEY] + if hlpo._ORIG_NUMBER_KEY in trial.user_attrs: + trial_num = trial.user_attrs[hlpo._ORIG_NUMBER_KEY] assert model.save_name == f"plugin-model_trial-{trial_num}" - # Artificially offset the number by 1 so that when we subtract 1 later on - # the indices work out correctly. + # Artificially offset the trial number so we don't raise the exception + # again. trial_num += 1 offset = 1 - if trial.number == num_trials / 2: + if trial_num == num_trials / 2: raise RuntimeError("half-way stop") successful_trials[trial_num - offset] = True + offset = 0 return 0 - def create_study() -> optuna.Study: - storage_path = tmp_path / "trial_test.db" - return optuna.create_study( - study_name="trial_test", - storage=f"sqlite:///{storage_path}", - load_if_exists=True, - ) - def optimize(study: optuna.Study) -> None: study.optimize( objective, @@ -262,30 +267,84 @@ def optimize(study: optuna.Study) -> None: ], ) - study = create_study() + study = hlpo.resume_study(study_args) with pytest.raises(RuntimeError): optimize(study) del study - study = create_study() - - OptunaPlugin.enqueue_failed_trials(study) + study = hlpo.resume_study(study_args) optimize(study) for v in successful_trials: assert v - def test_enqueue_failed_trials(self, tmp_path: pathlib.Path) -> None: - def objective(trial: optuna.Trial) -> float: - return 0 + def test_sampler_checkpoints(self, tmp_path: pathlib.Path) -> None: + @dc.dataclass + class TestRun: + samples: list[float] + chkpt_root: pathlib.Path + + num_trials = 10 + run1 = TestRun([], tmp_path / "run1") + run2 = TestRun([], tmp_path / "run2") + + run1.chkpt_root.mkdir(exist_ok=True) + run2.chkpt_root.mkdir(exist_ok=True) + + def objective( + trial: optuna.Trial, + raise_error: bool, + run: TestRun, + ) -> float: + hlpo.checkpoint_sampler(trial, run.chkpt_root) + res = trial.suggest_float("accuracy", 0, 1) + if raise_error and trial.number == num_trials // 2: + raise RuntimeError("half-way stop") + run.samples.append(res) + return res - storage_path = tmp_path / "enqueue_test.db" - study = optuna.create_study( - study_name="enqueue_test", storage=f"sqlite:///{storage_path}" + def create_study( + sampler: optuna.samplers.BaseSampler | None = None, + ) -> optuna.Study: + return optuna.create_study( + study_name="chkpt_test", + sampler=optuna.samplers.TPESampler(seed=rng.get_default_seed()) + if sampler is None + else sampler, + ) + + study = create_study() + study.optimize( + functools.partial( + objective, + raise_error=False, + run=run1, + ), + n_trials=num_trials, ) - with pytest.warns(UserWarning): - OptunaPlugin.enqueue_failed_trials( - study, [optuna.trial.TrialState.FAIL, optuna.trial.TrialState.COMPLETE] + chkpts = list(run1.chkpt_root.glob("*.pkl")) + assert len(chkpts) == num_trials + assert all(chkpt.stem == f"sampler_trial-{i}" for i, chkpt in enumerate(chkpts)) + del study + + study = create_study() + with pytest.raises(RuntimeError): + study.optimize( + functools.partial( + objective, + raise_error=True, + run=run2, + ), + n_trials=num_trials, ) + + torch.serialization.clear_safe_globals() + sampler = hlpo.restore_sampler(run2.chkpt_root) + study = create_study(sampler) + study.optimize( + functools.partial(objective, raise_error=False, run=run2), + n_trials=num_trials // 2, + ) + assert run2.samples == run1.samples