How to convert a Llama3 checkpoint for use in torchtitan

If you want to continue training from an existing model checkpoint, the checkpoint must be in the DCP (PyTorch Distributed Checkpoint, `torch.distributed.checkpoint`) format expected by the checkpoint manager. An example script for converting the original Llama3 checkpoints into the expected DCP format can be found in `scripts/convert_llama_to_dcp.py`.

The script expects a path to the original checkpoint files and a path to an output directory:

```bash
python3 scripts/convert_llama_to_dcp.py <input_dir> <output_dir>
```

How to Convert a torchtitan Checkpoint for Use in torchtune

This guide will walk you through the steps required to convert a checkpoint from torchtitan so that it can be loaded into torchtune.

Steps

1. **ENABLE CHECKPOINTING**

   In your torchtitan training config, ensure that `enable_checkpoint` is set to `true`:

   ```toml
   [checkpoint]
   enable_checkpoint = true
   folder = "checkpoint"
   interval_type = "steps"
   interval = 500
   ```
2. **SAVE ONLY MODEL WEIGHTS**

   Setting `model_weights_only` to `true` makes the checkpoint contain only the model weights, excluding the optimizer state and extra training states, which results in a smaller checkpoint:

   ```toml
   [checkpoint]
   enable_checkpoint = true
   model_weights_only = true
   ```
3. **CHOOSE DESIRED EXPORT PRECISION**

   The model states are stored in float32 by default. You can export the checkpoint in a lower-precision format such as bfloat16 by setting `export_dtype`:

   ```toml
   [checkpoint]
   enable_checkpoint = true
   model_weights_only = true
   export_dtype = "bfloat16"
   ```
4. **EXAMPLE CHECKPOINT CONFIGURATION**

   ```toml
   [checkpoint]
   enable_checkpoint = true
   folder = "checkpoint"
   interval_type = "steps"
   interval = 10
   load_step = 5
   model_weights_only = true
   export_dtype = "bfloat16"
   ```
5. **SAVE THE FINAL CHECKPOINT**

   Once the above settings are in place, the final checkpoint saved at the end of training will contain only the model weights, in the chosen export dtype. Checkpoints saved before the final step still contain the full training state, so training can be resumed from them.

6. **CONVERT SHARDED CHECKPOINTS TO A SINGLE FILE**

   Finally, once you have obtained the last checkpoint, use the following command to convert the sharded DCP checkpoint into a single `.pt` file that can be loaded into torchtune:

   ```bash
   python -m torch.distributed.checkpoint.format_utils dcp_to_torch torchtitan/outputs/checkpoint/step-1000 checkpoint.pt
   ```

That's it. You have now successfully converted a sharded torchtitan checkpoint for use in torchtune.