Python package for rematerialization-aware gradient checkpointing
FastCkpt: accelerate your LLM training in one line!

FastCkpt is a rematerialization-aware gradient checkpointing library designed to accelerate training with memory-efficient attention implementations such as FlashAttention and LightSeq. It ships monkey patches for both rematerialization-aware checkpointing and FlashAttention, so you can apply both in a single line!
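The core idea behind rematerialization-aware checkpointing can be sketched in plain Python (an illustrative toy, not FastCkpt's actual implementation): instead of storing every intermediate activation for the backward pass, store only selected checkpoints and recompute the rest from the nearest checkpoint when they are needed.

```python
def forward_store_all(layers, x):
    """Baseline: keep every activation (high memory)."""
    acts = [x]
    for f in layers:
        x = f(x)
        acts.append(x)
    return x, acts  # len(acts) == len(layers) + 1

def forward_checkpointed(layers, x, every=2):
    """Checkpointing: keep only every `every`-th activation; the rest
    are rematerialized (recomputed) during the backward pass."""
    ckpts = {0: x}
    for i, f in enumerate(layers):
        x = f(x)
        if (i + 1) % every == 0:
            ckpts[i + 1] = x
    return x, ckpts

def rematerialize(layers, ckpts, i):
    """Recompute activation i from the nearest stored checkpoint."""
    j = max(k for k in ckpts if k <= i)
    x = ckpts[j]
    for f in layers[j:i]:
        x = f(x)
    return x

layers = [lambda v: v + 1, lambda v: v * 2, lambda v: v - 3, lambda v: v * v]
out_full, acts = forward_store_all(layers, 1.0)
out_ckpt, ckpts = forward_checkpointed(layers, 1.0, every=2)
assert out_full == out_ckpt
# A recomputed activation matches the one the baseline stored:
assert rematerialize(layers, ckpts, 3) == acts[3]
```

The trade-off: the checkpointed forward stores far fewer activations, at the cost of recomputation in backward; a rematerialization-aware policy chooses checkpoints so that cheap-to-recompute activations (rather than expensive ones) are dropped.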

Paper: https://arxiv.org/pdf/2310.03294.pdf

News

  • [2023/10] FastCkpt now supports LlamaModel in Hugging Face Transformers!

Install

pip install fastckpt

Usage

FastCkpt now supports the Hugging Face training pipeline.

Use FastCkpt and FlashAttention

To use fastckpt together with flash_attn, import and run replace_hf_ckpt_with_fast_ckpt before importing transformers:

# add monkey patch for fastckpt
from fastckpt.llama_flash_attn_ckpt_monkey_patch import replace_hf_ckpt_with_fast_ckpt
replace_hf_ckpt_with_fast_ckpt()

# import transformers and other packages
import transformers
...
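Why must the patch run before import transformers? A from-import copies the function object into the importing module's namespace at import time, so patching the source module afterwards has no effect on names that were already bound. A minimal self-contained illustration of this ordering rule (fake module names, nothing here is from fastckpt or transformers):

```python
import sys
import types

# Fake "library" module exposing an attention function.
lib = types.ModuleType("fake_lib")
lib.attn = lambda: "slow"
sys.modules["fake_lib"] = lib

def import_consumer():
    # Simulates `import transformers`: the consumer binds fake_lib.attn
    # into its own namespace at import time (from-import copies the name).
    mod = types.ModuleType("fake_consumer")
    exec("from fake_lib import attn\ndef run():\n    return attn()", mod.__dict__)
    return mod

# Wrong order: consumer imported first keeps the old binding forever.
consumer = import_consumer()
sys.modules["fake_lib"].attn = lambda: "fast"
assert consumer.run() == "slow"

# Correct order: patch first, then import -> the patch is picked up.
consumer2 = import_consumer()
assert consumer2.run() == "fast"
```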

Use FlashAttention only

To replace only LlamaAttention with flash_attn, without changing the checkpointing strategy, import and run replace_llama_attn_with_flash_attn:

# add monkey patch for fastckpt
from fastckpt.llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn
replace_llama_attn_with_flash_attn()

# import transformers and other packages
import transformers
...
