frameworks for FSDP and model/pipeline parallelism #85
Hello Zheng, and thanks for your reply.
That sounds very good, and I would indeed appreciate it if you could share more details.
Are you using the HuggingFace model and trainer?
This is what I am using at the moment. LoRA runs well with the implementation from their PEFT library, but I couldn't get FSDP and CPU offloading to work as described in the Accelerate documentation.
How did you set up ZeRO-2 and CPU offloading?
On Mon, Sep 30, 2024, 04:14, Zheng YuLong wrote:
By using bf16, LoRA, Zero2, and offloading optimizations and parameters to
the CPU, I was able to fine-tune EVO with a 30k+ sequence on 8 A800 80GB
devices. With a per-device training batch size of 4 and gradient
accumulation steps of 4, each device consumed 48GB of memory, achieving a
speed of around 8-9 seconds per iteration.
I hope this helps! If you need any further adjustments, feel free to let
me know.
@adrienchaton Hi, I was wondering if you have more details on DeepSpeed fine-tuning? I have tried to train all parameters with DeepSpeed at a sequence length of 8K, but it always causes GPU OOM, whether I use stage 2 or stage 3 (with offloading to CPU).
Hi, I think we need to combine FSDP + CPU offloading with LoRA to fit the model within e.g. an 80GB GPU. LoRA was easy to set up for me with the PEFT library from HuggingFace. I am on vacation, but if it helps I can share a code snippet showing how to implement it when I am back. However, FSDP and CPU offloading caused me lots of errors, so I am blocked on that part... If you can share details on how you made it work, that would be fantastic.
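A minimal sketch of the kind of PEFT LoRA setup being discussed here; the target module names, rank, and other hyperparameters below are illustrative guesses, not the exact configuration used by anyone in this thread:

```python
import torch
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

# Load Evo through the HF wrapper; trust_remote_code is required for the custom StripedHyena code.
model = AutoModelForCausalLM.from_pretrained(
    "togethercomputer/evo-1-131k-base",
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
)

lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    # Placeholder module names: inspect model.named_modules() to choose the
    # projection layers you actually want to adapt in the StripedHyena blocks.
    target_modules=["Wqkv", "out_proj"],
)

model = get_peft_model(model, lora_config)
model.print_trainable_parameters()  # expect well under 1% of parameters to be trainable
```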
@adrienchaton Just to let you know, the underlying problem is that they did not implement gradient checkpointing. I added a gradient checkpointing implementation and can now fully fine-tune the model with an 8K window using DeepSpeed stage 2 without any OOM problems.
@xiyang-aads-lilly Could you share more details please? Are you using the HuggingFace AutoModelForCausalLM wrapper and the HF trainer? For me, setting the gradient checkpointing option did not seem to do anything. It would be great to see how you implemented gradient checkpointing and also how you run it with DeepSpeed. Did you proceed similarly to this documentation? https://huggingface.co/docs/accelerate/usage_guides/deepspeed
So if we load the evo-8k model, it uses https://huggingface.co/togethercomputer/evo-1-131k-base/blob/main/model.py as the backbone. If you check the function:

```python
def stateless_forward(self, x, padding_mask=None):
    if type(padding_mask) == torch.Tensor:
        x = x * padding_mask[..., None]

    for _, block in enumerate(self.blocks):
        x, _ = block(x, inference_params=None, padding_mask=padding_mask)
    return x, None
```

it clearly does not implement any checkpointing strategy. Even if you set the checkpointing args, it actually does nothing. I modified the function as below, along with some other small changes:

```python
def stateless_forward(self, x, padding_mask=None):
    if type(padding_mask) == torch.Tensor:
        x = x * padding_mask[..., None]

    for _, block in enumerate(self.blocks):
        if self.gradient_checkpointing and self.training:
            x, _ = self._gradient_checkpointing_func(block.__call__, x, None, padding_mask)
        else:
            x, _ = block(x, inference_params=None, padding_mask=padding_mask)
    return x, None
```

Now the checkpointing is working as expected. I raised an issue in the original StripedHyena repo (togethercomputer/stripedhyena#22) to discuss this.

The small changes I made:

```python
# in modeling_hyena.py
# under StripedHyenaModelForCausalLM
# add/modify the two following functions (requires `import functools` and `import torch`):
def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
    if not self.supports_gradient_checkpointing:
        raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")

    if gradient_checkpointing_kwargs is None:
        gradient_checkpointing_kwargs = {"use_reentrant": True}

    # TODO support deepspeed checkpoint
    gradient_checkpointing_func = functools.partial(
        torch.utils.checkpoint.checkpoint, **gradient_checkpointing_kwargs
    )

    self._set_gradient_checkpointing(
        enable=True, gradient_checkpointing_func=gradient_checkpointing_func
    )

    if getattr(self, "_hf_peft_config_loaded", False):
        # When using PEFT + gradient checkpointing + Trainer we need to make sure the input has requires_grad=True.
        # We do the same in PEFT: https://github.com/huggingface/peft/blob/85013987aa82aa1af3da1236b6902556ce3e483e/src/peft/peft_model.py#L334
        # When training with PEFT, only LoRA layers have requires_grad set to True, but the outputs of frozen layers
        # need to propagate the gradients to make sure the gradient flows.
        self.enable_input_require_grads()

def _set_gradient_checkpointing(self, enable, gradient_checkpointing_func):
    self.backbone.gradient_checkpointing = enable
    self.backbone._gradient_checkpointing_func = gradient_checkpointing_func
```

```python
# in model.py
# under StripedHyena
# in __init__, add:
self.gradient_checkpointing = False
self._gradient_checkpointing_func = None
```
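A quick sanity check that the patch above took effect; this is a sketch, assuming the class also sets `supports_gradient_checkpointing = True` and that `model` is a loaded `StripedHyenaModelForCausalLM` with the methods above patched in:

```python
# Hypothetical check that the patched methods are wired up; `model` is assumed to be
# a StripedHyenaModelForCausalLM loaded with trust_remote_code=True.
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": True})
assert model.backbone.gradient_checkpointing is True
assert model.backbone._gradient_checkpointing_func is not None

# The HF Trainer makes the same gradient_checkpointing_enable() call when
# TrainingArguments(gradient_checkpointing=True) is set, so no extra wiring is needed there.
```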
@adrienchaton

```python
from torch.utils.checkpoint import checkpoint
...
class StripedHyena(nn.Module):
    ...
    def stateless_forward(self, x, padding_mask=None):
        ...
        for _, block in enumerate(self.blocks):
            x, _ = checkpoint(block, x, None, padding_mask, use_reentrant=False)
        return x, None
    ...
```

If memory usage allows, you could also checkpoint only some of the blocks to get higher training speed (a rough sketch of that variant follows this comment).

For reference, the generic `_set_gradient_checkpointing` from transformers' `PreTrainedModel` looks like this:

```python
def _set_gradient_checkpointing(self, enable: bool = True, gradient_checkpointing_func: Callable = checkpoint):
    is_gradient_checkpointing_set = False

    # Apply it on the top-level module in case the top-level module supports it,
    # for example, LongT5Stack inherits from PreTrainedModel.
    if hasattr(self, "gradient_checkpointing"):
        self._gradient_checkpointing_func = gradient_checkpointing_func
        self.gradient_checkpointing = enable
        is_gradient_checkpointing_set = True

    for module in self.modules():
        if hasattr(module, "gradient_checkpointing"):
            module._gradient_checkpointing_func = gradient_checkpointing_func
            module.gradient_checkpointing = enable
            is_gradient_checkpointing_set = True

    if not is_gradient_checkpointing_set:
        raise ValueError(
            f"{self.__class__.__name__} is not compatible with gradient checkpointing. Make sure all the architecture support it by setting a boolean attribute"
            " `gradient_checkpointing` to modules of the model that uses checkpointing."
        )
```

so the override can simply delegate to it:

```python
def _set_gradient_checkpointing(self, enable, gradient_checkpointing_func):
    self.backbone.gradient_checkpointing = enable
    super()._set_gradient_checkpointing(enable, gradient_checkpointing_func)
```

and set a boolean `gradient_checkpointing` attribute on the modules that use checkpointing.
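A rough illustration of the speed/memory trade-off mentioned above: checkpointing only a subset of the blocks recomputes less in the backward pass at the cost of keeping more activations in memory. This is a hypothetical variant, not code from the thread; `checkpoint_every` is an arbitrary choice:

```python
import torch
from torch.utils.checkpoint import checkpoint

# Hypothetical variant of stateless_forward: checkpoint only every `checkpoint_every`-th
# block; the remaining blocks keep their activations and are not recomputed.
def stateless_forward(self, x, padding_mask=None, checkpoint_every=4):
    if isinstance(padding_mask, torch.Tensor):
        x = x * padding_mask[..., None]

    for i, block in enumerate(self.blocks):
        if self.training and i % checkpoint_every == 0:
            x, _ = checkpoint(block, x, None, padding_mask, use_reentrant=False)
        else:
            x, _ = block(x, inference_params=None, padding_mask=padding_mask)
    return x, None
```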
Thank you both @xiyang-aads-lilly @kawabata-tomoko for sharing the implementation details on gradient checkpointing. I didn't look into the details enough to notice that activating gradient checkpointing was doing nothing at all, since it was simply not implemented. I will try to run finetuning with LoRA + gradient checkpointing and then check if I can also make ZeRO-3 work to further reduce GPU memory consumption.
I am also surprised that they did not implement gradient checkpointing, because in their LLM version (https://huggingface.co/togethercomputer/StripedHyena-Hessian-7B/tree/main) they have checkpointing implemented.
Actually, ZeRO-2 may be enough for most sequence tasks. It can fully fine-tune on a 26,000 bp sequence with 7 or 8 A800 80GB GPUs (base model memory usage is close to 39 GB; additional memory usage is roughly ~15 GB per 10 kbp, so the 26,000 bp sequence used 79.8+ GB of memory).
Thanks, if I can properly set up ZeRO-2/3 I will try both. How did you set it up? Did you use the HF Trainer or Accelerate?
I used the HuggingFace Trainer and set the deepspeed argument to point to my DeepSpeed config file. I think it will work with Accelerate as well (under the hood, they are the same).
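For illustration, a minimal sketch of that kind of setup with the HF Trainer; the config file name and the hyperparameters are placeholders, not the poster's exact values, and `model`, `train_dataset`, and `data_collator` are assumed to be defined elsewhere:

```python
from transformers import Trainer, TrainingArguments

args = TrainingArguments(
    output_dir="evo-finetune",            # placeholder path
    deepspeed="ds_zero2_offload.json",    # placeholder path to a ZeRO-2 config (same shape as the dict in the next comment)
    bf16=True,
    gradient_checkpointing=True,          # requires the checkpointing patch discussed above
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    learning_rate=1e-4,
)

trainer = Trainer(
    model=model,                          # assumed: the (optionally LoRA-wrapped) Evo model
    args=args,
    train_dataset=train_dataset,          # assumed: tokenized sequence dataset
    data_collator=data_collator,          # assumed: padding/packing collator
)
trainer.train()
```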
In Accelerate, it's also easy to use DeepSpeed.

Step 1: DeepSpeed config:

```python
...
from accelerate import Accelerator, DeepSpeedPlugin

deepspeed_plugin = DeepSpeedPlugin(
    hf_ds_config={
        "train_batch_size": "auto",
        "train_micro_batch_size_per_gpu": "auto",
        "gradient_accumulation_steps": "auto",
        "gradient_clipping": "auto",
        "zero_allow_untested_optimizer": True,
        "bf16": {
            "enabled": "auto"
        },
        "zero_optimization": {
            "stage": 2,
            "contiguous_gradients": True,
            "stage3_max_live_parameters": 1e9,
            "stage3_max_reuse_distance": 1e9,
            "stage3_prefetch_bucket_size": 1e7,
            "stage3_param_persistence_threshold": 1e5,
            "reduce_bucket_size": 1e7,
            "sub_group_size": 1e9,
            "offload_optimizer": {
                "device": "cpu",
                "pin_memory": True
            },
            "offload_param": {
                "device": "cpu",
                "pin_memory": True
            }
        },
        "activation_checkpointing": {
            "partition_activations": True,
            "contiguous_memory_optimization": True,
            "cpu_checkpointing": True
        }
    }
)

accelerator = Accelerator(deepspeed_plugin=deepspeed_plugin)
device = accelerator.device
...
model.to(device)
...
```

Step 2: Prepare and train:

```python
...
trainer = accelerator.prepare(trainer)
trainer.train()
```

P.S. When I tried using ZeRO-3, it caused a RuntimeError in the embedding (the same as issue 24643), and the solution on GitHub does not work (caused …).
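Not from the thread, just for completeness: if you are not wrapping an existing Trainer, the more usual Accelerate pattern is to prepare the model, optimizer, and dataloader directly. A rough sketch under that assumption (variable names are placeholders; note that with `offload_optimizer: cpu` DeepSpeed generally expects its own `DeepSpeedCPUAdam` unless `zero_force_ds_cpu_optimizer` is disabled in the config):

```python
import torch
from accelerate import Accelerator

# deepspeed_plugin as defined in the comment above; model and train_dataloader are assumed to exist.
accelerator = Accelerator(deepspeed_plugin=deepspeed_plugin)

# Plain AdamW shown for simplicity; swap in deepspeed.ops.adam.DeepSpeedCPUAdam
# when offloading the optimizer state to CPU.
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

model, optimizer, train_dataloader = accelerator.prepare(model, optimizer, train_dataloader)

model.train()
for batch in train_dataloader:
    outputs = model(**batch)
    accelerator.backward(outputs.loss)  # lets DeepSpeed handle scaling/partitioning
    optimizer.step()
    optimizer.zero_grad()
```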
Thanks a lot, I really appreciate you sharing all these details! Somehow my own attempts at setting this up have not worked so far. I will try your recommendations as soon as possible!
Thanks everyone for sharing your insights, this was really helpful! I managed to put everything in place and I am getting nice training results, so I will share the setup which currently works best for me. I am running on a DGX with 8x A100-80GB.
Considering a batch size of 8, i.e. 1 per device and no gradient accumulation, for sequence chunks of 16,000 tokens the training speed is around 8.30 s/it with ~60 GB of memory per GPU. Gradient accumulation can then be tuned to reach the desired global batch size... Hope that helps if anyone is struggling with the setup...
Hello,
Following issue #11, which discusses finetuning code for Evo, I am specifically looking for information on which frameworks could be used to optimize Evo finetuning. And what was used for the original training, please?
Currently I am using HuggingFace and its ecosystem (e.g. PEFT). If I apply LoRA I can bring the trainable parameters ("adapters") down to 0.40% of the full model size, so that I can fit around 3,000 tokens per GPU (A100-80GB). In this setting (MSL = 3,000 tokens, batch size 1 and gradient accumulation), the training loss behaves well.
I need to increase the MSL at which I finetune Evo; 8k would be the minimum I would ideally aim for... Note that training speed is not my main concern here, but gradient checkpointing did not bring a sufficient decrease in memory use. The approaches I have looked into are:

- FSDP, i.e. sharding and offloading to CPU/RAM (a rough sketch follows this post)
- model parallelism / quantization: both techniques are more advanced and experimental, and since Evo is not a native HuggingFace class, most of their utilities for MP / quantization do not work because particular methods are not implemented

Besides extensive research and trial and error, I couldn't get to finetune Evo at MSL > 3,000... Thanks in advance if any tips can be shared, or if more details can be given on the original model training.
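For concreteness, a rough sketch of what the FSDP + CPU offload route mentioned above looks like with Accelerate; this is an assumption-laden illustration, not a configuration verified to work with Evo in this thread:

```python
from accelerate import Accelerator, FullyShardedDataParallelPlugin
from torch.distributed.fsdp import CPUOffload

# Shard parameters/gradients/optimizer state across GPUs and offload parameters to CPU RAM.
# auto_wrap_policy and other knobs are left at their defaults here and would likely need
# tuning for the StripedHyena block structure.
fsdp_plugin = FullyShardedDataParallelPlugin(
    cpu_offload=CPUOffload(offload_params=True),
)

accelerator = Accelerator(fsdp_plugin=fsdp_plugin, mixed_precision="bf16")

# model, optimizer, and train_dataloader are assumed to exist (e.g. the PEFT-wrapped Evo model):
# model, optimizer, train_dataloader = accelerator.prepare(model, optimizer, train_dataloader)
```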