[brief] Big expansion for ensuring optuna trials can be resumed.
[detailed]
- The previous system was unable to cope with situations where studies
  failed due to external reasons (system shutdowns, errors, process
  being killed, etc). To address this, the function was moved out of the
  plug-in class (since it didn't really need to be there) and now
  focuses on creating a new study with the correct trials moved over
  from the previous version. This ensures that if any trials fail or the
  process is killed and a trial is left running, it is correctly
  handled.
- Extends the system to allow checkpoints to be made from the samplers.
  This ensures reproducibility even if the study is stopped.
marovira committed Nov 23, 2024
1 parent 456085e commit 0f48836
Showing 2 changed files with 242 additions and 100 deletions.
215 changes: 149 additions & 66 deletions src/helios/plugins/optuna.py
@@ -2,8 +2,10 @@
    import optuna
except ImportError as e:
    raise ImportError("error: OptunaPlugin requires Optuna to be installed") from e
import gc
import pathlib
import pickle
import typing
import warnings

import torch

@@ -20,6 +22,152 @@
_ORIG_NUMBER_KEY = "hl:id"


def resume_study(
    study_args: dict[str, typing.Any],
    failed_states: typing.Sequence = (optuna.trial.TrialState.FAIL,),
) -> optuna.Study:
    """
    Resume a study that stopped because of a failure.

    The goal of this function is to allow studies that failed due to an error (either an
    exception, system error etc.) to continue utilising the built-in checkpoint system
    from Helios. To accomplish this, the function will do the following:

    #. Grab all the trials from the study created by the given ``study_args``, splitting
       them into three groups: completed, failed, and failed but completed.
    #. Create a new study with the same name and storage. This new study will get all of
       the completed trials of the original, and will have the failed trials re-enqueued.

    .. warning::
        This function should **only** be called before trials are started.

    .. warning::
        This function **requires** ``RDBStorage`` as the storage argument for
        ``optuna.create_study``.

    The ``failed_states`` argument can be used to set additional trial states to be
    considered as "failures". This can be useful when dealing with special cases where
    trials were either completed or pruned but need to be re-run.

    This function works in tandem with
    :py:meth:`~helios.plugins.optuna.OptunaPlugin.configure_model` to ensure that when
    the failed trial is re-run, the original save name is restored so any saved
    checkpoints can be re-used so the trial can continue instead of starting from
    scratch.

    .. note::
        Only trials that fail but haven't been completed will be enqueued by this
        function. If a trial fails and is completed later on, it will be skipped.

    Args:
        study_args: dictionary of arguments for ``optuna.create_study``.
        failed_states: the trial states that are considered to be failures and should
            be re-enqueued.
    """
    if "storage" not in study_args:
        raise TypeError("error: RDB storage is required for resuming studies")
    if "load_if_exists" not in study_args or not study_args["load_if_exists"]:
        raise KeyError("error: study must be created with 'load_if_exists' set to True")

    storage_str: str = study_args["storage"]
    if not isinstance(storage_str, str):
        raise TypeError("error: only strings are supported for 'storage'")

    storage = pathlib.Path(storage_str.removeprefix("sqlite:///")).resolve()

    # Step 1: create the study with the current DB and grab all the trials.
    study = optuna.create_study(**study_args)
    complete: list[optuna.trial.FrozenTrial] = []
    failed_but_completed: list[optuna.trial.FrozenTrial] = []
    failed: dict[int, optuna.trial.FrozenTrial] = {}

    for trial in study.trials:
        if (
            trial.state == optuna.trial.TrialState.COMPLETE
            and _ORIG_NUMBER_KEY in trial.user_attrs
        ):
            failed_but_completed.append(trial)
        elif trial.state == optuna.trial.TrialState.COMPLETE:
            complete.append(trial)
        elif trial.state in failed_states:
            failed[trial.number] = trial

    # Make sure that any trials that failed but were completed are pruned from the failed
    # trials list.
    for trial in failed_but_completed:
        trial_num = trial.user_attrs[_ORIG_NUMBER_KEY]
        failed.pop(trial_num, None)

    # Make sure that the study is cleared out before we attempt to rename the storage
    del study
    gc.collect()

    # Step 2: rename the DB and create a new empty study.
    tmp_storage = storage.parent / (storage.stem + "_tmp" + storage.suffix)
    storage.rename(tmp_storage)
    study = optuna.create_study(**study_args)

    # Step 3: move all the trials into the new study, re-setting all trials that failed.
    for trial in complete:
        study.add_trial(trial)

    for _, trial in failed.items():
        study.enqueue_trial(trial.params, {_ORIG_NUMBER_KEY: trial.number})

    # Once everything's done, clean up the temp storage.
    tmp_storage.unlink()

    return study
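
The grouping done in step 1 can be tried out without a database. The sketch below is not code from this commit: `FakeTrial` and `partition_trials` are hypothetical stand-ins for `optuna.trial.FrozenTrial` and the loop in `resume_study`, and trial states are plain strings instead of `TrialState` members. Passing `"RUNNING"` as a failure state is an assumption about how a caller might recover a trial left running by a killed process.

```python
from dataclasses import dataclass, field

ORIG_NUMBER_KEY = "hl:id"  # mirrors _ORIG_NUMBER_KEY in the diff


@dataclass
class FakeTrial:
    """Hypothetical stand-in for optuna.trial.FrozenTrial."""

    number: int
    state: str  # "COMPLETE", "FAIL", "RUNNING", ...
    user_attrs: dict = field(default_factory=dict)


def partition_trials(trials, failed_states=("FAIL",)):
    """Return (complete, failed) using the same grouping rules as the diff."""
    complete = []
    failed_but_completed = []
    failed = {}
    for trial in trials:
        if trial.state == "COMPLETE" and ORIG_NUMBER_KEY in trial.user_attrs:
            failed_but_completed.append(trial)
        elif trial.state == "COMPLETE":
            complete.append(trial)
        elif trial.state in failed_states:
            failed[trial.number] = trial

    # A failed trial that was later re-run to completion must not be enqueued again.
    for trial in failed_but_completed:
        failed.pop(trial.user_attrs[ORIG_NUMBER_KEY], None)
    return complete, failed


trials = [
    FakeTrial(0, "COMPLETE"),
    FakeTrial(1, "FAIL"),
    FakeTrial(2, "FAIL"),
    FakeTrial(3, "COMPLETE", {ORIG_NUMBER_KEY: 1}),  # re-run of trial 1
    FakeTrial(4, "RUNNING"),  # left running by a killed process
]
complete, failed = partition_trials(trials, failed_states=("FAIL", "RUNNING"))
print([t.number for t in complete])  # [0]
print(sorted(failed))  # [2, 4]
```

Note that in the diff only the `complete` list is passed to `study.add_trial`; trials in `failed_but_completed` are used solely to filter the `failed` dictionary.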


def checkpoint_sampler(trial: optuna.Trial, chkpt_root: pathlib.Path) -> None:
    """
    Create a checkpoint with the state of the sampler.

    This function can be used to ensure that if a study is restarted, the state of the
    sampler is recovered so trials can be reproducible. The function will automatically
    create a checkpoint using ``pickle.dump``.

    .. note::
        It is recommended that this function be called at the start of the objective
        function to ensure the checkpoint is made correctly, but it can be called at any
        time.

    Args:
        trial: the current trial.
        chkpt_root: the root where the checkpoints will be saved.
    """
    chkpt_path = chkpt_root / (f"sampler_trial-{trial.number}.pkl")

    sampler = trial.study.sampler
    with chkpt_path.open("wb") as outfile:
        pickle.dump(sampler, outfile)
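
A round trip through `pickle` illustrates why checkpointing the sampler helps reproducibility. This is a minimal sketch under stated assumptions: `random.Random` stands in for an Optuna sampler (real samplers carry their own RNG state), and `save_state` / `load_state` are hypothetical helpers that are not part of Helios.

```python
import pathlib
import pickle
import random
import tempfile


def save_state(obj, root: pathlib.Path, trial_number: int) -> pathlib.Path:
    """Pickle obj using the same file naming scheme as the diff."""
    path = root / f"sampler_trial-{trial_number}.pkl"
    with path.open("wb") as outfile:
        pickle.dump(obj, outfile)
    return path


def load_state(path: pathlib.Path):
    """Load a previously pickled object."""
    with path.open("rb") as infile:
        return pickle.load(infile)


with tempfile.TemporaryDirectory() as tmp:
    root = pathlib.Path(tmp)
    rng = random.Random(1234)  # stand-in for the sampler (assumption)
    rng.random()  # advance the state, as a sampler would after some trials
    chkpt = save_state(rng, root, trial_number=3)

    # Values drawn after the checkpoint by the original object...
    expected = [rng.random() for _ in range(3)]
    # ...match the values drawn by the restored copy.
    restored = load_state(chkpt)
    actual = [restored.random() for _ in range(3)]

print(actual == expected)  # True
```

As with any pickle file, a checkpoint should only be loaded when it comes from a trusted location.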


def restore_sampler(chkpt_root: pathlib.Path) -> optuna.samplers.BaseSampler:
    """
    Restore the sampler from a previously saved checkpoint.

    This function can be used in tandem with
    :py:func:`~helios.plugins.optuna.checkpoint_sampler` to ensure that the last
    checkpoint is loaded and the correct state is restored for the sampler. This function
    **needs** to be called before ``optuna.create_study`` is called.

    Args:
        chkpt_root: the root where the checkpoints are stored.

    Returns:
        The restored sampler.
    """

    def key(path: pathlib.Path) -> int:
        return int(path.stem.split("-")[-1])

    chkpts = list(chkpt_root.glob("*.pkl"))
    chkpts.sort(key=key)
    sampler = pickle.load(chkpts[-1].open("rb"))
    return sampler
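
The numeric sort key matters once a study passes ten trials, because sorting the file names as strings would place `sampler_trial-10.pkl` before `sampler_trial-2.pkl`. The sketch below shows the difference. It also adds a guard for an empty directory, which the function in the diff does not have; `latest_checkpoint` is a hypothetical helper, not part of Helios.

```python
import pathlib
import tempfile


def trial_number(path: pathlib.Path) -> int:
    """Extract the trailing trial number from 'sampler_trial-<n>.pkl'."""
    return int(path.stem.split("-")[-1])


def latest_checkpoint(root: pathlib.Path):
    """Return the checkpoint with the highest trial number, or None if none exist."""
    chkpts = sorted(root.glob("sampler_trial-*.pkl"), key=trial_number)
    return chkpts[-1] if chkpts else None


with tempfile.TemporaryDirectory() as tmp:
    root = pathlib.Path(tmp)
    empty_result = latest_checkpoint(root)  # no checkpoints yet

    for n in (1, 2, 10):
        (root / f"sampler_trial-{n}.pkl").touch()

    by_name = sorted(p.name for p in root.glob("*.pkl"))  # string ordering
    latest_name = latest_checkpoint(root).name  # numeric ordering

print(by_name[-1])  # sampler_trial-2.pkl
print(latest_name)  # sampler_trial-10.pkl
```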


@hlp.PLUGIN_REGISTRY.register
class OptunaPlugin(hlp.Plugin):
    """
@@ -93,71 +241,6 @@ def trial(self) -> optuna.Trial:
    def trial(self, t: optuna.Trial) -> None:
        self._trial = t

    @classmethod
    def enqueue_failed_trials(
        cls,
        study: optuna.study.Study,
        failed_states: typing.Sequence = (optuna.trial.TrialState.FAIL,),
    ) -> None:
        """
        Enqueue any failed trials so they can be re-run.

        This will add any failed trials from a previous run. This is used for cases when
        the study had to be stopped due to an error, exception, or by the user, allowing
        trials that didn't finish to complete.

        .. warning::
            This function should **only** be called before trials are started.

        The ``failed_states`` argument can be used to set additional trial states to be
        considered as "failures". This can be useful when dealing with special cases where
        trials were either completed or pruned but need to be re-run.

        This function works in tandem with
        :py:meth:`~helios.plugins.optuna.OptunaPlugin.configure_model` to ensure that when
        the failed trial is re-run, the original save name is restored so any saved
        checkpoints can be re-used so the trial can continue instead of starting from
        scratch.

        .. note::
            Only trials that fail but haven't been completed will be enqueued by this
            function. If a trial fails and is completed later on, it will be skipped.

        .. warning::
            Depending on the reason for a trial failing, it is possible for this function
            to re-add trials that will continue to fail. If you require special handling,
            you may override this function to achieve your desired behaviour.

        Args:
            study: the study to get the failed trials from and enqueue them.
            failed_states: the trial states that are considered to be failures and should
                be re-enqueued.
        """
        if optuna.trial.TrialState.COMPLETE in failed_states:
            warnings.warn(
                "warning: re-enqueuing completed trials could lead to incorrect results",
                stacklevel=2,
            )
        failed_but_completed: list[optuna.trial.FrozenTrial] = []
        failed: dict[int, optuna.trial.FrozenTrial] = {}
        for trial in study.trials:
            if (
                trial.state == optuna.trial.TrialState.COMPLETE
                and _ORIG_NUMBER_KEY in trial.user_attrs
            ):
                failed_but_completed.append(trial)

            if trial.state in failed_states:
                failed[trial.number] = trial

        for trial in failed_but_completed:
            trial_num = trial.user_attrs[_ORIG_NUMBER_KEY]
            failed.pop(trial_num, None)

        for _, trial in failed.items():
            study.enqueue_trial(trial.params, {_ORIG_NUMBER_KEY: trial.number})

    def configure_trainer(self, trainer: hlt.Trainer) -> None:
        """
        Configure the trainer with the required settings.