diff --git a/src/helios/trainer.py b/src/helios/trainer.py
index 1dfef8b..aaa7868 100644
--- a/src/helios/trainer.py
+++ b/src/helios/trainer.py
@@ -225,6 +225,7 @@ def __init__(
         src_root: pathlib.Path | None = None,
         import_prefix: str = "",
         print_banner: bool = True,
+        mp_sharing_strategy: str | None = None,
     ):
         """Create the trainer."""
         self._model: hlm.Model | None = None
@@ -347,6 +348,22 @@ def test(self, model: hlm.Model, datamodule: data.DataModule) -> None:
             logging.close_default_loggers()
             raise RuntimeError("error: uncaught exception") from e
 
+    def set_mp_sharing_strategy(self, strategy: str) -> None:
+        """
+        Set the torch multiprocessing sharing strategy.
+
+        This can be used to switch the sharing strategy used by torch.multiprocessing,
+        but it must be called prior to any calls to fit or test.
+
+        Note:
+            If the trainer is set to not use distributed training or if it's started
+            from torchrun, then this function does nothing.
+        """
+        if not self._is_distributed or self._is_torchrun:
+            return
+
+        mp.set_sharing_strategy(strategy)
+
     def _launch(
         self, model: hlm.Model, datamodule: data.DataModule, mode: _TrainerMode
     ) -> None:
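
Usage sketch for reviewers; nothing below is part of the patch. The `Trainer` class name, its construction, the `fit(model, datamodule)` call shape, and the `helios.model`/`helios.data` import paths are assumptions inferred from the signatures and `hlm`/`data` aliases visible in this diff; the model and datamodule arguments stand in for concrete subclasses.

```python
import torch.multiprocessing as mp

import helios.data as data
import helios.model as hlm
import helios.trainer as hlt


def run(model: hlm.Model, datamodule: data.DataModule) -> None:
    # Hypothetical construction; the real __init__ takes many more arguments
    # (src_root, import_prefix, print_banner, and now mp_sharing_strategy,
    # which presumably applies the same setting at construction time).
    trainer = hlt.Trainer(print_banner=False)

    # torch.multiprocessing reports which strategies this platform supports;
    # on Linux that is typically {"file_descriptor", "file_system"}.
    assert "file_system" in mp.get_all_sharing_strategies()

    # Per the new docstring: call before fit/test. It is a no-op when the
    # trainer is not distributed or when the job was started via torchrun.
    trainer.set_mp_sharing_strategy("file_system")

    trainer.fit(model, datamodule)
```

Switching to `"file_system"` is the usual workaround when the default `"file_descriptor"` strategy exhausts the process's open-file limit under many dataloader workers, at the cost of shared-memory files that can leak if workers crash.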