Add the hybrid sharding strategy in FSDP (HSDP) #884
Looks quite good. I just have one comment regarding the visibility of setup_2d_mesh_gangs.
src/fairseq2/gang.py
Outdated
@@ -707,6 +707,113 @@ def fake_gangs(device: Device) -> Gangs:
    return Gangs(gang, gang, gang)


def setup_2D_mesh_gangs(
My preference would be to make this an internal function (i.e. _setup_2d_mesh_gangs()). We already expose the actual functionality in setup_parallel_gangs() and to_fsdp() that are higher-level and abstract away the details of the mesh creation.
I implemented this but added an extra "setup_" function in gang.py to keep it coherent with setup_parallel_gangs().
src/fairseq2/gang.py
Outdated
    A ``dict`` of two gangs; 0 maps to the gang of 2D mesh row,
    1 maps to the gang of the 2D mesh column.
    """
    if row_length <= 0:
If we make this function internal (as I mentioned above), we can ideally move argument checks to the public functions (setup_parallel_gangs and to_fsdp) with more descriptive error messages. For instance, if the user calls setup_parallel_gangs with a wrong tp_size, the error message says "row_length must be less than or equal to root_gang.size", which is not clear at all right now.
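A minimal sketch of what this suggestion could look like, assuming simplified stand-in functions rather than the actual fairseq2 signatures (the names `setup_parallel_mesh` and `_setup_2d_mesh`, and the use of plain integers instead of gangs, are hypothetical): the public function validates in terms of the argument the caller actually passed, and the internal helper trusts its input.

```python
# Hypothetical, simplified stand-ins; these are NOT the fairseq2 functions.


def _setup_2d_mesh(world_size: int, row_length: int) -> list[list[int]]:
    """Internal helper: assumes the caller has already validated the arguments."""
    return [
        list(range(start, start + row_length))
        for start in range(0, world_size, row_length)
    ]


def setup_parallel_mesh(world_size: int, tp_size: int) -> list[list[int]]:
    """Public entry point: reports errors using the caller's own parameter names."""
    if tp_size <= 0:
        raise ValueError(f"`tp_size` must be greater than 0, but is {tp_size} instead.")

    if world_size % tp_size != 0:
        raise ValueError(f"`world_size` ({world_size}) must be divisible by `tp_size` ({tp_size}).")

    # The internal name `row_length` never leaks into a user-facing message.
    return _setup_2d_mesh(world_size, row_length=tp_size)
```

With this split, a wrong `tp_size` produces a message that mentions `tp_size`, not the internal `row_length`.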
Good point, less defensive programming in the internal function now.
src/fairseq2/nn/fsdp.py
Outdated
    if local_world_size == 1:
        raise ValueError(
            f"`local_world_size` must be greater than 1, but is {local_world_size} instead. "
            "This hybrid configuration would force FSDP to switch to use `NO_SHARD`, "
Nit. As a convention, we do not line-break individual sentences of error and log messages within the code. This makes it possible for users to "grep" any part of the sentence (e.g. if they see it in the log file and want to figure out where it originates in the code base). This was mostly inspired by Linux kernel coding conventions.
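As an illustration of that convention, here is a small sketch assuming a hypothetical helper `check_local_world_size`; the first sentence is taken from the snippet above, while the second sentence is made up for the example. Each sentence of the message sits on a single source line, so any fragment copied from a log can be found with one search.

```python
def check_local_world_size(local_world_size: int) -> None:
    """Hypothetical helper, used only to illustrate the message layout."""
    if local_world_size == 1:
        # A message may span several string literals, but every sentence
        # lives on exactly one source line, so grep can match any part of it.
        raise ValueError(
            f"`local_world_size` must be greater than 1, but is {local_world_size} instead. "
            "Hybrid sharding needs at least two processes in the sharding group."  # illustrative wording
        )
```

Searching the code base for `must be greater than 1, but is` then leads straight to the line that raises the error.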
Ah I actually was wondering about that, thank you for the explanation, implemented this.
Perfect! LGTM!
What does this PR do? Please describe:
Add the hybrid sharding strategy in FSDP (HSDP).
For now, integration in recipes is done and tested only for instruction finetuning.
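To make the 2D mesh idea concrete, here is a rough sketch in plain Python, not the code from this PR. It assumes ranks are laid out row by row in a grid that is `row_length` wide, and it returns the row and column a given rank belongs to, keyed 0 and 1 as in the docstring quoted in the review. The function name `mesh_groups` is hypothetical, and which of the two groups is used for sharding versus replication is left open here.

```python
def mesh_groups(world_size: int, row_length: int, rank: int) -> dict[int, list[int]]:
    """Return the mesh row (key 0) and mesh column (key 1) that contain `rank`.

    Assumes a row-major layout: rank = row * row_length + column.
    """
    if world_size % row_length != 0:
        raise ValueError(f"`world_size` ({world_size}) must be divisible by `row_length` ({row_length}).")

    row, column = divmod(rank, row_length)
    num_rows = world_size // row_length

    row_group = [row * row_length + c for c in range(row_length)]
    column_group = [r * row_length + column for r in range(num_rows)]

    return {0: row_group, 1: column_group}


# Example: 8 ranks in a 2 x 4 mesh, seen from rank 5.
groups = mesh_groups(world_size=8, row_length=4, rank=5)
print(groups[0])  # [4, 5, 6, 7]
print(groups[1])  # [1, 5]
```

In a hybrid setup, one of these two groups would typically shard the model state while the other keeps replicas in sync.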
Fixes #{issue number}
Does your PR introduce any breaking changes? If yes, please list them:
No