[brief] Allow the trainer to populate registries automatically.
[detailed]
- The main thing this fixes is the case where we use distributed
  training through spawn. Because spawn launches fresh Python
  interpreters instead of forking, all registries in the worker
  processes are going to be empty, so we need a way to populate them.
- This also offers a convenient way of populating the registries without
  having to do so manually prior to creating the trainer.
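
As a rough usage sketch (the trainer class name and project layout below are
assumptions for illustration; only the src_root and import_prefix parameters
come from this commit's diff):

# Hypothetical usage; assumes the trainer class is helios.trainer.Trainer
# and that the project code lives under src/my_project.
import pathlib

from helios.trainer import Trainer

trainer = Trainer(
    src_root=pathlib.Path("src/my_project"),  # scanned recursively for registrations
    import_prefix="my_project",  # prefix prepended when importing discovered modules
)
# When distributed training is launched through spawn, each worker re-runs
# the environment setup, which re-populates the registries from src_root.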
marovira committed Apr 26, 2024
1 parent 0ad3873 commit b936461
Showing 1 changed file with 17 additions and 1 deletion.
src/helios/trainer.py
@@ -218,6 +218,8 @@ def __init__(
         chkpt_root: pathlib.Path | None = None,
         log_path: pathlib.Path | None = None,
         run_path: pathlib.Path | None = None,
+        src_root: pathlib.Path | None = None,
+        import_prefix: str = "",
     ):
         """Create the trainer."""
         self._model: hlm.Model | None = None
@@ -255,6 +257,9 @@ def __init__(
         self._log_path = log_path
         self._run_path = run_path
 
+        self._src_root = src_root
+        self._import_prefix = import_prefix
+
         self._run_name = run_name
 
         self._validate_flags()
@@ -499,7 +504,10 @@ def _configure_env(self) -> None:
         """
         Configure the training environment.
 
-        This will seed the RNGs as well as setup any CUDA state (if using).
+        This will seed the RNGs as well as set up any CUDA state (if using). It will
+        also populate all of the registries, provided the source root is not None.
+        This prevents the registries from being empty when distributed training is
+        launched through spawn (note that torchrun doesn't have this problem).
         """
         rng.seed_rngs(self._random_seed)
         torch.use_deterministic_algorithms(self._enable_deterministic)
@@ -512,6 +520,11 @@ def _configure_env(self) -> None:
 
         logging.create_default_loggers(self._enable_tensorboard)
 
+        if self._src_root is not None:
+            core.update_all_registries(
+                self._src_root, recurse=True, import_prefix=self._import_prefix
+            )
+
     def _setup_datamodule(self) -> None:
         """Finish setting up the datamodule."""
         self.datamodule.is_distributed = self._is_distributed
@@ -619,6 +632,9 @@ def _validate_flags(self):
         if self._log_path.exists() and not self._log_path.is_dir():
             raise ValueError("error: log path must be a directory")
 
+        if self._src_root is not None and not self._src_root.is_dir():
+            raise ValueError("error: source root must be a directory")
+
     def _setup_device_flags(self, use_cpu: bool | None):
         """
         Configure the device state.
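
For context, here is a minimal sketch of what a registry-population helper
with the signature used in the diff (core.update_all_registries(src_root,
recurse=True, import_prefix=...)) might do. This is an illustrative
assumption, not helios's actual implementation: importing every module under
src_root forces decorator-based registrations to run again inside a freshly
spawned interpreter.

# Illustrative sketch only; the real core.update_all_registries in helios
# may behave differently.
import importlib
import pathlib


def update_all_registries(
    src_root: pathlib.Path, recurse: bool = True, import_prefix: str = ""
) -> None:
    """Import every module under src_root so registration decorators execute."""
    pattern = "**/*.py" if recurse else "*.py"
    for path in sorted(src_root.glob(pattern)):
        if path.name == "__init__.py":
            continue
        # Map src_root/foo/bar.py -> "<import_prefix>.foo.bar".
        rel = path.relative_to(src_root).with_suffix("")
        name = ".".join(rel.parts)
        if import_prefix:
            name = f"{import_prefix}.{name}"
        # Assumes the modules are importable, i.e. the package root is
        # already on sys.path.
        importlib.import_module(name)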
