
LocalSGD / DiLoCo support #39

Open
d4l3k opened this issue Dec 13, 2024 · 2 comments
Labels: enhancement (New feature or request)

d4l3k commented Dec 13, 2024

This is a tracking issue for adding LocalSGD support to torchft. There has been interest in LocalSGD and it's something we'd like to support.

This should be fairly straightforward: we can run the Manager + quorum in an outer loop and only periodically allreduce a copy of the weights.

Something like:

manager = Manager(...)
model = ...
optimizer = ...
criterion = ...
dataloader_iter = ...
local_steps = ...

while True:
    # inner loop: normal local training steps
    for step in range(local_steps):
        inputs, labels = next(dataloader_iter)
        optimizer.zero_grad()
        criterion(model(inputs), labels).backward()
        optimizer.step()

    # update quorum and PGs (could overlap with the optimizer steps above)
    manager.step()

    # free gradient memory to make room for the averaged weights
    optimizer.zero_grad(set_to_none=True)

    # copy the model weights and start the allreduce mean
    # we need a temporary copy to gracefully handle failures
    params = {}
    for name, param in model.named_parameters():
        copy = param.detach().clone()
        manager.allreduce_grad(copy)
        params[name] = copy

    # this will wait for all transfers to complete successfully
    if manager.should_commit():
        with torch.no_grad():
            for name, param in model.named_parameters():
                param.copy_(params[name])
                del params[name]

DiLoCo should be a small modification of this algorithm: apply a separate outer optimizer to the averaged weights instead of just averaging them.
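
As a rough sketch of what that could look like (not a final design; allreduce_grad/should_commit are reused from the example above, and the lr/momentum values are the DiLoCo paper's outer-optimizer settings):

import torch

# copies of the weights as of the last committed outer step
outer_params = {n: p.detach().clone() for n, p in model.named_parameters()}
# the DiLoCo paper uses SGD with Nesterov momentum as the outer optimizer
outer_opt = torch.optim.SGD(outer_params.values(), lr=0.7, momentum=0.9, nesterov=True)

def diloco_outer_step(manager, model):
    # average the current local weights across the quorum
    averaged = {}
    for name, param in model.named_parameters():
        copy = param.detach().clone()
        manager.allreduce_grad(copy)
        averaged[name] = copy

    if not manager.should_commit():
        return  # something failed; keep the local weights and retry next round

    # pseudo-gradient: how far the averaged weights drifted from the outer weights
    for name, outer_p in outer_params.items():
        outer_p.grad = outer_p - averaged[name]
    outer_opt.step()
    outer_opt.zero_grad(set_to_none=True)

    # reset the local model to the new outer weights for the next inner loop
    with torch.no_grad():
        for name, param in model.named_parameters():
            param.copy_(outer_params[name])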

For efficiency we should probably use the DDP reducer on the parameters directly and copy the underlying Storage to make a backup copy.
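
For illustration, a rough sketch of that backup/rollback (cloning untyped_storage() is just one possible way to snapshot the raw bytes; this isn't an existing torchft helper):

# snapshot the raw storage of each parameter so the live tensors can be
# handed to the reducer / allreduced in place
backups = {
    name: param.untyped_storage().clone()
    for name, param in model.named_parameters()
}

# ... allreduce the parameters in place via the reducer ...

if not manager.should_commit():
    # roll back to the pre-allreduce weights on failure
    for name, param in model.named_parameters():
        param.untyped_storage().copy_(backups[name])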

References:

d4l3k added the enhancement (New feature or request) label Dec 13, 2024
d4l3k changed the title from LocalSGD support to LocalSGD / DiLoCo support Dec 13, 2024
d4l3k commented Dec 17, 2024

One additional point here is when we allow rejoining/recovering. Our current implementation is quite rigid, but with LocalSGD we may want more control over when we detect failing workers as well as when we allow them to recover, to avoid blocking.

https://pytorch.slack.com/archives/C083HHTCU06/p1734466793734549?thread_ts=1734049373.379299&cid=C083HHTCU06

I think for flexibility we should change when we increment the step count:

  1. Rename start_step to get_quorum and add a recover_allowed field to it. That would let us call it multiple times per step, with the first call allowing recovery and the second not.
  2. Move the step incrementing to either should_commit or to a new explicit step/commit method.

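A sketch of the resulting outer loop (get_quorum, recover_allowed, and the explicit commit() are the proposed changes here, not the current API):

while True:
    # first quorum call of the outer step: recovering workers may rejoin here
    manager.get_quorum(recover_allowed=True)

    for step in range(local_steps):
        ...  # inner/local optimizer steps as in the example above

    # second call: re-establish the quorum without allowing recovery, so the
    # weight averaging isn't blocked waiting for a worker to catch up
    manager.get_quorum(recover_allowed=False)

    # ... allreduce a copy of the weights as above ...

    if manager.should_commit():
        # step count is incremented here rather than in start_step/get_quorum
        manager.commit()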

d4l3k commented Dec 18, 2024

For the quorum we have a few options:

  1. Detect early failures and restart via heartbeats somehow.
  2. Add proper join/shrink-only support in quorum (adds complexity to the quorum algorithm).
  3. Add support for multiple quorums per lighthouse. That way we can do tick/tock behavior: workers always join the primary quorum, and the members of the secondary quorum are only healthy workers. This avoids complicating the quorum algorithm, though it will require some fallback logic if we get stuck in the secondary quorum (see the sketch below).
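
A very rough sketch of option 3 from a worker's point of view (the quorum_id argument and a multi-quorum lighthouse are hypothetical, not an existing torchft API):

def tick_tock_quorum(manager):
    # every worker, including recovering ones, always joins the primary quorum
    primary = manager.get_quorum(quorum_id="primary", recover_allowed=True)

    # only already-healthy workers join the secondary quorum, which is the one
    # actually used for the weight averaging
    secondary = manager.get_quorum(quorum_id="secondary", recover_allowed=False)

    if not secondary:
        # fallback if we get stuck in the secondary quorum: fall back to the
        # primary membership and allow recovery this round
        secondary = primary
    return secondary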
