
frameworks for FSDP and model/pipeline parallelism #85

Open
adrienchaton opened this issue Sep 12, 2024 · 15 comments

@adrienchaton

Hello,

Following issue #11, which discusses finetuning code for Evo, I am specifically looking for information on which frameworks could be used to make Evo finetuning feasible at longer contexts. And what was used for the original training, please?

Currently I am using HuggingFace and its ecosystem (e.g. PEFT). With LoRA I can bring the trainable parameters ("adapters") down to 0.40% of the full model size, so that I can fit around 3000 tokens per GPU (A100-80GB). In this setting (MSL = 3000 tokens, batch size 1, gradient accumulation), the training loss behaves well.
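For concreteness, a minimal sketch of that LoRA setup, assuming the HF AutoModelForCausalLM wrapper with trust_remote_code; the checkpoint id, LoRA rank, and target_modules names below are illustrative placeholders and have to be matched to the actual StripedHyena projection layers:

import torch
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

# load the public Evo checkpoint through the HF wrapper (custom model code lives in the repo)
model = AutoModelForCausalLM.from_pretrained(
    "togethercomputer/evo-1-8k-base",
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
)

lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    target_modules=["Wqkv", "out_proj"],  # placeholder names, inspect the backbone modules first
    task_type="CAUSAL_LM",
)

model = get_peft_model(model, lora_config)
model.print_trainable_parameters()  # should report well under 1% of parameters as trainable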

I need to increase the MSL at which I finetune Evo; 8k would be the minimum I would ideally aim at. Note that training speed is not my main concern here, but gradient checkpointing did not bring a sufficient decrease in memory use.

FSDP, i.e. sharding and offloading to CPU/RAM:

  • it seems that one must cast all parameters to the same dtype, so I tried full BF16, although it isn't recommended (see the sketch after this list)
  • nonetheless, besides the dtype, the shapes of the parameters triggered errors too, e.g. it would not handle the embeddings, which are 2D
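For reference, a minimal sketch of the uniform-dtype cast attempted in the first bullet above; the checkpoint id is an illustrative placeholder, and this shows the workaround, not a recommended recipe:

import torch
from transformers import AutoModelForCausalLM

# cast everything to one dtype so FSDP does not see mixed parameter dtypes
model = AutoModelForCausalLM.from_pretrained(
    "togethercomputer/evo-1-8k-base", trust_remote_code=True
)
model = model.to(torch.bfloat16)

# sanity check: no parameter left in another dtype
assert {p.dtype for p in model.parameters()} == {torch.bfloat16}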

Model parallelism / quantization: both techniques are more advanced and experimental. Since Evo is not a native HuggingFace class, most of the HF utilities for MP / quantization do not work because particular methods are not implemented.

Despite extensive research and trial and error, I couldn't manage to finetune Evo at MSL > 3000. Thanks in advance if any tips can be shared or if more details can be given on the original model training.

@adrienchaton
Author

adrienchaton commented Oct 5, 2024 via email

@xiyang-aads-lilly

@adrienchaton Hi, I just wonder if you have more details on DeepSpeed finetuning? I have tried to train the full parameters with DeepSpeed at a sequence length of 8K, but it always causes GPU OOM, whether I use stage 2 or stage 3 (with offloading to CPU).

@adrienchaton
Author

adrienchaton commented Oct 10, 2024 via email

@xiyang-aads-lilly

@adrienchaton Just to let you know, the underlying problem is that they did not implement gradient checkpointing. I added a gradient checkpointing implementation and can now fully finetune the model with an 8K window using DeepSpeed stage 2 without any OOM problem.

@adrienchaton
Author

@xiyang-aads-lilly Could you share more details please? Are you using the HuggingFace AutoModelForCausalLM wrapper and the HF Trainer? For me, setting gradient_checkpointing=True in the TrainingArguments passed to the HF Trainer didn't cause any bug, but it also did not show any significant reduction in memory consumption.

It would be great to see how you implemented gradient checkpointing and also how you run it with DeepSpeed. In my case, using this parameter
https://huggingface.co/docs/transformers/main_classes/trainer#transformers.TrainingArguments.fsdp
did not work, e.g. fsdp=["full_shard", "offload"] would trigger errors due to mixed dtypes and shapes in the model parameters.

Did you proceed similarly to this documentation? https://huggingface.co/docs/accelerate/usage_guides/deepspeed
i.e. creating a specific DeepSpeed config for the accelerate launcher?
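For context, a sketch of roughly that failing configuration; only the fsdp and gradient_checkpointing settings come from the discussion above, the remaining arguments are placeholders:

from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="evo-finetune",        # placeholder
    per_device_train_batch_size=1,
    gradient_checkpointing=True,
    bf16=True,
    fsdp=["full_shard", "offload"],   # the setting that triggered the dtype/shape errors
)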

@xiyang-aads-lilly

xiyang-aads-lilly commented Oct 15, 2024


So if we load the Evo 8k model, it uses this as the backbone: https://huggingface.co/togethercomputer/evo-1-131k-base/blob/main/model.py. If you check the function stateless_forward (L377-L387):

def stateless_forward(self, x, padding_mask=None):
    if type(padding_mask) == torch.Tensor:
        x = x * padding_mask[..., None]

    for _, block in enumerate(self.blocks):
        x, _ = block(x, inference_params=None, padding_mask=padding_mask)
    return x, None

Clearly, it does not implement any checkpointing strategy. Even if you set the checkpointing args, it actually does nothing.

I modified the function as below, along with some other small changes.

def stateless_forward(self, x, padding_mask=None):
        if type(padding_mask) == torch.Tensor:
            x = x * padding_mask[..., None]

        for _, block in enumerate(self.blocks):
            if self.gradient_checkpointing and self.training:
                x, _ = self._gradient_checkpointing_func(block.__call__, x, None, padding_mask)
            else:
                x, _ = block(x, inference_params=None, padding_mask=padding_mask)

        return x, None

Now the checkpointing is working as expected.

I raised an issue in the original StripedHyena repo, togethercomputer/stripedhyena#22, to discuss this.

The small changes I made:

# in modeling_hyena.py
# under StripedHyenaModelForCausalLM
# add/modify the two following functions
# (this relies on `functools` and `torch.utils.checkpoint` being imported at the top of the file):
def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
    if not self.supports_gradient_checkpointing:
        raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")

    if gradient_checkpointing_kwargs is None:
        gradient_checkpointing_kwargs = {"use_reentrant": True}

    # TODO support deepspeed checkpoint
    gradient_checkpointing_func = functools.partial(
        torch.utils.checkpoint.checkpoint, **gradient_checkpointing_kwargs
    )

    self._set_gradient_checkpointing(
        enable=True, gradient_checkpointing_func=gradient_checkpointing_func
    )

    if getattr(self, "_hf_peft_config_loaded", False):
        # When using PEFT + gradient checkpointing + Trainer we need to make sure the input has requires_grad=True
        # we do it also on PEFT: https://github.com/huggingface/peft/blob/85013987aa82aa1af3da1236b6902556ce3e483e/src/peft/peft_model.py#L334
        # When training with PEFT, only LoRA layers will have requires grad set to True, but the output of frozen layers need to propagate
        # the gradients to make sure the gradient flows.
        self.enable_input_require_grads()

def _set_gradient_checkpointing(self, enable, gradient_checkpointing_func):
    self.backbone.gradient_checkpointing = enable
    self.backbone._gradient_checkpointing_func = gradient_checkpointing_func


# in model.py
# under StripedHyena
# in __init__
# add
        self.gradient_checkpointing = False
        self._gradient_checkpointing_func = None
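With these patches applied, a minimal sketch of how the HF Trainer can then be pointed at a DeepSpeed ZeRO-2 config; the config path, batch size, and accumulation steps are placeholders:

from transformers import Trainer, TrainingArguments

training_args = TrainingArguments(
    output_dir="evo-finetune",          # placeholder
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,      # placeholder
    gradient_checkpointing=True,        # now actually checkpoints the backbone blocks
    bf16=True,
    deepspeed="ds_config_zero2.json",   # placeholder path to a standard ZeRO-2 config
)

# model and train_dataset are assumed to be defined elsewhere
trainer = Trainer(model=model, args=training_args, train_dataset=train_dataset)
trainer.train()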

@kawabata-tomoko

kawabata-tomoko commented Oct 15, 2024


@adrienchaton
Same solution as @xiyang-aads-lilly. Actually, you can finetune sequences over 30k+ in length with LoRA and gradient checkpointing (A800 80GB × 7 or 8), even without DeepSpeed ZeRO. However, this method performed poorly on my demo dataset (a 225-sequence binary classification task).
Here is a sample.
Step 1: change the stateless_forward method in the class StripedHyena:

from torch.utils.checkpoint import checkpoint
...
class StripedHyena(nn.Module):
    ...
    def stateless_forward(self,x,padding_mask=None):
        ...
        for _, block in enumerate(self.blocks):
            x, _ =checkpoint(block, x, None, padding_mask, use_reentrant=False)
        return x, None
    ...

If memory usage allows, you could checkpoint only every few blocks (running the others normally) to get higher training speed; see the sketch below.
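For illustration, a sketch of that partial-checkpointing idea, assuming the same block interface as the method above; CHECKPOINT_EVERY is an illustrative knob, not part of the model:

import torch
from torch.utils.checkpoint import checkpoint

CHECKPOINT_EVERY = 2  # illustrative: checkpoint every second block, run the others normally

def stateless_forward(self, x, padding_mask=None):
    if type(padding_mask) == torch.Tensor:
        x = x * padding_mask[..., None]

    for i, block in enumerate(self.blocks):
        if self.training and i % CHECKPOINT_EVERY == 0:
            # recompute this block's activations during the backward pass
            x, _ = checkpoint(block, x, None, padding_mask, use_reentrant=False)
        else:
            # keep this block's activations in memory (faster, but uses more memory)
            x, _ = block(x, inference_params=None, padding_mask=padding_mask)
    return x, None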
Step 2 (may not be necessary):

  1. Add this method to the class StripedHyenaPreTrainedModel (it relies on from typing import Callable and the checkpoint import above):
    def _set_gradient_checkpointing(self, enable: bool = True, gradient_checkpointing_func: Callable = checkpoint):
        is_gradient_checkpointing_set = False

        # Apply it on the top-level module in case the top-level modules supports it
        # for example, LongT5Stack inherits from PreTrainedModel.
        if hasattr(self, "gradient_checkpointing"):
            self._gradient_checkpointing_func = gradient_checkpointing_func
            self.gradient_checkpointing = enable
            is_gradient_checkpointing_set = True

        for module in self.modules():
            if hasattr(module, "gradient_checkpointing"):
                module._gradient_checkpointing_func = gradient_checkpointing_func
                module.gradient_checkpointing = enable
                is_gradient_checkpointing_set = True

        if not is_gradient_checkpointing_set:
            raise ValueError(
                f"{self.__class__.__name__} is not compatible with gradient checkpointing. Make sure all the architecture support it by setting a boolean attribute"
                " gradient_checkpointing to modules of the model that uses checkpointing."
            )
  2. Override the _set_gradient_checkpointing method in the downstream task model like this:
    def _set_gradient_checkpointing(self, enable, gradient_checkpointing_func):
        self.backbone.gradient_checkpointing = enable
        super()._set_gradient_checkpointing(enable, gradient_checkpointing_func)

and set the attribute supports_gradient_checkpointing = True.
Now you can use gradient_checkpointing=True in the TrainingArguments.

@adrienchaton
Author

Thank you both @xiyang-aads-lilly @kawabata-tomoko for sharing the implementation details on gradient checkpointing. I didn't look closely enough at the details to notice that activating gradient checkpointing was doing nothing at all, since it was simply not implemented. I will try to run finetuning with LoRA + gradient checkpointing, and then check whether I can also make ZeRO-3 work to further reduce GPU memory consumption.

@xiyang-aads-lilly

xiyang-aads-lilly commented Oct 15, 2024


I am also surprised that they did not implement gradient checkpointing because in their LLM version (https://huggingface.co/togethercomputer/StripedHyena-Hessian-7B/tree/main), they have checkpointing implemented.

@kawabata-tomoko

kawabata-tomoko commented Oct 15, 2024


Actually, ZeRO-2 may be enough for most sequence tasks. It can fully finetune on a 26,000 bp sequence with 7 or 8 A800 80GB GPUs (the base model's memory usage is close to 39 GB; the additional memory usage is roughly ~15 GB per 10 kbp, and the 26,000 bp sequence used 79.8+ GB of memory).
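As a rough check of that scaling: 39 GB + (26 kbp / 10 kbp) × 15 GB ≈ 78 GB, which is consistent with the ~79.8 GB observed.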

@adrienchaton
Author

adrienchaton commented Oct 15, 2024

Thanks, if I can properly set up ZeRO-2/3 I will try both.

How did you set it up? Did you use the accelerate config and accelerate launch as described in https://huggingface.co/docs/accelerate/usage_guides/deepspeed#accelerate-deepspeed-plugin?

@xiyang-aads-lilly

xiyang-aads-lilly commented Oct 15, 2024


I used the HuggingFace Trainer and set the deepspeed argument pointing to my DeepSpeed config file.

I think if you use accelerate, it will work as well (under the hood, they are the same).

@kawabata-tomoko

kawabata-tomoko commented Oct 15, 2024


In accelerate, it's also easy to use DeepSpeed:

Step 1: DeepSpeed config:

...
from accelerate import Accelerator, DeepSpeedPlugin

deepspeed_plugin = DeepSpeedPlugin(
    hf_ds_config={
        "train_batch_size": "auto",
        "train_micro_batch_size_per_gpu": "auto",
        "gradient_accumulation_steps": "auto",
        "gradient_clipping": "auto",
        "zero_allow_untested_optimizer": True,
        "bf16": {
            "enabled": "auto"
        },
        "zero_optimization": {
            "stage": 2,
            "contiguous_gradients": True,
            "stage3_max_live_parameters": 1e9,
            "stage3_max_reuse_distance": 1e9,
            "stage3_prefetch_bucket_size": 1e7,
            "stage3_param_persistence_threshold": 1e5,
            "reduce_bucket_size": 1e7,
            "sub_group_size": 1e9,
            "offload_optimizer": {
                "device": "cpu",
                "pin_memory": True
            },
            "offload_param": {
                "device": "cpu",
                "pin_memory": True
            }
        },
        "activation_checkpointing": {
            "partition_activations": True,
            "contiguous_memory_optimization": True,
            "cpu_checkpointing": True
        }
    }
)
accelerator = Accelerator(deepspeed_plugin=deepspeed_plugin)
device = accelerator.device
...
model.to(device)
...

and add optim="adamw_apex_fused" to the TrainingArguments.

Step 2: prepare and train:

...
trainer = accelerator.prepare(trainer)
trainer.train()

P.S. When I tried ZeRO-3, it caused a RuntimeError in the embedding (same as issue 24643), and the solution suggested on GitHub did not work (it caused AssertionError: zero stage 2 requires an optimizer).

@adrienchaton
Author

Thanks a lot, I really appreciate you sharing all these details!

Somehow, just setting fsdp in the TrainingArguments did not work, so I imagine there must be some default parameters that were not suited to Evo finetuning and that must be explicitly set in the DeepSpeed config.

I will try your recommendations as soon as possible!

@adrienchaton
Author

adrienchaton commented Dec 19, 2024

thanks everyone for sharing your insights, this was really helpful!

I managed to put everything in place and am getting good training results, so I will share the setup that currently works best for me.

I am running on a DGX with 8*A100-80GB

With a batch size of 8, i.e. 1 per device and no gradient accumulation, sequence chunks of 16,000 tokens train at around 8.30 s/it with ~60 GB of GPU memory per device. Gradient accumulation can then be tuned to reach the desired global batch size. Hope that helps if anyone is struggling with the setup.
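For example, with a hypothetical gradient_accumulation_steps=4: 1 per device × 8 GPUs × 4 accumulation steps gives a global batch of 32 chunks of 16,000 tokens, i.e. roughly 512k tokens per optimizer step.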
