Integration with DCP #978
base: unflatten
Conversation
@@ -66,6 +68,49 @@ def get_layers(module):
        return layers


def pipe_to_sd(pipe):
@wz337, this might be interesting for dist state dict.
Thanks for making it work!
with tempfile.TemporaryDirectory() as tmpdir:
    # Simulate saving the pipe
    # Option 1:
I think Option 1 would be more likely used than Option 2 in a realistic setting. Could you please uncomment this block of code?
    # print(f"Saving pipeline stage {stage_idx}")
    # stage_mod = pipe.get_stage_module(stage_idx)
    # dcp.save(
    #     {f"stage_{stage_idx}": stage_mod},
Curious, is the dict required by the DCP API? Can a user directly save stage_mod?
Why does this matter? I think the DCP API had reasons for interfacing with a dict instead of a model. Adding a new variant that takes a model and gets its dict should be possible, but I think it's clearer this way that the only part of the model that gets saved is the dict.
Just to be clear: I like saving the state dict too (instead of the module). That's more composable to me.
My question above is: is the {f"stage_{stage_idx}": stage_mod} wrapper necessary?
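To illustrate what the question is about, here is a minimal, hypothetical sketch (plain Python, not the actual DCP implementation) of how wrapping a stage's state dict in a `stage_{idx}` dict changes the keys a checkpointer would see, versus passing the state dict directly:

```python
# Hypothetical sketch of checkpoint-key flattening; `flatten` mimics the
# prefixing behavior of a nested-dict checkpointer, not real DCP code.
def flatten(d, prefix=""):
    out = {}
    for k, v in d.items():
        key = f"{prefix}.{k}" if prefix else k
        if isinstance(v, dict):
            out.update(flatten(v, key))
        else:
            out[key] = v
    return out

# Stand-in for a stage module's state dict (values would be tensors).
stage_sd = {"embedding.weight": 1, "layers.0.weight": 2}

# Saving the state dict directly keeps the original FQNs:
direct = flatten(stage_sd)

# Wrapping in {f"stage_{idx}": ...} prefixes every key:
wrapped = flatten({"stage_0": stage_sd})

print(sorted(direct))   # ['embedding.weight', 'layers.0.weight']
print(sorted(wrapped))  # ['stage_0.embedding.weight', 'stage_0.layers.0.weight']
```

The wrapper thus bakes a `stage_0.` prefix into every saved key, which matters for whether the checkpoint can later be loaded into a model that has no stage level in its FQNs.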
@@ -66,6 +68,49 @@ def get_layers(module):
        return layers


def pipe_to_sd(pipe):
    sd = {}
    for stage_idx in range(pipe.num_stages):
Something a little fishy about this proposal (equally so for both Option 1 and Option 2) is that it's unlikely you'd want to iterate over all the stages in the pipe and load/save them.
Example 1: simple pipeline with 4 GPUs
rank0: save/load pipe.submod_0 only
...
Example 2: complex pipeline with 4 GPUs, 2 stages per GPU
rank0: save/load pipe.submod_0 and pipe.submod_4
rank1: save/load pipe.submod_1 and pipe.submod_5
...
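The per-rank ownership pattern in the examples above (a looped schedule where rank r owns stages r, r + world_size, ...) can be sketched as follows. This is an illustrative assumption about the schedule, not an API from the PR:

```python
# Sketch: which stage indices a rank would save/load under a looped
# schedule where rank r owns stages r, r + world_size, r + 2*world_size, ...
def stages_for_rank(rank, world_size, num_stages):
    return list(range(rank, num_stages, world_size))

# Example 1: simple pipeline, 4 stages on 4 GPUs -> one stage per rank.
print(stages_for_rank(0, 4, 4))  # [0]

# Example 2: 8 stages on 4 GPUs, 2 stages per GPU.
print(stages_for_rank(0, 4, 8))  # [0, 4]
print(stages_for_rank(1, 4, 8))  # [1, 5]
```

Under such a schedule, each rank would call save/load only for its own stage modules rather than looping over `range(pipe.num_stages)`.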
    sd = {}
    for stage_idx in range(pipe.num_stages):
        stage_mod = pipe.get_stage_module(stage_idx)
        sd[f"stage_{stage_idx}"] = stage_mod
It's not really clear to me why we need to add a prefix at all.
orig model
-----------
Transformer
  embedding
  layers
    0
    1
split model
-----------
submod_0
  embedding
  layers
    0
submod_1
  layers
    1
There should be no duplication of FQNs between submods/stages.
What are we doing about the 'submod_0' part in the FQN? When we do stage_mod = pipe.get_stage_module(stage_idx), does that return a module that has top-level keys like embedding and layers, or a module that has a top-level key of submod_n?
If the former, can't we just save/load the keys as usual?
If the latter, we can still save/load without a stage_{idx} prefix, I think, but we'll sadly be incompatible with loading into a non-PP model later on if we want to.
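For the latter case, the incompatibility could in principle be worked around by stripping the stage prefix from checkpoint keys before loading into a non-PP model. A minimal sketch, assuming a `submod_n.` key naming scheme (the key names here are illustrative, not taken from the PR):

```python
import re

# Hypothetical helper: strip a leading "submod_n." prefix from each key so a
# per-stage checkpoint can be loaded into a model whose FQNs have no stage level.
def strip_stage_prefix(sd):
    return {re.sub(r"^submod_\d+\.", "", k): v for k, v in sd.items()}

pp_sd = {"submod_0.embedding.weight": 1, "submod_1.layers.1.weight": 2}
print(strip_stage_prefix(pp_sd))
# {'embedding.weight': 1, 'layers.1.weight': 2}
```

Keys without a prefix pass through unchanged, so the same helper is a no-op on non-PP checkpoints.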
Former. @wconstab
What's our plan for this PR? @LucasLLC I think we are pretty close to the destination.
For code quality checks, please run:
Description
Please read our CONTRIBUTING.md prior to creating your first pull request.
Please include a summary of the feature or issue being fixed. Please also include relevant motivation and context. List any dependencies that are required for this change.
Testing out some checkpointing code.
PR description is WIP
Fixes #(issue)
Type of change
Please delete options that are not relevant.
Feature/Issue validation/testing
Please describe the Unit or Integration tests that you ran to verify your changes and relevant result summary. Provide instructions so it can be reproduced.
Please also list any relevant details for your test configuration.
Test A
Logs for Test A
Test B
Logs for Test B
Checklist: