This is a tracking issue for adding LocalSGD support into torchft. There's been interest in LocalSGD support and it's something we'd like to be able to support.
This should be fairly straightforward: we can use the Manager + quorum in an outer loop and only periodically allreduce a copy of the weights.
Something like:
```python
import torch
from torchft import Manager

manager = Manager(...)
model = ...

while True:
    for step in range(local_steps):
        inputs, labels = next(dataloader_iter)
        optimizer.zero_grad()
        criterion(model(inputs), labels).backward()
        optimizer.step()

    # update quorum and PGs (could overlap with the optimizer steps above)
    manager.step()

    # free gradient memory to make room for averaged weights
    optimizer.zero_grad(set_to_none=True)

    # copy the model weights and start the allreduce mean
    # we need a temporary copy to gracefully handle failures
    params = {}
    for name, param in model.named_parameters():
        copy = param.detach().clone()
        manager.allreduce_grad(copy)
        params[name] = copy

    # this will wait for all transfers to complete successfully
    if manager.should_commit():
        # in-place writes to leaf params that require grad must skip autograd
        with torch.no_grad():
            for name, param in model.named_parameters():
                param.copy_(params[name])
                del params[name]
```
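To make the control flow of the loop above concrete without a cluster, here is a toy single-process simulation. It is an illustration only, not torchft code: each "replica" holds one scalar weight and minimizes its own quadratic loss, the "allreduce mean" is a plain Python average, and the hypothetical `commit_fn` stands in for `manager.should_commit()`.

```python
# Toy, single-process simulation of LocalSGD (illustrative only; not torchft).
# Each "replica" holds one scalar weight and minimizes its own quadratic
# loss f_i(w) = 0.5 * (w - target_i) ** 2, whose gradient is (w - target_i).


def local_sgd(targets, local_steps, rounds, lr, commit_fn=lambda r: True):
    """Run `rounds` outer iterations of LocalSGD over len(targets) replicas.

    commit_fn(round) stands in for manager.should_commit(): when it returns
    False the averaged weights are discarded and replicas keep their own.
    """
    weights = [0.0 for _ in targets]
    for r in range(rounds):
        # inner loop: purely local SGD steps, no communication
        for _ in range(local_steps):
            weights = [w - lr * (w - t) for w, t in zip(weights, targets)]
        # "allreduce mean" into a temporary value, so a failed round is harmless
        averaged = sum(weights) / len(weights)
        if commit_fn(r):
            weights = [averaged for _ in weights]
    return weights
```

For example, `local_sgd([1.0, 3.0], local_steps=5, rounds=50, lr=0.1)` drives both replicas to the same weight near `2.0`, while passing `commit_fn=lambda r: False` leaves each replica at its own target, which mirrors how a failed commit simply keeps the local weights.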
DiLoCo should be a small modification of this algorithm: instead of just averaging the weights, it applies a separate (outer) optimizer.
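One way to picture that modification, as a hedged sketch: treat the difference between the weights at the start of the round and the averaged local weights as a "pseudo-gradient" and feed it to an outer Nesterov-momentum SGD step. The function name, the list-of-floats representation, and the default hyperparameters are placeholders, not torchft API; the momentum update follows the same formula as `torch.optim.SGD` with `nesterov=True` and zero dampening.

```python
# Hypothetical sketch of a DiLoCo-style outer step on flat lists of floats.
# Not torchft API: names, defaults and the data representation are illustrative.


def diloco_outer_step(global_params, local_params_per_replica, velocity,
                      outer_lr=0.7, momentum=0.9):
    """Apply one outer Nesterov-SGD step to the pseudo-gradient.

    pseudo-gradient = global_params - mean over replicas of local params
    (in torchft that mean would come from the periodic allreduce).
    Returns (new_global_params, new_velocity).
    """
    n = len(local_params_per_replica)
    new_params, new_velocity = [], []
    for i, theta in enumerate(global_params):
        mean_local = sum(p[i] for p in local_params_per_replica) / n
        g = theta - mean_local              # pseudo-gradient
        v = momentum * velocity[i] + g      # momentum buffer
        update = g + momentum * v           # Nesterov look-ahead
        new_params.append(theta - outer_lr * update)
        new_velocity.append(v)
    return new_params, new_velocity
```

With `outer_lr=1.0` and `momentum=0.0` the step collapses to `theta - (theta - mean) = mean`, i.e. exactly the plain weight averaging of the LocalSGD loop, which is why DiLoCo fits into the same outer-loop structure.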
For efficiency, we should probably use the DDP reducer on the parameters directly and copy the underlying Storage to make the backup copy.
One additional point here is when we allow rejoining/recovering. Our current implementation is quite rigid, but with LocalSGD we may want more control over when we detect failing workers as well as when we allow them to recover, so that recovery doesn't block the healthy workers.
I think for flexibility we should change when we increment the step count.

If we rename `start_step` to `get_quorum` and add a `recover_allowed` field to it, we could call it multiple times per step, with the first call allowing recovery and the second not.

We would then move the step incrementing to either `should_commit` or a new explicit `step`/`commit` method.
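A toy model of the proposed semantics, purely as an assumption about how the pieces could fit together: `ToyManager` is hypothetical, its `get_quorum` does no networking and only records calls, and the `all_ok` argument to `should_commit` is invented to simulate a failed round. It shows the two properties proposed above: `get_quorum` can be called several times within one step, and the step count only advances on a successful commit.

```python
# Hypothetical toy model of the proposed Manager API change (not torchft code):
# get_quorum() may be called several times per step, and the step counter only
# advances in should_commit().


class ToyManager:
    def __init__(self):
        self.step_count = 0
        self.log = []  # records (step_count, recover_allowed) per quorum call

    def get_quorum(self, recover_allowed=True):
        # A real implementation would contact the lighthouse here. With
        # recover_allowed=False, workers that still need to recover would have
        # to wait for a later call instead of blocking the healthy replicas.
        self.log.append((self.step_count, recover_allowed))

    def should_commit(self, all_ok=True):
        # The step increments only on a successful commit, never on quorum calls.
        if all_ok:
            self.step_count += 1
        return all_ok
```

In a LocalSGD loop this could look like calling `get_quorum(recover_allowed=True)` at the start of the outer step, so lagging workers can catch up while local steps run, and `get_quorum(recover_allowed=False)` just before the weight allreduce, so a late joiner cannot stall the sync.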
- Detect early failures and restart via heartbeats somehow.
- Add proper join/shrink-only support in quorum (adds complexity to quorum).
- Add support for multiple quorums per lighthouse. That way we can do tick/tock behavior: workers always join the primary quorum, and the secondary quorum only contains healthy workers. This avoids complicating the quorum algorithm, though it will require some fallback logic if we get stuck in the secondary quorum.
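The tick/tock option can be sketched as a pure function over worker IDs. This is a hypothetical illustration, not lighthouse code: the function name and arguments are invented, and the fallback rule (re-open as a primary quorum when no healthy primary member remains) is just one assumed choice for the "stuck in the secondary quorum" case.

```python
# Hypothetical sketch of the "tick/tock" multi-quorum idea (not lighthouse code).
# Even rounds form the primary quorum, which any live worker may join;
# odd rounds form the secondary quorum, restricted to healthy primary members.


def tick_tock_quorum(round_idx, live_workers, prev_primary, healthy):
    """Return the sorted member list for quorum round `round_idx`.

    live_workers: workers currently asking to participate
    prev_primary: members of the most recent primary quorum
    healthy:      workers that have not failed since that primary quorum
    """
    if round_idx % 2 == 0:
        # primary ("tick"): open to joins, so recovering workers enter here
        return sorted(live_workers)
    # secondary ("tock"): shrink-only, never admits new workers
    members = [w for w in prev_primary if w in healthy and w in live_workers]
    if not members:
        # assumed fallback: if the secondary quorum would be empty,
        # re-open it as a primary quorum instead of getting stuck
        return sorted(live_workers)
    return sorted(members)
```

For instance, with primary members `a, b, c`, a failed `b`, and a new worker `d`, the secondary quorum is `["a", "c"]`; `d` only gets in at the next primary round.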