Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add the hybrid sharding strategy in FSDP (HSDP) #884

Merged
merged 5 commits into from
Dec 20, 2024

Conversation

MartinGleize
Copy link
Contributor

@MartinGleize MartinGleize commented Dec 11, 2024

What does this PR do? Please describe:
Add the hybrid sharding strategy in FSDP (HSDP).
Integration in recipes is only done and tested on instruction finetune, for now.

Fixes #{issue number}

Does your PR introduce any breaking changes? If yes, please list them:
No

Check list:

  • Was the content of this PR discussed and approved via a GitHub issue? (no need for typos or documentation improvements)
  • Did you read the contributor guideline?
  • Did you make sure that your PR does only one thing instead of bundling different changes together?
  • Did you make sure to update the documentation with your changes? (if necessary)
  • Did you write any new necessary tests?
  • Did you verify new and existing tests pass locally with your changes?
  • Did you update the CHANGELOG? (no need for typos, documentation, or minor internal changes)

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Dec 11, 2024
@MartinGleize MartinGleize changed the title Implement the hybrid sharding strategy in FSDP (HSDP) Add the hybrid sharding strategy in FSDP (HSDP) Dec 12, 2024
@MartinGleize MartinGleize self-assigned this Dec 13, 2024
@MartinGleize MartinGleize marked this pull request as ready for review December 13, 2024 21:30
@MartinGleize MartinGleize force-pushed the mgleize/hsdp branch 2 times, most recently from 6122a65 to babd5b3 Compare December 13, 2024 23:09
Copy link
Contributor

@cbalioglu cbalioglu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks quite good. I just have one comment regarding the visibility of setup_2d_mesh_gangs.

@@ -707,6 +707,113 @@ def fake_gangs(device: Device) -> Gangs:
return Gangs(gang, gang, gang)


def setup_2D_mesh_gangs(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My preference would be to make this an internal function (i.e. _setup_2d_mesh_gangs()). We already expose the actual functionality in setup_parallel_gangs() and to_fsdp() that are higher-level and abstract away the details of the mesh creation.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I implemented this but added an extra "setup_" function in gang.py to keep it coherent with setup_parallel_gangs().

A ``dict`` of two gangs; 0 maps to the gang of 2D mesh row,
1 maps to the gang of the 2D mesh column.
"""
if row_length <= 0:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we make this function internal (as I mentioned above), we can ideally move argument checks to the public functions (setup_parallel_gangs and to_fsdp with more descriptive error messages. For instance, if the user calls setup_parallel_gangs with a wrong tp_size, the error message says "row_length must be less than or equal to root_gang.size" which is totally clear now.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point, less defensive programming in the internal function now.

if local_world_size == 1:
raise ValueError(
f"`local_world_size` must be greater than 1, but is {local_world_size} instead. "
"This hybrid configuration would force FSDP to switch to use `NO_SHARD`, "
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit. As a convention, we do not line-break individual sentences of an error and log message within the code. This makes it possible for users to "grep" any part of the sentence (e.g. if they see it in the log file and want to figure out where it originates in the code base). This was mostly inspired by Linux kernel coding conventions.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah I actually was wondering about that, thank you for the explanation, implemented this.

tests/unit/test_gang.py Show resolved Hide resolved
Copy link
Contributor

@cbalioglu cbalioglu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perfect! LGTM!

@cbalioglu cbalioglu merged commit 9a89641 into facebookresearch:main Dec 20, 2024
15 checks passed
@MartinGleize MartinGleize deleted the mgleize/hsdp branch December 20, 2024 13:11
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants