A faster and more memory-efficient implementation of zero_to_fp32
#6658
Conversation
@xylian86, FYI
@xu-song, just to clarify, we greatly appreciate this PR. The memory and speed benefits are very useful. My only concern is the HF_Hub-related changes, so hopefully those can be clarified. Can you please add the observed speed and memory benefits of these optimizations? Such details are generally useful for readers to better appreciate the value. Thanks!
@tjruwase Is there any alternative approach to sharding a torch state_dict? If so, the compatible feature to
Sorry, but I am a bit confused about the objective of this PR. The goal of zero_to_fp32 is to create a consolidated checkpoint state from the sharded checkpoints of ZeRO-* training, so I don't understand why state_dict sharding is a consideration here. It seems that there are two parts of this PR.
Am I correct?
1. Yes, HF_hub compatibility involves state_dict sharding.
2. Besides, our implementation exactly follows the goal of zero_to_fp32. As the document says, "By default, ..."
3. New impl:

v1 (DeepSpeed/deepspeed/utils/zero_to_fp32.py, lines 565 to 567 in 54903e0):
v2:

- state_dict_split = split_torch_state_dict_into_shards(state_dict,
+ mock_state_dict = {name: torch.empty(tensor.shape, dtype=tensor.dtype) for name, tensor in state_dict.items()}
+ state_dict_split = split_torch_state_dict_into_shards(mock_state_dict,
                                                        filename_pattern=filename_pattern,
                                                        max_shard_size=max_shard_size)

Convert pseudo tensors to `torch.empty` mock tensors (as in the diff above) before sharding.
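For context, here is a hedged, self-contained sketch of the v2 idea: a metadata-only mock state dict lets huggingface_hub's `split_torch_state_dict_into_shards` plan the shard layout without ever materializing the real weights. The tensor names, shapes, and shard size below are invented for illustration and are not taken from the PR.

```python
import torch
from huggingface_hub import split_torch_state_dict_into_shards

# Pretend state_dict; in the PR these would be lazy/pseudo tensors, but for
# planning the shards only shapes and dtypes matter.
state_dict = {
    "embed.weight": torch.zeros(1024, 1024),
    "lm_head.weight": torch.zeros(1024, 1024),
}

# Build a mock of torch.empty tensors: they are never written to, so the
# sharding helper never forces the real (or lazily merged) data into memory.
mock_state_dict = {
    name: torch.empty(t.shape, dtype=t.dtype) for name, t in state_dict.items()
}

split = split_torch_state_dict_into_shards(mock_state_dict,
                                           filename_pattern="model{suffix}.safetensors",
                                           max_shard_size="5MB")

# split.filename_to_tensors maps each output file to the tensor names it holds;
# the real tensors can then be materialized and saved one shard at a time.
for filename, tensor_names in split.filename_to_tensors.items():
    shard = {name: state_dict[name] for name in tensor_names}
    # save_file(shard, filename)  # e.g. safetensors.torch.save_file
```

The design point is that the sharding helper only needs tensor metadata, so feeding it cheap stand-ins avoids touching the lazily merged weights until each shard is actually written.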
Apologies for not being clear. The reason I referred to that doc was to show that
deepspeed/utils/zero_to_fp32.py (Outdated)

@@ -483,6 +530,7 @@ def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag=None, exclude_f
     - ``checkpoint_dir``: path to the desired checkpoint folder
     - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in 'latest' file. e.g., ``global_step14``
     - ``exclude_frozen_parameters``: exclude frozen parameters
+    - ``lazy_merge``: a more memory-efficient feature
Provide a brief description of why this is more memory-efficient, and perhaps mention important usage concepts like pseudo tensors and `contiguous()`.
Also, please add a unit test.
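For illustration, a minimal usage sketch of the pseudo-tensor/`contiguous()` concept mentioned above. It assumes the `lazy_merge` flag from the docstring diff and the materialize-on-`contiguous()` behavior described later in this thread; the final API may differ.

```python
from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint

# Cheap: with lazy_merge (name taken from the diff above), entries are pseudo
# tensors backed by mmapped shard files rather than real CPU tensors.
state_dict = get_fp32_state_dict_from_zero_checkpoint("path/to/checkpoint_dir",
                                                      lazy_merge=True)

# Materialize one parameter at a time; only this tensor's partitions are read
# from disk and merged into a single contiguous CPU tensor.
for name, pseudo_tensor in state_dict.items():
    real_tensor = pseudo_tensor.contiguous()
    # ...use or save real_tensor, then let it go out of scope to free memory
```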
This would be a good place for a unit test:
state_dict = get_fp32_state_dict_from_zero_checkpoint(filename, tag="checkpoint")
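A rough sketch of what such a test might look like. Assumptions: `save_dir` stands for a directory that already contains a ZeRO checkpoint saved with tag "checkpoint", and `lazy_merge` is the flag from the diff above; the test actually added to the PR may differ.

```python
import torch
from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint


def test_lazy_merge_matches_default(save_dir):
    # Reference: eager (default) consolidation.
    eager_sd = get_fp32_state_dict_from_zero_checkpoint(save_dir, tag="checkpoint")
    # Lazy path: pseudo tensors backed by mmapped shard files.
    lazy_sd = get_fp32_state_dict_from_zero_checkpoint(save_dir, tag="checkpoint",
                                                       lazy_merge=True)

    assert eager_sd.keys() == lazy_sd.keys()
    for name, ref in eager_sd.items():
        # Materialize the pseudo tensor before comparing values.
        assert torch.equal(ref, lazy_sd[name].contiguous())
```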
Thanks, I am working on it.
@tjruwase Sorry, you may have misunderstood the objective of
If your point is
No, my point is not
A unit test and more comments have been added. Thanks.
The formatting and unit test issues have been resolved.
This is a faster and more memory-efficient implementation of `zero_to_fp32`. The previous version doubles the memory usage, which causes CPU OOM for very large models (e.g., Llama 405B):
DeepSpeed/deepspeed/utils/zero_to_fp32.py, lines 438 to 441 in b647fb2
How does it work?
1. Checkpoint shard files are loaded with `mmap=True`, so the weights are memory-mapped rather than all storages being loaded into memory.
2. `GatheredTensor` holds the mmapped weights and tensor offsets; it is a memory-efficient pseudo tensor. Only when `tensor.contiguous()` is called does it load the related weights into memory and merge them into a single tensor (see the sketch below).
3. Throughout the process, only one shard of tensors is kept in memory.
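A minimal sketch of the pseudo-tensor idea in steps 1 and 2, under stated assumptions: the class name, attributes, file name, and key below are illustrative rather than the PR's actual `GatheredTensor`; only `torch.load(..., mmap=True)` and the materialize-on-`contiguous()` behavior come from the description above.

```python
import math
import torch


class PseudoTensor:
    """Illustrative stand-in for the PR's GatheredTensor.

    Holds only a reference to an mmapped flat buffer plus an offset and shape,
    so a whole state_dict of these costs almost no RAM; bytes are pulled from
    disk only when .contiguous() is called.
    """

    def __init__(self, flat_buffer, offset, shape):
        self.flat_buffer = flat_buffer  # 1-D tensor backed by an mmapped checkpoint file
        self.offset = offset            # start index of this parameter in the flat buffer
        self.shape = shape

    def contiguous(self):
        numel = math.prod(self.shape)
        # clone() forces this slice to be read from the mmapped file into RAM
        # and merged into a single ordinary contiguous tensor.
        return self.flat_buffer[self.offset:self.offset + numel].clone().view(self.shape)


# Step 1: shard files are opened with mmap=True, so weights stay on disk.
# "zero_shard.pt" and the "flat_fp32" key are placeholders for this sketch.
shard = torch.load("zero_shard.pt", map_location="cpu", mmap=True)
flat = shard["flat_fp32"]

# Step 2: build cheap pseudo tensors; materialize them one at a time.
weight = PseudoTensor(flat, offset=0, shape=(1024, 1024))
real_weight = weight.contiguous()  # only now is this parameter loaded into RAM
```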
How much benefit in speed and memory?
Experiments were conducted on a Linux host with 1 TB of memory; a detailed comparison table and reproduction scripts accompany the PR (not reproduced here). Memory usage drops from 2M to (1/n)M, where M is the memory cost of the full weights and n is the number of shards.
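To make the 2M to (1/n)M claim concrete, here is a rough back-of-the-envelope calculation. The 405B figure comes from the Llama example above; the shard count is an arbitrary choice for illustration, not a measurement from the PR.

```python
# Illustrative arithmetic only; not measurements from the PR.
params = 405e9              # Llama 405B parameters
M = params * 4 / 1e12       # fp32 full-weight memory in TB -> ~1.62 TB

old_peak = 2 * M            # ~3.24 TB, which overflows a 1 TB host -> CPU OOM
n = 8                       # number of output shards (arbitrary here)
new_peak = M / n            # ~0.20 TB: only one shard's tensors live in RAM
print(f"M={M:.2f} TB, old peak~{old_peak:.2f} TB, new peak~{new_peak:.2f} TB")
```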
Impl history
An earlier version (v1) was discarded due to its controversial use of `data_ptr()`. The current version (v2) instead builds a mock state dict of `torch.empty` tensors for sharding, as shown in the diff above.