From 6d6e9a4df51b986fe17e5e846839f4683f6054e4 Mon Sep 17 00:00:00 2001 From: Tristan Rice Date: Wed, 18 Dec 2024 14:52:56 -0800 Subject: [PATCH] local_sgd: initial version of fault tolerant LocalSGD (#47) --- docs/source/index.rst | 1 + docs/source/local_sgd.rst | 4 + torchft/ddp.py | 4 +- torchft/ddp_test.py | 6 +- torchft/local_sgd.py | 184 ++++++++++++++++++++++++++++++++++++++ torchft/local_sgd_test.py | 96 ++++++++++++++++++++ torchft/manager.py | 41 +++++---- torchft/manager_test.py | 26 +++--- 8 files changed, 323 insertions(+), 39 deletions(-) create mode 100644 docs/source/local_sgd.rst create mode 100644 torchft/local_sgd.py create mode 100644 torchft/local_sgd_test.py diff --git a/docs/source/index.rst b/docs/source/index.rst index 0037543..4d2a5af 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -17,6 +17,7 @@ the entire training job. manager optim ddp + local_sgd data checkpointing parameter_server diff --git a/docs/source/local_sgd.rst b/docs/source/local_sgd.rst new file mode 100644 index 0000000..6839eec --- /dev/null +++ b/docs/source/local_sgd.rst @@ -0,0 +1,4 @@ +.. automodule:: torchft.local_sgd + :members: + :undoc-members: + :show-inheritance: diff --git a/torchft/ddp.py b/torchft/ddp.py index e1d00a1..6fbea8f 100644 --- a/torchft/ddp.py +++ b/torchft/ddp.py @@ -68,7 +68,7 @@ def __init__(self, manager: "Manager", module: nn.Module, **kwargs: object) -> N def _comm_hook( state: "Manager", bucket: dist.GradBucket ) -> torch.futures.Future[torch.Tensor]: - return state.allreduce_grad(bucket.buffer()) + return state.allreduce(bucket.buffer()) class PureDistributedDataParallel(nn.Module): @@ -88,7 +88,7 @@ def __init__(self, manager: "Manager", module: nn.Module) -> None: def post_grad_hook(p: torch.Tensor) -> None: if p.grad is not None: - manager.allreduce_grad(p.grad) + manager.allreduce(p.grad) for p in module.parameters(): p.register_post_accumulate_grad_hook(post_grad_hook) diff --git a/torchft/ddp_test.py b/torchft/ddp_test.py index da3ece4..1a56dce 100644 --- a/torchft/ddp_test.py +++ b/torchft/ddp_test.py @@ -32,14 +32,14 @@ def test_pure_ddp(self) -> None: for p in m.parameters(): self.assertIsNotNone(p.grad) - self.assertEqual(manager.allreduce_grad.call_count, len(list(m.parameters()))) + self.assertEqual(manager.allreduce.call_count, len(list(m.parameters()))) def test_ddp(self) -> None: manager = create_autospec(Manager) call_count = 0 - def allreduce_grad(tensor: torch.Tensor) -> Future[torch.Tensor]: + def allreduce(tensor: torch.Tensor) -> Future[torch.Tensor]: nonlocal call_count call_count += 1 @@ -48,7 +48,7 @@ def allreduce_grad(tensor: torch.Tensor) -> Future[torch.Tensor]: fut.set_result(tensor) return fut - manager.allreduce_grad = allreduce_grad + manager.allreduce = allreduce m = nn.Linear(3, 4) m = DistributedDataParallel(manager, m) diff --git a/torchft/local_sgd.py b/torchft/local_sgd.py new file mode 100644 index 0000000..eef2b53 --- /dev/null +++ b/torchft/local_sgd.py @@ -0,0 +1,184 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +LocalSGD +========= + +This module implements a fault tolerant version of LocalSGD and related methods. 
+""" + +from typing import Any, Dict, List, Mapping, Optional + +import torch +from torch import nn, optim + +from torchft.manager import Manager + + +class LocalSGD(nn.Module): + """ + LocalSGD is a model wrapper similar to DistributedDataParallel that + implements the algorithm described in https://arxiv.org/pdf/1805.09767 + + This will synchronize the model parameters periodically in a fault tolerant + way using a torchft Manager. The allreduce on the parameters will happen + every sync_every steps after the optimizer.step call. + + To implement safe and fault tolerant, this requires a backup copy of the + weights. By default these are stored in CPU memory. If any error occurs + during the LocalSGD step, the step will be discarded and the model + parameters will reset back to the last time LocalSGD synchronized. + + The backup weights could be eliminated by relaxing the guarantee of exactly + `sync_every` steps but that would diverge from the LocalSGD algorithm. + DiLoCo also needs this backup copy to compute the delta. + + The torchft quorum is computed at the beginning of ``sync_every`` steps. If + any error occurs, or a worker fails between syncs, ``sync_every`` steps will be + discarded and a new quorum will be computed on the next step. + + If running in async mode, on a joining worker the first ``sync_every`` steps + will discarded as the model will be recovering during that period. When + using sync mode, the checkpoint will be restored prior to the first step. + + TODO: add a way via Manager to detect workers failing early for shrink only + TODO: add DiLoCo support + """ + + def __init__( + self, + manager: Manager, + model: nn.Module, + optimizer: optim.Optimizer, + sync_every: int, + backup_device: Optional[torch.device] = None, + pin_memory: bool = True, + ) -> None: + """ + Args: + manager: The manager to use. + model: The model to wrap. + optimizer: The optimizer used by the model. + sync_every: How often to sync the model weights. + backup_device: The device to store the backup of the model parameters on. (default cpu) + pin_memory: Whether to pin the memory used for the backup of the model parameters. + """ + super().__init__() + + self._manager = manager + self._model = model + self._local_step = 0 + self._started_step = False + self._sync_every = sync_every + assert sync_every >= 1, "sync_every must be greater than or equal to 1" + + device = backup_device or torch.device("cpu") + + self._backup_parameters: Dict[str, torch.Tensor] = {} + + for name, p in self._model.named_parameters(): + t = torch.empty(*tuple(p.shape), dtype=p.dtype, device=device) + if ( + pin_memory + and t.device == torch.device("cpu") + and torch.cuda.is_available() + ): + t = t.pin_memory() + self._backup_parameters[name] = t + + # Need to copy the parameters to the host to be safe if we are on the first step. 
+        self._save_parameters()
+
+        optimizer.register_step_post_hook(self._step_post_hook)
+
+    def _save_parameters(self) -> None:
+        # TODO: consider running copy on a separate stream
+        for name, p in self._model.named_parameters():
+            self._backup_parameters[name].copy_(p.data, non_blocking=True)
+
+    def _restore_parameters(self) -> None:
+        # TODO: consider running copy on a separate stream
+        for name, p in self._model.named_parameters():
+            p.data.copy_(self._backup_parameters[name], non_blocking=True)
+
+    # pyre-fixme[14]: support state_dict args
+    def state_dict(self) -> Dict[str, object]:
+        """
+        state_dict returns the state_dict from the last time LocalSGD
+        synchronized and not the current weights.
+        """
+        state_dict = self._model.state_dict()
+        for name, p in self._backup_parameters.items():
+            assert name in state_dict
+            state_dict[name] = p
+        return state_dict
+
+    def load_state_dict(
+        self, state_dict: Mapping[str, Any], strict: bool = True, assign: bool = False
+    ) -> None:
+        """
+        Loads the state dict into the model and the backup parameters.
+
+        This must be called while the model weights aren't being modified to
+        avoid corrupting the backup weights.
+        """
+        self._model.load_state_dict(state_dict, strict=strict, assign=assign)
+        self._save_parameters()
+
+    def forward(self, *args: object, **kwargs: object) -> object:
+        """
+        Run the model forward pass.
+
+        This should be called before the optimizer step.
+
+        This will start the quorum on the first step of each ``sync_every`` window.
+        """
+        if self._local_step == 0:
+            self._manager.start_quorum()
+
+        self._started_step = True
+
+        return self._model.forward(*args, **kwargs)
+
+    def _step_post_hook(
+        self, _optim: optim.Optimizer, _args: List[object], _kwargs: Dict[str, object]
+    ) -> None:
+        """
+        This hook is registered on the optimizer and is called after the optimizer step.
+
+        This will allreduce the model weights every ``sync_every`` steps.
+        If any errors occur it will restore the weights from the previous sync.
+
+        ``forward`` must be called before this function.
+        """
+        assert self._started_step, "forward must be called before step"
+        self._started_step = False
+
+        self._local_step += 1
+
+        if self._local_step >= self._sync_every:
+            self._local_step = 0
+            self._average()
+
+            if self._manager.should_commit():
+                # save the parameters so we can restore from them later if necessary.
+                self._save_parameters()
+            else:
+                # commit failed, restore from the backup parameters
+                self._restore_parameters()
+
+    def _average(self) -> None:
+        # TODO: do we need to broadcast buffers like DDP does?
+
+        works = []
+
+        for p in self._model.parameters():
+            # TODO: bucketize parameters
+            works.append(self._manager.allreduce(p))
+
+        for work in works:
+            work.wait()
diff --git a/torchft/local_sgd_test.py b/torchft/local_sgd_test.py
new file mode 100644
index 0000000..d2b73cd
--- /dev/null
+++ b/torchft/local_sgd_test.py
@@ -0,0 +1,96 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+ +from typing import Dict +from unittest import TestCase +from unittest.mock import create_autospec + +import torch +from torch import nn, optim + +from torchft.local_sgd import LocalSGD +from torchft.manager import Manager + + +class SimpleModel(nn.Module): + def __init__(self) -> None: + super().__init__() + + self.model = nn.Sequential( + nn.Linear(3, 4), + nn.ReLU(), + nn.Linear(4, 5), + nn.Sigmoid(), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.model(x) + + +def _params_dict(m: torch.nn.Module) -> Dict[str, torch.Tensor]: + return {name: p.data for name, p in m.named_parameters()} + + +def _copy_state_dict(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + return {name: value.clone().detach() for name, value in state_dict.items()} + + +class LocalSGDTest(TestCase): + def test_local_sgd_healthy(self) -> None: + base_m = SimpleModel() + optimizer = optim.SGD(base_m.parameters()) + manager = create_autospec(Manager) + + m = LocalSGD(manager, base_m, optimizer, sync_every=2) + self.assertEqual(m._local_step, 0) + + torch.testing.assert_close(m._backup_parameters, _params_dict(base_m)) + + inp = torch.rand(2, 3) + + loss = m(inp).mean() + loss.backward() + optimizer.step() + + self.assertEqual(m._local_step, 1) + self.assertEqual(manager.start_quorum.call_count, 1) + + loss = m(inp).mean() + loss.backward() + optimizer.step() + + manager.should_commit.return_value = True + self.assertEqual(m._local_step, 0) + + torch.testing.assert_close(m._backup_parameters, _params_dict(base_m)) + self.assertEqual(manager.should_commit.call_count, 1) + self.assertEqual(manager.allreduce.call_count, 4) + + def test_local_sgd_recovery(self) -> None: + base_m = SimpleModel() + optimizer = optim.SGD(base_m.parameters()) + manager = create_autospec(Manager) + + m = LocalSGD(manager, base_m, optimizer, sync_every=2) + + torch.testing.assert_close(m._backup_parameters, _params_dict(base_m)) + og_state_dict = _copy_state_dict(base_m.state_dict()) + + inp = torch.rand(2, 3) + + loss = m(inp).mean() + loss.backward() + optimizer.step() + + self.assertEqual(m._local_step, 1) + + state_dict = m.state_dict() + torch.testing.assert_close(state_dict, m._backup_parameters) + torch.testing.assert_close(state_dict, og_state_dict) + + m.load_state_dict(state_dict) + torch.testing.assert_close(_params_dict(base_m), state_dict) + torch.testing.assert_close(m._backup_parameters, _params_dict(base_m)) diff --git a/torchft/manager.py b/torchft/manager.py index a1e5167..1f76729 100644 --- a/torchft/manager.py +++ b/torchft/manager.py @@ -194,39 +194,39 @@ def shutdown(self) -> None: self._manager.shutdown() self._executor.shutdown() - def allreduce_grad(self, grad: torch.Tensor) -> torch.futures.Future[torch.Tensor]: + def allreduce(self, tensor: torch.Tensor) -> torch.futures.Future[torch.Tensor]: """ - Allreduce the gradient and return a Future that will be completed when - the gradient is ready. + Fault tolerant allreduce the tensor and return a Future that will be completed when + the tensor is ready. - This will automatically scale the gradient by 1 / world_size. + This will automatically scale the tensor by 1 / world_size. If an error occurs during the allreduce: * The Future will be completed with no error and instead tracked asynchronously. - * After the first error, all subsequent allreduce_grad calls will be noops and immediately return. - * The grad tensor must be zeroed before being used as it may be corrupted. 
+ * After the first error, all subsequent calls will be noops and immediately return. + * The tensor must be zeroed before being used as it may be corrupted. Args: - grad: the gradient to allreduce + tensor: the tensor to allreduce Returns: - a Future that will be completed with the allreduced gradient + a Future that will be completed with the allreduced tensor """ if self.errored(): fut = torch.futures.Future() # pyre-fixme[29]: not a function - fut.set_result(grad) + fut.set_result(tensor) return fut self.wait_quorum() if not self.is_participating(): - grad.zero_() + tensor.zero_() # TODO: increase timeout when waiting when healing try: # Run the allreduce async and save the work object so we can wait on # it later. - work = self._pg.allreduce([grad], ReduceOp.SUM) + work = self._pg.allreduce([tensor], ReduceOp.SUM) fut = work.get_future() # schedule grad normalization as a continuation @@ -234,17 +234,17 @@ def allreduce_grad(self, grad: torch.Tensor) -> torch.futures.Future[torch.Tenso def callback( fut: torch.futures.Future[List[torch.Tensor]], ) -> torch.Tensor: - nonlocal grad + nonlocal tensor # check for exceptions fut.value() - grad /= self.num_participants() + tensor /= self.num_participants() - return grad + return tensor fut = fut.then(callback) - fut = self.wrap_future(fut, grad) + fut = self.wrap_future(fut, tensor) return fut except Exception as e: @@ -254,7 +254,7 @@ def callback( self.report_error(e) fut = torch.futures.Future() # pyre-fixme[29]: not a function - fut.set_result(grad) + fut.set_result(tensor) return fut def report_error(self, e: Exception) -> None: @@ -324,12 +324,11 @@ def start_quorum(self, allow_heal: bool = True) -> None: It's best practice to call this before the forwards pass of each step for performance as computing quorum may take some time. - If allow_heal is set, the manager will attempt to heal either - synchronously before returning or asynchronously prior to any network - calls. - Args: - allow_heal: whether to allow healing at the beginning of the step + allow_heal: (experimental) whether to allow healing at the beginning of the step + If allow_heal is set, the manager will attempt to heal either + synchronously before returning or asynchronously prior to any network + calls. All replicas must pass the same value to allow_heal. 
""" # wait for previous quorum to complete diff --git a/torchft/manager_test.py b/torchft/manager_test.py index ad6d5a5..e119cc1 100644 --- a/torchft/manager_test.py +++ b/torchft/manager_test.py @@ -103,7 +103,7 @@ def test_quorum_happy(self, client_mock: MagicMock) -> None: self.assertEqual(manager.batches_committed(), 0) manager.start_quorum() - manager.allreduce_grad(torch.tensor([1.0])).wait() + manager.allreduce(torch.tensor([1.0])).wait() self.assertEqual(len(manager._pending_work), 1) self.assertTrue(manager.should_commit()) self.assertEqual(len(manager._pending_work), 0) @@ -141,7 +141,7 @@ def test_quorum_heal_sync(self, client_mock: MagicMock) -> None: self.assertEqual(manager.current_step(), 0) manager.start_quorum() - manager.allreduce_grad(torch.tensor([1.0])).wait() + manager.allreduce(torch.tensor([1.0])).wait() self.assertFalse(manager._healing) self.assertTrue(manager.is_participating()) self.assertEqual(manager.num_participants(), 2) @@ -190,7 +190,7 @@ def test_quorum_heal_async_not_enough_participants( self.assertEqual(manager.num_participants(), 1) grad = torch.tensor([1.0]) - manager.allreduce_grad(grad).wait() + manager.allreduce(grad).wait() torch.testing.assert_close(grad, torch.zeros_like(grad)) # don't commit since num_max < min_replica_size self.assertFalse(manager.should_commit()) @@ -240,7 +240,7 @@ def test_quorum_heal_async_zero_grad(self, client_mock: MagicMock) -> None: self.assertTrue(manager._healing) grad = torch.tensor([1.0]) - manager.allreduce_grad(grad).wait() + manager.allreduce(grad).wait() torch.testing.assert_close(grad, torch.zeros_like(grad)) # don't commit since num_max < min_replica_size self.assertTrue(manager.should_commit()) @@ -280,17 +280,17 @@ def test_allreduce_error(self, client_mock: MagicMock) -> None: self.assertEqual(manager.current_step(), 0) manager.start_quorum() - manager.allreduce_grad(torch.tensor([1.0])).wait() + manager.allreduce(torch.tensor([1.0])).wait() # pyre-ignore[16]: _pg is mocked self.assertEqual(manager._pg.allreduce.call_count, 1) # inject failure when work queued # pyre-ignore[16]: _pg is mocked manager._pg.allreduce.side_effect = RuntimeError("injected failure") - manager.allreduce_grad(torch.tensor([1.0])).wait() + manager.allreduce(torch.tensor([1.0])).wait() self.assertTrue(manager._errored) # this should be skipped due to error - manager.allreduce_grad(torch.tensor([1.0])).wait() + manager.allreduce(torch.tensor([1.0])).wait() self.assertEqual(manager._pg.allreduce.call_count, 2) # pyre-ignore[16]: _pg is mocked self.assertEqual(manager._pg.allreduce.return_value.get_future.call_count, 1) @@ -320,7 +320,7 @@ def test_allreduce_error(self, client_mock: MagicMock) -> None: bad_fut = torch.futures.Future() # pyre-fixme[29]: not a function bad_fut.set_exception(RuntimeError("injected failure")) manager._pg.allreduce.return_value.get_future.return_value = bad_fut - manager.allreduce_grad(torch.tensor([1.0])).wait() + manager.allreduce(torch.tensor([1.0])).wait() self.assertEqual(manager._pg.allreduce.return_value.get_future.call_count, 2) self.assertTrue(manager._errored) self.assertFalse(manager.should_commit()) @@ -343,7 +343,7 @@ def test_allreduce_error(self, client_mock: MagicMock) -> None: ) manager.start_quorum() - manager.allreduce_grad(torch.tensor([1.0])).wait() + manager.allreduce(torch.tensor([1.0])).wait() self.assertTrue(manager.should_commit()) @patch("torchft.manager.ManagerClient", autospec=True) @@ -375,7 +375,7 @@ def test_quorum_fixed_world_size(self, client_mock: MagicMock) -> None: 
self.assertEqual(manager.batches_committed(), 0) manager.start_quorum() - manager.allreduce_grad(torch.tensor([1.0])).wait() + manager.allreduce(torch.tensor([1.0])).wait() self.assertEqual(manager.is_participating(), rank != 2) self.assertEqual(manager.num_participants(), 2) @@ -408,7 +408,7 @@ def test_quorum_no_healing(self, client_mock: MagicMock) -> None: self.assertEqual(manager.batches_committed(), 0) manager.start_quorum(allow_heal=False) - manager.allreduce_grad(torch.tensor([1.0])).wait() + manager.allreduce(torch.tensor([1.0])).wait() self.assertFalse(manager.is_participating()) self.assertEqual(manager.num_participants(), 2) @@ -472,7 +472,7 @@ def test_manager_numerics(self, client_mock: MagicMock) -> None: self.assertTrue(manager.is_participating()) fut = torch.futures.Future() # pyre-fixme[29]: not a function - fut = manager.allreduce_grad(torch.tensor([1.0])) + fut = manager.allreduce(torch.tensor([1.0])) result = fut.value() torch.testing.assert_close(result, torch.tensor([1.0 / 5])) @@ -480,6 +480,6 @@ def test_manager_numerics(self, client_mock: MagicMock) -> None: manager._healing = True self.assertFalse(manager.is_participating()) fut = torch.futures.Future() # pyre-fixme[29]: not a function - fut = manager.allreduce_grad(torch.tensor([1.0])) + fut = manager.allreduce(torch.tensor([1.0])) result = fut.value() torch.testing.assert_close(result, torch.tensor([0.0]))
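
The patch above shows the wrapper's internals and its unit tests, so here is a minimal end-to-end sketch of how the new LocalSGD wrapper is meant to be driven. It assumes a torchft Manager has already been constructed and configured elsewhere (its constructor arguments, process group, and lighthouse setup are out of scope for this patch); the rest mirrors the API added here.

# Hypothetical training loop using the LocalSGD wrapper from this patch.
# `manager` is assumed to be an already-configured torchft Manager; how it is
# constructed (process group, lighthouse, replica id) is not shown here.
import torch
from torch import nn, optim

from torchft.local_sgd import LocalSGD
from torchft.manager import Manager


def train(manager: Manager, num_steps: int = 100) -> None:
    model = nn.Linear(3, 4)
    optimizer = optim.SGD(model.parameters(), lr=0.01)

    # Parameters are allreduced every `sync_every` optimizer steps and rolled
    # back to the CPU backup copy if the commit for that window fails.
    local_sgd = LocalSGD(manager, model, optimizer, sync_every=2)

    for _ in range(num_steps):
        inp = torch.rand(8, 3)

        # forward() starts the quorum on the first local step of each window.
        loss = local_sgd(inp).sum()
        loss.backward()

        # The step post-hook registered by LocalSGD counts local steps and, at
        # the end of a window, runs the fault tolerant allreduce followed by
        # manager.should_commit() to decide between saving and restoring.
        optimizer.step()
        optimizer.zero_grad()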
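Since allreduce_grad is renamed to allreduce and documented as a general fault tolerant collective, the calling contract is worth spelling out for code outside the DDP hook: the returned future always completes (errors are tracked by the Manager rather than raised), the result is scaled by 1 / num_participants(), a non-participating or errored replica may hand back a zeroed or stale tensor, and should_commit() is the signal for whether the result may be applied. A rough sketch of a caller written against that contract follows; the helper name is illustrative, not part of the patch.

# Illustrative helper built on the Manager.allreduce() semantics documented in
# this patch; `average_tensors` is not an API added by the change.
from typing import List

import torch

from torchft.manager import Manager


def average_tensors(manager: Manager, tensors: List[torch.Tensor]) -> bool:
    """Allreduce-averages `tensors` in place; returns True if the step committed."""
    futs = [manager.allreduce(t) for t in tensors]

    # wait() does not raise on collective failure: the Manager records the
    # error and resolves the future with the original tensor, whose contents
    # must then be treated as garbage.
    for fut in futs:
        fut.wait()

    # should_commit() reports whether the whole quorum finished without
    # errors; on False the caller must roll back, e.g. by restoring a backup
    # copy the way LocalSGD does.
    return manager.should_commit()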
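One behavior checkpointing code should account for: LocalSGD.state_dict() returns the weights from the last synchronization point rather than the live, mid-window weights, and load_state_dict() also refreshes the backup copy. A small sketch of a checkpoint round trip under those rules; the helper names and path handling are illustrative only.

# Checkpoint round-trip sketch: state_dict() yields the last-synchronized
# weights, so a restored worker resumes from a consistent sync point even if
# it crashed partway through a `sync_every` window. Helper names are made up.
import torch

from torchft.local_sgd import LocalSGD


def save_checkpoint(local_sgd: LocalSGD, path: str) -> None:
    # Last-synchronized weights, not the in-progress local steps.
    torch.save(local_sgd.state_dict(), path)


def load_checkpoint(local_sgd: LocalSGD, path: str) -> None:
    state = torch.load(path, map_location="cpu")
    # load_state_dict() copies the weights into both the wrapped model and the
    # backup buffers, so it must not race with an in-flight optimizer step.
    local_sgd.load_state_dict(state)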