wraps DDP models with DSD #857

Open · wants to merge 1 commit into base: master
55 changes: 53 additions & 2 deletions torchtnt/framework/callbacks/dcp_saver.py
@@ -20,7 +20,7 @@
)
from torch.distributed.checkpoint.planner import LoadPlanner, SavePlanner
from torch.distributed.checkpoint.storage import StorageReader, StorageWriter

from torch.nn.parallel import DistributedDataParallel
from torchtnt.framework.callbacks._checkpoint_utils import (
    _prepare_app_state_for_checkpoint,
    _prepare_app_state_for_restore,
@@ -41,6 +41,7 @@
from torchtnt.utils.checkpoint import BestCheckpointConfig, CheckpointPath
from torchtnt.utils.optimizer import init_optim_state
from torchtnt.utils.rank_zero_log import rank_zero_info, rank_zero_warn

from torchtnt.utils.stateful import MultiStateful, Stateful


@@ -62,6 +63,48 @@
    FileSystemWriter as Writer,
)

# The code below provides backward compatibility for PyTorch versions that
# don't include distributed state dict.
# TODO: remove this fallback once those versions are no longer supported.
try:
    import torch.distributed.checkpoint.state_dict as dsd

    # pyre-ignore Incompatible variable type [9]
    get_model_state_dict = dsd.get_model_state_dict

    def set_model_state_dict(mod: torch.nn.Module, state_dict: Dict[str, Any]) -> None:
        dsd.set_model_state_dict(mod, state_dict)

except ImportError:
    logger.warning(
        "torch.distributed.checkpoint.state_dict is not available, "
        "falling back on defaults. Consider updating PyTorch, as this version "
        "will not be supported in the future."
    )

    def get_model_state_dict(mod: torch.nn.Module) -> Dict[str, Any]:
        return mod.state_dict()

    def set_model_state_dict(mod: torch.nn.Module, state_dict: Dict[str, Any]) -> None:
        mod.load_state_dict(state_dict)
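A minimal sketch (editor's illustration, not part of this PR) of why DDP-wrapped modules need this treatment: nn.Module.state_dict() on a DDP module prefixes every key with "module.", while the distributed state dict API returns keys as if the plain module had been saved. The single-process gloo group exists only so DDP can be constructed locally.

import os

import torch
import torch.distributed as dist
from torch.distributed.checkpoint.state_dict import get_model_state_dict
from torch.nn.parallel import DistributedDataParallel

# Single-process process group so DDP can wrap a module locally.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

ddp = DistributedDataParallel(torch.nn.Linear(2, 2))
print(list(ddp.state_dict()))           # ['module.weight', 'module.bias']
print(list(get_model_state_dict(ddp)))  # ['weight', 'bias']

dist.destroy_process_group()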


class DSDModelWrapper(Stateful):
    """This wrapper converts state dicts to Distributed State Dicts, essentially
    generating state dicts as if they were created using single-device methods.
    This is useful when checkpointed models might be resharded, or loaded in
    notebooks or other non-distributed settings.
    """

    def __init__(self, mod: torch.nn.Module) -> None:
        self.mod: torch.nn.Module = mod

    def state_dict(self) -> Dict[str, Any]:
        return get_model_state_dict(self.mod)

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        set_model_state_dict(self.mod, state_dict)
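A hedged usage sketch (editor's illustration, reusing the single-process DDP setup from the sketch above): the wrapper is a Stateful, so checkpoint code can treat it like any other app-state entry while it normalizes the DDP state dict.

wrapped = DSDModelWrapper(ddp)  # `ddp` from the previous sketch

sd = wrapped.state_dict()    # keys look single-device: 'weight', 'bias'
wrapped.load_state_dict(sd)  # routes through set_model_state_dict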


class DistributedCheckpointSaver(BaseCheckpointer):
    """
@@ -148,6 +191,11 @@ def _checkpoint_impl(
        curr_snapshot_wait = hook == "on_train_end"

        app_state = _prepare_app_state_for_checkpoint(state, unit, intra_epoch)

        for key, obj in app_state.items():
            if isinstance(obj, DistributedDataParallel):
                app_state[key] = DSDModelWrapper(obj)

        # TODO: evaluate whether we need to implement the equivalent of torchsnapshot.RNGState()
        if self._async_checkpoint:
            with get_timing_context(state, f"{self.__class__.__name__}.async_save"):
@@ -315,14 +363,17 @@ def restore(
)

        # necessary for loading optimizers since states are initialized lazily
        for key, obj in app_state.items():
            # Sometimes optimizers are actually held in a wrapper which handles calling
            # state_dict and load_state_dict, as is the case for
            # `torchtnt.utils.prepare_module.FSDPOptimizerWrapper`; this handles that case.
            optimizer = getattr(obj, "optimizer", obj)
            if isinstance(optimizer, torch.optim.Optimizer):
                init_optim_state(optimizer)

            if isinstance(obj, DistributedDataParallel):
                app_state[key] = DSDModelWrapper(obj)

        try:
            dcp.load(
                {"app_state": MultiStateful(app_state)},
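Finally, a hedged end-to-end sketch of the notebook scenario the wrapper enables (editor's illustration: `MyModel`, the app-state key "module", and the checkpoint path are all hypothetical, and this assumes a PyTorch version whose dcp.load supports single-process, non-distributed use):

import torch
import torch.distributed.checkpoint as dcp

model = MyModel()  # hypothetical: same architecture, no DDP wrapper needed
state = {"app_state": {"module": model.state_dict()}}
dcp.load(state, checkpoint_id="/tmp/checkpoints/epoch_2_step_100")  # hypothetical path
model.load_state_dict(state["app_state"]["module"])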