From b9364617d77524f92d8095bb1dfd8a9950210867 Mon Sep 17 00:00:00 2001
From: "Mauricio A. Rovira Galvez" <8482308+marovira@users.noreply.github.com>
Date: Fri, 26 Apr 2024 12:06:09 -0700
Subject: [PATCH] [brief] Allow the trainer to populate registries
 automatically.

[detailed]
- The main thing this fixes is the case where we use distributed training
  through spawn. Because spawn starts a fresh Python interpreter in each
  worker process (rather than forking), all registries are going to be
  empty, so we need a way to populate them.
- This also offers a convenient way of populating the registries without
  having to do so manually prior to creating the trainer.
---
 src/helios/trainer.py | 18 +++++++++++++++++-
 1 file changed, 17 insertions(+), 1 deletion(-)

diff --git a/src/helios/trainer.py b/src/helios/trainer.py
index b01a696..edc70d8 100644
--- a/src/helios/trainer.py
+++ b/src/helios/trainer.py
@@ -218,6 +218,8 @@ def __init__(
         chkpt_root: pathlib.Path | None = None,
         log_path: pathlib.Path | None = None,
         run_path: pathlib.Path | None = None,
+        src_root: pathlib.Path | None = None,
+        import_prefix: str = "",
     ):
         """Create the trainer."""
         self._model: hlm.Model | None = None
@@ -255,6 +257,9 @@ def __init__(
         self._log_path = log_path
         self._run_path = run_path
 
+        self._src_root = src_root
+        self._import_prefix = import_prefix
+
         self._run_name = run_name
 
         self._validate_flags()
@@ -499,7 +504,10 @@ def _configure_env(self) -> None:
         """
         Configure the training environment.
 
-        This will seed the RNGs as well as setup any CUDA state (if using).
+        This will seed the RNGs as well as setup any CUDA state (if using). It will also
+        populate all of the registries provided the source root is not None. This is to
+        prevent the registries from being empty if distributed training is launched
+        through spawn (note that torchrun doesn't have this problem).
         """
         rng.seed_rngs(self._random_seed)
         torch.use_deterministic_algorithms(self._enable_deterministic)
@@ -512,6 +520,11 @@ def _configure_env(self) -> None:
 
         logging.create_default_loggers(self._enable_tensorboard)
 
+        if self._src_root is not None:
+            core.update_all_registries(
+                self._src_root, recurse=True, import_prefix=self._import_prefix
+            )
+
     def _setup_datamodule(self) -> None:
         """Finish setting up the datamodule."""
         self.datamodule.is_distributed = self._is_distributed
@@ -619,6 +632,9 @@ def _validate_flags(self):
         if self._log_path.exists() and not self._log_path.is_dir():
             raise ValueError("error: log path must be a directory")
 
+        if self._src_root is not None and not self._src_root.is_dir():
+            raise ValueError("error: source root must be a directory")
+
     def _setup_device_flags(self, use_cpu: bool | None):
         """
         Configure the device state.
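
For reference, a minimal usage sketch of the new arguments. This is not part
of the patch and is assumption-laden: the Trainer class name, the run_name
argument, and the my_project package layout are assumed; only src_root and
import_prefix are introduced by this change.

    import pathlib

    from helios.trainer import Trainer

    # Assumed layout: src/my_project/ contains modules that register models
    # and datamodules at import time. Passing src_root lets _configure_env()
    # re-import them via core.update_all_registries in each spawned worker,
    # so the registries are populated even though spawn starts a fresh
    # interpreter.
    trainer = Trainer(
        run_name="example",                       # assumed existing argument
        src_root=pathlib.Path("src/my_project"),  # must be a directory
        import_prefix="my_project",               # assumed prefix format
    )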