diff --git a/src/helios/trainer.py b/src/helios/trainer.py
index b01a696..edc70d8 100644
--- a/src/helios/trainer.py
+++ b/src/helios/trainer.py
@@ -218,6 +218,8 @@ def __init__(
         chkpt_root: pathlib.Path | None = None,
         log_path: pathlib.Path | None = None,
         run_path: pathlib.Path | None = None,
+        src_root: pathlib.Path | None = None,
+        import_prefix: str = "",
     ):
         """Create the trainer."""
         self._model: hlm.Model | None = None
@@ -255,6 +257,9 @@ def __init__(
         self._log_path = log_path
         self._run_path = run_path
 
+        self._src_root = src_root
+        self._import_prefix = import_prefix
+
         self._run_name = run_name
 
         self._validate_flags()
@@ -499,7 +504,10 @@ def _configure_env(self) -> None:
         """
         Configure the training environment.
 
-        This will seed the RNGs as well as setup any CUDA state (if using).
+        This will seed the RNGs as well as setup any CUDA state (if using). It will
+        also update all of the registries, provided the source root is not None. This
+        prevents the registries from being empty when distributed training is launched
+        through spawn (note that torchrun doesn't have this problem).
         """
         rng.seed_rngs(self._random_seed)
         torch.use_deterministic_algorithms(self._enable_deterministic)
@@ -512,6 +520,11 @@ def _configure_env(self) -> None:
 
         logging.create_default_loggers(self._enable_tensorboard)
 
+        if self._src_root is not None:
+            core.update_all_registries(
+                self._src_root, recurse=True, import_prefix=self._import_prefix
+            )
+
     def _setup_datamodule(self) -> None:
         """Finish setting up the datamodule."""
         self.datamodule.is_distributed = self._is_distributed
@@ -619,6 +632,9 @@ def _validate_flags(self):
         if self._log_path.exists() and not self._log_path.is_dir():
             raise ValueError("error: log path must be a directory")
 
+        if self._src_root is not None and not self._src_root.is_dir():
+            raise ValueError("error: source root must be a directory")
+
     def _setup_device_flags(self, use_cpu: bool | None):
         """
         Configure the device state.
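
Below is a minimal sketch of how the new arguments might be used from a training
script. It assumes the class in trainer.py is exposed as helios.trainer.Trainer and
that the project keeps its registrable components under src/my_project; those names,
along with the "baseline" run name, are illustrative only, while src_root,
import_prefix, and run_name come from the diff above.

    import pathlib

    from helios import trainer as hlt

    # Hypothetical project layout: registrable components (models, datamodules, ...)
    # live under src/my_project and are importable as my_project.<module>.
    t = hlt.Trainer(
        run_name="baseline",
        src_root=pathlib.Path("src/my_project"),
        import_prefix="my_project",
    )

    # With these set, _configure_env calls
    # core.update_all_registries(src_root, recurse=True, import_prefix=import_prefix),
    # so workers started through spawn repopulate the registries themselves instead
    # of starting with empty ones. torchrun workers re-run the launching script's
    # imports and therefore do not rely on this path.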