diff --git a/src/helios/plugins/optuna.py b/src/helios/plugins/optuna.py index a64efea..009a3d6 100644 --- a/src/helios/plugins/optuna.py +++ b/src/helios/plugins/optuna.py @@ -2,8 +2,10 @@ import optuna except ImportError as e: raise ImportError("error: OptunaPlugin requires Optuna to be installed") from e +import gc +import pathlib +import pickle import typing -import warnings import torch @@ -20,6 +22,152 @@ _ORIG_NUMBER_KEY = "hl:id" +def resume_study( + study_args: dict[str, typing.Any], + failed_states: typing.Sequence = (optuna.trial.TrialState.FAIL,), +) -> optuna.Study: + """ + Resume a study that stopped because of a failure. + + The goal of this function is to allow studies that failed due to an error (either an + exception, system error etc.) to continue utilising the built-in checkpoint system + from Helios. To accomplish this, the function will do the following: + + #. Grab all the trials from the study created by the given ``study_args``, splitting + them into three groups: completed, failed, and failed but completed. + #. Create a new study with the same name and storage. This new study will get all of + the completed trials of the original, and will have the failed trials re-enqueued. + + .. warning:: + This function should **only** be called before trials are started. + .. warning:: + This function **requires** ``RDBStorage`` as the storage argument for + ``optuna.create_study``. + + The ``failed_states`` argument can be used to set additional trial states to be + considered as "failures". This can be useful when dealing with special cases where + trials were either completed or pruned but need to be re-run. + + This function works in tandem with + :py:meth:`~helios.plugins.optuna.OptunaPlugin.configure_model` to ensure that when + the failed trial is re-run, the original save name is restored so any saved + checkpoints can be re-used so the trial can continue instead of starting from + scratch. + + .. note:: + Only trials that fail but haven't been completed will be enqueued by this + function. If a trial fails and is completed later on, it will be skipped. + + Args: + study_args: dictionary of arguments for ``optuna.create_study``. + failed_states: the trial states that are considered to be failures and should + be re-enqueued. + """ + if "storage" not in study_args: + raise TypeError("error: RDB storage is required for resuming studies") + if "load_if_exists" not in study_args or not study_args["load_if_exists"]: + raise KeyError("error: study must be created with 'load_if_exists' set to True") + + storage_str: str = study_args["storage"] + if not isinstance(storage_str, str): + raise TypeError("error: only strings are supported for 'storage'") + + storage = pathlib.Path(storage_str.removeprefix("sqlite:///")).resolve() + + # Step 1: create the study with the current DB and grab all the trials. + study = optuna.create_study(**study_args) + complete: list[optuna.trial.FrozenTrial] = [] + failed_but_completed: list[optuna.trial.FrozenTrial] = [] + failed: dict[int, optuna.trial.FrozenTrial] = {} + + for trial in study.trials: + if ( + trial.state == optuna.trial.TrialState.COMPLETE + and _ORIG_NUMBER_KEY in trial.user_attrs + ): + failed_but_completed.append(trial) + elif trial.state == optuna.trial.TrialState.COMPLETE: + complete.append(trial) + elif trial.state in failed_states: + failed[trial.number] = trial + + # Make sure that any trials that failed but were completed are pruned from the failed + # trials list. + for trial in failed_but_completed: + trial_num = trial.user_attrs[_ORIG_NUMBER_KEY] + failed.pop(trial_num, None) + + # Make sure that the study is cleared out before we attempt to rename the storage + del study + gc.collect() + + # Step 2: rename the DB and create a new empty study. + tmp_storage = storage.parent / (storage.stem + "_tmp" + storage.suffix) + storage.rename(tmp_storage) + study = optuna.create_study(**study_args) + + # Step 3: move all the trials into the new study, re-setting all trials that failed. + for trial in complete: + study.add_trial(trial) + + for _, trial in failed.items(): + study.enqueue_trial(trial.params, {_ORIG_NUMBER_KEY: trial.number}) + + # Once everything's done, clean up the temp storage. + tmp_storage.unlink() + + return study + + +def checkpoint_sampler(trial: optuna.Trial, chkpt_root: pathlib.Path) -> None: + """ + Create a checkpoint with the state of the sampler. + + This function can be used to ensure that if a study is restarted, the state of the + sampler is recovered so trials can be reproducible. The function will automatically + create a checkpoint using ``torch.save``. + + .. note:: + It is recommended that this function be called at the start of the objective + function to ensure the checkpoint is made correctly, but it can be called at any + time. + + Args: + trial: the current trial. + chkpt_root: the root where the checkpoints will be saved. + """ + chkpt_path = chkpt_root / (f"sampler_trial-{trial.number}.pkl") + + sampler = trial.study.sampler + with chkpt_path.open("wb") as outfile: + pickle.dump(sampler, outfile) + + +def restore_sampler(chkpt_root: pathlib.Path) -> optuna.samplers.BaseSampler: + """ + Restore the sampler from a previously saved checkpoint. + + This function can be used in tandem with + :py:func:`~helios.plugins.optuna.checkpoint_sampler` to ensure that the last + checkpoint is loaded and the correct state is restored for the sampler. This function + **needs** to be called before ``optuna.create_study`` is called. + + Args: + chkpt_root: the root where the checkpoints are stored. + + Returns: + The restored sampler. + """ + + def key(path: pathlib.Path) -> int: + return int(path.stem.split("-")[-1]) + + chkpts = list(chkpt_root.glob("*.pkl")) + chkpts.sort(key=key) + sampler = pickle.load(chkpts[-1].open("rb")) + return sampler + + @hlp.PLUGIN_REGISTRY.register class OptunaPlugin(hlp.Plugin): """ @@ -93,71 +241,6 @@ def trial(self) -> optuna.Trial: def trial(self, t: optuna.Trial) -> None: self._trial = t - @classmethod - def enqueue_failed_trials( - cls, - study: optuna.study.Study, - failed_states: typing.Sequence = (optuna.trial.TrialState.FAIL,), - ) -> None: - """ - Enqueue any failed trials so they can be re-run. - - This will add any failed trials from a previous run. This is used for cases when - the study had to be stopped due to an error, exception, or by the user, allowing - trials that didn't finish to complete. - - .. warning:: - This function should **only** be called before trials are started. - - The ``failed_states`` argument can be used to set additional trial states to be - considered as "failures". This can be useful when dealing with special cases where - trials were either completed or pruned but need to be re-run. - - This function works in tandem with - :py:meth:`~helios.plugins.optuna.OptunaPlugin.configure_model` to ensure that when - the failed trial is re-run, the original save name is restored so any saved - checkpoints can be re-used so the trial can continue instead of starting from - scratch. - - .. note:: - Only trials that fail but haven't been completed will be enqueued by this - function. If a trial fails and is completed later on, it will be skipped. - - .. warning:: - Depending on the reason for a trial failing, it is possible for this function - to re-add trials that will continue to fail. If you require special handling, - you may override this function to achieve your desired behaviour. - - - Args: - study: the study to get the failed trials from and enqueue them. - failed_states: the trial states that are considered to be failures and should - be re-enqueued. - """ - if optuna.trial.TrialState.COMPLETE in failed_states: - warnings.warn( - "warning: re-enqueuing completed trials could lead to incorrect results", - stacklevel=2, - ) - failed_but_completed: list[optuna.trial.FrozenTrial] = [] - failed: dict[int, optuna.trial.FrozenTrial] = {} - for trial in study.trials: - if ( - trial.state == optuna.trial.TrialState.COMPLETE - and _ORIG_NUMBER_KEY in trial.user_attrs - ): - failed_but_completed.append(trial) - - if trial.state in failed_states: - failed[trial.number] = trial - - for trial in failed_but_completed: - trial_num = trial.user_attrs[_ORIG_NUMBER_KEY] - failed.pop(trial_num, None) - - for _, trial in failed.items(): - study.enqueue_trial(trial.params, {_ORIG_NUMBER_KEY: trial.number}) - def configure_trainer(self, trainer: hlt.Trainer) -> None: """ Configure the trainer with the required settings. diff --git a/test/test_plugins.py b/test/test_plugins.py index b9d9421..98285c4 100644 --- a/test/test_plugins.py +++ b/test/test_plugins.py @@ -1,3 +1,5 @@ +import dataclasses as dc +import functools import pathlib import typing @@ -7,8 +9,12 @@ import helios.model as hlm import helios.plugins as hlp +import helios.plugins.optuna as hlpo import helios.trainer as hlt -from helios.plugins.optuna import _ORIG_NUMBER_KEY, OptunaPlugin +from helios.core import rng + +# Ignore the use of private members so we can test them correctly. +# ruff: noqa: SLF001 class ExceptionPlugin(hlp.Plugin): @@ -124,12 +130,12 @@ def test_configure(self) -> None: ) class TestOptunaPlugin: def test_plugin_id(self) -> None: - assert hasattr(OptunaPlugin, "plugin_id") - assert OptunaPlugin.plugin_id == "optuna" + assert hasattr(hlpo.OptunaPlugin, "plugin_id") + assert hlpo.OptunaPlugin.plugin_id == "optuna" def test_invalid_storage(self) -> None: def objective(trial: optuna.Trial) -> int: - plugin = OptunaPlugin(trial, "accuracy") + plugin = hlpo.OptunaPlugin(trial, "accuracy") plugin.is_distributed = True with pytest.raises(ValueError): plugin.setup() @@ -141,7 +147,7 @@ def objective(trial: optuna.Trial) -> int: def test_configure(self) -> None: def objective(trial: optuna.Trial) -> int: - plugin = OptunaPlugin(trial, "accuracy") + plugin = hlpo.OptunaPlugin(trial, "accuracy") trainer = hlt.Trainer() plugin.configure_trainer(trainer) @@ -171,7 +177,7 @@ def check_in_sequence(self, val: typing.Any, t: type, seq: typing.Container) -> def test_suggest(self) -> None: def objective(trial: optuna.Trial) -> int: - plugin = OptunaPlugin(trial, "accuracy") + plugin = hlpo.OptunaPlugin(trial, "accuracy") with pytest.raises(KeyError): plugin.suggest("foo", "bar") @@ -199,7 +205,7 @@ def objective(trial: optuna.Trial) -> int: def test_state_dict(self) -> None: def objective(trial: optuna.Trial) -> int: - plugin = OptunaPlugin(trial, "accuracy") + plugin = hlpo.OptunaPlugin(trial, "accuracy") x = plugin.suggest("float", "x", low=-10, high=10) state_dict = plugin.state_dict() @@ -214,42 +220,41 @@ def objective(trial: optuna.Trial) -> int: def test_resume_trial(self, tmp_path: pathlib.Path) -> None: num_trials = 10 + storage_path = tmp_path / "trial_test.db" successful_trials = [False for _ in range(num_trials)] offset = 0 + study_args = { + "study_name": "trial_test", + "storage": f"sqlite:///{storage_path}", + "load_if_exists": True, + } def objective(trial: optuna.Trial) -> float: nonlocal offset - plugin = OptunaPlugin(trial, "accuracy") + plugin = hlpo.OptunaPlugin(trial, "accuracy") plugin.setup() model = PluginModel("plugin-model") plugin.configure_model(model) trial_num = trial.number - if _ORIG_NUMBER_KEY in trial.user_attrs: - trial_num = trial.user_attrs[_ORIG_NUMBER_KEY] + if hlpo._ORIG_NUMBER_KEY in trial.user_attrs: + trial_num = trial.user_attrs[hlpo._ORIG_NUMBER_KEY] assert model.save_name == f"plugin-model_trial-{trial_num}" - # Artificially offset the number by 1 so that when we subtract 1 later on - # the indices work out correctly. + # Artificially offset the trial number so we don't raise the exception + # again. trial_num += 1 offset = 1 - if trial.number == num_trials / 2: + if trial_num == num_trials / 2: raise RuntimeError("half-way stop") successful_trials[trial_num - offset] = True + offset = 0 return 0 - def create_study() -> optuna.Study: - storage_path = tmp_path / "trial_test.db" - return optuna.create_study( - study_name="trial_test", - storage=f"sqlite:///{storage_path}", - load_if_exists=True, - ) - def optimize(study: optuna.Study) -> None: study.optimize( objective, @@ -262,30 +267,84 @@ def optimize(study: optuna.Study) -> None: ], ) - study = create_study() + study = hlpo.resume_study(study_args) with pytest.raises(RuntimeError): optimize(study) del study - study = create_study() - - OptunaPlugin.enqueue_failed_trials(study) + study = hlpo.resume_study(study_args) optimize(study) for v in successful_trials: assert v - def test_enqueue_failed_trials(self, tmp_path: pathlib.Path) -> None: - def objective(trial: optuna.Trial) -> float: - return 0 + def test_sampler_checkpoints(self, tmp_path: pathlib.Path) -> None: + @dc.dataclass + class TestRun: + samples: list[float] + chkpt_root: pathlib.Path + + num_trials = 10 + run1 = TestRun([], tmp_path / "run1") + run2 = TestRun([], tmp_path / "run2") + + run1.chkpt_root.mkdir(exist_ok=True) + run2.chkpt_root.mkdir(exist_ok=True) + + def objective( + trial: optuna.Trial, + raise_error: bool, + run: TestRun, + ) -> float: + hlpo.checkpoint_sampler(trial, run.chkpt_root) + res = trial.suggest_float("accuracy", 0, 1) + if raise_error and trial.number == num_trials // 2: + raise RuntimeError("half-way stop") + run.samples.append(res) + return res - storage_path = tmp_path / "enqueue_test.db" - study = optuna.create_study( - study_name="enqueue_test", storage=f"sqlite:///{storage_path}" + def create_study( + sampler: optuna.samplers.BaseSampler | None = None, + ) -> optuna.Study: + return optuna.create_study( + study_name="chkpt_test", + sampler=optuna.samplers.TPESampler(seed=rng.get_default_seed()) + if sampler is None + else sampler, + ) + + study = create_study() + study.optimize( + functools.partial( + objective, + raise_error=False, + run=run1, + ), + n_trials=num_trials, ) - with pytest.warns(UserWarning): - OptunaPlugin.enqueue_failed_trials( - study, [optuna.trial.TrialState.FAIL, optuna.trial.TrialState.COMPLETE] + chkpts = list(run1.chkpt_root.glob("*.pkl")) + assert len(chkpts) == num_trials + assert all(chkpt.stem == f"sampler_trial-{i}" for i, chkpt in enumerate(chkpts)) + del study + + study = create_study() + with pytest.raises(RuntimeError): + study.optimize( + functools.partial( + objective, + raise_error=True, + run=run2, + ), + n_trials=num_trials, ) + + torch.serialization.clear_safe_globals() + sampler = hlpo.restore_sampler(run2.chkpt_root) + study = create_study(sampler) + study.optimize( + functools.partial(objective, raise_error=False, run=run2), + n_trials=num_trials // 2, + ) + assert run2.samples == run1.samples