diff --git a/torchtnt/framework/callbacks/dcp_saver.py b/torchtnt/framework/callbacks/dcp_saver.py
index 59d6ded012..04c436c150 100644
--- a/torchtnt/framework/callbacks/dcp_saver.py
+++ b/torchtnt/framework/callbacks/dcp_saver.py
@@ -20,7 +20,7 @@
 )
 from torch.distributed.checkpoint.planner import LoadPlanner, SavePlanner
 from torch.distributed.checkpoint.storage import StorageReader, StorageWriter
-
+from torch.nn.parallel import DistributedDataParallel
 from torchtnt.framework.callbacks._checkpoint_utils import (
     _prepare_app_state_for_checkpoint,
     _prepare_app_state_for_restore,
@@ -41,6 +41,7 @@
 from torchtnt.utils.checkpoint import BestCheckpointConfig, CheckpointPath
 from torchtnt.utils.optimizer import init_optim_state
 from torchtnt.utils.rank_zero_log import rank_zero_info, rank_zero_warn
+
 from torchtnt.utils.stateful import MultiStateful, Stateful
 
 
@@ -62,6 +63,48 @@
         FileSystemWriter as Writer,
     )
 
+# below code provides BC for PyTorch versions which don't include distributed state dict
+# TODO: remove below code once this path is no longer supported
+try:
+    import torch.distributed.checkpoint.state_dict as dsd
+
+    # pyre-ignore Incompatible variable type [9]
+    get_model_state_dict = dsd.get_model_state_dict
+
+    def set_model_state_dict(mod: torch.nn.Module, state_dict: Dict[str, Any]) -> None:
+        return dsd.set_model_state_dict(mod, state_dict)
+
+except ImportError:
+    logger.warn(
+        "torch.distributed.checkpoint.state_dict is not available, "
+        "falling back on defaults. Consider updating PyTorch, as this version "
+        "will not be supported in the future."
+    )
+
+    def get_model_state_dict(mod: torch.nn.Module) -> Dict[str, Any]:
+        return mod.state_dict()
+
+    def set_model_state_dict(mod: torch.nn.Module, state_dict: Dict[str, Any]) -> None:
+        return mod.load_state_dict(state_dict)
+
+
+class DSDModelWrapper(Stateful):
+    """This wrapper converts state dicts to Distributed State Dicts, essentially generating
+    state dicts as if they were created using single-device methods. This is useful
+    when checkpointed models might be resharded, or loaded in notebooks or other non-distributed
+    settings.
+
+    """
+
+    def __init__(self, mod: torch.nn.Module) -> None:
+        self.mod: torch.nn.Module = mod
+
+    def state_dict(self) -> Dict[str, Any]:
+        return get_model_state_dict(self.mod)
+
+    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
+        set_model_state_dict(self.mod, state_dict)
+
 
 class DistributedCheckpointSaver(BaseCheckpointer):
     """
@@ -148,6 +191,11 @@ def _checkpoint_impl(
         curr_snapshot_wait = hook == "on_train_end"
 
         app_state = _prepare_app_state_for_checkpoint(state, unit, intra_epoch)
+
+        for key, obj in app_state.items():
+            if isinstance(obj, DistributedDataParallel):
+                app_state[key] = DSDModelWrapper(obj)
+
         # TODO: evaluate whether we need to implement the equivalent of torchsnapshot.RNGState()
         if self._async_checkpoint:
             with get_timing_context(state, f"{self.__class__.__name__}.async_save"):
@@ -315,7 +363,7 @@ def restore(
         )
 
         # necessary for loading optimizers since states are initialized lazy
-        for obj in app_state.values():
+        for key, obj in app_state.items():
             # sometimes optimizers are actually held in a wrapper which handles calling
             # state_dict and load_state_dict, sa is the case for
             # `torchtnt.utils.prepare_module.FSDPOptimizerWrapper`, this handles that case.
@@ -323,6 +371,9 @@ def restore(
             if isinstance(optimizer, torch.optim.Optimizer):
                 init_optim_state(optimizer)
 
+            if isinstance(obj, DistributedDataParallel):
+                app_state[key] = DSDModelWrapper(obj)
+
         try:
             dcp.load(
                 {"app_state": MultiStateful(app_state)},
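
For reviewers, a minimal self-contained sketch (not part of the diff) of the behavior the `DSDModelWrapper` indirection relies on: `get_model_state_dict` from `torch.distributed.checkpoint.state_dict` returns keys for the unwrapped module, so a checkpoint written from a DDP-wrapped model can later be loaded into a plain module or a resharded one. The single-process gloo setup, module, and port below are illustrative only.

```python
# Illustrative sketch only -- not part of this diff. Assumes a PyTorch version
# that ships torch.distributed.checkpoint.state_dict (the non-fallback branch above).
import os

import torch
import torch.distributed as dist
from torch.distributed.checkpoint.state_dict import get_model_state_dict
from torch.nn.parallel import DistributedDataParallel


def main() -> None:
    # Single-process gloo group: just enough for DDP to construct on CPU.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=0, world_size=1)

    model = torch.nn.Linear(4, 2)
    ddp_model = DistributedDataParallel(model)

    # DDP's own state_dict carries the "module." prefix ...
    assert all(k.startswith("module.") for k in ddp_model.state_dict())
    # ... while the distributed state dict uses the unwrapped module's keys,
    # which is what DSDModelWrapper hands to dcp.save / dcp.load.
    assert set(get_model_state_dict(ddp_model)) == set(model.state_dict())

    dist.destroy_process_group()


if __name__ == "__main__":
    main()
```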