
Commit 507a553

[brief] Add a way to set the sharing strategy from torch.multiprocessing.

[detailed]
marovira committed May 29, 2024
1 parent f88b9dd commit 507a553
Showing 1 changed file with 16 additions and 0 deletions.
src/helios/trainer.py: 16 additions & 0 deletions
@@ -225,6 +225,7 @@ def __init__(
         src_root: pathlib.Path | None = None,
         import_prefix: str = "",
         print_banner: bool = True,
+        mp_sharing_strategy: str | None = None,
     ):
         """Create the trainer."""
         self._model: hlm.Model | None = None
@@ -347,6 +348,21 @@ def test(self, model: hlm.Model, datamodule: data.DataModule) -> None:
             logging.close_default_loggers()
             raise RuntimeError("error: uncaught exception") from e
 
+    def set_mp_sharing_strategy(self, strategy: str) -> None:
+        """
+        Set the torch multiprocessing sharing strategy.
+
+        This can be used to switch the sharing strategy used by torch.multiprocessing,
+        but it must be called prior to any calls to fit or test.
+
+        Note:
+            If the trainer is set to not use distributed training or if it's started
+            from torchrun, then this function does nothing.
+        """
+        if not self._is_distributed or self._is_torchrun:
+            return
+
+        mp.set_sharing_strategy(strategy)
+
     def _launch(
         self, model: hlm.Model, datamodule: data.DataModule, mode: _TrainerMode
     ) -> None:
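For context, here is a minimal usage sketch of the new hook; it is not part of the commit. It assumes the class defined in src/helios/trainer.py is named Trainer and can be constructed with defaults; only set_mp_sharing_strategy, the new mp_sharing_strategy constructor argument, and the test() entry point are taken from the diff (fit() is assumed to be analogous). The strategy names themselves come from torch.multiprocessing, which on Linux documents "file_system" and "file_descriptor".

import torch.multiprocessing as mp

from helios.trainer import Trainer  # class name assumed from src/helios/trainer.py

# torch.multiprocessing reports which sharing strategies this platform supports.
print(mp.get_all_sharing_strategies())  # e.g. {'file_descriptor', 'file_system'}

trainer = Trainer()  # other constructor arguments elided; see __init__ in the diff

# Must run before fit()/test(); per the Note in the docstring above, it is a
# no-op unless the trainer uses distributed training and was not started from
# torchrun. "file_system" backs shared tensors with files rather than file
# descriptors, a common workaround for "too many open files" errors when using
# many DataLoader workers.
trainer.set_mp_sharing_strategy("file_system")

# model and datamodule are user-defined hlm.Model / data.DataModule instances,
# matching the test() signature shown in the diff.
trainer.fit(model, datamodule)

The change to __init__ suggests the same strategy can instead be supplied at construction time via mp_sharing_strategy, though how that argument is applied lies outside the lines shown in this diff.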
