
[FSDP][optim_state_dict] Consolidate the arguments and logic of optim_state_dict and optim_state_dict_to_load (pytorch#96534)

Summary:
The current `optim_state_dict()` does not require users to call `optim.state_dict()` first, while `optim_state_dict_to_load()` requires users to call `optim.load_state_dict()` afterward. This PR makes both APIs give users the option to skip that extra call.

This PR also changes the argument order of `optim_state_dict_to_load`, which is a breaking change, so we should land it as soon as possible, before the API is adopted in production.
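
A minimal usage sketch of the two APIs after this change (assuming an FSDP-wrapped `model` and its optimizer `optim` already exist; the variable names are illustrative):

```python
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# Save: optim_state_dict() falls back to optim.state_dict() internally when
# no optim_state_dict argument is passed, so no extra call is needed.
osd = FSDP.optim_state_dict(model, optim)

# Load: note the new argument order (model, optim, osd). With
# load_directly=True the API also calls optim.load_state_dict() on the
# converted result before returning it.
flattened_osd = FSDP.optim_state_dict_to_load(model, optim, osd, load_directly=True)
```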

Test Plan: CI

Differential Revision: D43925068

Pull Request resolved: pytorch#96534
Approved by: https://github.com/rohan-varma
fegin authored and pytorchmergebot committed Mar 23, 2023
1 parent 1fb1c6e commit 580b470
Showing 7 changed files with 99 additions and 35 deletions.
2 changes: 1 addition & 1 deletion test/distributed/checkpoint/test_2d_fsdp_dt_checkpoint.py
@@ -182,7 +182,7 @@ def _test_fsdp_dt_checkpoint(self, fsdp_pg=None) -> None:
storage_reader=dist_cp.FileSystemReader(CHECKPOINT_DIR),
)
flattened_osd = FSDP.optim_state_dict_to_load(
optim_state["optim"], model_2, optim_2
model_2, optim_2, optim_state["optim"]
)
optim_2.load_state_dict(flattened_osd)

2 changes: 1 addition & 1 deletion test/distributed/checkpoint/test_fsdp_optim_state.py
@@ -81,7 +81,7 @@ def test_distributed_tensor_planner(self) -> None:
)

flattened_osd = FSDP.optim_state_dict_to_load(
optim_state["optim"], model_2, optim_2
model_2, optim_2, optim_state["optim"]
)
optim_2.load_state_dict(flattened_osd)

67 changes: 50 additions & 17 deletions test/distributed/fsdp/test_fsdp_optim_state.py
@@ -284,6 +284,24 @@ def param_group1(self) -> List[torch.nn.Parameter]:
return list(self.block2.parameters()) + list(self.block0.parameters())


# Simple and boring model used to test the interface and some corner cases that
# do not require a complicated wrapping strategy.
class TestDummyModel(torch.nn.Module):
def __init__(self):
super().__init__()
torch.manual_seed(0)
self.net1 = nn.Sequential(nn.Linear(8, 16), nn.ReLU())
self.net2 = nn.Sequential(nn.Linear(16, 32), nn.ReLU())
self.net3 = nn.Linear(32, 64)
self.net4 = nn.Sequential(nn.ReLU(), nn.Linear(64, 8))

def forward(self, x):
return self.net4(self.net3(self.net2(self.net1(x))))

def get_input(self):
return torch.rand(8, 8, device="cuda")


class TestFSDPOptimState(FSDPTest):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@@ -1079,8 +1097,8 @@ def _test_load_optim_state(
optim=optim2,
)
elif osd_comm_method == _OSDCommMethod.OPTIM_STATE_DICT:
sharded_osd1 = FSDP.optim_state_dict_to_load(fsdp_osd1, model2, optim2)
sharded_osd2 = FSDP.optim_state_dict_to_load(fsdp_osd2, model2, optim2)
sharded_osd1 = FSDP.optim_state_dict_to_load(model2, optim2, fsdp_osd1)
sharded_osd2 = FSDP.optim_state_dict_to_load(model2, optim2, fsdp_osd2)

# As a sanity check, check that sharding the second model's full/sharded
# optimizer state dict according to itself is equivalent to its local
@@ -1638,7 +1656,7 @@ def forward(self, x):

# Load the state back to see if load_optim_state_dict works.
state_dict_to_load = FSDP.optim_state_dict_to_load(
state_dicts[1], models[1], optims[1], is_named_optimizer=True
models[1], optims[1], state_dicts[1], is_named_optimizer=True
)
optims[1].load_state_dict(state_dict_to_load)
state_dicts[1] = FSDP.optim_state_dict(models[1], optims[1])
@@ -1652,18 +1670,6 @@ def forward(self, x):

@skip_if_lt_x_gpu(2)
def test_with_empty_optimizer_state(self):
class TestDummyModel(torch.nn.Module):
def __init__(self):
super().__init__()
torch.manual_seed(0)
self.net1 = nn.Sequential(nn.Linear(8, 16), nn.ReLU())
self.net2 = nn.Sequential(nn.Linear(16, 32), nn.ReLU())
self.net3 = nn.Linear(32, 64)
self.net4 = nn.Sequential(nn.ReLU(), nn.Linear(64, 8))

def forward(self, x):
return self.net4(self.net3(self.net2(self.net1(x))))

model = FSDP(TestDummyModel().cuda())
optim = torch.optim.Adam(model.parameters(), lr=1e-2)
state_dict = optim.state_dict()
@@ -1737,10 +1743,10 @@ def _test_load_optim_state_with_optim_state_dict(
# according to the second model and (2) for the second model according
# to the second model
sharded_osd1 = FSDP.optim_state_dict_to_load(
fsdp_osd1, model2, optim2, group=new_group
model2, optim2, fsdp_osd1, group=new_group
)
sharded_osd2 = FSDP.optim_state_dict_to_load(
fsdp_osd2, model2, optim2, group=new_group
model2, optim2, fsdp_osd2, group=new_group
)

# As a sanity check, check that sharding the second model's full/sharded
@@ -1774,6 +1780,33 @@ def _test_load_optim_state_with_optim_state_dict(
optim2.load_state_dict(sharded_osd2)
self._step_model(model2, optim2, num_iters=num_iters)

@skip_if_lt_x_gpu(2)
def test_interface_arguments(self):
model = FSDP(TestDummyModel().cuda())
optim = torch.optim.Adam(model.parameters(), lr=1e-2)

def step():
loss = model(model.get_input())
loss.backward(loss)
optim.step()

step()
original_osd = deepcopy(optim.state_dict())
osd = FSDP.optim_state_dict(model, optim, optim_state_dict=original_osd)
self._check_same_state(
FSDP.optim_state_dict(model, optim), osd, check_same_param_keys=True
)
step()
osd_to_load = FSDP.optim_state_dict_to_load(
model, optim, osd, load_directly=True
)
self._check_same_state(
optim.state_dict(), original_osd, check_same_param_keys=True
)

# TODO: add local/sharded/full state_dict and CPU offloading and rank0
# interface test here, https://github.com/pytorch/pytorch/issues/97163


instantiate_parametrized_tests(TestFSDPOptimState)

@@ -107,7 +107,7 @@ def run_fsdp_checkpoint_example(rank, world_size):
)

flattened_osd = FSDP.optim_state_dict_to_load(
optim_state["optim"], model_2, optim_2
model_2, optim_2, optim_state["optim"]
)
optim_2.load_state_dict(flattened_osd)

2 changes: 1 addition & 1 deletion torch/distributed/checkpoint/optimizer.py
@@ -242,7 +242,7 @@ def load_sharded_optimizer_state_dict(
>>> )
>>>
>>> flattened_osd = FSDP.optim_state_dict_to_load(
>>> optim_state["optimizer"], model, optim
>>> model, optim, optim_state["optimizer"]
>>> )
>>>
>>> optim.load_state_dict(flattened_osd)
55 changes: 43 additions & 12 deletions torch/distributed/fsdp/fully_sharded_data_parallel.py
@@ -1692,13 +1692,25 @@ def rekey_optim_state_dict(
def optim_state_dict(
model: torch.nn.Module,
optim: torch.optim.Optimizer,
optim_state_dict: Optional[Dict[str, Any]] = None,
group: Optional[dist.ProcessGroup] = None,
) -> Dict[str, Any]:
"""
Returns the state dict of ``optim`` for the ``model`` that is (partially)
sharded by FSDP. The state may be sharded, consolidated, or consolidated
on rank 0 only depending on the ``state_dict_type`` set by
:meth:`set_state_dict_type` or :meth:`state_dict_type`.
Transforms the state_dict of ``optim`` for the ``model`` that is sharded
by FSDP to one of the three types: 1) full optimizer state_dict, 2)
sharded optimizer state_dict, 3) local optimizer state_dict.
For full optimizer state_dict, all states are unflattened and not sharded.
Rank0 only and CPU only can be specified via :meth:`state_dict_type` to
avoid OOM.
For sharded optimizer state_dict, all states are unflattened but sharded.
CPU only can be specified via :meth:`state_dict_type` to further save
memory.
For local state_dict, no transformation will be performed. However, a state
will be converted from torch.Tensor to ShardedTensor to represent its sharded
nature (this is not supported yet).
Example::
@@ -1739,6 +1751,9 @@ def optim_state_dict(
were passed into the optimizer ``optim``.
optim (torch.optim.Optimizer): Optimizer for ``model`` 's
parameters.
optim_state_dict (Dict[str, Any]): The target optimizer state_dict to
transform. If the value is None, optim.state_dict() will be used. (
Default: ``None``)
group (dist.ProcessGroup): Model's process group across which parameters
are sharded or ``None`` if using the default process group. (
Default: ``None``)
@@ -1749,10 +1764,12 @@ def optim_state_dict(
``state_dict_type``.
"""
state_dict_settings = FullyShardedDataParallel.get_state_dict_type(model)
if optim_state_dict is None:
optim_state_dict = optim.state_dict()
return FullyShardedDataParallel._optim_state_dict_impl(
model=model,
optim=optim,
optim_state_dict=optim.state_dict(),
optim_state_dict=optim_state_dict,
optim_input=None,
rank0_only=getattr(state_dict_settings, "rank0_only", False),
full_state_dict=state_dict_settings.state_dict_type
@@ -1803,16 +1820,18 @@ def optim_state_dict_post_hook(

@staticmethod
def optim_state_dict_to_load(
optim_state_dict: Dict[str, Any],
model: torch.nn.Module,
optim: torch.optim.Optimizer,
optim_state_dict: Dict[str, Any],
is_named_optimizer: bool = False,
load_directly: bool = False,
group: Optional[dist.ProcessGroup] = None,
) -> Dict[str, Any]:
"""
Given a saved ``optim_state_dict``, converts it to the optimizer state_dict
that can be loaded to ``optim`` which is the optimizer for ``model``.
``model`` is (partially) sharded by FullyShardedDataParallel.
Given an ``optim_state_dict`` that is transformed through
:meth:`optim_state_dict`, converts it to the flattened optimizer
state_dict that can be loaded into ``optim``, which is the optimizer for
``model``. ``model`` must be sharded by FullyShardedDataParallel.
>>> # xdoctest: +SKIP("undefined variables")
>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
@@ -1828,7 +1847,12 @@ def optim_state_dict_to_load(
>>> FullOptimStateDictConfig(rank0_only=False),
>>> )
>>> state_dict = model.state_dict()
>>> optim_state_dict = FSDP.optim_state_dict(model, optim)
>>> original_osd = optim.state_dict()
>>> optim_state_dict = FSDP.optim_state_dict(
>>> model,
>>> optim,
>>> optim_state_dict=original_osd
>>> )
>>> save_a_checkpoint(state_dict, optim_state_dict)
>>> # Load a checkpoint
>>> model, optim = ...
@@ -1846,21 +1870,25 @@ def optim_state_dict_to_load(
>>> optim.load_state_dict(optim_state_dict)
Args:
optim_state_dict (Dict[str, Any]): The optimizer states to be loaded.
model (torch.nn.Module): Root module (which may or may not be a
:class:`FullyShardedDataParallel` instance) whose parameters
were passed into the optimizer ``optim``.
optim (torch.optim.Optimizer): Optimizer for ``model`` 's
parameters.
optim_state_dict (Dict[str, Any]): The optimizer states to be loaded.
is_named_optimizer (bool): Is this optimizer a NamedOptimizer or
KeyedOptimizer. Only set to True if ``optim`` is TorchRec's
KeyedOptimizer or torch.distributed's NamedOptimizer.
load_directly (bool): If this is set to True, this API will also
call optim.load_state_dict(result) before returning the result.
Otherwise, users are responsible for calling ``optim.load_state_dict()``.
(Default: ``False``)
group (dist.ProcessGroup): Model's process group across which parameters
are sharded or ``None`` if using the default process group. (
Default: ``None``)
"""
state_dict_settings = FullyShardedDataParallel.get_state_dict_type(model)
return FullyShardedDataParallel._optim_state_dict_to_load_impl(
result = FullyShardedDataParallel._optim_state_dict_to_load_impl(
optim_state_dict=optim_state_dict,
model=model,
optim_input=None,
Expand All @@ -1872,6 +1900,9 @@ def optim_state_dict_to_load(
is_named_optimizer=is_named_optimizer,
group=group,
)
if load_directly:
optim.load_state_dict(result)
return result

@staticmethod
def load_optim_state_dict_pre_hook(
4 changes: 2 additions & 2 deletions torch/distributed/optim/named_optimizer.py
@@ -300,8 +300,8 @@ def init_state(self) -> None:

def _pre_load_state_dict(self, state_dict) -> Dict[str, Any]:
if isinstance(self.module, FSDP):
return FSDP.load_optim_state_dict_pre_hook(
self.module, self._optimizer, state_dict
return FSDP.optim_state_dict_to_load(
self.module, self._optimizer, state_dict, is_named_optimizer=True
)
return state_dict

