diff --git a/.github/workflows/build_linux_wheels.yaml b/.github/workflows/build_linux_wheels.yaml
index 7b30a64ad6..bcfe639531 100644
--- a/.github/workflows/build_linux_wheels.yaml
+++ b/.github/workflows/build_linux_wheels.yaml
@@ -36,6 +36,8 @@ jobs:
with:
repository: pytorch/torchtune
ref: ""
+ test-infra-repository: pytorch/test-infra
+ test-infra-ref: main
package-name: torchtune
build-matrix: ${{ needs.generate-matrix.outputs.matrix }}
pre-script: .github/scripts/pre_build_script.sh
diff --git a/.github/workflows/export.yaml b/.github/workflows/export.yaml
new file mode 100644
index 0000000000..e641568673
--- /dev/null
+++ b/.github/workflows/export.yaml
@@ -0,0 +1,51 @@
+name: Export
+
+on:
+ push:
+ paths:
+ - 'torchtune/modules/_export/**'
+ - 'tests/torchtune/modules/_export/**'
+ pull_request:
+ paths:
+ - 'torchtune/modules/_export/**'
+ - 'tests/torchtune/modules/_export/**'
+ schedule:
+ # Runs at midnight evvery day
+ - cron: '0 0 * * *'
+
+concurrency:
+ group: unit-test${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_number || github.ref }}
+ cancel-in-progress: true
+
+defaults:
+ run:
+ shell: bash -l -eo pipefail {0}
+
+jobs:
+ export_unit_tests:
+ if: github.repository_owner == 'pytorch'
+ runs-on: ubuntu-latest
+ strategy:
+ matrix:
+ python-version: ['3.9', '3.10', '3.11']
+ steps:
+ - name: Check out repo
+ uses: actions/checkout@v3
+ - name: Setup conda env
+ uses: conda-incubator/setup-miniconda@v2
+ with:
+ auto-update-conda: true
+ miniconda-version: "latest"
+ activate-environment: test
+ python-version: ${{ matrix.python-version }}
+ - name: Update pip
+ run: python -m pip install --upgrade pip
+ - name: Install dependencies
+ run: |
+ bash torchtune/modules/_export/install_requirements.sh
+ python -m pip install torchao
+ python -m pip install -e ".[dev]"
+ - name: Run unit tests with coverage
+ run: pytest tests/torchtune/modules/_export --cov=. --cov-report=xml --durations=20 -vv
+ - name: Upload Coverage to Codecov
+ uses: codecov/codecov-action@v3
diff --git a/.github/workflows/gpu_test.yaml b/.github/workflows/gpu_test.yaml
index 67b4a0705a..42dfd4d16b 100644
--- a/.github/workflows/gpu_test.yaml
+++ b/.github/workflows/gpu_test.yaml
@@ -55,6 +55,6 @@ jobs:
python -m pip install -e ".[dev]"
python -m pip install lm-eval==0.4.5
- name: Run recipe and unit tests with coverage
- run: pytest tests --with-integration --cov=. --cov-report=xml --durations=20 -vv
+ run: pytest tests --ignore tests/torchtune/modules/_export --with-integration --cov=. --cov-report=xml --durations=20 -vv
- name: Upload Coverage to Codecov
uses: codecov/codecov-action@v3
diff --git a/.github/workflows/unit_test.yaml b/.github/workflows/unit_test.yaml
index 4a988b5455..188f42e084 100644
--- a/.github/workflows/unit_test.yaml
+++ b/.github/workflows/unit_test.yaml
@@ -37,6 +37,6 @@ jobs:
python -m pip install torch torchvision torchao
python -m pip install -e ".[dev]"
- name: Run unit tests with coverage
- run: pytest tests --cov=. --cov-report=xml --durations=20 -vv
+ run: pytest tests --ignore tests/torchtune/modules/_export --cov=. --cov-report=xml --durations=20 -vv
- name: Upload Coverage to Codecov
uses: codecov/codecov-action@v3
diff --git a/README.md b/README.md
index 8caa38c890..c9980a6c0f 100644
--- a/README.md
+++ b/README.md
@@ -10,6 +10,7 @@
[**Introduction**](#introduction) | [**Installation**](#installation) | [**Get Started**](#get-started) | [**Documentation**](https://pytorch.org/torchtune/main/index.html) | [**Community**](#community) | [**License**](#license) | [**Citing torchtune**](#citing-torchtune)
### 📣 Recent updates 📣
+* *December 2024*: torchtune now supports **Llama 3.3 70B**! Try it out by following our installation instructions [here](#Installation), then run any of the configs [here](recipes/configs/llama3_3).
* *November 2024*: torchtune has released [v0.4.0](https://github.com/pytorch/torchtune/releases/tag/v0.4.0) which includes stable support for exciting features like activation offloading and multimodal QLoRA
* *November 2024*: torchtune has added [Gemma2](recipes/configs/gemma2) to its models!
* *October 2024*: torchtune added support for Qwen2.5 models - find the recipes [here](recipes/configs/qwen2_5/)
@@ -39,6 +40,7 @@ torchtune currently supports the following models.
| Model | Sizes |
|-----------------------------------------------|-----------|
+| [Llama3.3](https://www.llama.com/docs/model-cards-and-prompt-formats/llama3_3) | 70B [[models](torchtune/models/llama3_3/_model_builders.py), [configs](recipes/configs/llama3_3/)] |
| [Llama3.2-Vision](https://www.llama.com/docs/model-cards-and-prompt-formats/llama3_2#-llama-3.2-vision-models-(11b/90b)-) | 11B, 90B [[models](torchtune/models/llama3_2_vision/_model_builders.py), [configs](recipes/configs/llama3_2_vision/)] |
| [Llama3.2](https://www.llama.com/docs/model-cards-and-prompt-formats/llama3_2) | 1B, 3B [[models](torchtune/models/llama3_2/_model_builders.py), [configs](recipes/configs/llama3_2/)] |
| [Llama3.1](https://llama.meta.com/docs/model-cards-and-prompt-formats/llama3_1) | 8B, 70B, 405B [[models](torchtune/models/llama3_1/_model_builders.py), [configs](recipes/configs/llama3_1/)] |
@@ -67,7 +69,8 @@ torchtune provides the following finetuning recipes for training on one or more
| LoRA Finetuning | 1-8 | [lora_finetune_single_device](recipes/lora_finetune_single_device.py)
[lora_finetune_distributed](recipes/lora_finetune_distributed.py) | [Qwen2 0.5B single-device](recipes/configs/qwen2/0.5B_lora_single_device.yaml)
[Gemma 7B distributed](recipes/configs/gemma/7B_lora.yaml)
| QLoRA Finetuning | 1-8 | [lora_finetune_single_device](recipes/lora_finetune_single_device.py)
[lora_finetune_distributed](recipes/lora_finetune_distributed.py)| [Phi3 Mini single-device](recipes/configs/phi3/mini_qlora_single_device.yaml)
[Llama 3.1 405B distributed](recipes/configs/llama3_1/405B_qlora.yaml)
| DoRA/QDoRA Finetuning | 1-8 | [lora_finetune_single_device](recipes/lora_finetune_single_device.py)
[lora_finetune_distributed](recipes/lora_finetune_distributed.py)| [Llama3 8B QDoRA single-device](recipes/configs/llama3/8B_qdora_single_device.yaml)
[Llama3 8B DoRA distributed](recipes/configs/llama3/8B_dora.yaml)
-| Quantization-Aware Training | 4-8 | [qat_distributed](recipes/qat_distributed.py)| [Llama3 8B QAT](recipes/configs/llama3/8B_qat_full.yaml)
+| Quantization-Aware Training | 2-8 | [qat_distributed](recipes/qat_distributed.py)| [Llama3 8B QAT](recipes/configs/llama3/8B_qat_full.yaml)
+| Quantization-Aware Training and LoRA Finetuning | 2-8 | [qat_lora_finetune_distributed](recipes/qat_lora_finetune_distributed.py)| [Llama3 8B QAT](recipes/configs/llama3/8B_qat_lora.yaml)
| Direct Preference Optimization |1-8 | [lora_dpo_single_device](recipes/lora_dpo_single_device.py)
[lora_dpo_distributed](recipes/lora_dpo_distributed.py) | [Llama2 7B single-device](recipes/configs/llama2/7B_lora_dpo_single_device.yaml)
[Llama2 7B distributed](recipes/configs/llama2/7B_lora_dpo.yaml)
| Proximal Policy Optimization | 1 | [ppo_full_finetune_single_device](recipes/ppo_full_finetune_single_device.py) | [Mistral 7B](recipes/configs/mistral/7B_full_ppo_low_memory.yaml)
| Knowledge Distillation | 1 | [knowledge_distillation_single_device](recipes/knowledge_distillation_single_device.py) | [Qwen2 1.5B -> 0.5B](recipes/configs/qwen2/knowledge_distillation_single_device.yaml)
diff --git a/docs/source/api_ref_models.rst b/docs/source/api_ref_models.rst
index e658de7294..645e7f2c19 100644
--- a/docs/source/api_ref_models.rst
+++ b/docs/source/api_ref_models.rst
@@ -6,6 +6,31 @@ torchtune.models
.. currentmodule:: torchtune.models
+llama3.3
+--------
+
+Text-only models from the 3.3 version of `Llama3 family `_.
+
+Important: You need to request access on `Hugging Face `__ before downloading it.
+
+To download the Llama-3.3-70B-Instruct model:
+
+.. code-block:: bash
+
+ tune download meta-llama/Llama-3.3-70B-Instruct --ignore-patterns "original/consolidated.00.pth" --hf-token
+
+.. autosummary::
+ :toctree: generated/
+ :nosignatures:
+
+ llama3_3.llama3_3_70b
+ llama3_3.lora_llama3_3_70b
+ llama3_3.qlora_llama3_3_70b
+
+.. note::
+
+ The Llama3.3 tokenizer reuses the :class:`~torchtune.models.llama3.llama3_tokenizer` class.
+
llama3.2
--------
diff --git a/docs/source/basics/multimodal_datasets.rst b/docs/source/basics/multimodal_datasets.rst
index 9ecadd92a7..7a45f06f72 100644
--- a/docs/source/basics/multimodal_datasets.rst
+++ b/docs/source/basics/multimodal_datasets.rst
@@ -71,12 +71,12 @@ in the text, ``""`` for where to place the image tokens. This will get re
.. code-block:: yaml
- # In config - model_transforms takes the place of the tokenizer
- model_transform:
+ tokenizer:
_component_: torchtune.models.llama3_2_vision_transform
path: /tmp/Meta-Llama-3-8B-Instruct/original/tokenizer.model
prompt_template: torchtune.data.QuestionAnswerTemplate
max_seq_len: 8192
+ image_size: 560
dataset:
_component_: torchtune.datasets.multimodal.multimodal_chat_dataset
@@ -137,7 +137,7 @@ For most datasets, you will also need to specify the ``split`` and/or the subset
.. code-block:: yaml
# In config
- model_transform:
+ tokenizer:
_component_: torchtune.models.llama3_2_vision.llama3_2_vision_transform
path: /tmp/Meta-Llama-3-8B-Instruct/original/tokenizer.model
max_seq_len: 8192
diff --git a/docs/source/index.rst b/docs/source/index.rst
index 318c82b3e2..d62ad77b63 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -113,6 +113,7 @@ torchtune tutorials.
recipes/recipes_overview
recipes/lora_finetune_single_device
recipes/qat_distributed
+ recipes/dpo
.. toctree::
:glob:
diff --git a/docs/source/recipes/dpo.rst b/docs/source/recipes/dpo.rst
new file mode 100644
index 0000000000..5fdb455a35
--- /dev/null
+++ b/docs/source/recipes/dpo.rst
@@ -0,0 +1,75 @@
+.. _dpo_recipe_label:
+
+====================================
+Direct Preference Optimization
+====================================
+
+This recipe supports several `Direct Preference Optimization `_ (DPO)-style fine-tuning techniques.
+These techniques aim to steer (or `align `_) a model towards some desirable behaviours.
+For example, a common goal is to train language models to produce safe and honest outputs,
+or to be `helpful and harmless `_.
+
+To see the best results when using this recipe, it may be helpful to first fine-tune your model with using supervised fine-tuning to ensure your model is
+on-distribution for the domain you're interested in. To do this, check out our other fine-tuning recipes in the :ref:`recipe overview ` which
+support a variety of SFT paradigms.
+
+After supervised fine-tuning, here is an example of DPO with Llama 3.1 8B:
+
+.. note::
+
+ You may need to be granted access to the Llama model you're interested in. See
+ :ref:`here ` for details on accessing gated repositories.
+
+
+.. code-block:: bash
+
+ tune download meta-llama/Meta-Llama-3.1-8B-Instruct \
+ --ignore-patterns "original/consolidated.00.pth"
+ --HF_TOKEN
+
+ # run on a single device
+ tune run lora_dpo_single_device --config llama3_1/8B_lora_dpo_single_device
+
+ # run on two gpus
+ tune run --nproc_per_node 2 lora_dpo_distributed --config llama3_1/8B_lora_dpo
+
+It's easy to get started with this recipe with your dataset of choice, including custom local datasets,
+and datasets from Hugging Face. Check out our primer on :ref:`preference datasets ` to
+see how to do this.
+
+For this recipe we include different DPO-style losses:
+
+* :class:`Direct Preference Optimization ` (DPO) loss [#]_. The DPO loss function
+ increases the relative log-probabilities of preferred to un-preferred responses, whilst using log probabilities
+ from a reference model to prevent policy degradation during training. Alongside RLHF, this is the most commonly used
+ alignment technique and is used to train a growing number of state-of-the-art LLMs e.g. Llama3.1, Gemma 2, Qwen2, etc.
+ This is a good starting point for alignment fine-tuning.
+* :class:`Statistical Rejection Sampling Optimization ` (RSO) or "hinge" loss [#]_.
+ RSO builds on concepts from support vector machines and DPO, applying a margin-based approach that penalizes
+ low-quality responses while ensuring a significant gap between chosen and un-chosen log probabilities.
+
+To use any of these, simply use the ``loss`` config entry or flag through the :ref:`cli_label`:
+
+.. code-block:: bash
+
+ tune run lora_dpo_single_device --config llama2/7B_lora_dpo_single_device \
+ loss=torchtune.modules.loss.RSOLoss \
+ gamma=0.5
+
+.. todo (@SalmanMohammadi) point to an example repo for SimPO
+
+For a deeper understanding of the different levers you can pull when using this recipe,
+see our documentation for the different PEFT training paradigms we support:
+
+* :ref:`glossary_lora`
+* :ref:`glossary_qlora`
+* :ref:`glossary_dora`
+
+Many of our other memory optimization features can be used in this recipe. You can learn more about all of our memory optimization features in our :ref:`memory optimization overview`.
+
+.. rubric:: References:
+
+.. [#] Rafailov, R., Sharma, A., Mitchell, E., Manning, C.D., Ermon, S. and Finn, C., 2024.
+ Direct preference optimization: Your language model is secretly a reward model. Advances in Neural Information Processing Systems, 36.
+.. [#] Liu, T., Zhao, Y., Joshi, R., Khalman, M., Saleh, M., Liu, P.J. and Liu, J., 2023.
+ Statistical rejection sampling improves preference optimization. arXiv preprint arXiv:2309.06657.
diff --git a/docs/source/recipes/lora_finetune_single_device.rst b/docs/source/recipes/lora_finetune_single_device.rst
index 4b4d476058..ffcca11d53 100644
--- a/docs/source/recipes/lora_finetune_single_device.rst
+++ b/docs/source/recipes/lora_finetune_single_device.rst
@@ -8,7 +8,7 @@ This recipe supports finetuning on next-token prediction tasks using parameter e
such as :ref:`glossary_lora` and :ref:`glossary_qlora`. These techniques
significantly reduce memory consumption during training whilst still maintaining competitive performance.
-We provide configs which you can get up and running quickly. Here is an example with llama 3.1 8B:
+We provide configs which you can get up and running quickly. Here is an example with Llama 3.1 8B:
.. note::
diff --git a/docs/source/recipes/recipes_overview.rst b/docs/source/recipes/recipes_overview.rst
index a1c4f39ef3..e6e8c9cd63 100644
--- a/docs/source/recipes/recipes_overview.rst
+++ b/docs/source/recipes/recipes_overview.rst
@@ -28,7 +28,7 @@ Our recipes include:
* Single-device full fine-tuning
* Distributed full fine-tuning
* Distributed LoRA fine-tuning
-* Direct Preference Optimization (DPO)
+* :ref:`Direct Preference Optimization (DPO) `
* Proximal Policy Optimization (PPO)
* :ref:`Distributed Quantization-Aware Training (QAT)`.
diff --git a/recipes/configs/code_llama2/7B_full_low_memory.yaml b/recipes/configs/code_llama2/7B_full_low_memory.yaml
index b6586e8b5a..ad941803bb 100644
--- a/recipes/configs/code_llama2/7B_full_low_memory.yaml
+++ b/recipes/configs/code_llama2/7B_full_low_memory.yaml
@@ -19,6 +19,8 @@
#
# This config works only for training on single device.
+output_dir: /tmp/torchtune/code_llama2_7B/full_low_memory # /tmp may be deleted by your system. Change it to your preference.
+
# Model arguments
model:
_component_: torchtune.models.code_llama2.code_llama2_7b
@@ -39,7 +41,7 @@ checkpointer:
pytorch_model-00003-of-00003.bin
]
recipe_checkpoint: null
- output_dir: /tmp/CodeLlama-7b-hf
+ output_dir: ${output_dir}
model_type: LLAMA2
resume_from_checkpoint: False
@@ -55,14 +57,14 @@ shuffle: True
epochs: 1
max_steps_per_epoch: null
batch_size: 2
-gradient_accumulation_steps: 1 # Use to increase virtual batch size
+gradient_accumulation_steps: 1 # Use to increase effective batch size
optimizer:
_component_: bitsandbytes.optim.PagedAdamW
lr: 2e-5
optimizer_in_bwd: True # True saves memory. Requires gradient_accumulation_steps=1
loss:
_component_: torchtune.modules.loss.CEWithChunkedOutputLoss
-compile: False # pytorch compile, set to true for better perf/memory
+compile: False # torch.compile the model + loss, True increases speed + decreases memory
# Training env
device: cuda
@@ -73,13 +75,13 @@ enable_activation_offloading: True # True reduces memory
dtype: bf16
# Logging
-output_dir: /tmp/codellama_finetune_output
metric_logger:
_component_: torchtune.training.metric_logging.DiskLogger
- log_dir: /tmp/CodeLlama-7b-hf/logs
+ log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: True
+
# Profiler (disabled)
profiler:
_component_: torchtune.training.setup_torch_profiler
diff --git a/recipes/configs/code_llama2/7B_lora_single_device.yaml b/recipes/configs/code_llama2/7B_lora_single_device.yaml
index 11f2ffc6c6..416c11fc27 100644
--- a/recipes/configs/code_llama2/7B_lora_single_device.yaml
+++ b/recipes/configs/code_llama2/7B_lora_single_device.yaml
@@ -15,6 +15,8 @@
#
# This config works only for training on single device.
+output_dir: /tmp/torchtune/code_llama2_7B/lora_single_device # /tmp may be deleted by your system. Change it to your preference.
+
# Model Arguments
model:
_component_: torchtune.models.code_llama2.lora_code_llama2_7b
@@ -42,7 +44,7 @@ checkpointer:
]
adapter_checkpoint: null
recipe_checkpoint: null
- output_dir: /tmp/CodeLlama-7b-hf
+ output_dir: ${output_dir}
model_type: LLAMA2
resume_from_checkpoint: False
save_adapter_weights_only: False
@@ -59,7 +61,7 @@ shuffle: True
epochs: 1
max_steps_per_epoch: null
batch_size: 2
-gradient_accumulation_steps: 8 # Use to increase virtual batch size
+gradient_accumulation_steps: 8 # Use to increase effective batch size
optimizer:
_component_: torch.optim.AdamW
fused: True
@@ -70,7 +72,7 @@ lr_scheduler:
num_warmup_steps: 100
loss:
_component_: torchtune.modules.loss.CEWithChunkedOutputLoss
-compile: False # pytorch compile, set to true for better perf/memory
+compile: False # torch.compile the model + loss, True increases speed + decreases memory
# Training env
device: cuda
@@ -81,10 +83,9 @@ enable_activation_offloading: False # True reduces memory
dtype: bf16
# Logging
-output_dir: /tmp/codellama_lora_finetune_output
metric_logger:
_component_: torchtune.training.metric_logging.DiskLogger
- log_dir: /tmp/CodeLlama-7b-hf/logs
+ log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: True
diff --git a/recipes/configs/code_llama2/7B_qlora_single_device.yaml b/recipes/configs/code_llama2/7B_qlora_single_device.yaml
index ad21d8074d..9f3f1dbe4e 100644
--- a/recipes/configs/code_llama2/7B_qlora_single_device.yaml
+++ b/recipes/configs/code_llama2/7B_qlora_single_device.yaml
@@ -15,6 +15,8 @@
#
# This config works only for training on single device.
+output_dir: /tmp/torchtune/code_llama2_7B/qlora_single_device # /tmp may be deleted by your system. Change it to your preference.
+
# Model Arguments
model:
_component_: torchtune.models.code_llama2.qlora_code_llama2_7b
@@ -42,7 +44,7 @@ checkpointer:
]
adapter_checkpoint: null
recipe_checkpoint: null
- output_dir: /tmp/CodeLlama-7b-hf
+ output_dir: ${output_dir}
model_type: LLAMA2
resume_from_checkpoint: False
save_adapter_weights_only: False
@@ -58,7 +60,7 @@ shuffle: True
epochs: 1
max_steps_per_epoch: null
batch_size: 2
-gradient_accumulation_steps: 8 # Use to increase virtual batch size
+gradient_accumulation_steps: 8 # Use to increase effective batch size
optimizer:
_component_: torch.optim.AdamW
fused: True
@@ -69,7 +71,7 @@ lr_scheduler:
num_warmup_steps: 100
loss:
_component_: torchtune.modules.loss.CEWithChunkedOutputLoss
-compile: False # pytorch compile, set to true for better perf/memory
+compile: False # torch.compile the model + loss, True increases speed + decreases memory
# Training env
device: cuda
@@ -80,10 +82,9 @@ enable_activation_offloading: False # True reduces memory
dtype: bf16
# Logging
-output_dir: /tmp/codellama_qlora_finetune_output
metric_logger:
_component_: torchtune.training.metric_logging.DiskLogger
- log_dir: /tmp/CodeLlama-7b-hf/logs
+ log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: True
diff --git a/recipes/configs/dev/8B_full_experimental.yaml b/recipes/configs/dev/8B_full_experimental.yaml
index c3fb212093..d8f5e8956f 100644
--- a/recipes/configs/dev/8B_full_experimental.yaml
+++ b/recipes/configs/dev/8B_full_experimental.yaml
@@ -18,6 +18,8 @@
# best to use 8B_full_single_device.yaml for those cases
+output_dir: /tmp/torchtune/dev_8B/full_experimental # /tmp may be deleted by your system. Change it to your preference.
+
# Tokenizer
tokenizer:
_component_: torchtune.models.llama3.llama3_tokenizer
@@ -42,7 +44,7 @@ checkpointer:
consolidated.00.pth
]
recipe_checkpoint: null
- output_dir: /tmp/Meta-Llama-3-8B/
+ output_dir: ${output_dir}
model_type: LLAMA3
resume_from_checkpoint: False
@@ -57,8 +59,8 @@ optimizer:
loss:
_component_: torchtune.modules.loss.CEWithChunkedOutputLoss
max_steps_per_epoch: null
-gradient_accumulation_steps: 1 # Use to increase virtual batch size
-compile: False # pytorch compile, set to true for better perf/memory
+gradient_accumulation_steps: 1 # Use to increase effective batch size
+compile: False # torch.compile the model + loss, True increases speed + decreases memory
# Training env
device: cuda
@@ -77,11 +79,11 @@ dtype: bf16
# Logging
metric_logger:
_component_: torchtune.training.metric_logging.DiskLogger
- log_dir: ${output_dir}
-output_dir: /tmp/alpaca-llama3-finetune
+ log_dir: ${output_dir}/logs
log_every_n_steps: null
log_peak_memory_stats: True
+
# Profiler (disabled)
profiler:
_component_: torchtune.training.setup_torch_profiler
diff --git a/recipes/configs/gemma/2B_full.yaml b/recipes/configs/gemma/2B_full.yaml
index 5bd0f05a02..fa692e0f0d 100644
--- a/recipes/configs/gemma/2B_full.yaml
+++ b/recipes/configs/gemma/2B_full.yaml
@@ -16,6 +16,8 @@
# This config works only when the model is being fine-tuned on 2+ GPUs.
+output_dir: /tmp/torchtune/gemma_2B/full # /tmp may be deleted by your system. Change it to your preference.
+
# Tokenizer
tokenizer:
_component_: torchtune.models.gemma.gemma_tokenizer
@@ -40,7 +42,7 @@ checkpointer:
model-00002-of-00002.safetensors,
]
recipe_checkpoint: null
- output_dir: /tmp/gemma-2b
+ output_dir: ${output_dir}
model_type: GEMMA
resume_from_checkpoint: False
@@ -54,8 +56,8 @@ optimizer:
loss:
_component_: torchtune.modules.loss.CEWithChunkedOutputLoss
max_steps_per_epoch: null
-gradient_accumulation_steps: 1 # Use to increase virtual batch size
-compile: False # pytorch compile, set to true for better perf/memory
+gradient_accumulation_steps: 1 # Use to increase effective batch size
+compile: False # torch.compile the model + loss, True increases speed + decreases memory
optimizer_in_bwd: False # True saves memory. Requires gradient_accumulation_steps=1
# Training env
@@ -71,11 +73,11 @@ dtype: bf16
# Logging
metric_logger:
_component_: torchtune.training.metric_logging.DiskLogger
- log_dir: ${output_dir}
-output_dir: /tmp/alpaca-gemma-finetune
+ log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: True
+
# Profiler (disabled)
profiler:
_component_: torchtune.training.setup_torch_profiler
diff --git a/recipes/configs/gemma/2B_lora.yaml b/recipes/configs/gemma/2B_lora.yaml
index d947b358b0..4a79abcf99 100644
--- a/recipes/configs/gemma/2B_lora.yaml
+++ b/recipes/configs/gemma/2B_lora.yaml
@@ -15,6 +15,8 @@
#
# This config works only when the model is being fine-tuned on 2+ GPUs.
+output_dir: /tmp/torchtune/gemma_2B/lora # /tmp may be deleted by your system. Change it to your preference.
+
# Tokenizer
tokenizer:
_component_: torchtune.models.gemma.gemma_tokenizer
@@ -44,7 +46,7 @@ checkpointer:
model-00002-of-00002.safetensors,
]
recipe_checkpoint: null
- output_dir: /tmp/gemma-2b
+ output_dir: ${output_dir}
model_type: GEMMA
resume_from_checkpoint: False
@@ -66,8 +68,8 @@ loss:
batch_size: 4
epochs: 1
max_steps_per_epoch: null
-gradient_accumulation_steps: 1 # Use to increase virtual batch size
-compile: False # pytorch compile, set to true for better perf/memory
+gradient_accumulation_steps: 1 # Use to increase effective batch size
+compile: False # torch.compile the model + loss, True increases speed + decreases memory
# Training env
device: cuda
@@ -82,11 +84,11 @@ dtype: bf16
# Logging
metric_logger:
_component_: torchtune.training.metric_logging.DiskLogger
- log_dir: ${output_dir}
-output_dir: /tmp/alpaca-gemma-lora
+ log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: True
+
# Profiler (disabled)
profiler:
_component_: torchtune.training.setup_torch_profiler
diff --git a/recipes/configs/gemma/2B_lora_single_device.yaml b/recipes/configs/gemma/2B_lora_single_device.yaml
index 0559dc218c..e0b473e4ec 100644
--- a/recipes/configs/gemma/2B_lora_single_device.yaml
+++ b/recipes/configs/gemma/2B_lora_single_device.yaml
@@ -15,6 +15,8 @@
#
# This config works only for training on single device.
+output_dir: /tmp/torchtune/gemma_2B/lora_single_device # /tmp may be deleted by your system. Change it to your preference.
+
# Tokenizer
tokenizer:
_component_: torchtune.models.gemma.gemma_tokenizer
@@ -44,7 +46,7 @@ checkpointer:
model-00002-of-00002.safetensors,
]
recipe_checkpoint: null
- output_dir: /tmp/gemma-2b
+ output_dir: ${output_dir}
model_type: GEMMA
resume_from_checkpoint: False
save_adapter_weights_only: False
@@ -65,8 +67,8 @@ loss:
batch_size: 4
epochs: 1
max_steps_per_epoch: null
-gradient_accumulation_steps: 8 # Use to increase virtual batch size
-compile: False # pytorch compile, set to true for better perf/memory
+gradient_accumulation_steps: 8 # Use to increase effective batch size
+compile: False # torch.compile the model + loss, True increases speed + decreases memory
# Training env
device: cuda
@@ -81,8 +83,7 @@ dtype: bf16
# Logging
metric_logger:
_component_: torchtune.training.metric_logging.DiskLogger
- log_dir: ${output_dir}
-output_dir: /tmp/alpaca-gemma-lora
+ log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: True
diff --git a/recipes/configs/gemma/2B_qlora_single_device.yaml b/recipes/configs/gemma/2B_qlora_single_device.yaml
index a3c7f3a5f9..842dc2580f 100644
--- a/recipes/configs/gemma/2B_qlora_single_device.yaml
+++ b/recipes/configs/gemma/2B_qlora_single_device.yaml
@@ -15,6 +15,8 @@
#
# This config works only for training on single device.
+output_dir: /tmp/torchtune/gemma_2B/qlora_single_device # /tmp may be deleted by your system. Change it to your preference.
+
# Tokenizer
tokenizer:
_component_: torchtune.models.gemma.gemma_tokenizer
@@ -44,7 +46,7 @@ checkpointer:
model-00002-of-00002.safetensors,
]
recipe_checkpoint: null
- output_dir: /tmp/gemma-2b
+ output_dir: ${output_dir}
model_type: GEMMA
resume_from_checkpoint: False
save_adapter_weights_only: False
@@ -65,8 +67,8 @@ loss:
batch_size: 4
epochs: 1
max_steps_per_epoch: null
-gradient_accumulation_steps: 8 # Use to increase virtual batch size
-compile: False # pytorch compile, set to true for better perf/memory
+gradient_accumulation_steps: 8 # Use to increase effective batch size
+compile: False # torch.compile the model + loss, True increases speed + decreases memory
# Training env
device: cuda
@@ -81,8 +83,7 @@ dtype: bf16
# Logging
metric_logger:
_component_: torchtune.training.metric_logging.DiskLogger
- log_dir: ${output_dir}
-output_dir: /tmp/alpaca-gemma-lora
+ log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: True
diff --git a/recipes/configs/gemma/7B_full.yaml b/recipes/configs/gemma/7B_full.yaml
index 4555235385..47206ed291 100644
--- a/recipes/configs/gemma/7B_full.yaml
+++ b/recipes/configs/gemma/7B_full.yaml
@@ -16,6 +16,8 @@
# This config works only when the model is being fine-tuned on 2+ GPUs.
+output_dir: /tmp/torchtune/gemma_7B/full # /tmp may be deleted by your system. Change it to your preference.
+
# Tokenizer
tokenizer:
_component_: torchtune.models.gemma.gemma_tokenizer
@@ -42,7 +44,7 @@ checkpointer:
model-00004-of-00004.safetensors,
]
recipe_checkpoint: null
- output_dir: /tmp/gemma-7b
+ output_dir: ${output_dir}
model_type: GEMMA
resume_from_checkpoint: False
@@ -56,8 +58,8 @@ optimizer:
loss:
_component_: torchtune.modules.loss.CEWithChunkedOutputLoss
max_steps_per_epoch: null
-gradient_accumulation_steps: 1 # Use to increase virtual batch size
-compile: False # pytorch compile, set to true for better perf/memory
+gradient_accumulation_steps: 1 # Use to increase effective batch size
+compile: False # torch.compile the model + loss, True increases speed + decreases memory
optimizer_in_bwd: False # True saves memory. Requires gradient_accumulation_steps=1
# Training env
@@ -73,11 +75,11 @@ dtype: bf16
# Logging
metric_logger:
_component_: torchtune.training.metric_logging.DiskLogger
- log_dir: ${output_dir}
-output_dir: /tmp/alpaca-gemma-finetune
+ log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: True
+
# Profiler (disabled)
profiler:
_component_: torchtune.training.setup_torch_profiler
diff --git a/recipes/configs/gemma/7B_lora.yaml b/recipes/configs/gemma/7B_lora.yaml
index a67e9ea3e7..3383bae31c 100644
--- a/recipes/configs/gemma/7B_lora.yaml
+++ b/recipes/configs/gemma/7B_lora.yaml
@@ -16,6 +16,8 @@
# This config works only when the model is being fine-tuned on 2+ GPUs.
+output_dir: /tmp/torchtune/gemma_7B/lora # /tmp may be deleted by your system. Change it to your preference.
+
# Tokenizer
tokenizer:
_component_: torchtune.models.gemma.gemma_tokenizer
@@ -47,7 +49,7 @@ checkpointer:
model-00004-of-00004.safetensors,
]
recipe_checkpoint: null
- output_dir: /tmp/gemma-7b/
+ output_dir: ${output_dir}
model_type: GEMMA
resume_from_checkpoint: False
save_adapter_weights_only: False
@@ -68,8 +70,8 @@ loss:
batch_size: 4
epochs: 1
max_steps_per_epoch: null
-gradient_accumulation_steps: 1 # Use to increase virtual batch size
-compile: False # pytorch compile, set to true for better perf/memory
+gradient_accumulation_steps: 1 # Use to increase effective batch size
+compile: False # torch.compile the model + loss, True increases speed + decreases memory
# Training env
device: cuda
@@ -84,11 +86,11 @@ dtype: bf16
# Logging
metric_logger:
_component_: torchtune.training.metric_logging.DiskLogger
- log_dir: ${output_dir}
-output_dir: /tmp/alpaca-gemma-lora
+ log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: True
+
# Profiler (disabled)
profiler:
_component_: torchtune.training.setup_torch_profiler
diff --git a/recipes/configs/gemma/7B_lora_single_device.yaml b/recipes/configs/gemma/7B_lora_single_device.yaml
index 82d1399b20..e055b09bd5 100644
--- a/recipes/configs/gemma/7B_lora_single_device.yaml
+++ b/recipes/configs/gemma/7B_lora_single_device.yaml
@@ -15,6 +15,8 @@
#
# This config works only for training on single device.
+output_dir: /tmp/torchtune/gemma_7B/lora_single_device # /tmp may be deleted by your system. Change it to your preference.
+
# Tokenizer
tokenizer:
_component_: torchtune.models.gemma.gemma_tokenizer
@@ -46,7 +48,7 @@ checkpointer:
model-00004-of-00004.safetensors,
]
recipe_checkpoint: null
- output_dir: /tmp/gemma-7b/
+ output_dir: ${output_dir}
model_type: GEMMA
resume_from_checkpoint: False
save_adapter_weights_only: False
@@ -67,8 +69,8 @@ loss:
batch_size: 8
epochs: 1
max_steps_per_epoch: null
-gradient_accumulation_steps: 8 # Use to increase virtual batch size
-compile: False # pytorch compile, set to true for better perf/memory
+gradient_accumulation_steps: 8 # Use to increase effective batch size
+compile: False # torch.compile the model + loss, True increases speed + decreases memory
# Training env
device: cuda
@@ -83,8 +85,7 @@ dtype: bf16
# Logging
metric_logger:
_component_: torchtune.training.metric_logging.DiskLogger
- log_dir: ${output_dir}
-output_dir: /tmp/alpaca-gemma-lora
+ log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: True
diff --git a/recipes/configs/gemma/7B_qlora_single_device.yaml b/recipes/configs/gemma/7B_qlora_single_device.yaml
index 471de7572a..01fb823b4a 100644
--- a/recipes/configs/gemma/7B_qlora_single_device.yaml
+++ b/recipes/configs/gemma/7B_qlora_single_device.yaml
@@ -15,6 +15,8 @@
#
# This config works only for training on single device.
+output_dir: /tmp/torchtune/gemma_7B/qlora_single_device # /tmp may be deleted by your system. Change it to your preference.
+
# Tokenizer
tokenizer:
_component_: torchtune.models.gemma.gemma_tokenizer
@@ -46,7 +48,7 @@ checkpointer:
model-00004-of-00004.safetensors,
]
recipe_checkpoint: null
- output_dir: /tmp/gemma-7b/
+ output_dir: ${output_dir}
model_type: GEMMA
resume_from_checkpoint: False
save_adapter_weights_only: False
@@ -67,8 +69,8 @@ loss:
batch_size: 4
epochs: 1
max_steps_per_epoch: null
-gradient_accumulation_steps: 8 # Use to increase virtual batch size
-compile: False # pytorch compile, set to true for better perf/memory
+gradient_accumulation_steps: 8 # Use to increase effective batch size
+compile: False # torch.compile the model + loss, True increases speed + decreases memory
# Training env
device: cuda
@@ -83,8 +85,7 @@ dtype: bf16
# Logging
metric_logger:
_component_: torchtune.training.metric_logging.DiskLogger
- log_dir: ${output_dir}
-output_dir: /tmp/alpaca-gemma-lora
+ log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: True
diff --git a/recipes/configs/gemma2/27B_full.yaml b/recipes/configs/gemma2/27B_full.yaml
index ddc89b38b2..46a31b6821 100644
--- a/recipes/configs/gemma2/27B_full.yaml
+++ b/recipes/configs/gemma2/27B_full.yaml
@@ -16,6 +16,8 @@
# This config works only when the model is being fine-tuned on 2+ GPUs.
+output_dir: /tmp/torchtune/gemma2_27B/full # /tmp may be deleted by your system. Change it to your preference.
+
# Tokenizer
tokenizer:
_component_: torchtune.models.gemma.gemma_tokenizer
@@ -23,8 +25,8 @@ tokenizer:
# Dataset
dataset:
- packed: False # Set to true for great speed ups
_component_: torchtune.datasets.alpaca_dataset
+ packed: False # True increases speed
seed: null
shuffle: True
@@ -39,7 +41,7 @@ checkpointer:
filename_format: model-{}-of-{}.safetensors
max_filename: "00024"
recipe_checkpoint: null
- output_dir: /tmp/gemma-2-27b
+ output_dir: ${output_dir}
model_type: GEMMA2
resume_from_checkpoint: False
@@ -53,14 +55,16 @@ optimizer:
loss:
_component_: torchtune.modules.loss.CEWithChunkedOutputLoss
max_steps_per_epoch: null
-gradient_accumulation_steps: 1
-compile: False # pytorch compile, set to true for perf/memory improvement
+gradient_accumulation_steps: 1 # Use to increase effective batch size
+compile: False # torch.compile the model + loss, True increases speed + decreases memory
+optimizer_in_bwd: False # True saves memory. Requires gradient_accumulation_steps=1
# Training env
device: cuda
# Memory management
-enable_activation_checkpointing: True
+enable_activation_checkpointing: True # True reduces memory
+enable_activation_offloading: False # True reduces memory
# Reduced precision
dtype: bf16
@@ -68,7 +72,31 @@ dtype: bf16
# Logging
metric_logger:
_component_: torchtune.training.metric_logging.DiskLogger
- log_dir: ${output_dir}
-output_dir: /tmp/alpaca-gemma2-27b-finetune
+ log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: True
+
+# Profiler (disabled)
+profiler:
+ _component_: torchtune.training.setup_torch_profiler
+ enabled: False
+
+ #Output directory of trace artifacts
+ output_dir: ${output_dir}/profiling_outputs
+
+ #`torch.profiler.ProfilerActivity` types to trace
+ cpu: True
+ cuda: True
+
+ #trace options passed to `torch.profiler.profile`
+ profile_memory: False
+ with_stack: False
+ record_shapes: True
+ with_flops: False
+
+ # `torch.profiler.schedule` options:
+ # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat
+ wait_steps: 5
+ warmup_steps: 3
+ active_steps: 2
+ num_cycles: 1
diff --git a/recipes/configs/gemma2/27B_lora.yaml b/recipes/configs/gemma2/27B_lora.yaml
index d00455f01a..c8b96ed55e 100644
--- a/recipes/configs/gemma2/27B_lora.yaml
+++ b/recipes/configs/gemma2/27B_lora.yaml
@@ -16,6 +16,8 @@
# This config works only when the model is being fine-tuned on 2+ GPUs.
+output_dir: /tmp/torchtune/gemma2_27B/lora # /tmp may be deleted by your system. Change it to your preference.
+
# Tokenizer
tokenizer:
_component_: torchtune.models.gemma.gemma_tokenizer
@@ -23,18 +25,18 @@ tokenizer:
# Dataset
dataset:
- packed: False # Set to true for great speed ups
_component_: torchtune.datasets.alpaca_dataset
+ packed: False # True increases speed
seed: null
shuffle: True
# Model Arguments
model:
_component_: torchtune.models.gemma2.lora_gemma2_27b
- lora_attn_modules: ['q_proj', 'k_proj', 'v_proj']
+ lora_attn_modules: ['q_proj', 'v_proj', 'output_proj']
apply_lora_to_mlp: True
- lora_rank: 64
- lora_alpha: 128
+ lora_rank: 64 # higher increases accuracy and memory
+ lora_alpha: 128 # usually alpha=2*rank
lora_dropout: 0.0
checkpointer:
@@ -44,7 +46,7 @@ checkpointer:
filename_format: model-{}-of-{}.safetensors
max_filename: "00024"
recipe_checkpoint: null
- output_dir: /tmp/gemma-2-27b/
+ output_dir: ${output_dir}
model_type: GEMMA2
resume_from_checkpoint: False
save_adapter_weights_only: False
@@ -65,14 +67,15 @@ loss:
batch_size: 4
epochs: 1
max_steps_per_epoch: null
-gradient_accumulation_steps: 1
-compile: False # pytorch compile, set to true for perf/memory improvement
+gradient_accumulation_steps: 1 # Use to increase effective batch size
+compile: False # torch.compile the model + loss, True increases speed + decreases memory
# Training env
device: cuda
# Memory management
-enable_activation_checkpointing: True
+enable_activation_checkpointing: True # True reduces memory
+enable_activation_offloading: False # True reduces memory
# Reduced precision
dtype: bf16
@@ -80,7 +83,31 @@ dtype: bf16
# Logging
metric_logger:
_component_: torchtune.training.metric_logging.DiskLogger
- log_dir: ${output_dir}
-output_dir: /tmp/alpaca-gemma2-27b-lora
+ log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: True
+
+# Profiler (disabled)
+profiler:
+ _component_: torchtune.training.setup_torch_profiler
+ enabled: False
+
+ #Output directory of trace artifacts
+ output_dir: ${output_dir}/profiling_outputs
+
+ #`torch.profiler.ProfilerActivity` types to trace
+ cpu: True
+ cuda: True
+
+ #trace options passed to `torch.profiler.profile`
+ profile_memory: False
+ with_stack: False
+ record_shapes: True
+ with_flops: False
+
+ # `torch.profiler.schedule` options:
+ # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat
+ wait_steps: 5
+ warmup_steps: 3
+ active_steps: 2
+ num_cycles: 1
diff --git a/recipes/configs/gemma2/27B_lora_single_device.yaml b/recipes/configs/gemma2/27B_lora_single_device.yaml
index 3cbdac3cf4..74af4c22b5 100644
--- a/recipes/configs/gemma2/27B_lora_single_device.yaml
+++ b/recipes/configs/gemma2/27B_lora_single_device.yaml
@@ -15,6 +15,8 @@
#
# This config works only for training on single device.
+output_dir: /tmp/torchtune/gemma2_27B/lora_single_device # /tmp may be deleted by your system. Change it to your preference.
+
# Tokenizer
tokenizer:
_component_: torchtune.models.gemma.gemma_tokenizer
@@ -22,18 +24,18 @@ tokenizer:
# Dataset
dataset:
- packed: False # Set to true for great speed ups
_component_: torchtune.datasets.alpaca_dataset
+ packed: False # True increases speed
seed: null
shuffle: True
# Model Arguments
model:
_component_: torchtune.models.gemma2.lora_gemma2_27b
- lora_attn_modules: ['q_proj', 'k_proj', 'v_proj']
+ lora_attn_modules: ['q_proj', 'v_proj', 'output_proj']
apply_lora_to_mlp: True
- lora_rank: 8
- lora_alpha: 16
+ lora_rank: 8 # higher increases accuracy and memory
+ lora_alpha: 16 # usually alpha=2*rank
lora_dropout: 0.0
checkpointer:
@@ -43,7 +45,7 @@ checkpointer:
filename_format: model-{}-of-{}.safetensors
max_filename: "00024"
recipe_checkpoint: null
- output_dir: /tmp/gemma-2-27b/
+ output_dir: ${output_dir}
model_type: GEMMA2
resume_from_checkpoint: False
save_adapter_weights_only: False
@@ -64,15 +66,15 @@ loss:
batch_size: 2
epochs: 1
max_steps_per_epoch: null
-gradient_accumulation_steps: 8
-compile: False # pytorch compile, set to true for perf/memory improvement
+gradient_accumulation_steps: 8 # Use to increase effective batch size
+compile: False # torch.compile the model + loss, True increases speed + decreases memory
# Training env
device: cuda
# Memory management
-enable_activation_checkpointing: True
-enable_activation_offloading: False
+enable_activation_checkpointing: True # True reduces memory
+enable_activation_offloading: False # True reduces memory
# Reduced precision
dtype: bf16
@@ -80,8 +82,7 @@ dtype: bf16
# Logging
metric_logger:
_component_: torchtune.training.metric_logging.DiskLogger
- log_dir: ${output_dir}
-output_dir: /tmp/alpaca-gemma2-27b-lora
+ log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: True
diff --git a/recipes/configs/gemma2/27B_qlora_single_device.yaml b/recipes/configs/gemma2/27B_qlora_single_device.yaml
index 51e481a621..2f11ef13ab 100644
--- a/recipes/configs/gemma2/27B_qlora_single_device.yaml
+++ b/recipes/configs/gemma2/27B_qlora_single_device.yaml
@@ -15,6 +15,8 @@
#
# This config works only for training on single device.
+output_dir: /tmp/torchtune/gemma2_27B/qlora_single_device # /tmp may be deleted by your system. Change it to your preference.
+
# Tokenizer
tokenizer:
_component_: torchtune.models.gemma.gemma_tokenizer
@@ -22,18 +24,18 @@ tokenizer:
# Dataset
dataset:
- packed: False # Set to true for great speed ups
_component_: torchtune.datasets.alpaca_dataset
+ packed: False # True increases speed
seed: null
shuffle: True
# Model Arguments
model:
_component_: torchtune.models.gemma2.qlora_gemma2_27b
- lora_attn_modules: ['q_proj', 'k_proj', 'v_proj']
+ lora_attn_modules: ['q_proj', 'v_proj', 'output_proj']
apply_lora_to_mlp: True
- lora_rank: 64
- lora_alpha: 128
+ lora_rank: 64 # higher increases accuracy and memory
+ lora_alpha: 128 # usually alpha=2*rank
lora_dropout: 0.0
checkpointer:
@@ -43,7 +45,7 @@ checkpointer:
filename_format: model-{}-of-{}.safetensors
max_filename: "00024"
recipe_checkpoint: null
- output_dir: /tmp/gemma-2-27b/
+ output_dir: ${output_dir}
model_type: GEMMA2
resume_from_checkpoint: False
save_adapter_weights_only: False
@@ -64,15 +66,15 @@ loss:
batch_size: 4
epochs: 1
max_steps_per_epoch: null
-gradient_accumulation_steps: 4
-compile: False # pytorch compile, set to true for perf/memory improvement
+gradient_accumulation_steps: 8 # Use to increase effective batch size
+compile: False # torch.compile the model + loss, True increases speed + decreases memory
# Training env
device: cuda
# Memory management
-enable_activation_checkpointing: True
-enable_activation_offloading: False
+enable_activation_checkpointing: True # True reduces memory
+enable_activation_offloading: False # True reduces memory
# Reduced precision
dtype: bf16
@@ -80,8 +82,7 @@ dtype: bf16
# Logging
metric_logger:
_component_: torchtune.training.metric_logging.DiskLogger
- log_dir: ${output_dir}
-output_dir: /tmp/alpaca-gemma2-27b-lora
+ log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: True
diff --git a/recipes/configs/gemma2/2B_full.yaml b/recipes/configs/gemma2/2B_full.yaml
index b87cf1ccf9..42b034fa2c 100644
--- a/recipes/configs/gemma2/2B_full.yaml
+++ b/recipes/configs/gemma2/2B_full.yaml
@@ -16,6 +16,8 @@
# This config works only when the model is being fine-tuned on 2+ GPUs.
+output_dir: /tmp/torchtune/gemma2_2B/full # /tmp may be deleted by your system. Change it to your preference.
+
# Tokenizer
tokenizer:
_component_: torchtune.models.gemma.gemma_tokenizer
@@ -23,8 +25,8 @@ tokenizer:
# Dataset
dataset:
- packed: False # Set to true for great speed ups
_component_: torchtune.datasets.alpaca_dataset
+ packed: False # True increases speed
seed: null
shuffle: True
@@ -41,7 +43,7 @@ checkpointer:
model-00003-of-00003.safetensors,
]
recipe_checkpoint: null
- output_dir: /tmp/gemma-2-2b
+ output_dir: ${output_dir}
model_type: GEMMA2
resume_from_checkpoint: False
@@ -55,14 +57,16 @@ optimizer:
loss:
_component_: torchtune.modules.loss.CEWithChunkedOutputLoss
max_steps_per_epoch: null
-gradient_accumulation_steps: 1
-compile: False # pytorch compile, set to true for perf/memory improvement
+gradient_accumulation_steps: 1 # Use to increase effective batch size
+compile: False # torch.compile the model + loss, True increases speed + decreases memory
+optimizer_in_bwd: False # True saves memory. Requires gradient_accumulation_steps=1
# Training env
device: cuda
# Memory management
-enable_activation_checkpointing: True
+enable_activation_checkpointing: True # True reduces memory
+enable_activation_offloading: False # True reduces memory
# Reduced precision
dtype: bf16
@@ -70,7 +74,31 @@ dtype: bf16
# Logging
metric_logger:
_component_: torchtune.training.metric_logging.DiskLogger
- log_dir: ${output_dir}
-output_dir: /tmp/alpaca-gemma2-finetune
+ log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: True
+
+# Profiler (disabled)
+profiler:
+ _component_: torchtune.training.setup_torch_profiler
+ enabled: False
+
+ #Output directory of trace artifacts
+ output_dir: ${output_dir}/profiling_outputs
+
+ #`torch.profiler.ProfilerActivity` types to trace
+ cpu: True
+ cuda: True
+
+ #trace options passed to `torch.profiler.profile`
+ profile_memory: False
+ with_stack: False
+ record_shapes: True
+ with_flops: False
+
+ # `torch.profiler.schedule` options:
+ # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat
+ wait_steps: 5
+ warmup_steps: 3
+ active_steps: 2
+ num_cycles: 1
diff --git a/recipes/configs/gemma2/2B_lora.yaml b/recipes/configs/gemma2/2B_lora.yaml
index 381739dbd5..3a38fc5d0c 100644
--- a/recipes/configs/gemma2/2B_lora.yaml
+++ b/recipes/configs/gemma2/2B_lora.yaml
@@ -15,6 +15,8 @@
#
# This config works only when the model is being fine-tuned on 2+ GPUs.
+output_dir: /tmp/torchtune/gemma2_2B/lora # /tmp may be deleted by your system. Change it to your preference.
+
# Tokenizer
tokenizer:
_component_: torchtune.models.gemma.gemma_tokenizer
@@ -22,18 +24,18 @@ tokenizer:
# Dataset
dataset:
- packed: False # Set to true for great speed ups
_component_: torchtune.datasets.alpaca_dataset
+ packed: False # True increases speed
seed: null
shuffle: True
# Model Arguments
model:
_component_: torchtune.models.gemma2.lora_gemma2_2b
- lora_attn_modules: ['q_proj', 'k_proj', 'v_proj']
+ lora_attn_modules: ['q_proj', 'v_proj', 'output_proj']
apply_lora_to_mlp: True
- lora_rank: 64
- lora_alpha: 128
+ lora_rank: 64 # higher increases accuracy and memory
+ lora_alpha: 128 # usually alpha=2*rank
lora_dropout: 0.0
checkpointer:
@@ -45,7 +47,7 @@ checkpointer:
model-00003-of-00003.safetensors,
]
recipe_checkpoint: null
- output_dir: /tmp/gemma-2-2b
+ output_dir: ${output_dir}
model_type: GEMMA2
resume_from_checkpoint: False
@@ -67,14 +69,15 @@ loss:
batch_size: 4
epochs: 1
max_steps_per_epoch: null
-gradient_accumulation_steps: 1
-compile: False # pytorch compile, set to true for perf/memory improvement
+gradient_accumulation_steps: 1 # Use to increase effective batch size
+compile: False # torch.compile the model + loss, True increases speed + decreases memory
# Training env
device: cuda
# Memory management
-enable_activation_checkpointing: True
+enable_activation_checkpointing: True # True reduces memory
+enable_activation_offloading: False # True reduces memory
# Reduced precision
dtype: bf16
@@ -82,7 +85,31 @@ dtype: bf16
# Logging
metric_logger:
_component_: torchtune.training.metric_logging.DiskLogger
- log_dir: ${output_dir}
-output_dir: /tmp/alpaca-gemma2-lora
+ log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: True
+
+# Profiler (disabled)
+profiler:
+ _component_: torchtune.training.setup_torch_profiler
+ enabled: False
+
+ #Output directory of trace artifacts
+ output_dir: ${output_dir}/profiling_outputs
+
+ #`torch.profiler.ProfilerActivity` types to trace
+ cpu: True
+ cuda: True
+
+ #trace options passed to `torch.profiler.profile`
+ profile_memory: False
+ with_stack: False
+ record_shapes: True
+ with_flops: False
+
+ # `torch.profiler.schedule` options:
+ # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat
+ wait_steps: 5
+ warmup_steps: 3
+ active_steps: 2
+ num_cycles: 1
diff --git a/recipes/configs/gemma2/2B_lora_single_device.yaml b/recipes/configs/gemma2/2B_lora_single_device.yaml
index d36a0d1f4c..228e1b3b33 100644
--- a/recipes/configs/gemma2/2B_lora_single_device.yaml
+++ b/recipes/configs/gemma2/2B_lora_single_device.yaml
@@ -15,6 +15,8 @@
#
# This config works only for training on single device.
+output_dir: /tmp/torchtune/gemma2_2B/lora_single_device # /tmp may be deleted by your system. Change it to your preference.
+
# Tokenizer
tokenizer:
_component_: torchtune.models.gemma.gemma_tokenizer
@@ -22,18 +24,18 @@ tokenizer:
# Dataset
dataset:
- packed: False # Set to true for great speed ups
_component_: torchtune.datasets.alpaca_dataset
+ packed: False # True increases speed
seed: null
shuffle: True
# Model Arguments
model:
_component_: torchtune.models.gemma2.lora_gemma2_2b
- lora_attn_modules: ['q_proj', 'k_proj', 'v_proj']
+ lora_attn_modules: ['q_proj', 'v_proj', 'output_proj']
apply_lora_to_mlp: True
- lora_rank: 64
- lora_alpha: 128
+ lora_rank: 64 # higher increases accuracy and memory
+ lora_alpha: 128 # usually alpha=2*rank
lora_dropout: 0.0
checkpointer:
@@ -45,7 +47,7 @@ checkpointer:
model-00003-of-00003.safetensors,
]
recipe_checkpoint: null
- output_dir: /tmp/gemma-2-2b
+ output_dir: ${output_dir}
model_type: GEMMA2
resume_from_checkpoint: False
save_adapter_weights_only: False
@@ -66,15 +68,15 @@ loss:
batch_size: 8
epochs: 1
max_steps_per_epoch: null
-gradient_accumulation_steps: 2
-compile: False # pytorch compile, set to true for perf/memory improvement
+gradient_accumulation_steps: 8 # Use to increase effective batch size
+compile: False # torch.compile the model + loss, True increases speed + decreases memory
# Training env
device: cuda
# Memory management
-enable_activation_checkpointing: True
-enable_activation_offloading: False
+enable_activation_checkpointing: True # True reduces memory
+enable_activation_offloading: False # True reduces memory
# Reduced precision
dtype: bf16
@@ -82,8 +84,7 @@ dtype: bf16
# Logging
metric_logger:
_component_: torchtune.training.metric_logging.DiskLogger
- log_dir: ${output_dir}
-output_dir: /tmp/alpaca-gemma2-lora
+ log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: True
diff --git a/recipes/configs/gemma2/2B_qlora_single_device.yaml b/recipes/configs/gemma2/2B_qlora_single_device.yaml
index c56a51953c..16dd23cc51 100644
--- a/recipes/configs/gemma2/2B_qlora_single_device.yaml
+++ b/recipes/configs/gemma2/2B_qlora_single_device.yaml
@@ -15,6 +15,8 @@
#
# This config works only for training on single device.
+output_dir: /tmp/torchtune/gemma2_2B/qlora_single_device # /tmp may be deleted by your system. Change it to your preference.
+
# Tokenizer
tokenizer:
_component_: torchtune.models.gemma.gemma_tokenizer
@@ -22,18 +24,18 @@ tokenizer:
# Dataset
dataset:
- packed: False # Set to true for great speed ups
_component_: torchtune.datasets.alpaca_dataset
+ packed: False # True increases speed
seed: null
shuffle: True
# Model Arguments
model:
_component_: torchtune.models.gemma2.qlora_gemma2_2b
- lora_attn_modules: ['q_proj', 'k_proj', 'v_proj']
+ lora_attn_modules: ['q_proj', 'v_proj', 'output_proj']
apply_lora_to_mlp: True
- lora_rank: 64
- lora_alpha: 128
+ lora_rank: 64 # higher increases accuracy and memory
+ lora_alpha: 128 # usually alpha=2*rank
lora_dropout: 0.0
checkpointer:
@@ -45,7 +47,7 @@ checkpointer:
model-00003-of-00003.safetensors,
]
recipe_checkpoint: null
- output_dir: /tmp/gemma-2-2b
+ output_dir: ${output_dir}
model_type: GEMMA2
resume_from_checkpoint: False
save_adapter_weights_only: False
@@ -66,15 +68,15 @@ loss:
batch_size: 4
epochs: 1
max_steps_per_epoch: null
-gradient_accumulation_steps: 4
-compile: False # pytorch compile, set to true for perf/memory improvement
+gradient_accumulation_steps: 8 # Use to increase effective batch size
+compile: False # torch.compile the model + loss, True increases speed + decreases memory
# Training env
device: cuda
# Memory management
-enable_activation_checkpointing: True
-enable_activation_offloading: False
+enable_activation_checkpointing: True # True reduces memory
+enable_activation_offloading: False # True reduces memory
# Reduced precision
dtype: bf16
@@ -82,8 +84,7 @@ dtype: bf16
# Logging
metric_logger:
_component_: torchtune.training.metric_logging.DiskLogger
- log_dir: ${output_dir}
-output_dir: /tmp/alpaca-gemma2-lora
+ log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: True
diff --git a/recipes/configs/gemma2/9B_full.yaml b/recipes/configs/gemma2/9B_full.yaml
index 0fc7e6e4e4..bbb31fb268 100644
--- a/recipes/configs/gemma2/9B_full.yaml
+++ b/recipes/configs/gemma2/9B_full.yaml
@@ -16,6 +16,8 @@
# This config works only when the model is being fine-tuned on 2+ GPUs.
+output_dir: /tmp/torchtune/gemma2_9B/full # /tmp may be deleted by your system. Change it to your preference.
+
# Tokenizer
tokenizer:
_component_: torchtune.models.gemma.gemma_tokenizer
@@ -23,8 +25,8 @@ tokenizer:
# Dataset
dataset:
- packed: False # Set to true for great speed ups
_component_: torchtune.datasets.alpaca_dataset
+ packed: False # True increases speed
seed: null
shuffle: True
@@ -39,7 +41,7 @@ checkpointer:
filename_format: model-{}-of-{}.safetensors
max_filename: "00008"
recipe_checkpoint: null
- output_dir: /tmp/gemma-2-9b
+ output_dir: ${output_dir}
model_type: GEMMA2
resume_from_checkpoint: False
@@ -53,14 +55,16 @@ optimizer:
loss:
_component_: torchtune.modules.loss.CEWithChunkedOutputLoss
max_steps_per_epoch: null
-gradient_accumulation_steps: 1
-compile: False # pytorch compile, set to true for perf/memory improvement
+gradient_accumulation_steps: 1 # Use to increase effective batch size
+compile: False # torch.compile the model + loss, True increases speed + decreases memory
+optimizer_in_bwd: False # True saves memory. Requires gradient_accumulation_steps=1
# Training env
device: cuda
# Memory management
-enable_activation_checkpointing: True
+enable_activation_checkpointing: True # True reduces memory
+enable_activation_offloading: False # True reduces memory
# Reduced precision
dtype: bf16
@@ -68,7 +72,31 @@ dtype: bf16
# Logging
metric_logger:
_component_: torchtune.training.metric_logging.DiskLogger
- log_dir: ${output_dir}
-output_dir: /tmp/alpaca-gemma2-9b-finetune
+ log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: True
+
+# Profiler (disabled)
+profiler:
+ _component_: torchtune.training.setup_torch_profiler
+ enabled: False
+
+ #Output directory of trace artifacts
+ output_dir: ${output_dir}/profiling_outputs
+
+ #`torch.profiler.ProfilerActivity` types to trace
+ cpu: True
+ cuda: True
+
+ #trace options passed to `torch.profiler.profile`
+ profile_memory: False
+ with_stack: False
+ record_shapes: True
+ with_flops: False
+
+ # `torch.profiler.schedule` options:
+ # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat
+ wait_steps: 5
+ warmup_steps: 3
+ active_steps: 2
+ num_cycles: 1
diff --git a/recipes/configs/gemma2/9B_lora.yaml b/recipes/configs/gemma2/9B_lora.yaml
index 5c12391cea..3c402433a0 100644
--- a/recipes/configs/gemma2/9B_lora.yaml
+++ b/recipes/configs/gemma2/9B_lora.yaml
@@ -16,6 +16,8 @@
# This config works only when the model is being fine-tuned on 2+ GPUs.
+output_dir: /tmp/torchtune/gemma2_9B/lora # /tmp may be deleted by your system. Change it to your preference.
+
# Tokenizer
tokenizer:
_component_: torchtune.models.gemma.gemma_tokenizer
@@ -23,18 +25,18 @@ tokenizer:
# Dataset
dataset:
- packed: False # Set to true for great speed ups
_component_: torchtune.datasets.alpaca_dataset
+ packed: False # True increases speed
seed: null
shuffle: True
# Model Arguments
model:
_component_: torchtune.models.gemma2.lora_gemma2_9b
- lora_attn_modules: ['q_proj', 'k_proj', 'v_proj']
+ lora_attn_modules: ['q_proj', 'v_proj', 'output_proj']
apply_lora_to_mlp: True
- lora_rank: 64
- lora_alpha: 128
+ lora_rank: 64 # higher increases accuracy and memory
+ lora_alpha: 128 # usually alpha=2*rank
lora_dropout: 0.0
checkpointer:
@@ -44,7 +46,7 @@ checkpointer:
filename_format: model-{}-of-{}.safetensors
max_filename: "00008"
recipe_checkpoint: null
- output_dir: /tmp/gemma-2-9b/
+ output_dir: ${output_dir}
model_type: GEMMA2
resume_from_checkpoint: False
save_adapter_weights_only: False
@@ -65,14 +67,15 @@ loss:
batch_size: 4
epochs: 1
max_steps_per_epoch: null
-gradient_accumulation_steps: 1
-compile: False # pytorch compile, set to true for perf/memory improvement
+gradient_accumulation_steps: 1 # Use to increase effective batch size
+compile: False # torch.compile the model + loss, True increases speed + decreases memory
# Training env
device: cuda
# Memory management
-enable_activation_checkpointing: True
+enable_activation_checkpointing: True # True reduces memory
+enable_activation_offloading: False # True reduces memory
# Reduced precision
dtype: bf16
@@ -80,7 +83,31 @@ dtype: bf16
# Logging
metric_logger:
_component_: torchtune.training.metric_logging.DiskLogger
- log_dir: ${output_dir}
-output_dir: /tmp/alpaca-gemma2-9b-lora
+ log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: True
+
+# Profiler (disabled)
+profiler:
+ _component_: torchtune.training.setup_torch_profiler
+ enabled: False
+
+ #Output directory of trace artifacts
+ output_dir: ${output_dir}/profiling_outputs
+
+ #`torch.profiler.ProfilerActivity` types to trace
+ cpu: True
+ cuda: True
+
+ #trace options passed to `torch.profiler.profile`
+ profile_memory: False
+ with_stack: False
+ record_shapes: True
+ with_flops: False
+
+ # `torch.profiler.schedule` options:
+ # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat
+ wait_steps: 5
+ warmup_steps: 3
+ active_steps: 2
+ num_cycles: 1
diff --git a/recipes/configs/gemma2/9B_lora_single_device.yaml b/recipes/configs/gemma2/9B_lora_single_device.yaml
index a16cb1130e..7a665c1f12 100644
--- a/recipes/configs/gemma2/9B_lora_single_device.yaml
+++ b/recipes/configs/gemma2/9B_lora_single_device.yaml
@@ -15,6 +15,8 @@
#
# This config works only for training on single device.
+output_dir: /tmp/torchtune/gemma2_9B/lora_single_device # /tmp may be deleted by your system. Change it to your preference.
+
# Tokenizer
tokenizer:
_component_: torchtune.models.gemma.gemma_tokenizer
@@ -22,18 +24,18 @@ tokenizer:
# Dataset
dataset:
- packed: False # Set to true for great speed ups
_component_: torchtune.datasets.alpaca_dataset
+ packed: False # True increases speed
seed: null
shuffle: True
# Model Arguments
model:
_component_: torchtune.models.gemma2.lora_gemma2_9b
- lora_attn_modules: ['q_proj', 'k_proj', 'v_proj']
+ lora_attn_modules: ['q_proj', 'v_proj', 'output_proj']
apply_lora_to_mlp: True
- lora_rank: 8
- lora_alpha: 16
+ lora_rank: 8 # higher increases accuracy and memory
+ lora_alpha: 16 # usually alpha=2*rank
lora_dropout: 0.0
checkpointer:
@@ -43,7 +45,7 @@ checkpointer:
filename_format: model-{}-of-{}.safetensors
max_filename: "00008"
recipe_checkpoint: null
- output_dir: /tmp/gemma-2-9b/
+ output_dir: ${output_dir}
model_type: GEMMA2
resume_from_checkpoint: False
save_adapter_weights_only: False
@@ -64,15 +66,15 @@ loss:
batch_size: 8
epochs: 1
max_steps_per_epoch: null
-gradient_accumulation_steps: 2
-compile: False # pytorch compile, set to true for perf/memory improvement
+gradient_accumulation_steps: 8 # Use to increase effective batch size
+compile: False # torch.compile the model + loss, True increases speed + decreases memory
# Training env
device: cuda
# Memory management
-enable_activation_checkpointing: True
-enable_activation_offloading: False
+enable_activation_checkpointing: True # True reduces memory
+enable_activation_offloading: False # True reduces memory
# Reduced precision
dtype: bf16
@@ -80,8 +82,7 @@ dtype: bf16
# Logging
metric_logger:
_component_: torchtune.training.metric_logging.DiskLogger
- log_dir: ${output_dir}
-output_dir: /tmp/alpaca-gemma2-9b-lora
+ log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: True
diff --git a/recipes/configs/gemma2/9B_qlora_single_device.yaml b/recipes/configs/gemma2/9B_qlora_single_device.yaml
index dd9bdf9a9d..eff7057b63 100644
--- a/recipes/configs/gemma2/9B_qlora_single_device.yaml
+++ b/recipes/configs/gemma2/9B_qlora_single_device.yaml
@@ -15,6 +15,8 @@
#
# This config works only for training on single device.
+output_dir: /tmp/torchtune/gemma2_9B/qlora_single_device # /tmp may be deleted by your system. Change it to your preference.
+
# Tokenizer
tokenizer:
_component_: torchtune.models.gemma.gemma_tokenizer
@@ -22,18 +24,18 @@ tokenizer:
# Dataset
dataset:
- packed: False # Set to true for great speed ups
_component_: torchtune.datasets.alpaca_dataset
+ packed: False # True increases speed
seed: null
shuffle: True
# Model Arguments
model:
_component_: torchtune.models.gemma2.qlora_gemma2_9b
- lora_attn_modules: ['q_proj', 'k_proj', 'v_proj']
+ lora_attn_modules: ['q_proj', 'v_proj', 'output_proj']
apply_lora_to_mlp: True
- lora_rank: 64
- lora_alpha: 128
+ lora_rank: 64 # higher increases accuracy and memory
+ lora_alpha: 128 # usually alpha=2*rank
lora_dropout: 0.0
checkpointer:
@@ -43,7 +45,7 @@ checkpointer:
filename_format: model-{}-of-{}.safetensors
max_filename: "00008"
recipe_checkpoint: null
- output_dir: /tmp/gemma-2-9b/
+ output_dir: ${output_dir}
model_type: GEMMA2
resume_from_checkpoint: False
save_adapter_weights_only: False
@@ -64,15 +66,15 @@ loss:
batch_size: 4
epochs: 1
max_steps_per_epoch: null
-gradient_accumulation_steps: 4
-compile: False # pytorch compile, set to true for perf/memory improvement
+gradient_accumulation_steps: 8 # Use to increase effective batch size
+compile: False # torch.compile the model + loss, True increases speed + decreases memory
# Training env
device: cuda
# Memory management
-enable_activation_checkpointing: True
-enable_activation_offloading: False
+enable_activation_checkpointing: True # True reduces memory
+enable_activation_offloading: False # True reduces memory
# Reduced precision
dtype: bf16
@@ -80,8 +82,7 @@ dtype: bf16
# Logging
metric_logger:
_component_: torchtune.training.metric_logging.DiskLogger
- log_dir: ${output_dir}
-output_dir: /tmp/alpaca-gemma2-9b-lora
+ log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: True
diff --git a/recipes/configs/llama2/13B_full.yaml b/recipes/configs/llama2/13B_full.yaml
index fd7b7421c1..67932bbb1b 100644
--- a/recipes/configs/llama2/13B_full.yaml
+++ b/recipes/configs/llama2/13B_full.yaml
@@ -18,6 +18,8 @@
# 7B_full_single_device.yaml. Please update the model and checkpoints to 13B
# in that config.
+output_dir: /tmp/torchtune/llama2_13B/full # /tmp may be deleted by your system. Change it to your preference.
+
# Model Arguments
model:
_component_: torchtune.models.llama2.llama2_13b
@@ -31,7 +33,7 @@ checkpointer:
pytorch_model-00003-of-00003.bin
]
recipe_checkpoint: null
- output_dir: /tmp/Llama-2-13b-hf/
+ output_dir: ${output_dir}
model_type: LLAMA2
resume_from_checkpoint: False
@@ -58,8 +60,8 @@ optimizer:
loss:
_component_: torchtune.modules.loss.CEWithChunkedOutputLoss
max_steps_per_epoch: null
-gradient_accumulation_steps: 1 # Use to increase virtual batch size
-compile: False # pytorch compile, set to true for better perf/memory
+gradient_accumulation_steps: 1 # Use to increase effective batch size
+compile: False # torch.compile the model + loss, True increases speed + decreases memory
optimizer_in_bwd: False # True saves memory. Requires gradient_accumulation_steps=1
# Training env
@@ -75,11 +77,11 @@ dtype: bf16
# Logging
metric_logger:
_component_: torchtune.training.metric_logging.DiskLogger
- log_dir: ${output_dir}
-output_dir: /tmp/alpaca-llama2-finetune
+ log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: True
+
# Profiler (disabled)
profiler:
_component_: torchtune.training.setup_torch_profiler
diff --git a/recipes/configs/llama2/13B_lora.yaml b/recipes/configs/llama2/13B_lora.yaml
index 2bae98471d..fbd8a2141d 100644
--- a/recipes/configs/llama2/13B_lora.yaml
+++ b/recipes/configs/llama2/13B_lora.yaml
@@ -19,6 +19,8 @@
# the 13B model.
+output_dir: /tmp/torchtune/llama2_13B/lora # /tmp may be deleted by your system. Change it to your preference.
+
# Model Arguments
model:
_component_: torchtune.models.llama2.lora_llama2_13b
@@ -39,7 +41,7 @@ checkpointer:
]
adapter_checkpoint: null
recipe_checkpoint: null
- output_dir: /tmp/Llama-2-13b-hf/
+ output_dir: ${output_dir}
model_type: LLAMA2
resume_from_checkpoint: False
save_adapter_weights_only: False
@@ -74,14 +76,13 @@ loss:
# Training
epochs: 1
max_steps_per_epoch: null
-gradient_accumulation_steps: 8 # Use to increase virtual batch size
-compile: False # pytorch compile, set to true for better perf/memory
+gradient_accumulation_steps: 8 # Use to increase effective batch size
+compile: False # torch.compile the model + loss, True increases speed + decreases memory
# Logging
-output_dir: /tmp/lora_finetune_output
metric_logger:
_component_: torchtune.training.metric_logging.DiskLogger
- log_dir: ${output_dir}
+ log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: True
@@ -91,6 +92,7 @@ dtype: bf16
enable_activation_checkpointing: False # True reduces memory
enable_activation_offloading: False # True reduces memory
+
# Profiler (disabled)
profiler:
_component_: torchtune.training.setup_torch_profiler
diff --git a/recipes/configs/llama2/13B_qlora_single_device.yaml b/recipes/configs/llama2/13B_qlora_single_device.yaml
index 62e74c4e62..69558858bd 100644
--- a/recipes/configs/llama2/13B_qlora_single_device.yaml
+++ b/recipes/configs/llama2/13B_qlora_single_device.yaml
@@ -15,6 +15,8 @@
#
# This config works only for training on single device.
+output_dir: /tmp/torchtune/llama2_13B/qlora_single_device # /tmp may be deleted by your system. Change it to your preference.
+
# Model Arguments
model:
_component_: torchtune.models.llama2.qlora_llama2_13b
@@ -40,7 +42,7 @@ checkpointer:
]
adapter_checkpoint: null
recipe_checkpoint: null
- output_dir: /tmp/Llama-2-13b-hf/
+ output_dir: ${output_dir}
model_type: LLAMA2
resume_from_checkpoint: False
save_adapter_weights_only: False
@@ -69,14 +71,13 @@ loss:
# Training
epochs: 1
max_steps_per_epoch: null
-gradient_accumulation_steps: 8 # Use to increase virtual batch size
-compile: False # pytorch compile, set to true for better perf/memory
+gradient_accumulation_steps: 8 # Use to increase effective batch size
+compile: False # torch.compile the model + loss, True increases speed + decreases memory
# Logging
-output_dir: /tmp/qlora_finetune_output/
metric_logger:
_component_: torchtune.training.metric_logging.DiskLogger
- log_dir: ${output_dir}
+ log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: True
diff --git a/recipes/configs/llama2/70B_lora.yaml b/recipes/configs/llama2/70B_lora.yaml
index bf2a0817d0..a9be1f6cb6 100644
--- a/recipes/configs/llama2/70B_lora.yaml
+++ b/recipes/configs/llama2/70B_lora.yaml
@@ -9,6 +9,8 @@
# # tune run --nproc_per_node 8 lora_finetune_distributed --config llama2/70B_lora
#
+output_dir: /tmp/torchtune/llama2_70B/lora # /tmp may be deleted by your system. Change it to your preference.
+
# Model Arguments
model:
_component_: torchtune.models.llama2.lora_llama2_70b
@@ -45,7 +47,7 @@ checkpointer:
pytorch_model-00015-of-00015.bin,
]
recipe_checkpoint: null
- output_dir: /tmp/Llama-2-70b-hf
+ output_dir: ${output_dir}
model_type: LLAMA2
resume_from_checkpoint: False
save_adapter_weights_only: False
@@ -74,14 +76,13 @@ loss:
# Training
epochs: 1
max_steps_per_epoch: null
-compile: False # pytorch compile, set to true for better perf/memory
-gradient_accumulation_steps: 1 # Use to increase virtual batch size
+compile: False # torch.compile the model + loss, True increases speed + decreases memory
+gradient_accumulation_steps: 1 # Use to increase effective batch size
# Logging
-output_dir: /tmp/lora_finetune_output
metric_logger:
_component_: torchtune.training.metric_logging.DiskLogger
- log_dir: ${output_dir}
+ log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: True
@@ -91,6 +92,7 @@ dtype: bf16
enable_activation_checkpointing: True # True reduces memory
enable_activation_offloading: False # True reduces memory
+
# Profiler (disabled)
profiler:
_component_: torchtune.training.setup_torch_profiler
diff --git a/recipes/configs/llama2/70B_qlora.yaml b/recipes/configs/llama2/70B_qlora.yaml
index 38444bf0c7..3e48bbcaa0 100644
--- a/recipes/configs/llama2/70B_qlora.yaml
+++ b/recipes/configs/llama2/70B_qlora.yaml
@@ -14,6 +14,8 @@
# tune run --nproc_per_node 8 lora_finetune_distributed --config llama2/70B_qlora checkpointer.checkpoint_dir=
#
+output_dir: /tmp/torchtune/llama2_70B/qlora # /tmp may be deleted by your system. Change it to your preference.
+
# Model Arguments
model:
_component_: torchtune.models.llama2.qlora_llama2_70b
@@ -50,7 +52,7 @@ checkpointer:
pytorch_model-00015-of-00015.bin,
]
recipe_checkpoint: null
- output_dir: /tmp/Llama-2-70b-hf
+ output_dir: ${output_dir}
model_type: LLAMA2
resume_from_checkpoint: False
save_adapter_weights_only: False
@@ -83,14 +85,13 @@ fsdp:
# Training
epochs: 1
max_steps_per_epoch: null
-gradient_accumulation_steps: 1 # Use to increase virtual batch size
-compile: False # pytorch compile, set to true for better perf/memory
+gradient_accumulation_steps: 1 # Use to increase effective batch size
+compile: False # torch.compile the model + loss, True increases speed + decreases memory
# Logging
-output_dir: /tmp/qlora_finetune_output
metric_logger:
_component_: torchtune.training.metric_logging.DiskLogger
- log_dir: ${output_dir}
+ log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: True
@@ -100,6 +101,7 @@ dtype: bf16
enable_activation_checkpointing: True # True reduces memory
enable_activation_offloading: False # True reduces memory
+
# Profiler (disabled)
profiler:
_component_: torchtune.training.setup_torch_profiler
diff --git a/recipes/configs/llama2/7B_full.yaml b/recipes/configs/llama2/7B_full.yaml
index 7e69c8f5a6..40fb804035 100644
--- a/recipes/configs/llama2/7B_full.yaml
+++ b/recipes/configs/llama2/7B_full.yaml
@@ -18,6 +18,8 @@
# best to use 7B_full_single_device.yaml for those cases
+output_dir: /tmp/torchtune/llama2_7B/full # /tmp may be deleted by your system. Change it to your preference.
+
# Tokenizer
tokenizer:
_component_: torchtune.models.llama2.llama2_tokenizer
@@ -43,7 +45,7 @@ checkpointer:
pytorch_model-00002-of-00002.bin
]
recipe_checkpoint: null
- output_dir: /tmp/Llama-2-7b-hf
+ output_dir: ${output_dir}
model_type: LLAMA2
resume_from_checkpoint: False
@@ -57,8 +59,8 @@ optimizer:
loss:
_component_: torchtune.modules.loss.CEWithChunkedOutputLoss
max_steps_per_epoch: null
-gradient_accumulation_steps: 1 # Use to increase virtual batch size
-compile: False # pytorch compile, set to true for better perf/memory
+gradient_accumulation_steps: 1 # Use to increase effective batch size
+compile: False # torch.compile the model + loss, True increases speed + decreases memory
optimizer_in_bwd: False # True saves memory. Requires gradient_accumulation_steps=1
# Training env
@@ -74,11 +76,11 @@ dtype: bf16
# Logging
metric_logger:
_component_: torchtune.training.metric_logging.DiskLogger
- log_dir: ${output_dir}
-output_dir: /tmp/alpaca-llama2-finetune
+ log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: True
+
# Profiler (disabled)
profiler:
_component_: torchtune.training.setup_torch_profiler
diff --git a/recipes/configs/llama2/7B_full_low_memory.yaml b/recipes/configs/llama2/7B_full_low_memory.yaml
index d7ee50898e..29d157dbf6 100644
--- a/recipes/configs/llama2/7B_full_low_memory.yaml
+++ b/recipes/configs/llama2/7B_full_low_memory.yaml
@@ -20,6 +20,8 @@
# This config works only for training on single device.
+output_dir: /tmp/torchtune/llama2_7B/full_low_memory # /tmp may be deleted by your system. Change it to your preference.
+
# Tokenizer
tokenizer:
_component_: torchtune.models.llama2.llama2_tokenizer
@@ -45,7 +47,7 @@ checkpointer:
pytorch_model-00002-of-00002.bin
]
recipe_checkpoint: null
- output_dir: /tmp/Llama-2-7b-hf
+ output_dir: ${output_dir}
model_type: LLAMA2
resume_from_checkpoint: False
@@ -62,8 +64,8 @@ optimizer_in_bwd: True # True saves memory. Requires gradient_accumulation_step
loss:
_component_: torchtune.modules.loss.CEWithChunkedOutputLoss
max_steps_per_epoch: null
-gradient_accumulation_steps: 1 # Use to increase virtual batch size
-compile: False # pytorch compile, set to true for better perf/memory
+gradient_accumulation_steps: 1 # Use to increase effective batch size
+compile: False # torch.compile the model + loss, True increases speed + decreases memory
# Training environment
device: cuda
@@ -78,11 +80,11 @@ dtype: bf16
# Logging
metric_logger:
_component_: torchtune.training.metric_logging.DiskLogger
- log_dir: ${output_dir}
-output_dir: /tmp/alpaca-llama2-finetune
+ log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: True
+
# Profiler (disabled)
profiler:
_component_: torchtune.training.setup_torch_profiler
diff --git a/recipes/configs/llama2/7B_lora.yaml b/recipes/configs/llama2/7B_lora.yaml
index 5bf21ccb2c..fc8ec7e346 100644
--- a/recipes/configs/llama2/7B_lora.yaml
+++ b/recipes/configs/llama2/7B_lora.yaml
@@ -18,6 +18,8 @@
# or 7B_qlora_single_device.yaml
+output_dir: /tmp/torchtune/llama2_7B/lora # /tmp may be deleted by your system. Change it to your preference.
+
# Model Arguments
model:
_component_: torchtune.models.llama2.lora_llama2_7b
@@ -42,7 +44,7 @@ checkpointer:
]
adapter_checkpoint: null
recipe_checkpoint: null
- output_dir: /tmp/Llama-2-7b-hf
+ output_dir: ${output_dir}
model_type: LLAMA2
resume_from_checkpoint: False
save_adapter_weights_only: False
@@ -71,14 +73,13 @@ loss:
# Training
epochs: 1
max_steps_per_epoch: null
-compile: False # pytorch compile, set to true for better perf/memory
-gradient_accumulation_steps: 8 # Use to increase virtual batch size
+compile: False # torch.compile the model + loss, True increases speed + decreases memory
+gradient_accumulation_steps: 8 # Use to increase effective batch size
# Logging
-output_dir: /tmp/lora_finetune_output
metric_logger:
_component_: torchtune.training.metric_logging.DiskLogger
- log_dir: ${output_dir}
+ log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: True
diff --git a/recipes/configs/llama2/7B_lora_dpo.yaml b/recipes/configs/llama2/7B_lora_dpo.yaml
index abf1b43138..250d62db44 100644
--- a/recipes/configs/llama2/7B_lora_dpo.yaml
+++ b/recipes/configs/llama2/7B_lora_dpo.yaml
@@ -16,6 +16,8 @@
# This config works best when the model is being fine-tuned on 2+ GPUs.
# For single device LoRA DPO alignment please use 7B_lora_dpo_single_device.yaml
+output_dir: /tmp/torchtune/llama2_7B/lora_dpo # /tmp may be deleted by your system. Change it to your preference.
+
# Model Arguments
model:
_component_: torchtune.models.llama2.lora_llama2_7b
@@ -39,7 +41,7 @@ checkpointer:
[pytorch_model-00001-of-00002.bin, pytorch_model-00002-of-00002.bin]
adapter_checkpoint: null
recipe_checkpoint: null
- output_dir: /tmp/Llama-2-7b-hf
+ output_dir: ${output_dir}
model_type: LLAMA2
resume_from_checkpoint: False
save_adapter_weights_only: False
@@ -69,18 +71,20 @@ loss:
# Training
epochs: 1
max_steps_per_epoch: 1000
-gradient_accumulation_steps: 8 # Use to increase virtual batch size
-compile: False # pytorch compile, set to true for better perf/memory
+gradient_accumulation_steps: 8 # Use to increase effective batch size
+compile: False # torch.compile the model + loss, True increases speed + decreases memory
# Logging
-output_dir: /tmp/lora_dpo_output/
metric_logger:
_component_: torchtune.training.metric_logging.DiskLogger
- log_dir: ${output_dir}
+ log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: True
# Environment
device: cuda
dtype: bf16
+
+# Memory management
enable_activation_checkpointing: True # True reduces memory
+enable_activation_offloading: False # True reduces memory
diff --git a/recipes/configs/llama2/7B_lora_dpo_single_device.yaml b/recipes/configs/llama2/7B_lora_dpo_single_device.yaml
index 7543cb5d6f..4d154c38ce 100644
--- a/recipes/configs/llama2/7B_lora_dpo_single_device.yaml
+++ b/recipes/configs/llama2/7B_lora_dpo_single_device.yaml
@@ -15,6 +15,8 @@
#
# This config works only for training on single device.
+output_dir: /tmp/torchtune/llama2_7B/lora_dpo_single_device # /tmp may be deleted by your system. Change it to your preference.
+
# Model Arguments
model:
_component_: torchtune.models.llama2.lora_llama2_7b
@@ -38,7 +40,7 @@ checkpointer:
[pytorch_model-00001-of-00002.bin, pytorch_model-00002-of-00002.bin]
adapter_checkpoint: null
recipe_checkpoint: null
- output_dir: /tmp/Llama-2-7b-hf
+ output_dir: ${output_dir}
model_type: LLAMA2
resume_from_checkpoint: False
save_adapter_weights_only: False
@@ -66,18 +68,20 @@ loss:
# Training
epochs: 1
max_steps_per_epoch: 1000
-gradient_accumulation_steps: 8 # Use to increase virtual batch size
-compile: False # pytorch compile, set to true for better perf/memory
+gradient_accumulation_steps: 8 # Use to increase effective batch size
+compile: False # torch.compile the model + loss, True increases speed + decreases memory
# Logging
-output_dir: /tmp/lora_dpo_output/
metric_logger:
_component_: torchtune.training.metric_logging.DiskLogger
- log_dir: ${output_dir}
+ log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: True
# Environment
device: cuda
dtype: bf16
+
+# Memory management
enable_activation_checkpointing: True # True reduces memory
+enable_activation_offloading: False # True reduces memory
diff --git a/recipes/configs/llama2/7B_lora_single_device.yaml b/recipes/configs/llama2/7B_lora_single_device.yaml
index 4196cc5a59..ac87053ae1 100644
--- a/recipes/configs/llama2/7B_lora_single_device.yaml
+++ b/recipes/configs/llama2/7B_lora_single_device.yaml
@@ -16,6 +16,8 @@
# This config works only for training on single device.
+output_dir: /tmp/torchtune/llama2_7B/lora_single_device # /tmp may be deleted by your system. Change it to your preference.
+
# Model Arguments
model:
_component_: torchtune.models.llama2.lora_llama2_7b
@@ -40,7 +42,7 @@ checkpointer:
]
adapter_checkpoint: null
recipe_checkpoint: null
- output_dir: /tmp/Llama-2-7b-hf
+ output_dir: ${output_dir}
model_type: LLAMA2
resume_from_checkpoint: False
save_adapter_weights_only: False
@@ -69,14 +71,13 @@ loss:
# Training
epochs: 1
max_steps_per_epoch: null
-gradient_accumulation_steps: 8 # Use to increase virtual batch size
-compile: False # pytorch compile, set to true for better perf/memory
+gradient_accumulation_steps: 8 # Use to increase effective batch size
+compile: False # torch.compile the model + loss, True increases speed + decreases memory
# Logging
-output_dir: /tmp/lora_finetune_output
metric_logger:
_component_: torchtune.training.metric_logging.DiskLogger
- log_dir: ${output_dir}
+ log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: True
diff --git a/recipes/configs/llama2/7B_qat_full.yaml b/recipes/configs/llama2/7B_qat_full.yaml
index 15a3f000e4..a82e3f580c 100644
--- a/recipes/configs/llama2/7B_qat_full.yaml
+++ b/recipes/configs/llama2/7B_qat_full.yaml
@@ -14,6 +14,8 @@
# tune run --nnodes 1 --nproc_per_node 4 qat_distributed --config llama2/7B_qat_full checkpointer.checkpoint_dir=
+output_dir: /tmp/torchtune/llama2_7B/qat_full # /tmp may be deleted by your system. Change it to your preference.
+
# Tokenizer
tokenizer:
_component_: torchtune.models.llama2.llama2_tokenizer
@@ -39,7 +41,7 @@ checkpointer:
pytorch_model-00002-of-00002.bin
]
recipe_checkpoint: null
- output_dir: /tmp/Llama-2-7b-hf
+ output_dir: ${output_dir}
model_type: LLAMA2
resume_from_checkpoint: False
@@ -53,8 +55,8 @@ optimizer:
loss:
_component_: torchtune.modules.loss.CEWithChunkedOutputLoss
max_steps_per_epoch: null
-gradient_accumulation_steps: 1 # Use to increase virtual batch size
-compile: False # pytorch compile, set to true for better perf/memory
+gradient_accumulation_steps: 1 # Use to increase effective batch size
+compile: False # torch.compile the model + loss, True increases speed + decreases memory
optimizer_in_bwd: False # True saves memory. Requires gradient_accumulation_steps=1
# QAT arguments
@@ -75,11 +77,11 @@ dtype: bf16
# Logging
metric_logger:
_component_: torchtune.training.metric_logging.DiskLogger
- log_dir: ${output_dir}
-output_dir: /tmp/alpaca-llama2-finetune
+ log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: True
+
# Profiler (disabled)
profiler:
_component_: torchtune.training.setup_torch_profiler
diff --git a/recipes/configs/llama2/7B_qlora.yaml b/recipes/configs/llama2/7B_qlora.yaml
index 667b94c376..49c44c27f3 100644
--- a/recipes/configs/llama2/7B_qlora.yaml
+++ b/recipes/configs/llama2/7B_qlora.yaml
@@ -17,6 +17,8 @@
# For single device LoRA finetuning please use 7B_lora_single_device.yaml
# or 7B_qlora_single_device.yaml
+output_dir: /tmp/torchtune/llama2_7B/qlora # /tmp may be deleted by your system. Change it to your preference.
+
# Model Arguments
model:
_component_: torchtune.models.llama2.qlora_llama2_7b
@@ -41,7 +43,7 @@ checkpointer:
]
adapter_checkpoint: null
recipe_checkpoint: null
- output_dir: /tmp/Llama-2-7b-hf
+ output_dir: ${output_dir}
model_type: LLAMA2
resume_from_checkpoint: False
save_adapter_weights_only: False
@@ -74,14 +76,13 @@ fsdp:
# Training
epochs: 1
max_steps_per_epoch: null
-gradient_accumulation_steps: 8 # Use to increase virtual batch size
-compile: False # pytorch compile, set to true for better perf/memory
+gradient_accumulation_steps: 8 # Use to increase effective batch size
+compile: False # torch.compile the model + loss, True increases speed + decreases memory
# Logging
-output_dir: /tmp/qlora_finetune_output/
metric_logger:
_component_: torchtune.training.metric_logging.DiskLogger
- log_dir: ${output_dir}
+ log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: True
@@ -91,6 +92,7 @@ dtype: bf16
enable_activation_checkpointing: True # True reduces memory
enable_activation_offloading: False # True reduces memory
+
# Profiler (disabled)
profiler:
_component_: torchtune.training.setup_torch_profiler
diff --git a/recipes/configs/llama2/7B_qlora_single_device.yaml b/recipes/configs/llama2/7B_qlora_single_device.yaml
index 028265007e..a925ac782b 100644
--- a/recipes/configs/llama2/7B_qlora_single_device.yaml
+++ b/recipes/configs/llama2/7B_qlora_single_device.yaml
@@ -15,6 +15,8 @@
#
# This config works only for training on single device.
+output_dir: /tmp/torchtune/llama2_7B/qlora_single_device # /tmp may be deleted by your system. Change it to your preference.
+
# Model Arguments
model:
_component_: torchtune.models.llama2.qlora_llama2_7b
@@ -39,7 +41,7 @@ checkpointer:
]
adapter_checkpoint: null
recipe_checkpoint: null
- output_dir: /tmp/Llama-2-7b-hf
+ output_dir: ${output_dir}
model_type: LLAMA2
resume_from_checkpoint: False
save_adapter_weights_only: False
@@ -68,14 +70,13 @@ loss:
# Training
epochs: 1
max_steps_per_epoch: null
-gradient_accumulation_steps: 8 # Use to increase virtual batch size
-compile: False # pytorch compile, set to true for better perf/memory
+gradient_accumulation_steps: 8 # Use to increase effective batch size
+compile: False # torch.compile the model + loss, True increases speed + decreases memory
# Logging
-output_dir: /tmp/qlora_finetune_output/
metric_logger:
_component_: torchtune.training.metric_logging.DiskLogger
- log_dir: ${output_dir}
+ log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: True
diff --git a/recipes/configs/llama3/70B_full.yaml b/recipes/configs/llama3/70B_full.yaml
index 5878b2fd95..df07de0165 100644
--- a/recipes/configs/llama3/70B_full.yaml
+++ b/recipes/configs/llama3/70B_full.yaml
@@ -17,6 +17,8 @@
#
+output_dir: /tmp/torchtune/llama3_70B/full # /tmp may be deleted by your system. Change it to your preference.
+
# Tokenizer
tokenizer:
_component_: torchtune.models.llama3.llama3_tokenizer
@@ -70,7 +72,7 @@ checkpointer:
model-00030-of-00030.safetensors,
]
recipe_checkpoint: null
- output_dir: /tmp/Meta-Llama-3-70b
+ output_dir: ${output_dir}
model_type: LLAMA3
resume_from_checkpoint: False
@@ -86,7 +88,7 @@ optimizer:
loss:
_component_: torchtune.modules.loss.CEWithChunkedOutputLoss
max_steps_per_epoch: null
-gradient_accumulation_steps: 1 # Use to increase virtual batch size
+gradient_accumulation_steps: 1 # Use to increase effective batch size
# Training env
device: cuda
@@ -94,9 +96,9 @@ device: cuda
# Memory management
enable_activation_checkpointing: True # True reduces memory
enable_activation_offloading: False # True reduces memory
-custom_sharded_layers: ['tok_embeddings', 'output'] # Layers to shard separately (useful for large vocab size models). Lower Memory, but lower speed
+custom_sharded_layers: ['tok_embeddings', 'output'] # Layers to shard separately (useful for large vocab size models). Lower Memory, but lower speed.
fsdp_cpu_offload: True
-compile: False # pytorch compile, set to true for better perf/memory
+compile: False # torch.compile the model + loss, True increases speed + decreases memory
optimizer_in_bwd: False # True saves memory. Requires gradient_accumulation_steps=1
# Reduced precision
@@ -105,11 +107,11 @@ dtype: bf16
# Logging
metric_logger:
_component_: torchtune.training.metric_logging.DiskLogger
- log_dir: ${output_dir}
-output_dir: /tmp/full-llama3-finetune
+ log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: True
+
# Profiler (disabled)
profiler:
_component_: torchtune.training.setup_torch_profiler
diff --git a/recipes/configs/llama3/70B_lora.yaml b/recipes/configs/llama3/70B_lora.yaml
index 2e4e718f62..2d0931cc07 100644
--- a/recipes/configs/llama3/70B_lora.yaml
+++ b/recipes/configs/llama3/70B_lora.yaml
@@ -9,6 +9,8 @@
# # tune run --nproc_per_node 8 lora_finetune_distributed --config llama3/70B_lora
#
+output_dir: /tmp/torchtune/llama3_70B/lora # /tmp may be deleted by your system. Change it to your preference.
+
# Model Arguments
model:
_component_: torchtune.models.llama3.lora_llama3_70b
@@ -60,7 +62,7 @@ checkpointer:
model-00030-of-00030.safetensors,
]
recipe_checkpoint: null
- output_dir: /tmp/Meta-Llama-3-70B-Instruct
+ output_dir: ${output_dir}
model_type: LLAMA3
resume_from_checkpoint: False
save_adapter_weights_only: False
@@ -89,14 +91,13 @@ loss:
# Training
epochs: 1
max_steps_per_epoch: null
-gradient_accumulation_steps: 1 # Use to increase virtual batch size
-compile: False # pytorch compile, set to true for better perf/memory
+gradient_accumulation_steps: 1 # Use to increase effective batch size
+compile: False # torch.compile the model + loss, True increases speed + decreases memory
# Logging
-output_dir: /tmp/lora_finetune_output
metric_logger:
_component_: torchtune.training.metric_logging.DiskLogger
- log_dir: ${output_dir}
+ log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: True
@@ -105,6 +106,8 @@ device: cuda
dtype: bf16
enable_activation_checkpointing: True # True reduces memory
enable_activation_offloading: False # True reduces memory
+# custom_sharded_layers: ['tok_embeddings', 'output'] # Layers to shard separately (useful for large vocab size models). Lower Memory, but lower speed.
+
# Profiler (disabled)
profiler:
diff --git a/recipes/configs/llama3/8B_dora.yaml b/recipes/configs/llama3/8B_dora.yaml
index 276f303807..98bd75b08e 100644
--- a/recipes/configs/llama3/8B_dora.yaml
+++ b/recipes/configs/llama3/8B_dora.yaml
@@ -14,6 +14,8 @@
# tune run --nnodes 1 --nproc_per_node 2 lora_finetune_distributed --config llama3/8B_dora checkpointer.checkpoint_dir=
+output_dir: /tmp/torchtune/llama3_8B/dora # /tmp may be deleted by your system. Change it to your preference.
+
# Model Arguments
model:
_component_: torchtune.models.llama3.lora_llama3_8b
@@ -36,7 +38,7 @@ checkpointer:
consolidated.00.pth
]
recipe_checkpoint: null
- output_dir: /tmp/Meta-Llama-3-8B-Instruct/
+ output_dir: ${output_dir}
model_type: LLAMA3
resume_from_checkpoint: False
@@ -64,14 +66,13 @@ loss:
# Training
epochs: 1
max_steps_per_epoch: null
-gradient_accumulation_steps: 1 # Use to increase virtual batch size
-compile: False # pytorch compile, set to true for better perf/memory
+gradient_accumulation_steps: 1 # Use to increase effective batch size
+compile: False # torch.compile the model + loss, True increases speed + decreases memory
# Logging
-output_dir: /tmp/dora_finetune_output
metric_logger:
_component_: torchtune.training.metric_logging.DiskLogger
- log_dir: ${output_dir}
+ log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: True
@@ -81,6 +82,7 @@ dtype: bf16
enable_activation_checkpointing: False # True reduces memory
enable_activation_offloading: False # True reduces memory
+
# Profiler (disabled)
profiler:
_component_: torchtune.training.setup_torch_profiler
diff --git a/recipes/configs/llama3/8B_dora_single_device.yaml b/recipes/configs/llama3/8B_dora_single_device.yaml
index 82c7c765b5..4258cc08a4 100644
--- a/recipes/configs/llama3/8B_dora_single_device.yaml
+++ b/recipes/configs/llama3/8B_dora_single_device.yaml
@@ -16,6 +16,8 @@
# This config works only for training on single device.
+output_dir: /tmp/torchtune/llama3_8B/dora_single_device # /tmp may be deleted by your system. Change it to your preference.
+
# Model Arguments
model:
_component_: torchtune.models.llama3.lora_llama3_8b
@@ -38,7 +40,7 @@ checkpointer:
consolidated.00.pth
]
recipe_checkpoint: null
- output_dir: /tmp/Meta-Llama-3-8B-Instruct/
+ output_dir: ${output_dir}
model_type: LLAMA3
resume_from_checkpoint: False
@@ -66,14 +68,13 @@ loss:
# Training
epochs: 1
max_steps_per_epoch: null
-gradient_accumulation_steps: 8 # Use to increase virtual batch size
-compile: False # pytorch compile, set to true for better perf/memory
+gradient_accumulation_steps: 8 # Use to increase effective batch size
+compile: False # torch.compile the model + loss, True increases speed + decreases memory
# Logging
-output_dir: /tmp/dora_finetune_output
metric_logger:
_component_: torchtune.training.metric_logging.DiskLogger
- log_dir: ${output_dir}
+ log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: True
diff --git a/recipes/configs/llama3/8B_full.yaml b/recipes/configs/llama3/8B_full.yaml
index a065fa9ece..2723d08c90 100644
--- a/recipes/configs/llama3/8B_full.yaml
+++ b/recipes/configs/llama3/8B_full.yaml
@@ -18,6 +18,8 @@
# best to use 8B_full_single_device.yaml for those cases
+output_dir: /tmp/torchtune/llama3_8B/full # /tmp may be deleted by your system. Change it to your preference.
+
# Tokenizer
tokenizer:
_component_: torchtune.models.llama3.llama3_tokenizer
@@ -42,7 +44,7 @@ checkpointer:
consolidated.00.pth
]
recipe_checkpoint: null
- output_dir: /tmp/Meta-Llama-3-8B-Instruct/
+ output_dir: ${output_dir}
model_type: LLAMA3
resume_from_checkpoint: False
@@ -57,8 +59,8 @@ optimizer:
loss:
_component_: torchtune.modules.loss.CEWithChunkedOutputLoss
max_steps_per_epoch: null
-gradient_accumulation_steps: 1 # Use to increase virtual batch size
-compile: False # pytorch compile, set to true for better perf/memory
+gradient_accumulation_steps: 1 # Use to increase effective batch size
+compile: False # torch.compile the model + loss, True increases speed + decreases memory
optimizer_in_bwd: False # True saves memory. Requires gradient_accumulation_steps=1
# Training env
@@ -75,11 +77,11 @@ dtype: bf16
# Logging
metric_logger:
_component_: torchtune.training.metric_logging.DiskLogger
- log_dir: ${output_dir}
-output_dir: /tmp/full-llama3-finetune
+ log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: True
+
# Profiler (disabled)
profiler:
_component_: torchtune.training.setup_torch_profiler
diff --git a/recipes/configs/llama3/8B_full_single_device.yaml b/recipes/configs/llama3/8B_full_single_device.yaml
index a63845fc30..ad534c62b9 100644
--- a/recipes/configs/llama3/8B_full_single_device.yaml
+++ b/recipes/configs/llama3/8B_full_single_device.yaml
@@ -20,6 +20,8 @@
# This config works only for training on single device.
+output_dir: /tmp/torchtune/llama3_8B/full_single_device # /tmp may be deleted by your system. Change it to your preference.
+
# Tokenizer
tokenizer:
_component_: torchtune.models.llama3.llama3_tokenizer
@@ -44,7 +46,7 @@ checkpointer:
consolidated.00.pth
]
recipe_checkpoint: null
- output_dir: /tmp/Meta-Llama-3-8B-Instruct/
+ output_dir: ${output_dir}
model_type: LLAMA3
resume_from_checkpoint: False
@@ -60,9 +62,9 @@ lr_scheduler:
loss:
_component_: torchtune.modules.loss.CEWithChunkedOutputLoss
max_steps_per_epoch: null
-gradient_accumulation_steps: 1 # Use to increase virtual batch size
+gradient_accumulation_steps: 1 # Use to increase effective batch size
optimizer_in_bwd: True # True saves memory. Requires gradient_accumulation_steps=1
-compile: False # pytorch compile, set to true for better perf/memory
+compile: False # torch.compile the model + loss, True increases speed + decreases memory
# Training environment
device: cuda
@@ -77,11 +79,11 @@ dtype: bf16
# Logging
metric_logger:
_component_: torchtune.training.metric_logging.DiskLogger
- log_dir: ${output_dir}
-output_dir: /tmp/full-llama3-finetune
+ log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: True
+
# Profiler (disabled)
profiler:
_component_: torchtune.training.setup_torch_profiler
diff --git a/recipes/configs/llama3/8B_lora.yaml b/recipes/configs/llama3/8B_lora.yaml
index f9cb8f9d95..7ac7bd1942 100644
--- a/recipes/configs/llama3/8B_lora.yaml
+++ b/recipes/configs/llama3/8B_lora.yaml
@@ -17,6 +17,8 @@
# For single device LoRA finetuning please use 8B_lora_single_device.yaml
# or 8B_qlora_single_device.yaml
+output_dir: /tmp/torchtune/llama3_8B/lora # /tmp may be deleted by your system. Change it to your preference.
+
# Tokenizer
tokenizer:
_component_: torchtune.models.llama3.llama3_tokenizer
@@ -40,7 +42,7 @@ checkpointer:
consolidated.00.pth
]
recipe_checkpoint: null
- output_dir: /tmp/Meta-Llama-3-8B-Instruct/
+ output_dir: ${output_dir}
model_type: LLAMA3
resume_from_checkpoint: False
save_adapter_weights_only: False
@@ -69,14 +71,13 @@ loss:
# Training
epochs: 1
max_steps_per_epoch: null
-gradient_accumulation_steps: 8 # Use to increase virtual batch size
-compile: False # pytorch compile, set to true for better perf/memory
+gradient_accumulation_steps: 8 # Use to increase effective batch size
+compile: False # torch.compile the model + loss, True increases speed + decreases memory
# Logging
-output_dir: /tmp/lora_finetune_output
metric_logger:
_component_: torchtune.training.metric_logging.DiskLogger
- log_dir: ${output_dir}
+ log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: True
@@ -86,6 +87,7 @@ dtype: bf16
enable_activation_checkpointing: False # True reduces memory
enable_activation_offloading: False # True reduces memory
+
# Profiler (disabled)
profiler:
_component_: torchtune.training.setup_torch_profiler
diff --git a/recipes/configs/llama3/8B_lora_single_device.yaml b/recipes/configs/llama3/8B_lora_single_device.yaml
index 5ae3a0088a..8b1db9d06d 100644
--- a/recipes/configs/llama3/8B_lora_single_device.yaml
+++ b/recipes/configs/llama3/8B_lora_single_device.yaml
@@ -16,6 +16,8 @@
# This config works only for training on single device.
+output_dir: /tmp/torchtune/llama3_8B/lora_single_device # /tmp may be deleted by your system. Change it to your preference.
+
# Model Arguments
model:
_component_: torchtune.models.llama3.lora_llama3_8b
@@ -39,7 +41,7 @@ checkpointer:
consolidated.00.pth
]
recipe_checkpoint: null
- output_dir: /tmp/Meta-Llama-3-8B-Instruct/
+ output_dir: ${output_dir}
model_type: LLAMA3
resume_from_checkpoint: False
save_adapter_weights_only: False
@@ -68,14 +70,13 @@ loss:
# Training
epochs: 1
max_steps_per_epoch: null
-gradient_accumulation_steps: 8 # Use to increase virtual batch size
-compile: False # pytorch compile, set to true for better perf/memory
+gradient_accumulation_steps: 8 # Use to increase effective batch size
+compile: False # torch.compile the model + loss, True increases speed + decreases memory
# Logging
-output_dir: /tmp/lora_finetune_output
metric_logger:
_component_: torchtune.training.metric_logging.DiskLogger
- log_dir: ${output_dir}
+ log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: True
@@ -87,6 +88,7 @@ dtype: bf16
enable_activation_checkpointing: True # True reduces memory
enable_activation_offloading: False # True reduces memory
+
# Profiler (disabled)
profiler:
_component_: torchtune.training.setup_torch_profiler
diff --git a/recipes/configs/llama3/8B_qat_full.yaml b/recipes/configs/llama3/8B_qat_full.yaml
index 49c5c7ee74..b1d9bfad5b 100644
--- a/recipes/configs/llama3/8B_qat_full.yaml
+++ b/recipes/configs/llama3/8B_qat_full.yaml
@@ -13,6 +13,8 @@
# you can run:
# tune run --nproc_per_node 4 qat_distributed --config llama3/8B_qat_full checkpointer.checkpoint_dir=
+output_dir: /tmp/torchtune/llama3_8B/qat_full # /tmp may be deleted by your system. Change it to your preference.
+
# Tokenizer
tokenizer:
_component_: torchtune.models.llama3.llama3_tokenizer
@@ -37,7 +39,7 @@ checkpointer:
consolidated.00.pth
]
recipe_checkpoint: null
- output_dir: /tmp/Meta-Llama-3-8B-Instruct/
+ output_dir: ${output_dir}
model_type: LLAMA3
resume_from_checkpoint: False
@@ -57,8 +59,8 @@ optimizer:
loss:
_component_: torchtune.modules.loss.CEWithChunkedOutputLoss
max_steps_per_epoch: null
-gradient_accumulation_steps: 1 # Use to increase virtual batch size
-compile: False # pytorch compile, set to true for better perf/memory
+gradient_accumulation_steps: 1 # Use to increase effective batch size
+compile: False # torch.compile the model + loss, True increases speed + decreases memory
optimizer_in_bwd: False # True saves memory. Requires gradient_accumulation_steps=1
# Training env
@@ -75,11 +77,11 @@ dtype: bf16
# Logging
metric_logger:
_component_: torchtune.training.metric_logging.DiskLogger
- log_dir: ${output_dir}
-output_dir: /tmp/full-llama3-finetune
+ log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: True
+
# Profiler (disabled)
profiler:
_component_: torchtune.training.setup_torch_profiler
diff --git a/recipes/configs/llama3/8B_qat_lora.yaml b/recipes/configs/llama3/8B_qat_lora.yaml
new file mode 100644
index 0000000000..5a889a3d63
--- /dev/null
+++ b/recipes/configs/llama3/8B_qat_lora.yaml
@@ -0,0 +1,115 @@
+# Config for multi-device QAT + LoRA finetuning in qat_lora_finetune_distributed.py
+# using a Llama3 8B Instruct model
+#
+# This config assumes that you've run the following command before launching
+# this run:
+# tune download meta-llama/Meta-Llama-3-8B-Instruct --output-dir /tmp/Meta-Llama-3-8B-Instruct --hf-token
+#
+# To launch on 2 devices, run the following command from root:
+# tune run --nproc_per_node 2 qat_lora_finetune_distributed --config llama3/8B_qat_lora
+#
+# You can add specific overrides through the command line. For example
+# to override the checkpointer directory while launching training
+# you can run:
+# tune run --nproc_per_node 2 qat_lora_finetune_distributed --config llama3/8B_qat_lora checkpointer.checkpoint_dir=
+
+output_dir: /tmp/torchtune/llama3_8B/qat_lora # /tmp may be deleted by your system. Change it to your preference.
+
+# Tokenizer
+tokenizer:
+ _component_: torchtune.models.llama3.llama3_tokenizer
+ path: /tmp/Meta-Llama-3-8B-Instruct/original/tokenizer.model
+ max_seq_len: null
+
+# Model Arguments
+model:
+ _component_: torchtune.models.llama3.lora_llama3_8b
+ lora_attn_modules: ['q_proj', 'v_proj', 'output_proj']
+ apply_lora_to_mlp: True
+ apply_lora_to_output: False
+ lora_rank: 8 # higher increases accuracy and memory
+ lora_alpha: 16 # usually alpha=2*rank
+ lora_dropout: 0.0
+
+checkpointer:
+ _component_: torchtune.training.FullModelMetaCheckpointer
+ checkpoint_dir: /tmp/Meta-Llama-3-8B-Instruct/original/
+ checkpoint_files: [
+ consolidated.00.pth
+ ]
+ recipe_checkpoint: null
+ output_dir: ${output_dir}
+ model_type: LLAMA3
+resume_from_checkpoint: False
+save_adapter_weights_only: False
+
+# Dataset and Sampler
+dataset:
+ _component_: torchtune.datasets.alpaca_cleaned_dataset
+ packed: False # True increases speed
+seed: null
+shuffle: True
+batch_size: 2
+
+# Optimizer and Scheduler
+optimizer:
+ _component_: torch.optim.AdamW
+ fused: True
+ weight_decay: 0.01
+ lr: 3e-4
+lr_scheduler:
+ _component_: torchtune.training.lr_schedulers.get_cosine_schedule_with_warmup
+ num_warmup_steps: 100
+
+loss:
+ _component_: torchtune.modules.loss.CEWithChunkedOutputLoss
+
+# Training
+epochs: 1
+max_steps_per_epoch: null
+gradient_accumulation_steps: 8 # Use to increase effective batch size
+compile: False # torch.compile the model + loss, True increases speed + decreases memory
+
+# Logging
+metric_logger:
+ _component_: torchtune.training.metric_logging.DiskLogger
+ log_dir: ${output_dir}/logs
+log_every_n_steps: 1
+log_peak_memory_stats: True
+
+# Environment
+device: cuda
+dtype: bf16
+enable_activation_checkpointing: False # True reduces memory
+enable_activation_offloading: False # True reduces memory
+
+
+# Profiler (disabled)
+profiler:
+ _component_: torchtune.training.setup_torch_profiler
+ enabled: False
+
+ #Output directory of trace artifacts
+ output_dir: ${output_dir}/profiling_outputs
+
+ #`torch.profiler.ProfilerActivity` types to trace
+ cpu: True
+ cuda: True
+
+ #trace options passed to `torch.profiler.profile`
+ profile_memory: False
+ with_stack: False
+ record_shapes: True
+ with_flops: False
+
+ # `torch.profiler.schedule` options:
+ # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat
+ wait_steps: 5
+ warmup_steps: 3
+ active_steps: 2
+ num_cycles: 1
+
+# QAT arguments
+quantizer:
+ _component_: torchtune.training.quantization.Int8DynActInt4WeightQATQuantizer
+ groupsize: 256
diff --git a/recipes/configs/llama3/8B_qdora_single_device.yaml b/recipes/configs/llama3/8B_qdora_single_device.yaml
index 823e0f75fe..8a5a39b58b 100644
--- a/recipes/configs/llama3/8B_qdora_single_device.yaml
+++ b/recipes/configs/llama3/8B_qdora_single_device.yaml
@@ -16,6 +16,8 @@
# This config works only for training on single device.
+output_dir: /tmp/torchtune/llama3_8B/qdora_single_device # /tmp may be deleted by your system. Change it to your preference.
+
# Model Arguments
model:
_component_: torchtune.models.llama3.lora_llama3_8b
@@ -39,7 +41,7 @@ checkpointer:
consolidated.00.pth
]
recipe_checkpoint: null
- output_dir: /tmp/Meta-Llama-3-8B-Instruct/
+ output_dir: ${output_dir}
model_type: LLAMA3
resume_from_checkpoint: False
@@ -67,14 +69,13 @@ loss:
# Training
epochs: 1
max_steps_per_epoch: null
-gradient_accumulation_steps: 8 # Use to increase virtual batch size
-compile: False # pytorch compile, set to true for better perf/memory
+gradient_accumulation_steps: 8 # Use to increase effective batch size
+compile: False # torch.compile the model + loss, True increases speed + decreases memory
# Logging
-output_dir: /tmp/qdora_finetune_output
metric_logger:
_component_: torchtune.training.metric_logging.DiskLogger
- log_dir: ${output_dir}
+ log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: True
diff --git a/recipes/configs/llama3/8B_qlora_single_device.yaml b/recipes/configs/llama3/8B_qlora_single_device.yaml
index 76af71e432..4922ada0f0 100644
--- a/recipes/configs/llama3/8B_qlora_single_device.yaml
+++ b/recipes/configs/llama3/8B_qlora_single_device.yaml
@@ -15,6 +15,8 @@
#
# This config works only for training on single device.
+output_dir: /tmp/torchtune/llama3_8B/qlora_single_device # /tmp may be deleted by your system. Change it to your preference.
+
# Model Arguments
model:
_component_: torchtune.models.llama3.qlora_llama3_8b
@@ -38,7 +40,7 @@ checkpointer:
consolidated.00.pth
]
recipe_checkpoint: null
- output_dir: /tmp/Meta-Llama-3-8B-Instruct/
+ output_dir: ${output_dir}
model_type: LLAMA3
resume_from_checkpoint: False
save_adapter_weights_only: False
@@ -67,14 +69,13 @@ loss:
# Training
epochs: 1
max_steps_per_epoch: null
-gradient_accumulation_steps: 8 # Use to increase virtual batch size
-compile: False # pytorch compile, set to true for better perf/memory
+gradient_accumulation_steps: 8 # Use to increase effective batch size
+compile: False # torch.compile the model + loss, True increases speed + decreases memory
# Logging
-output_dir: /tmp/qlora_finetune_output/
metric_logger:
_component_: torchtune.training.metric_logging.DiskLogger
- log_dir: ${output_dir}
+ log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: True
@@ -86,6 +87,7 @@ dtype: bf16
enable_activation_checkpointing: True # True reduces memory
enable_activation_offloading: False # True reduces memory
+
# Profiler (disabled)
profiler:
_component_: torchtune.training.setup_torch_profiler
diff --git a/recipes/configs/llama3_1/405B_qlora.yaml b/recipes/configs/llama3_1/405B_qlora.yaml
index cc4eead534..749b16717f 100644
--- a/recipes/configs/llama3_1/405B_qlora.yaml
+++ b/recipes/configs/llama3_1/405B_qlora.yaml
@@ -14,6 +14,8 @@
#
+output_dir: /tmp/torchtune/llama3_1_405B/qlora # /tmp may be deleted by your system. Change it to your preference.
+
# Model Arguments
model:
_component_: torchtune.models.llama3_1.qlora_llama3_1_405b
@@ -34,7 +36,7 @@ checkpointer:
filename_format: model-{}-of-{}.safetensors
max_filename: 00191
recipe_checkpoint: null
- output_dir: /tmp/Meta-Llama-3.1-405B-Instruct/
+ output_dir: ${output_dir}
model_type: LLAMA3
resume_from_checkpoint: False
save_adapter_weights_only: True # Set to false to save the whole model + adapter merged
@@ -67,14 +69,13 @@ fsdp:
# Training
epochs: 1
max_steps_per_epoch: null
-gradient_accumulation_steps: 8 # Use to increase virtual batch size
-compile: False # pytorch compile, set to true for better perf/memory
+gradient_accumulation_steps: 8 # Use to increase effective batch size
+compile: False # torch.compile the model + loss, True increases speed + decreases memory
# Logging
-output_dir: /tmp/qlora_finetune_output
metric_logger:
_component_: torchtune.training.metric_logging.DiskLogger
- log_dir: ${output_dir}
+ log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: True
@@ -84,6 +85,7 @@ dtype: bf16
enable_activation_checkpointing: True # True reduces memory
enable_activation_offloading: False # True reduces memory
+
# Profiler (disabled)
profiler:
_component_: torchtune.training.setup_torch_profiler
diff --git a/recipes/configs/llama3_1/70B_full.yaml b/recipes/configs/llama3_1/70B_full.yaml
index d92fcef1f6..d3d546cfdb 100644
--- a/recipes/configs/llama3_1/70B_full.yaml
+++ b/recipes/configs/llama3_1/70B_full.yaml
@@ -16,6 +16,8 @@
# This config is only tested on an 8xA100 machine.
#
+output_dir: /tmp/torchtune/llama3_1_70B/full # /tmp may be deleted by your system. Change it to your preference.
+
# Tokenizer
tokenizer:
_component_: torchtune.models.llama3.llama3_tokenizer
@@ -69,7 +71,7 @@ checkpointer:
model-00030-of-00030.safetensors,
]
recipe_checkpoint: null
- output_dir: /tmp/Meta-Llama-3.1-70B-Instruct/
+ output_dir: ${output_dir}
model_type: LLAMA3
resume_from_checkpoint: False
@@ -87,7 +89,7 @@ optimizer:
loss:
_component_: torchtune.modules.loss.CEWithChunkedOutputLoss
max_steps_per_epoch: null
-gradient_accumulation_steps: 1 # Use to increase virtual batch size
+gradient_accumulation_steps: 1 # Use to increase effective batch size
# Training env
@@ -98,7 +100,7 @@ enable_activation_checkpointing: True # True reduces memory
enable_activation_offloading: False # True reduces memory
custom_sharded_layers: ['tok_embeddings', 'output'] # Layers to shard separately (useful for large vocab size models). Lower Memory, but lower speed.
fsdp_cpu_offload: True
-compile: False # pytorch compile, set to true for better perf/memory
+compile: False # torch.compile the model + loss, True increases speed + decreases memory
optimizer_in_bwd: False # True saves memory. Requires gradient_accumulation_steps=1
# Reduced precision
@@ -107,11 +109,11 @@ dtype: bf16
# Logging
metric_logger:
_component_: torchtune.training.metric_logging.DiskLogger
- log_dir: ${output_dir}
-output_dir: /tmp/full-llama3_1-finetune
+ log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: True
+
# Profiler (disabled)
profiler:
_component_: torchtune.training.setup_torch_profiler
diff --git a/recipes/configs/llama3_1/70B_lora.yaml b/recipes/configs/llama3_1/70B_lora.yaml
index a89c01b4c1..c27636d2fb 100644
--- a/recipes/configs/llama3_1/70B_lora.yaml
+++ b/recipes/configs/llama3_1/70B_lora.yaml
@@ -8,6 +8,8 @@
# This config needs 8 GPUs to run
# tune run --nproc_per_node 8 lora_finetune_distributed --config llama3_1/70B_lora
+output_dir: /tmp/torchtune/llama3_1_70B/lora # /tmp may be deleted by your system. Change it to your preference.
+
# Model Arguments
model:
_component_: torchtune.models.llama3_1.lora_llama3_1_70b
@@ -59,7 +61,7 @@ checkpointer:
model-00030-of-00030.safetensors,
]
recipe_checkpoint: null
- output_dir: /tmp/Meta-Llama-3.1-70B-Instruct/
+ output_dir: ${output_dir}
model_type: LLAMA3
resume_from_checkpoint: False
save_adapter_weights_only: True # Set to false to save the whole model + adapter merged
@@ -88,14 +90,13 @@ loss:
# Training
epochs: 1
max_steps_per_epoch: null
-gradient_accumulation_steps: 1 # Use to increase virtual batch size
-compile: False # pytorch compile, set to true for better perf/memory
+gradient_accumulation_steps: 1 # Use to increase effective batch size
+compile: False # torch.compile the model + loss, True increases speed + decreases memory
# Logging
-output_dir: /tmp/lora-llama3_1-finetune-output
metric_logger:
_component_: torchtune.training.metric_logging.DiskLogger
- log_dir: ${output_dir}
+ log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: True
@@ -104,6 +105,8 @@ device: cuda
dtype: bf16
enable_activation_checkpointing: True # True reduces memory
enable_activation_offloading: False # True reduces memory
+# custom_sharded_layers: ['tok_embeddings', 'output'] # Layers to shard separately (useful for large vocab size models). Lower Memory, but lower speed.
+
# Profiler (disabled)
profiler:
diff --git a/recipes/configs/llama3_1/8B_full.yaml b/recipes/configs/llama3_1/8B_full.yaml
index 32aff922cf..357d20356d 100644
--- a/recipes/configs/llama3_1/8B_full.yaml
+++ b/recipes/configs/llama3_1/8B_full.yaml
@@ -18,6 +18,8 @@
# best to use 8B_full_single_device.yaml for those cases
+output_dir: /tmp/torchtune/llama3_1_8B/full # /tmp may be deleted by your system. Change it to your preference.
+
# Tokenizer
tokenizer:
_component_: torchtune.models.llama3.llama3_tokenizer
@@ -45,7 +47,7 @@ checkpointer:
model-00004-of-00004.safetensors
]
recipe_checkpoint: null
- output_dir: /tmp/Meta-Llama-3.1-8B-Instruct/
+ output_dir: ${output_dir}
model_type: LLAMA3
resume_from_checkpoint: False
@@ -60,9 +62,9 @@ optimizer:
loss:
_component_: torchtune.modules.loss.CEWithChunkedOutputLoss
max_steps_per_epoch: null
-compile: False # pytorch compile, set to true for better perf/memory
+compile: False # torch.compile the model + loss, True increases speed + decreases memory
optimizer_in_bwd: False # True saves memory. Requires gradient_accumulation_steps=1
-gradient_accumulation_steps: 1 # Use to increase virtual batch size
+gradient_accumulation_steps: 1 # Use to increase effective batch size
# Training env
device: cuda
@@ -78,11 +80,11 @@ dtype: bf16
# Logging
metric_logger:
_component_: torchtune.training.metric_logging.DiskLogger
- log_dir: ${output_dir}
-output_dir: /tmp/full-llama3.1-finetune
+ log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: True
+
# Profiler (disabled)
profiler:
_component_: torchtune.training.setup_torch_profiler
diff --git a/recipes/configs/llama3_1/8B_full_single_device.yaml b/recipes/configs/llama3_1/8B_full_single_device.yaml
index 66f397e1df..1429b9cc2b 100644
--- a/recipes/configs/llama3_1/8B_full_single_device.yaml
+++ b/recipes/configs/llama3_1/8B_full_single_device.yaml
@@ -20,6 +20,8 @@
# This config works only for training on single device.
+output_dir: /tmp/torchtune/llama3_1_8B/full_single_device # /tmp may be deleted by your system. Change it to your preference.
+
# Tokenizer
tokenizer:
_component_: torchtune.models.llama3.llama3_tokenizer
@@ -47,7 +49,7 @@ checkpointer:
model-00004-of-00004.safetensors
]
recipe_checkpoint: null
- output_dir: /tmp/Meta-Llama-3.1-8B-Instruct/
+ output_dir: ${output_dir}
model_type: LLAMA3
resume_from_checkpoint: False
@@ -60,9 +62,9 @@ optimizer:
loss:
_component_: torchtune.modules.loss.CEWithChunkedOutputLoss
max_steps_per_epoch: null
-gradient_accumulation_steps: 1 # Use to increase virtual batch size
+gradient_accumulation_steps: 1 # Use to increase effective batch size
optimizer_in_bwd: True # True saves memory. Requires gradient_accumulation_steps=1
-compile: False # pytorch compile, set to true for better perf/memory
+compile: False # torch.compile the model + loss, True increases speed + decreases memory
# Training environment
device: cuda
@@ -77,11 +79,11 @@ dtype: bf16
# Logging
metric_logger:
_component_: torchtune.training.metric_logging.DiskLogger
- log_dir: ${output_dir}
-output_dir: /tmp/full-llama3.1-finetune
+ log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: True
+
# Profiler (disabled)
profiler:
_component_: torchtune.training.setup_torch_profiler
diff --git a/recipes/configs/llama3_1/8B_lora.yaml b/recipes/configs/llama3_1/8B_lora.yaml
index b889f20fe2..7303194173 100644
--- a/recipes/configs/llama3_1/8B_lora.yaml
+++ b/recipes/configs/llama3_1/8B_lora.yaml
@@ -17,6 +17,8 @@
# For single device LoRA finetuning please use 8B_lora_single_device.yaml
# or 8B_qlora_single_device.yaml
+output_dir: /tmp/torchtune/llama3_1_8B/lora # /tmp may be deleted by your system. Change it to your preference.
+
# Tokenizer
tokenizer:
_component_: torchtune.models.llama3.llama3_tokenizer
@@ -43,7 +45,7 @@ checkpointer:
model-00004-of-00004.safetensors
]
recipe_checkpoint: null
- output_dir: /tmp/Meta-Llama-3.1-8B-Instruct/
+ output_dir: ${output_dir}
model_type: LLAMA3
resume_from_checkpoint: False
save_adapter_weights_only: False
@@ -72,14 +74,13 @@ loss:
# Training
epochs: 1
max_steps_per_epoch: null
-gradient_accumulation_steps: 8 # Use to increase virtual batch size
-compile: False # pytorch compile, set to true for better perf/memory
+gradient_accumulation_steps: 8 # Use to increase effective batch size
+compile: False # torch.compile the model + loss, True increases speed + decreases memory
# Logging
-output_dir: /tmp/lora_finetune_output
metric_logger:
_component_: torchtune.training.metric_logging.DiskLogger
- log_dir: ${output_dir}
+ log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: True
@@ -89,6 +90,7 @@ dtype: bf16
enable_activation_checkpointing: False # True reduces memory
enable_activation_offloading: False # True reduces memory
+
# Profiler (disabled)
profiler:
_component_: torchtune.training.setup_torch_profiler
diff --git a/recipes/configs/llama3_1/8B_lora_dpo.yaml b/recipes/configs/llama3_1/8B_lora_dpo.yaml
new file mode 100644
index 0000000000..7160362b2a
--- /dev/null
+++ b/recipes/configs/llama3_1/8B_lora_dpo.yaml
@@ -0,0 +1,93 @@
+# Config for multi-device LoRA DPO alignment in lora_dpo_distributed.py
+# using a Llama2 7B model
+#
+# This config assumes that you've run the following command before launching
+# this run:
+# tune download meta-llama/Meta-Llama-3.1-8B-Instruct --output-dir /tmp/Meta-Llama-3.1-8B-Instruct --ignore-patterns "original/consolidated.00.pth"
+#
+# To launch on 2 devices, run the following command from root:
+# tune run --nnodes 1 --nproc_per_node 2 lora_dpo_distributed --config llama3_1/8B_lora_dpo
+#
+# You can add specific overrides through the command line. For example
+# to override the checkpointer directory while launching training
+# you can run:
+# tune run --nnodes 1 --nproc_per_node 2 lora_dpo_distributed --config llama3_1/8B_lora_dpo checkpointer.checkpoint_dir=
+#
+# This config works best when the model is being fine-tuned on 2+ GPUs.
+# For single device LoRA DPO alignment please use llama3_1/8B_lora_dpo_single_device
+
+output_dir: /tmp/torchtune/llama3_1_8B/lora_dpo # /tmp may be deleted by your system. Change it to your preference.
+
+# Model Arguments
+model:
+ _component_: torchtune.models.llama3_1.lora_llama3_1_8b
+ lora_attn_modules: ['q_proj', 'v_proj', 'output_proj']
+ apply_lora_to_mlp: True
+ apply_lora_to_output: False
+ lora_rank: 8 # higher increases accuracy and memory
+ lora_alpha: 16 # usually alpha=2*rank
+ lora_dropout: 0.0
+
+# Tokenizer
+tokenizer:
+ _component_: torchtune.models.llama3.llama3_tokenizer
+ path: /tmp/Meta-Llama-3.1-8B-Instruct/original/tokenizer.model
+ max_seq_len: null
+
+checkpointer:
+ _component_: torchtune.training.FullModelHFCheckpointer
+ checkpoint_dir: /tmp/Meta-Llama-3.1-8B-Instruct/
+ checkpoint_files: [
+ model-00001-of-00004.safetensors,
+ model-00002-of-00004.safetensors,
+ model-00003-of-00004.safetensors,
+ model-00004-of-00004.safetensors
+ ]
+ recipe_checkpoint: null
+ output_dir: ${output_dir}
+ model_type: LLAMA3
+resume_from_checkpoint: False
+save_adapter_weights_only: False
+
+# Dataset and Sampler
+dataset:
+ _component_: torchtune.datasets.stack_exchange_paired_dataset
+seed: null
+shuffle: True
+batch_size: 4
+
+# Optimizer and Scheduler
+optimizer:
+ _component_: torch.optim.AdamW
+ fused: True
+ weight_decay: 0.05
+ lr: 5e-4
+lr_scheduler:
+ _component_: torchtune.training.lr_schedulers.get_cosine_schedule_with_warmup
+ num_warmup_steps: 100
+
+loss:
+ _component_: torchtune.rlhf.loss.DPOLoss
+ beta: 0.1
+ label_smoothing: 0
+
+# Training
+epochs: 1
+max_steps_per_epoch: 1000
+gradient_accumulation_steps: 8 # Use to increase effective batch size
+compile: False # torch.compile the model + loss, True increases speed + decreases memory
+
+# Logging
+metric_logger:
+ _component_: torchtune.training.metric_logging.DiskLogger
+ log_dir: ${output_dir}/logs
+log_every_n_steps: 1
+log_peak_memory_stats: True
+
+# Environment
+device: cuda
+dtype: bf16
+
+# Memory management
+enable_activation_checkpointing: True # True reduces memory
+enable_activation_offloading: False # True reduces memory
diff --git a/recipes/configs/llama3_1/8B_lora_dpo_single_device.yaml b/recipes/configs/llama3_1/8B_lora_dpo_single_device.yaml
new file mode 100644
index 0000000000..81d6158b28
--- /dev/null
+++ b/recipes/configs/llama3_1/8B_lora_dpo_single_device.yaml
@@ -0,0 +1,90 @@
+# Config for single device LoRA DPO alignment in lora_dpo_single_device.py
+# using a Llama2 7B model
+#
+# This config assumes that you've run the following command before launching
+# this run:
+# tune download meta-llama/Meta-Llama-3.1-8B-Instruct --output-dir /tmp/Meta-Llama-3.1-8B-Instruct --ignore-patterns "original/consolidated.00.pth"
+#
+# To launch on a single device, run the following command from root:
+# tune run lora_dpo_single_device --config llama3_1/8B_lora_dpo_single_device
+#
+# You can add specific overrides through the command line. For example
+# to override the checkpointer directory while launching training
+# you can run:
+# tune run lora_dpo_single_device --config llama3_1/8B_lora_dpo_single_device checkpointer.checkpoint_dir=
+#
+# This config works only for training on single device.
+
+output_dir: /tmp/torchtune/llama3_1_8B/lora_dpo_single_device # /tmp may be deleted by your system. Change it to your preference.
+
+# Model Arguments
+model:
+ _component_: torchtune.models.llama3_1.lora_llama3_1_8b
+ lora_attn_modules: ['q_proj', 'v_proj', 'output_proj']
+ apply_lora_to_mlp: True
+ apply_lora_to_output: False
+ lora_rank: 8 # higher increases accuracy and memory
+ lora_alpha: 16 # usually alpha=2*rank
+ lora_dropout: 0.0
+
+# Tokenizer
+tokenizer:
+ _component_: torchtune.models.llama3.llama3_tokenizer
+ path: /tmp/Meta-Llama-3.1-8B-Instruct/original/tokenizer.model
+ max_seq_len: null
+
+checkpointer:
+ _component_: torchtune.training.FullModelHFCheckpointer
+ checkpoint_dir: /tmp/Meta-Llama-3.1-8B-Instruct/
+ checkpoint_files: [
+ model-00001-of-00004.safetensors,
+ model-00002-of-00004.safetensors,
+ model-00003-of-00004.safetensors,
+ model-00004-of-00004.safetensors
+ ]
+ recipe_checkpoint: null
+ output_dir: ${output_dir}
+ model_type: LLAMA3
+resume_from_checkpoint: False
+save_adapter_weights_only: False
+
+# Dataset and Sampler
+dataset:
+ _component_: torchtune.datasets.stack_exchange_paired_dataset
+seed: null
+shuffle: True
+batch_size: 4
+
+# Optimizer and Scheduler
+optimizer:
+ _component_: torch.optim.AdamW
+ fused: True
+ weight_decay: 0.05
+ lr: 5e-4
+lr_scheduler:
+ _component_: torchtune.training.lr_schedulers.get_cosine_schedule_with_warmup
+ num_warmup_steps: 100
+
+loss:
+ _component_: torchtune.rlhf.loss.DPOLoss
+
+# Training
+epochs: 1
+max_steps_per_epoch: 1000
+gradient_accumulation_steps: 8 # Use to increase effective batch size
+compile: False # torch.compile the model + loss, True increases speed + decreases memory
+
+# Logging
+metric_logger:
+ _component_: torchtune.training.metric_logging.DiskLogger
+ log_dir: ${output_dir}/logs
+log_every_n_steps: 1
+log_peak_memory_stats: True
+
+# Environment
+device: cuda
+dtype: bf16
+
+# Memory management
+enable_activation_checkpointing: True # True reduces memory
+enable_activation_offloading: False # True reduces memory
diff --git a/recipes/configs/llama3_1/8B_lora_single_device.yaml b/recipes/configs/llama3_1/8B_lora_single_device.yaml
index f631dcfd7e..0f19750219 100644
--- a/recipes/configs/llama3_1/8B_lora_single_device.yaml
+++ b/recipes/configs/llama3_1/8B_lora_single_device.yaml
@@ -16,6 +16,8 @@
# This config works only for training on single device.
+output_dir: /tmp/torchtune/llama3_1_8B/lora_single_device # /tmp may be deleted by your system. Change it to your preference.
+
# Model Arguments
model:
_component_: torchtune.models.llama3_1.lora_llama3_1_8b
@@ -42,7 +44,7 @@ checkpointer:
model-00004-of-00004.safetensors
]
recipe_checkpoint: null
- output_dir: /tmp/Meta-Llama-3.1-8B-Instruct/
+ output_dir: ${output_dir}
model_type: LLAMA3
resume_from_checkpoint: False
save_adapter_weights_only: False
@@ -71,14 +73,13 @@ loss:
# Training
epochs: 1
max_steps_per_epoch: null
-gradient_accumulation_steps: 8 # Use to increase virtual batch size
-compile: False # pytorch compile, set to true for better perf/memory
+gradient_accumulation_steps: 8 # Use to increase effective batch size
+compile: False # torch.compile the model + loss, True increases speed + decreases memory
# Logging
-output_dir: /tmp/lora_finetune_output
metric_logger:
_component_: torchtune.training.metric_logging.DiskLogger
- log_dir: ${output_dir}
+ log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: True
@@ -90,6 +91,7 @@ dtype: bf16
enable_activation_checkpointing: True # True reduces memory
enable_activation_offloading: False # True reduces memory
+
# Profiler (disabled)
profiler:
_component_: torchtune.training.setup_torch_profiler
diff --git a/recipes/configs/llama3_1/8B_qat_lora.yaml b/recipes/configs/llama3_1/8B_qat_lora.yaml
new file mode 100644
index 0000000000..d25351a0e4
--- /dev/null
+++ b/recipes/configs/llama3_1/8B_qat_lora.yaml
@@ -0,0 +1,118 @@
+# Config for multi-device QAT + LoRA finetuning in qat_lora_finetune_distributed.py
+# using a Llama3.1 8B Instruct model
+#
+# This config assumes that you've run the following command before launching
+# this run:
+# tune download meta-llama/Meta-Llama-3.1-8B-Instruct --output-dir /tmp/Meta-Llama-3.1-8B-Instruct --ignore-patterns "original/consolidated.00.pth"
+#
+# To launch on 2 devices, run the following command from root:
+# tune run --nproc_per_node 2 qat_lora_finetune_distributed --config llama3_1/8B_qat_lora
+#
+# You can add specific overrides through the command line. For example
+# to override the checkpointer directory while launching training
+# you can run:
+# tune run --nproc_per_node 2 qat_lora_finetune_distributed --config llama3_1/8B_qat_lora checkpointer.checkpoint_dir=
+
+output_dir: /tmp/torchtune/llama3_1_8B/qat_lora # /tmp may be deleted by your system. Change it to your preference.
+
+# Tokenizer
+tokenizer:
+ _component_: torchtune.models.llama3.llama3_tokenizer
+ path: /tmp/Meta-Llama-3.1-8B-Instruct/original/tokenizer.model
+ max_seq_len: null
+
+# Model Arguments
+model:
+ _component_: torchtune.models.llama3_1.lora_llama3_1_8b
+ lora_attn_modules: ['q_proj', 'v_proj', 'output_proj']
+ apply_lora_to_mlp: True
+ apply_lora_to_output: False
+ lora_rank: 8 # higher increases accuracy and memory
+ lora_alpha: 16 # usually alpha=2*rank
+ lora_dropout: 0.0
+
+checkpointer:
+ _component_: torchtune.training.FullModelHFCheckpointer
+ checkpoint_dir: /tmp/Meta-Llama-3.1-8B-Instruct/
+ checkpoint_files: [
+ model-00001-of-00004.safetensors,
+ model-00002-of-00004.safetensors,
+ model-00003-of-00004.safetensors,
+ model-00004-of-00004.safetensors
+ ]
+ recipe_checkpoint: null
+ output_dir: ${output_dir}
+ model_type: LLAMA3
+resume_from_checkpoint: False
+save_adapter_weights_only: False
+
+# Dataset and Sampler
+dataset:
+ _component_: torchtune.datasets.alpaca_cleaned_dataset
+ packed: False # True increases speed
+seed: null
+shuffle: True
+batch_size: 2
+
+# Optimizer and Scheduler
+optimizer:
+ _component_: torch.optim.AdamW
+ fused: True
+ weight_decay: 0.01
+ lr: 3e-4
+lr_scheduler:
+ _component_: torchtune.training.lr_schedulers.get_cosine_schedule_with_warmup
+ num_warmup_steps: 100
+
+loss:
+ _component_: torchtune.modules.loss.CEWithChunkedOutputLoss
+
+# Training
+epochs: 1
+max_steps_per_epoch: null
+gradient_accumulation_steps: 8 # Use to increase effective batch size
+compile: False # torch.compile the model + loss, True increases speed + decreases memory
+
+# Logging
+metric_logger:
+ _component_: torchtune.training.metric_logging.DiskLogger
+ log_dir: ${output_dir}/logs
+log_every_n_steps: 1
+log_peak_memory_stats: True
+
+# Environment
+device: cuda
+dtype: bf16
+enable_activation_checkpointing: False # True reduces memory
+enable_activation_offloading: False # True reduces memory
+
+
+# Profiler (disabled)
+profiler:
+ _component_: torchtune.training.setup_torch_profiler
+ enabled: False
+
+ #Output directory of trace artifacts
+ output_dir: ${output_dir}/profiling_outputs
+
+ #`torch.profiler.ProfilerActivity` types to trace
+ cpu: True
+ cuda: True
+
+ #trace options passed to `torch.profiler.profile`
+ profile_memory: False
+ with_stack: False
+ record_shapes: True
+ with_flops: False
+
+ # `torch.profiler.schedule` options:
+ # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat
+ wait_steps: 5
+ warmup_steps: 3
+ active_steps: 2
+ num_cycles: 1
+
+# QAT arguments
+quantizer:
+ _component_: torchtune.training.quantization.Int8DynActInt4WeightQATQuantizer
+ groupsize: 256
diff --git a/recipes/configs/llama3_1/8B_qlora_single_device.yaml b/recipes/configs/llama3_1/8B_qlora_single_device.yaml
index 57c8cdb513..3386601917 100644
--- a/recipes/configs/llama3_1/8B_qlora_single_device.yaml
+++ b/recipes/configs/llama3_1/8B_qlora_single_device.yaml
@@ -15,6 +15,8 @@
#
# This config works only for training on single device.
+output_dir: /tmp/torchtune/llama3_1_8B/qlora_single_device # /tmp may be deleted by your system. Change it to your preference.
+
# Model Arguments
model:
_component_: torchtune.models.llama3_1.qlora_llama3_1_8b
@@ -41,7 +43,7 @@ checkpointer:
model-00004-of-00004.safetensors
]
recipe_checkpoint: null
- output_dir: /tmp/Meta-Llama-3.1-8B-Instruct/
+ output_dir: ${output_dir}
model_type: LLAMA3
resume_from_checkpoint: False
save_adapter_weights_only: False
@@ -70,14 +72,13 @@ loss:
# Training
epochs: 1
max_steps_per_epoch: null
-gradient_accumulation_steps: 8 # Use to increase virtual batch size
-compile: False # pytorch compile, set to true for better perf/memory
+gradient_accumulation_steps: 8 # Use to increase effective batch size
+compile: False # torch.compile the model + loss, True increases speed + decreases memory
# Logging
-output_dir: /tmp/qlora_finetune_output/
metric_logger:
_component_: torchtune.training.metric_logging.DiskLogger
- log_dir: ${output_dir}
+ log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: True
@@ -89,6 +90,7 @@ dtype: bf16
enable_activation_checkpointing: True # True reduces memory
enable_activation_offloading: False # True reduces memory
+
# Profiler (disabled)
profiler:
_component_: torchtune.training.setup_torch_profiler
diff --git a/recipes/configs/llama3_2/1B_full.yaml b/recipes/configs/llama3_2/1B_full.yaml
index 56fc968b0d..25c7de45c1 100644
--- a/recipes/configs/llama3_2/1B_full.yaml
+++ b/recipes/configs/llama3_2/1B_full.yaml
@@ -18,6 +18,8 @@
# best to use 1B_full_single_device.yaml for those cases
+output_dir: /tmp/torchtune/llama3_2_1B/full # /tmp may be deleted by your system. Change it to your preference.
+
# Tokenizer
tokenizer:
_component_: torchtune.models.llama3.llama3_tokenizer
@@ -42,7 +44,7 @@ checkpointer:
model.safetensors
]
recipe_checkpoint: null
- output_dir: /tmp/Llama-3.2-1B-Instruct/
+ output_dir: ${output_dir}
model_type: LLAMA3_2
resume_from_checkpoint: False
@@ -57,7 +59,7 @@ optimizer:
loss:
_component_: torchtune.modules.loss.CEWithChunkedOutputLoss
max_steps_per_epoch: null
-gradient_accumulation_steps: 8 # Use to increase virtual batch size
+gradient_accumulation_steps: 8 # Use to increase effective batch size
# Training env
@@ -66,7 +68,7 @@ device: cuda
# Memory management
enable_activation_checkpointing: False # True reduces memory
enable_activation_offloading: False # True reduces memory
-compile: False # pytorch compile, set to true for better perf/memory
+compile: False # torch.compile the model + loss, True increases speed + decreases memory
optimizer_in_bwd: False # True saves memory. Requires gradient_accumulation_steps=1
# Reduced precision
@@ -75,11 +77,11 @@ dtype: bf16
# Logging
metric_logger:
_component_: torchtune.training.metric_logging.DiskLogger
- log_dir: ${output_dir}
-output_dir: /tmp/full-llama3.2-finetune
+ log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: True
+
# Profiler (disabled)
profiler:
_component_: torchtune.training.setup_torch_profiler
diff --git a/recipes/configs/llama3_2/1B_full_single_device.yaml b/recipes/configs/llama3_2/1B_full_single_device.yaml
index 6945def166..77b1b98410 100644
--- a/recipes/configs/llama3_2/1B_full_single_device.yaml
+++ b/recipes/configs/llama3_2/1B_full_single_device.yaml
@@ -19,11 +19,8 @@
#
# This config works only for training on single device.
-output_dir: /tmp/llama_3_2_1b/full_single_device
-# Model Arguments
-model:
- _component_: torchtune.models.llama3_2.llama3_2_1b
+output_dir: /tmp/torchtune/llama3_2_1B/full_single_device # /tmp may be deleted by your system. Change it to your preference.
# Tokenizer
tokenizer:
@@ -59,9 +56,9 @@ optimizer:
loss:
_component_: torchtune.modules.loss.CEWithChunkedOutputLoss
max_steps_per_epoch: null
-gradient_accumulation_steps: 1 # Use to increase virtual batch size
+gradient_accumulation_steps: 1 # Use to increase effective batch size
optimizer_in_bwd: True # True saves memory. Requires gradient_accumulation_steps=1
-compile: False # pytorch compile, set to true for better perf/memory
+compile: False # torch.compile the model + loss, True increases speed + decreases memory
# Training environment
device: cuda
@@ -77,10 +74,10 @@ dtype: bf16
metric_logger:
_component_: torchtune.training.metric_logging.DiskLogger
log_dir: ${output_dir}/logs
-
log_every_n_steps: 1
log_peak_memory_stats: True
+
# Profiler (disabled)
profiler:
_component_: torchtune.training.setup_torch_profiler
diff --git a/recipes/configs/llama3_2/1B_lora.yaml b/recipes/configs/llama3_2/1B_lora.yaml
index 4903e482ba..15e14be3b1 100644
--- a/recipes/configs/llama3_2/1B_lora.yaml
+++ b/recipes/configs/llama3_2/1B_lora.yaml
@@ -17,6 +17,8 @@
# For single device LoRA finetuning please use 1B_lora_single_device.yaml
# or 1B_qlora_single_device.yaml
+output_dir: /tmp/torchtune/llama3_2_1B/lora # /tmp may be deleted by your system. Change it to your preference.
+
# Tokenizer
tokenizer:
_component_: torchtune.models.llama3.llama3_tokenizer
@@ -39,7 +41,7 @@ checkpointer:
model.safetensors
]
recipe_checkpoint: null
- output_dir: /tmp/Llama-3.2-1B-Instruct/
+ output_dir: ${output_dir}
model_type: LLAMA3_2
resume_from_checkpoint: False
save_adapter_weights_only: False
@@ -68,14 +70,13 @@ loss:
# Training
epochs: 1
max_steps_per_epoch: null
-gradient_accumulation_steps: 8 # Use to increase virtual batch size
-compile: False # pytorch compile, set to true for better perf/memory
+gradient_accumulation_steps: 8 # Use to increase effective batch size
+compile: False # torch.compile the model + loss, True increases speed + decreases memory
# Logging
-output_dir: /tmp/lora_finetune_output
metric_logger:
_component_: torchtune.training.metric_logging.DiskLogger
- log_dir: ${output_dir}
+ log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: True
@@ -85,6 +86,7 @@ dtype: bf16
enable_activation_checkpointing: False # True reduces memory
enable_activation_offloading: False # True reduces memory
+
# Profiler (disabled)
profiler:
_component_: torchtune.training.setup_torch_profiler
diff --git a/recipes/configs/llama3_2/1B_lora_single_device.yaml b/recipes/configs/llama3_2/1B_lora_single_device.yaml
index 951af6439c..adabdbdddb 100644
--- a/recipes/configs/llama3_2/1B_lora_single_device.yaml
+++ b/recipes/configs/llama3_2/1B_lora_single_device.yaml
@@ -17,6 +17,8 @@
output_dir: /tmp/llama_3_2_1b/lora_single_device
+output_dir: /tmp/torchtune/llama3_2_1B/lora_single_device # /tmp may be deleted by your system. Change it to your preference.
+
# Model Arguments
model:
_component_: torchtune.models.llama3_2.lora_llama3_2_1b
@@ -68,8 +70,8 @@ loss:
# Training
epochs: 1
max_steps_per_epoch: null
-gradient_accumulation_steps: 8 # Use to increase virtual batch size
-compile: False # pytorch compile, set to true for better perf/memory
+gradient_accumulation_steps: 8 # Use to increase effective batch size
+compile: False # torch.compile the model + loss, True increases speed + decreases memory
# Logging
metric_logger:
@@ -86,6 +88,7 @@ dtype: bf16
enable_activation_checkpointing: False # True reduces memory
enable_activation_offloading: False # True reduces memory
+
# Profiler (disabled)
profiler:
_component_: torchtune.training.setup_torch_profiler
diff --git a/recipes/configs/llama3_2/1B_qat_lora.yaml b/recipes/configs/llama3_2/1B_qat_lora.yaml
new file mode 100644
index 0000000000..79f628367f
--- /dev/null
+++ b/recipes/configs/llama3_2/1B_qat_lora.yaml
@@ -0,0 +1,114 @@
+# Config for multi-device QAT + LoRA finetuning in qat_lora_finetune_distributed.py
+# using a Llama3.2 1B Instruct model
+#
+# This config assumes that you've run the following command before launching
+# this run:
+# tune download meta-llama/Llama-3.2-1B-Instruct --output-dir /tmp/Llama-3.2-1B-Instruct --ignore-patterns "original/consolidated.00.pth"
+#
+# To launch on 2 devices, run the following command from root:
+# tune run --nproc_per_node 2 qat_lora_finetune_distributed --config llama3_2/1B_qat_lora
+#
+# You can add specific overrides through the command line. For example
+# to override the checkpointer directory while launching training
+# you can run:
+# tune run --nproc_per_node 2 qat_lora_finetune_distributed --config llama3_2/1B_qat_lora checkpointer.checkpoint_dir=
+
+output_dir: /tmp/torchtune/llama3_2_1B/qat_lora # /tmp may be deleted by your system. Change it to your preference.
+
+# Tokenizer
+tokenizer:
+ _component_: torchtune.models.llama3.llama3_tokenizer
+ path: /tmp/Llama-3.2-1B-Instruct/original/tokenizer.model
+ max_seq_len: null
+
+# Model Arguments
+model:
+ _component_: torchtune.models.llama3_2.lora_llama3_2_1b
+ lora_attn_modules: ['q_proj', 'v_proj', 'output_proj']
+ apply_lora_to_mlp: True
+ lora_rank: 64 # higher increases accuracy and memory
+ lora_alpha: 128 # usually alpha=2*rank
+ lora_dropout: 0.0
+
+checkpointer:
+ _component_: torchtune.training.FullModelHFCheckpointer
+ checkpoint_dir: /tmp/Llama-3.2-1B-Instruct/
+ checkpoint_files: [
+ model.safetensors
+ ]
+ recipe_checkpoint: null
+ output_dir: ${output_dir}
+ model_type: LLAMA3_2
+resume_from_checkpoint: False
+save_adapter_weights_only: False
+
+# Dataset and Sampler
+dataset:
+ _component_: torchtune.datasets.alpaca_cleaned_dataset
+ packed: False # True increases speed
+seed: null
+shuffle: True
+batch_size: 4
+
+# Optimizer and Scheduler
+optimizer:
+ _component_: torch.optim.AdamW
+ fused: True
+ weight_decay: 0.01
+ lr: 3e-4
+lr_scheduler:
+ _component_: torchtune.training.lr_schedulers.get_cosine_schedule_with_warmup
+ num_warmup_steps: 100
+
+loss:
+ _component_: torchtune.modules.loss.CEWithChunkedOutputLoss
+
+# Training
+epochs: 1
+max_steps_per_epoch: null
+gradient_accumulation_steps: 8 # Use to increase effective batch size
+compile: False # torch.compile the model + loss, True increases speed + decreases memory
+
+# Logging
+metric_logger:
+ _component_: torchtune.training.metric_logging.DiskLogger
+ log_dir: ${output_dir}/logs
+log_every_n_steps: 1
+log_peak_memory_stats: True
+
+# Environment
+device: cuda
+dtype: bf16
+enable_activation_checkpointing: False # True reduces memory
+enable_activation_offloading: False # True reduces memory
+
+
+# Profiler (disabled)
+profiler:
+ _component_: torchtune.training.setup_torch_profiler
+ enabled: False
+
+ #Output directory of trace artifacts
+ output_dir: ${output_dir}/profiling_outputs
+
+ #`torch.profiler.ProfilerActivity` types to trace
+ cpu: True
+ cuda: True
+
+ #trace options passed to `torch.profiler.profile`
+ profile_memory: False
+ with_stack: False
+ record_shapes: True
+ with_flops: False
+
+ # `torch.profiler.schedule` options:
+ # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat
+ wait_steps: 5
+ warmup_steps: 3
+ active_steps: 2
+ num_cycles: 1
+
+# QAT arguments
+quantizer:
+ _component_: torchtune.training.quantization.Int8DynActInt4WeightQATQuantizer
+ groupsize: 256
diff --git a/recipes/configs/llama3_2/1B_qlora_single_device.yaml b/recipes/configs/llama3_2/1B_qlora_single_device.yaml
index 3573ae38fc..99165c806f 100644
--- a/recipes/configs/llama3_2/1B_qlora_single_device.yaml
+++ b/recipes/configs/llama3_2/1B_qlora_single_device.yaml
@@ -15,6 +15,8 @@
#
# This config works only for training on single device.
+output_dir: /tmp/torchtune/llama3_2_1B/qlora_single_device # /tmp may be deleted by your system. Change it to your preference.
+
# Model Arguments
model:
_component_: torchtune.models.llama3_2.qlora_llama3_2_1b
@@ -37,7 +39,7 @@ checkpointer:
model.safetensors
]
recipe_checkpoint: null
- output_dir: /tmp/Llama-3.2-1B-Instruct/
+ output_dir: ${output_dir}
model_type: LLAMA3_2
resume_from_checkpoint: False
save_adapter_weights_only: False
@@ -66,14 +68,13 @@ loss:
# Training
epochs: 1
max_steps_per_epoch: null
-gradient_accumulation_steps: 8 # Use to increase virtual batch size
-compile: False # pytorch compile, set to true for better perf/memory
+gradient_accumulation_steps: 8 # Use to increase effective batch size
+compile: False # torch.compile the model + loss, True increases speed + decreases memory
# Logging
-output_dir: /tmp/lora_finetune_output
metric_logger:
_component_: torchtune.training.metric_logging.DiskLogger
- log_dir: ${output_dir}
+ log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: True
@@ -85,6 +86,7 @@ dtype: bf16
enable_activation_checkpointing: False # True reduces memory
enable_activation_offloading: False # True reduces memory
+
# Profiler (disabled)
profiler:
_component_: torchtune.training.setup_torch_profiler
diff --git a/recipes/configs/llama3_2/3B_full.yaml b/recipes/configs/llama3_2/3B_full.yaml
index 4128bb58e7..0703437596 100644
--- a/recipes/configs/llama3_2/3B_full.yaml
+++ b/recipes/configs/llama3_2/3B_full.yaml
@@ -18,6 +18,8 @@
# best to use 3B_full_single_device.yaml for those cases
+output_dir: /tmp/torchtune/llama3_2_3B/full # /tmp may be deleted by your system. Change it to your preference.
+
# Tokenizer
tokenizer:
_component_: torchtune.models.llama3.llama3_tokenizer
@@ -43,7 +45,7 @@ checkpointer:
model-00002-of-00002.safetensors,
]
recipe_checkpoint: null
- output_dir: /tmp/Llama-3.2-3B-Instruct/
+ output_dir: ${output_dir}
model_type: LLAMA3_2
resume_from_checkpoint: False
@@ -58,7 +60,7 @@ optimizer:
loss:
_component_: torchtune.modules.loss.CEWithChunkedOutputLoss
max_steps_per_epoch: null
-gradient_accumulation_steps: 8 # Use to increase virtual batch size
+gradient_accumulation_steps: 8 # Use to increase effective batch size
# Training env
device: cuda
@@ -66,7 +68,7 @@ device: cuda
# Memory management
enable_activation_checkpointing: True # True reduces memory
enable_activation_offloading: False # True reduces memory
-compile: False # pytorch compile, set to true for better perf/memory
+compile: False # torch.compile the model + loss, True increases speed + decreases memory
optimizer_in_bwd: False # True saves memory. Requires gradient_accumulation_steps=1
# Reduced precision
@@ -75,11 +77,11 @@ dtype: bf16
# Logging
metric_logger:
_component_: torchtune.training.metric_logging.DiskLogger
- log_dir: ${output_dir}
-output_dir: /tmp/full-llama3.2-finetune
+ log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: True
+
# Profiler (disabled)
profiler:
_component_: torchtune.training.setup_torch_profiler
diff --git a/recipes/configs/llama3_2/3B_full_single_device.yaml b/recipes/configs/llama3_2/3B_full_single_device.yaml
index ebc49ae1fb..052c524019 100644
--- a/recipes/configs/llama3_2/3B_full_single_device.yaml
+++ b/recipes/configs/llama3_2/3B_full_single_device.yaml
@@ -20,6 +20,8 @@
# This config works only for training on single device.
+output_dir: /tmp/torchtune/llama3_2_3B/full_single_device # /tmp may be deleted by your system. Change it to your preference.
+
# Tokenizer
tokenizer:
_component_: torchtune.models.llama3.llama3_tokenizer
@@ -45,7 +47,7 @@ checkpointer:
model-00002-of-00002.safetensors,
]
recipe_checkpoint: null
- output_dir: /tmp/Llama-3.2-3B-Instruct/
+ output_dir: ${output_dir}
model_type: LLAMA3_2
resume_from_checkpoint: False
@@ -58,9 +60,9 @@ optimizer:
loss:
_component_: torchtune.modules.loss.CEWithChunkedOutputLoss
max_steps_per_epoch: null
-gradient_accumulation_steps: 1 # Use to increase virtual batch size
+gradient_accumulation_steps: 1 # Use to increase effective batch size
optimizer_in_bwd: True # True saves memory. Requires gradient_accumulation_steps=1
-compile: False # pytorch compile, set to true for better perf/memory
+compile: False # torch.compile the model + loss, True increases speed + decreases memory
# Training environment
device: cuda
@@ -75,11 +77,11 @@ dtype: bf16
# Logging
metric_logger:
_component_: torchtune.training.metric_logging.DiskLogger
- log_dir: ${output_dir}
-output_dir: /tmp/full-llama3.2-finetune
+ log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: True
+
# Profiler (disabled)
profiler:
_component_: torchtune.training.setup_torch_profiler
diff --git a/recipes/configs/llama3_2/3B_lora.yaml b/recipes/configs/llama3_2/3B_lora.yaml
index 0e790b20cb..9575df0f78 100644
--- a/recipes/configs/llama3_2/3B_lora.yaml
+++ b/recipes/configs/llama3_2/3B_lora.yaml
@@ -17,6 +17,8 @@
# For single device LoRA finetuning please use 3B_lora_single_device.yaml
# or 3B_qlora_single_device.yaml
+output_dir: /tmp/torchtune/llama3_2_3B/lora # /tmp may be deleted by your system. Change it to your preference.
+
# Tokenizer
tokenizer:
_component_: torchtune.models.llama3.llama3_tokenizer
@@ -40,7 +42,7 @@ checkpointer:
model-00002-of-00002.safetensors,
]
recipe_checkpoint: null
- output_dir: /tmp/Llama-3.2-3B-Instruct/
+ output_dir: ${output_dir}
model_type: LLAMA3_2
resume_from_checkpoint: False
save_adapter_weights_only: False
@@ -69,14 +71,13 @@ loss:
# Training
epochs: 1
max_steps_per_epoch: null
-gradient_accumulation_steps: 8 # Use to increase virtual batch size
-compile: False # pytorch compile, set to true for better perf/memory
+gradient_accumulation_steps: 8 # Use to increase effective batch size
+compile: False # torch.compile the model + loss, True increases speed + decreases memory
# Logging
-output_dir: /tmp/lora_finetune_output
metric_logger:
_component_: torchtune.training.metric_logging.DiskLogger
- log_dir: ${output_dir}
+ log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: True
@@ -86,6 +87,7 @@ dtype: bf16
enable_activation_checkpointing: False # True reduces memory
enable_activation_offloading: False # True reduces memory
+
# Profiler (disabled)
profiler:
_component_: torchtune.training.setup_torch_profiler
diff --git a/recipes/configs/llama3_2/3B_lora_single_device.yaml b/recipes/configs/llama3_2/3B_lora_single_device.yaml
index 29e021d150..451455253a 100644
--- a/recipes/configs/llama3_2/3B_lora_single_device.yaml
+++ b/recipes/configs/llama3_2/3B_lora_single_device.yaml
@@ -16,6 +16,8 @@
# This config works only for training on single device.
+output_dir: /tmp/torchtune/llama3_2_3B/lora_single_device # /tmp may be deleted by your system. Change it to your preference.
+
# Model Arguments
model:
_component_: torchtune.models.llama3_2.lora_llama3_2_3b
@@ -39,7 +41,7 @@ checkpointer:
model-00002-of-00002.safetensors,
]
recipe_checkpoint: null
- output_dir: /tmp/Llama-3.2-3B-Instruct/
+ output_dir: ${output_dir}
model_type: LLAMA3_2
resume_from_checkpoint: False
save_adapter_weights_only: False
@@ -68,14 +70,13 @@ loss:
# Training
epochs: 1
max_steps_per_epoch: null
-gradient_accumulation_steps: 8 # Use to increase virtual batch size
-compile: False # pytorch compile, set to true for better perf/memory
+gradient_accumulation_steps: 8 # Use to increase effective batch size
+compile: False # torch.compile the model + loss, True increases speed + decreases memory
# Logging
-output_dir: /tmp/lora_finetune_output
metric_logger:
_component_: torchtune.training.metric_logging.DiskLogger
- log_dir: ${output_dir}
+ log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: True
@@ -87,6 +88,7 @@ dtype: bf16
enable_activation_checkpointing: True # True reduces memory
enable_activation_offloading: False # True reduces memory
+
# Profiler (disabled)
profiler:
_component_: torchtune.training.setup_torch_profiler
diff --git a/recipes/configs/llama3_2/3B_qat_lora.yaml b/recipes/configs/llama3_2/3B_qat_lora.yaml
new file mode 100644
index 0000000000..6b69aebac2
--- /dev/null
+++ b/recipes/configs/llama3_2/3B_qat_lora.yaml
@@ -0,0 +1,115 @@
+# Config for multi-device QAT + LoRA finetuning in qat_lora_finetune_distributed.py
+# using a Llama3.2 3B Instruct model
+#
+# This config assumes that you've run the following command before launching
+# this run:
+# tune download meta-llama/Llama-3.2-3B-Instruct --output-dir /tmp/Llama-3.2-3B-Instruct --ignore-patterns "original/consolidated.00.pth"
+#
+# To launch on 2 devices, run the following command from root:
+# tune run --nproc_per_node 2 qat_lora_finetune_distributed --config llama3_2/3B_qat_lora
+#
+# You can add specific overrides through the command line. For example
+# to override the checkpointer directory while launching training
+# you can run:
+# tune run --nproc_per_node 2 qat_lora_finetune_distributed --config llama3_2/3B_qat_lora checkpointer.checkpoint_dir=
+
+output_dir: /tmp/torchtune/llama3_2_3B/qat_lora # /tmp may be deleted by your system. Change it to your preference.
+
+# Tokenizer
+tokenizer:
+ _component_: torchtune.models.llama3.llama3_tokenizer
+ path: /tmp/Llama-3.2-3B-Instruct/original/tokenizer.model
+ max_seq_len: null
+
+# Model Arguments
+model:
+ _component_: torchtune.models.llama3_2.lora_llama3_2_3b
+ lora_attn_modules: ['q_proj', 'v_proj', 'output_proj']
+ apply_lora_to_mlp: True
+ lora_rank: 64 # higher increases accuracy and memory
+ lora_alpha: 128 # usually alpha=2*rank
+ lora_dropout: 0.0
+
+checkpointer:
+ _component_: torchtune.training.FullModelHFCheckpointer
+ checkpoint_dir: /tmp/Llama-3.2-3B-Instruct/
+ checkpoint_files: [
+ model-00001-of-00002.safetensors,
+ model-00002-of-00002.safetensors,
+ ]
+ recipe_checkpoint: null
+ output_dir: ${output_dir}
+ model_type: LLAMA3_2
+resume_from_checkpoint: False
+save_adapter_weights_only: False
+
+# Dataset and Sampler
+dataset:
+ _component_: torchtune.datasets.alpaca_cleaned_dataset
+ packed: False # True increases speed
+seed: null
+shuffle: True
+batch_size: 4
+
+# Optimizer and Scheduler
+optimizer:
+ _component_: torch.optim.AdamW
+ fused: True
+ weight_decay: 0.01
+ lr: 3e-4
+lr_scheduler:
+ _component_: torchtune.training.lr_schedulers.get_cosine_schedule_with_warmup
+ num_warmup_steps: 100
+
+loss:
+ _component_: torchtune.modules.loss.CEWithChunkedOutputLoss
+
+# Training
+epochs: 1
+max_steps_per_epoch: null
+gradient_accumulation_steps: 8 # Use to increase effective batch size
+compile: False # torch.compile the model + loss, True increases speed + decreases memory
+
+# Logging
+metric_logger:
+ _component_: torchtune.training.metric_logging.DiskLogger
+ log_dir: ${output_dir}/logs
+log_every_n_steps: 1
+log_peak_memory_stats: True
+
+# Environment
+device: cuda
+dtype: bf16
+enable_activation_checkpointing: False # True reduces memory
+enable_activation_offloading: False # True reduces memory
+
+
+# Profiler (disabled)
+profiler:
+ _component_: torchtune.training.setup_torch_profiler
+ enabled: False
+
+ #Output directory of trace artifacts
+ output_dir: ${output_dir}/profiling_outputs
+
+ #`torch.profiler.ProfilerActivity` types to trace
+ cpu: True
+ cuda: True
+
+ #trace options passed to `torch.profiler.profile`
+ profile_memory: False
+ with_stack: False
+ record_shapes: True
+ with_flops: False
+
+ # `torch.profiler.schedule` options:
+ # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat
+ wait_steps: 5
+ warmup_steps: 3
+ active_steps: 2
+ num_cycles: 1
+
+# QAT arguments
+quantizer:
+ _component_: torchtune.training.quantization.Int8DynActInt4WeightQATQuantizer
+ groupsize: 256
diff --git a/recipes/configs/llama3_2/3B_qlora_single_device.yaml b/recipes/configs/llama3_2/3B_qlora_single_device.yaml
index 7ffa146e51..3cc504f1d0 100644
--- a/recipes/configs/llama3_2/3B_qlora_single_device.yaml
+++ b/recipes/configs/llama3_2/3B_qlora_single_device.yaml
@@ -15,6 +15,8 @@
#
# This config works only for training on single device.
+output_dir: /tmp/torchtune/llama3_2_3B/qlora_single_device # /tmp may be deleted by your system. Change it to your preference.
+
# Model Arguments
model:
_component_: torchtune.models.llama3_2.qlora_llama3_2_3b
@@ -38,7 +40,7 @@ checkpointer:
model-00002-of-00002.safetensors,
]
recipe_checkpoint: null
- output_dir: /tmp/Llama-3.2-3B-Instruct/
+ output_dir: ${output_dir}
model_type: LLAMA3_2
resume_from_checkpoint: False
save_adapter_weights_only: False
@@ -67,14 +69,13 @@ loss:
# Training
epochs: 1
max_steps_per_epoch: null
-gradient_accumulation_steps: 8 # Use to increase virtual batch size
-compile: False # pytorch compile, set to true for better perf/memory
+gradient_accumulation_steps: 8 # Use to increase effective batch size
+compile: False # torch.compile the model + loss, True increases speed + decreases memory
# Logging
-output_dir: /tmp/lora_finetune_output
metric_logger:
_component_: torchtune.training.metric_logging.DiskLogger
- log_dir: ${output_dir}
+ log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: True
@@ -86,6 +87,7 @@ dtype: bf16
enable_activation_checkpointing: True # True reduces memory
enable_activation_offloading: False # True reduces memory
+
# Profiler (disabled)
profiler:
_component_: torchtune.training.setup_torch_profiler
diff --git a/recipes/configs/llama3_2/knowledge_distillation_distributed.yaml b/recipes/configs/llama3_2/8B_to_1B_KD_lora_distributed.yaml
similarity index 88%
rename from recipes/configs/llama3_2/knowledge_distillation_distributed.yaml
rename to recipes/configs/llama3_2/8B_to_1B_KD_lora_distributed.yaml
index 8ef1bcbea3..877039bab0 100644
--- a/recipes/configs/llama3_2/knowledge_distillation_distributed.yaml
+++ b/recipes/configs/llama3_2/8B_to_1B_KD_lora_distributed.yaml
@@ -10,11 +10,13 @@
# tune run --nnodes 1 --nproc_per_node 2 lora_finetune_distributed --config llama3_1/8B_lora
#
# To launch on 2 devices, run the following command from root:
-# tune run --nnodes 1 --nproc_per_node 2 knowledge_distillation_distributed --config llama3_2/knowledge_distillation_distributed
+# tune run --nnodes 1 --nproc_per_node 2 knowledge_distillation_distributed --config llama3_2/8B_to_1B_KD_lora_distributed
#
# This config works best for distilling on 2+ devices.
+output_dir: /tmp/torchtune/llama3_2_8B_to_1B/KD_lora_distributed # /tmp may be deleted by your system. Change it to your preference.
+
# Model Arguments
model:
_component_: torchtune.models.llama3_2.lora_llama3_2_1b
@@ -41,7 +43,7 @@ checkpointer:
model.safetensors
]
recipe_checkpoint: null
- output_dir: /tmp/Llama-3.2-1B-Instruct/
+ output_dir: ${output_dir}
model_type: LLAMA3
resume_from_checkpoint: False
save_adapter_weights_only: False
@@ -88,14 +90,13 @@ kd_ratio: 0.5
# Training
epochs: 1
max_steps_per_epoch: null
-compile: False # pytorch compile, set to true for better perf/memory
-gradient_accumulation_steps: 8 # Use to increase virtual batch size
+compile: False # torch.compile the model + loss, True increases speed + decreases memory
+gradient_accumulation_steps: 8 # Use to increase effective batch size
# Logging
-output_dir: /tmp/kd_output
metric_logger:
_component_: torchtune.training.metric_logging.DiskLogger
- log_dir: ${output_dir}
+ log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: False
@@ -103,6 +104,7 @@ log_peak_memory_stats: False
device: cuda
dtype: bf16
enable_activation_checkpointing: False # True reduces memory
+enable_activation_offloading: False # True reduces memory
# Show case the usage of pytorch profiler
# Set enabled to False as it's only needed for debugging training
diff --git a/recipes/configs/llama3_2/knowledge_distillation_single_device.yaml b/recipes/configs/llama3_2/8B_to_1B_KD_lora_single_device.yaml
similarity index 88%
rename from recipes/configs/llama3_2/knowledge_distillation_single_device.yaml
rename to recipes/configs/llama3_2/8B_to_1B_KD_lora_single_device.yaml
index e08fb8ad7a..103a649f84 100644
--- a/recipes/configs/llama3_2/knowledge_distillation_single_device.yaml
+++ b/recipes/configs/llama3_2/8B_to_1B_KD_lora_single_device.yaml
@@ -10,11 +10,13 @@
# tune run lora_finetune_single_device --config llama3_1/8B_lora_single_device
#
# To launch on a single device, run the following command from root:
-# tune run knowledge_distillation_single_device --config llama3_2/knowledge_distillation_single_device
+# tune run knowledge_distillation_single_device --config llama3_2/8B_to_1B_KD_lora_single_device
#
# This config works only for training on single device.
+output_dir: /tmp/torchtune/llama3_2_8B_to_1B/KD_lora_single_device # /tmp may be deleted by your system. Change it to your preference.
+
# Model Arguments
model:
_component_: torchtune.models.llama3_2.lora_llama3_2_1b
@@ -41,7 +43,7 @@ checkpointer:
model.safetensors
]
recipe_checkpoint: null
- output_dir: /tmp/Llama-3.2-1B-Instruct/
+ output_dir: ${output_dir}
model_type: LLAMA3
resume_from_checkpoint: False
save_adapter_weights_only: False
@@ -88,14 +90,13 @@ kd_ratio: 0.5
# Training
epochs: 1
max_steps_per_epoch: null
-gradient_accumulation_steps: 8 # Use to increase virtual batch size
-compile: False # pytorch compile, set to true for better perf/memory
+gradient_accumulation_steps: 8 # Use to increase effective batch size
+compile: False # torch.compile the model + loss, True increases speed + decreases memory
# Logging
-output_dir: /tmp/kd_output
metric_logger:
_component_: torchtune.training.metric_logging.DiskLogger
- log_dir: ${output_dir}
+ log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: True
@@ -105,6 +106,9 @@ dtype: bf16
# Activations Memory
enable_activation_checkpointing: False # True reduces memory
+enable_activation_offloading: False # True reduces memory
+
+
# Profiler (disabled)
profiler:
diff --git a/recipes/configs/llama3_2_vision/11B_full.yaml b/recipes/configs/llama3_2_vision/11B_full.yaml
index 51173f162a..5f0e970a66 100644
--- a/recipes/configs/llama3_2_vision/11B_full.yaml
+++ b/recipes/configs/llama3_2_vision/11B_full.yaml
@@ -15,6 +15,8 @@
# Single device full finetuning requires more memory optimizations. It's
# best to use 11B_full_single_device.yaml for those cases.
+output_dir: /tmp/torchtune/llama3_2_vision_11B/full # /tmp may be deleted by your system. Change it to your preference.
+
# Model arguments
model:
_component_: torchtune.models.llama3_2_vision.llama3_2_vision_11b
@@ -38,7 +40,7 @@ checkpointer:
filename_format: model-{}-of-{}.safetensors
max_filename: "00005"
recipe_checkpoint: null
- output_dir: /tmp/Llama-3.2-11B-Vision-Instruct/
+ output_dir: ${output_dir}
model_type: LLAMA3_VISION
resume_from_checkpoint: False
@@ -55,7 +57,7 @@ collate_fn: torchtune.data.padded_collate_tiled_images_and_mask
epochs: 1
max_steps_per_epoch: null
batch_size: 2
-gradient_accumulation_steps: 8 # Use to increase virtual batch size
+gradient_accumulation_steps: 8 # Use to increase effective batch size
optimizer:
_component_: torch.optim.AdamW
lr: 2e-5
@@ -65,7 +67,7 @@ optimizer_in_bwd: False # True saves memory. Requires gradient_accumulation_ste
loss:
_component_: torchtune.modules.loss.CEWithChunkedOutputLoss
clip_grad_norm: 1.0
-compile: False # pytorch compile, set to true for better perf/memory
+compile: False # torch.compile the model + loss, True increases speed + decreases memory
# Training env
device: cuda
@@ -76,13 +78,13 @@ custom_sharded_layers: ['decoder.tok_embeddings'] # Layers to shard separately
dtype: bf16
# Logging
-output_dir: /tmp/full-llama3.2-vision--finetune
metric_logger:
_component_: torchtune.training.metric_logging.DiskLogger
- log_dir: /tmp/Llama-3.2-11B-Vision-Instruct/logs
+ log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: True
+
# Profiler (disabled)
profiler:
_component_: torchtune.training.setup_torch_profiler
diff --git a/recipes/configs/llama3_2_vision/11B_full_single_device.yaml b/recipes/configs/llama3_2_vision/11B_full_single_device.yaml
index d10afdcbfe..daa678d0e5 100644
--- a/recipes/configs/llama3_2_vision/11B_full_single_device.yaml
+++ b/recipes/configs/llama3_2_vision/11B_full_single_device.yaml
@@ -17,6 +17,8 @@
#
# This config works only for training on single device.
+output_dir: /tmp/torchtune/llama3_2_vision_11B/full_single_device # /tmp may be deleted by your system. Change it to your preference.
+
# Model arguments
model:
_component_: torchtune.models.llama3_2_vision.llama3_2_vision_11b
@@ -40,7 +42,7 @@ checkpointer:
filename_format: model-{}-of-{}.safetensors
max_filename: "00005"
recipe_checkpoint: null
- output_dir: /tmp/Llama-3.2-11B-Vision-Instruct/
+ output_dir: ${output_dir}
model_type: LLAMA3_VISION
resume_from_checkpoint: False
@@ -57,7 +59,7 @@ collate_fn: torchtune.data.padded_collate_tiled_images_and_mask
epochs: 1
max_steps_per_epoch: null
batch_size: 2
-gradient_accumulation_steps: 8 # Use to increase virtual batch size
+gradient_accumulation_steps: 8 # Use to increase effective batch size
optimizer:
_component_: bitsandbytes.optim.PagedAdamW8bit
lr: 2e-5
@@ -66,7 +68,7 @@ optimizer_in_bwd: False # True saves memory. Requires gradient_accumulation_ste
loss:
_component_: torchtune.modules.loss.CEWithChunkedOutputLoss
clip_grad_norm: 1.0
-compile: False # pytorch compile, set to true for better perf/memory
+compile: False # torch.compile the model + loss, True increases speed + decreases memory
# Training env
device: cuda
@@ -76,10 +78,9 @@ enable_activation_checkpointing: True # True reduces memory
dtype: bf16
# Logging
-output_dir: /tmp/full-llama3.2-vision--finetune
metric_logger:
_component_: torchtune.training.metric_logging.DiskLogger
- log_dir: /tmp/Llama-3.2-11B-Vision-Instruct/logs
+ log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: True
diff --git a/recipes/configs/llama3_2_vision/11B_lora.yaml b/recipes/configs/llama3_2_vision/11B_lora.yaml
index b394b9ffbf..c54e8c571a 100644
--- a/recipes/configs/llama3_2_vision/11B_lora.yaml
+++ b/recipes/configs/llama3_2_vision/11B_lora.yaml
@@ -15,6 +15,8 @@
# For single device LoRA finetuning please use 11B_lora_single_device.yaml
# or 11B_qlora_single_device.yaml
+output_dir: /tmp/torchtune/llama3_2_vision_11B/lora # /tmp may be deleted by your system. Change it to your preference.
+
# Model arguments
model:
_component_: torchtune.models.llama3_2_vision.lora_llama3_2_vision_11b
@@ -44,7 +46,7 @@ checkpointer:
filename_format: model-{}-of-{}.safetensors
max_filename: "00005"
recipe_checkpoint: null
- output_dir: /tmp/Llama-3.2-11B-Vision-Instruct/
+ output_dir: ${output_dir}
model_type: LLAMA3_VISION
resume_from_checkpoint: False
save_adapter_weights_only: False # PeFT formatting not available yet. This will save it in torchtune format only.
@@ -62,7 +64,7 @@ collate_fn: torchtune.data.padded_collate_tiled_images_and_mask
epochs: 1
max_steps_per_epoch: null
batch_size: 2
-gradient_accumulation_steps: 8 # Use to increase virtual batch size
+gradient_accumulation_steps: 8 # Use to increase effective batch size
optimizer:
_component_: torch.optim.AdamW
fused: True
@@ -75,7 +77,7 @@ lr_scheduler:
loss:
_component_: torchtune.modules.loss.CEWithChunkedOutputLoss
clip_grad_norm: 1.0
-compile: False # pytorch compile, set to true for better perf/memory
+compile: False # torch.compile the model + loss, True increases speed + decreases memory
# Training env
device: cuda
@@ -85,13 +87,13 @@ enable_activation_checkpointing: True # True reduces memory
dtype: bf16
# Logging
-output_dir: /tmp/lora-llama3.2-vision-finetune
metric_logger:
_component_: torchtune.training.metric_logging.DiskLogger
- log_dir: /tmp/Llama-3.2-11B-Vision-Instruct/logs
+ log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: True
+
# Profiler (disabled)
profiler:
_component_: torchtune.training.setup_torch_profiler
diff --git a/recipes/configs/llama3_2_vision/11B_lora_single_device.yaml b/recipes/configs/llama3_2_vision/11B_lora_single_device.yaml
index c248ccaee8..25682f0d2f 100644
--- a/recipes/configs/llama3_2_vision/11B_lora_single_device.yaml
+++ b/recipes/configs/llama3_2_vision/11B_lora_single_device.yaml
@@ -13,6 +13,8 @@
#
# This config works only for training on single device.
+output_dir: /tmp/torchtune/llama3_2_vision_11B/lora_single_device # /tmp may be deleted by your system. Change it to your preference.
+
# Model arguments
model:
_component_: torchtune.models.llama3_2_vision.lora_llama3_2_vision_11b
@@ -42,7 +44,7 @@ checkpointer:
filename_format: model-{}-of-{}.safetensors
max_filename: "00005"
recipe_checkpoint: null
- output_dir: /tmp/Llama-3.2-11B-Vision-Instruct/
+ output_dir: ${output_dir}
model_type: LLAMA3_VISION
resume_from_checkpoint: False
save_adapter_weights_only: False # PeFT formatting not available yet. This will save it in torchtune format only.
@@ -60,7 +62,7 @@ collate_fn: torchtune.data.padded_collate_tiled_images_and_mask
epochs: 1
max_steps_per_epoch: null
batch_size: 2
-gradient_accumulation_steps: 8 # Use to increase virtual batch size
+gradient_accumulation_steps: 8 # Use to increase effective batch size
optimizer:
_component_: torch.optim.AdamW
fused: True
@@ -73,7 +75,7 @@ lr_scheduler:
loss:
_component_: torchtune.modules.loss.CEWithChunkedOutputLoss
clip_grad_norm: 1.0
-compile: False # pytorch compile, set to true for better perf/memory
+compile: False # torch.compile the model + loss, True increases speed + decreases memory
# Training env
device: cuda
@@ -83,13 +85,13 @@ enable_activation_checkpointing: True # True reduces memory
dtype: bf16
# Logging
-output_dir: /tmp/lora-llama3.2-vision-finetune
metric_logger:
_component_: torchtune.training.metric_logging.DiskLogger
- log_dir: /tmp/Llama-3.2-11B-Vision-Instruct/logs
+ log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: True
+
# Profiler (disabled)
profiler:
_component_: torchtune.training.setup_torch_profiler
diff --git a/recipes/configs/llama3_2_vision/11B_qlora.yaml b/recipes/configs/llama3_2_vision/11B_qlora.yaml
index c934e78008..6d93181726 100644
--- a/recipes/configs/llama3_2_vision/11B_qlora.yaml
+++ b/recipes/configs/llama3_2_vision/11B_qlora.yaml
@@ -14,6 +14,8 @@
# This config works best when the model is being fine-tuned on 2+ GPUs.
# For single device QLoRA finetuning please use 11B_qlora_single_device.yaml
+output_dir: /tmp/torchtune/llama3_2_vision_11B/qlora # /tmp may be deleted by your system. Change it to your preference.
+
# Model arguments
model:
_component_: torchtune.models.llama3_2_vision.qlora_llama3_2_vision_11b
@@ -43,7 +45,7 @@ checkpointer:
filename_format: model-{}-of-{}.safetensors
max_filename: "00005"
recipe_checkpoint: null
- output_dir: /tmp/Llama-3.2-11B-Vision-Instruct/
+ output_dir: ${output_dir}
model_type: LLAMA3_VISION
resume_from_checkpoint: False
save_adapter_weights_only: False # PeFT formatting not available yet. This will save it in torchtune format only.
@@ -60,7 +62,7 @@ collate_fn: torchtune.data.padded_collate_tiled_images_and_mask
epochs: 1
max_steps_per_epoch: null
batch_size: 2
-gradient_accumulation_steps: 8 # Use to increase virtual batch size
+gradient_accumulation_steps: 8 # Use to increase effective batch size
optimizer:
_component_: torch.optim.AdamW
fused: True
@@ -72,7 +74,7 @@ lr_scheduler:
loss:
_component_: torchtune.modules.loss.CEWithChunkedOutputLoss
clip_grad_norm: 1.0
-compile: False # pytorch compile, set to true for better perf/memory
+compile: False # torch.compile the model + loss, True increases speed + decreases memory
# Training env
device: cuda
@@ -82,13 +84,13 @@ enable_activation_checkpointing: True # True reduces memory
dtype: bf16
# Logging
-output_dir: /tmp/qlora-llama3.2-vision-finetune
metric_logger:
_component_: torchtune.training.metric_logging.DiskLogger
- log_dir: /tmp/Llama-3.2-11B-Vision-Instruct/logs
+ log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: True
+
# Profiler (disabled)
profiler:
_component_: torchtune.training.setup_torch_profiler
diff --git a/recipes/configs/llama3_2_vision/11B_qlora_single_device.yaml b/recipes/configs/llama3_2_vision/11B_qlora_single_device.yaml
index 531f27a52f..7d94c41709 100644
--- a/recipes/configs/llama3_2_vision/11B_qlora_single_device.yaml
+++ b/recipes/configs/llama3_2_vision/11B_qlora_single_device.yaml
@@ -13,6 +13,8 @@
#
# This config works only for training on single device.
+output_dir: /tmp/torchtune/llama3_2_vision_11B/qlora_single_device # /tmp may be deleted by your system. Change it to your preference.
+
# Model arguments
model:
_component_: torchtune.models.llama3_2_vision.qlora_llama3_2_vision_11b
@@ -42,7 +44,7 @@ checkpointer:
filename_format: model-{}-of-{}.safetensors
max_filename: "00005"
recipe_checkpoint: null
- output_dir: /tmp/Llama-3.2-11B-Vision-Instruct/
+ output_dir: ${output_dir}
model_type: LLAMA3_VISION
resume_from_checkpoint: False
save_adapter_weights_only: False # PeFT formatting not available yet. This will save it in torchtune format only.
@@ -59,7 +61,7 @@ collate_fn: torchtune.data.padded_collate_tiled_images_and_mask
epochs: 1
max_steps_per_epoch: null
batch_size: 2
-gradient_accumulation_steps: 8 # Use to increase virtual batch size
+gradient_accumulation_steps: 8 # Use to increase effective batch size
optimizer:
_component_: torch.optim.AdamW
fused: True
@@ -72,7 +74,7 @@ lr_scheduler:
loss:
_component_: torchtune.modules.loss.CEWithChunkedOutputLoss
clip_grad_norm: 1.0
-compile: False # pytorch compile, set to true for better perf/memory
+compile: False # torch.compile the model + loss, True increases speed + decreases memory
# Training env
device: cuda
@@ -82,13 +84,13 @@ enable_activation_checkpointing: True # True reduces memory
dtype: bf16
# Logging
-output_dir: /tmp/qlora-llama3.2-vision-finetune
metric_logger:
_component_: torchtune.training.metric_logging.DiskLogger
- log_dir: /tmp/Llama-3.2-11B-Vision-Instruct/logs
+ log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: True
+
# Profiler (disabled)
profiler:
_component_: torchtune.training.setup_torch_profiler
diff --git a/recipes/configs/llama3_2_vision/90B_full.yaml b/recipes/configs/llama3_2_vision/90B_full.yaml
index 2ef3c271eb..9d96b966cd 100644
--- a/recipes/configs/llama3_2_vision/90B_full.yaml
+++ b/recipes/configs/llama3_2_vision/90B_full.yaml
@@ -13,6 +13,8 @@
#
# This config needs 8 GPUs to run.
+output_dir: /tmp/torchtune/llama3_2_vision_90B/full # /tmp may be deleted by your system. Change it to your preference.
+
# Model arguments
model:
_component_: torchtune.models.llama3_2_vision.llama3_2_vision_90b
@@ -36,7 +38,7 @@ checkpointer:
filename_format: model-{}-of-{}.safetensors
max_filename: "00037"
recipe_checkpoint: null
- output_dir: /tmp/Llama-3.2-90B-Vision-Instruct/
+ output_dir: ${output_dir}
model_type: LLAMA3_VISION
resume_from_checkpoint: False
@@ -52,7 +54,7 @@ collate_fn: torchtune.data.padded_collate_tiled_images_and_mask
epochs: 1
max_steps_per_epoch: null
batch_size: 2
-gradient_accumulation_steps: 8 # Use to increase virtual batch size
+gradient_accumulation_steps: 8 # Use to increase effective batch size
optimizer:
_component_: torch.optim.AdamW
lr: 2e-5
@@ -62,7 +64,7 @@ optimizer_in_bwd: False # True saves memory. Requires gradient_accumulation_ste
loss:
_component_: torchtune.modules.loss.CEWithChunkedOutputLoss
clip_grad_norm: 1.0
-compile: False # pytorch compile, set to true for better perf/memory
+compile: False # torch.compile the model + loss, True increases speed + decreases memory
# Training env
device: cuda
@@ -73,13 +75,13 @@ custom_sharded_layers: ['decoder.tok_embeddings'] # Layers to shard separately
dtype: bf16
# Logging
-output_dir: /tmp/full-llama3.2-vision--finetune
metric_logger:
_component_: torchtune.training.metric_logging.DiskLogger
- log_dir: /tmp/Llama-3.2-90B-Vision-Instruct/logs
+ log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: True
+
# Profiler (disabled)
profiler:
_component_: torchtune.training.setup_torch_profiler
diff --git a/recipes/configs/llama3_2_vision/90B_lora.yaml b/recipes/configs/llama3_2_vision/90B_lora.yaml
index 970c7dab81..10b2c0d841 100644
--- a/recipes/configs/llama3_2_vision/90B_lora.yaml
+++ b/recipes/configs/llama3_2_vision/90B_lora.yaml
@@ -13,6 +13,8 @@
#
# This config works best when the model is being fine-tuned on 4+ GPUs.
+output_dir: /tmp/torchtune/llama3_2_vision_90B/lora # /tmp may be deleted by your system. Change it to your preference.
+
# Model arguments
model:
_component_: torchtune.models.llama3_2_vision.lora_llama3_2_vision_90b
@@ -42,7 +44,7 @@ checkpointer:
filename_format: model-{}-of-{}.safetensors
max_filename: "00037"
recipe_checkpoint: null
- output_dir: /tmp/Llama-3.2-90B-Vision-Instruct/
+ output_dir: ${output_dir}
model_type: LLAMA3_VISION
resume_from_checkpoint: False
save_adapter_weights_only: False # PeFT formatting not available yet. This will save it in torchtune format only.
@@ -59,7 +61,7 @@ collate_fn: torchtune.data.padded_collate_tiled_images_and_mask
epochs: 1
max_steps_per_epoch: null
batch_size: 2
-gradient_accumulation_steps: 8 # Use to increase virtual batch size
+gradient_accumulation_steps: 8 # Use to increase effective batch size
optimizer:
_component_: torch.optim.AdamW
fused: True
@@ -72,7 +74,7 @@ lr_scheduler:
loss:
_component_: torchtune.modules.loss.CEWithChunkedOutputLoss
clip_grad_norm: 1.0
-compile: False # pytorch compile, set to true for better perf/memory
+compile: False # torch.compile the model + loss, True increases speed + decreases memory
# Training env
device: cuda
@@ -82,13 +84,13 @@ enable_activation_checkpointing: True # True reduces memory
dtype: bf16
# Logging
-output_dir: /tmp/full-llama3.2-vision-finetune
metric_logger:
_component_: torchtune.training.metric_logging.DiskLogger
- log_dir: /tmp/Llama-3.2-90B-Vision-Instruct/logs
+ log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: True
+
# Profiler (disabled)
profiler:
_component_: torchtune.training.setup_torch_profiler
diff --git a/recipes/configs/llama3_2_vision/90B_qlora.yaml b/recipes/configs/llama3_2_vision/90B_qlora.yaml
index 888093d574..fe8df1ef47 100644
--- a/recipes/configs/llama3_2_vision/90B_qlora.yaml
+++ b/recipes/configs/llama3_2_vision/90B_qlora.yaml
@@ -13,6 +13,8 @@
#
# This config works best when the model is being fine-tuned on 4+ GPUs.
+output_dir: /tmp/torchtune/llama3_2_vision_90B/qlora # /tmp may be deleted by your system. Change it to your preference.
+
# Model arguments
model:
_component_: torchtune.models.llama3_2_vision.qlora_llama3_2_vision_90b
@@ -42,7 +44,7 @@ checkpointer:
filename_format: model-{}-of-{}.safetensors
max_filename: "00037"
recipe_checkpoint: null
- output_dir: /tmp/Llama-3.2-90B-Vision-Instruct/
+ output_dir: ${output_dir}
model_type: LLAMA3_VISION
resume_from_checkpoint: False
save_adapter_weights_only: False # PeFT formatting not available yet. This will save it in torchtune format only.
@@ -59,7 +61,7 @@ collate_fn: torchtune.data.padded_collate_tiled_images_and_mask
epochs: 1
max_steps_per_epoch: null
batch_size: 2
-gradient_accumulation_steps: 8 # Use to increase virtual batch size
+gradient_accumulation_steps: 8 # Use to increase effective batch size
optimizer:
_component_: torch.optim.AdamW
fused: True
@@ -71,7 +73,7 @@ lr_scheduler:
loss:
_component_: torchtune.modules.loss.CEWithChunkedOutputLoss
clip_grad_norm: 1.0
-compile: False # pytorch compile, set to true for better perf/memory
+compile: False # torch.compile the model + loss, True increases speed + decreases memory
# Training env
device: cuda
@@ -81,13 +83,13 @@ enable_activation_checkpointing: True # True reduces memory
dtype: bf16
# Logging
-output_dir: /tmp/qlora-llama3.2-vision-finetune
metric_logger:
_component_: torchtune.training.metric_logging.DiskLogger
- log_dir: /tmp/Llama-3.2-90B-Vision-Instruct/logs
+ log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: True
+
# Profiler (disabled)
profiler:
_component_: torchtune.training.setup_torch_profiler
diff --git a/recipes/configs/llama3_3/70B_full.yaml b/recipes/configs/llama3_3/70B_full.yaml
new file mode 100644
index 0000000000..8f96a5fbd7
--- /dev/null
+++ b/recipes/configs/llama3_3/70B_full.yaml
@@ -0,0 +1,138 @@
+# Config for multi-device full finetuning in full_finetune_distributed.py
+# using a Llama3.3 70B Instruct model
+#
+# This config assumes that you've run the following command before launching
+# this run:
+# tune download meta-llama/Llama-3.3-70B-Instruct --ignore-patterns "original/consolidated*"
+#
+# To launch on 8 devices, run the following command from root:
+# tune run --nproc_per_node 8 full_finetune_distributed --config llama3_3/70B_full
+#
+# You can add specific overrides through the command line. For example
+# to override the checkpointer directory while launching training
+# you can run:
+# tune run --nproc_per_node 8 full_finetune_distributed --config llama3_3/70B_full checkpointer.checkpoint_dir=
+#
+# This config is only tested on an 8xA100 machine.
+#
+
+# Tokenizer
+tokenizer:
+ _component_: torchtune.models.llama3.llama3_tokenizer
+ path: /tmp/Llama-3.3-70B-Instruct/original/tokenizer.model
+ max_seq_len: null
+
+# Dataset
+dataset:
+ _component_: torchtune.datasets.alpaca_dataset
+ packed: False # True increases speed
+seed: null
+shuffle: True
+
+# Model Arguments
+model:
+ _component_: torchtune.models.llama3_3.llama3_3_70b
+
+checkpointer:
+ _component_: torchtune.training.FullModelHFCheckpointer
+ checkpoint_dir: /tmp/Llama-3.3-70B-Instruct/
+ checkpoint_files: [
+ model-00001-of-00030.safetensors,
+ model-00002-of-00030.safetensors,
+ model-00003-of-00030.safetensors,
+ model-00004-of-00030.safetensors,
+ model-00005-of-00030.safetensors,
+ model-00006-of-00030.safetensors,
+ model-00007-of-00030.safetensors,
+ model-00008-of-00030.safetensors,
+ model-00009-of-00030.safetensors,
+ model-00010-of-00030.safetensors,
+ model-00011-of-00030.safetensors,
+ model-00012-of-00030.safetensors,
+ model-00013-of-00030.safetensors,
+ model-00014-of-00030.safetensors,
+ model-00015-of-00030.safetensors,
+ model-00016-of-00030.safetensors,
+ model-00017-of-00030.safetensors,
+ model-00018-of-00030.safetensors,
+ model-00019-of-00030.safetensors,
+ model-00020-of-00030.safetensors,
+ model-00021-of-00030.safetensors,
+ model-00022-of-00030.safetensors,
+ model-00023-of-00030.safetensors,
+ model-00024-of-00030.safetensors,
+ model-00025-of-00030.safetensors,
+ model-00026-of-00030.safetensors,
+ model-00027-of-00030.safetensors,
+ model-00028-of-00030.safetensors,
+ model-00029-of-00030.safetensors,
+ model-00030-of-00030.safetensors,
+ ]
+ recipe_checkpoint: null
+ output_dir: /tmp/Llama-3.3-70B-Instruct/
+ model_type: LLAMA3
+resume_from_checkpoint: False
+
+# Fine-tuning arguments
+batch_size: 2
+epochs: 1
+
+optimizer:
+ _component_: torch.optim.AdamW
+ lr: 2e-5
+ # Note: highly recommended to use fused=True optimizer flag
+ # with CPU offload for faster optimizer step.
+ fused: True
+
+loss:
+ _component_: torchtune.modules.loss.CEWithChunkedOutputLoss
+max_steps_per_epoch: null
+gradient_accumulation_steps: 1 # Use to increase virtual batch size
+
+
+# Training env
+device: cuda
+
+# Memory management
+enable_activation_checkpointing: True # True reduces memory
+enable_activation_offloading: False # True reduces memory
+custom_sharded_layers: ['tok_embeddings', 'output'] # Layers to shard separately (useful for large vocab size models). Lower Memory, but lower speed.
+fsdp_cpu_offload: True
+compile: False # pytorch compile, set to true for better perf/memory
+optimizer_in_bwd: False # True saves memory. Requires gradient_accumulation_steps=1
+
+# Reduced precision
+dtype: bf16
+
+# Logging
+metric_logger:
+ _component_: torchtune.training.metric_logging.DiskLogger
+ log_dir: ${output_dir}
+output_dir: /tmp/full-llama3_3-finetune
+log_every_n_steps: 1
+log_peak_memory_stats: True
+
+# Profiler (disabled)
+profiler:
+ _component_: torchtune.training.setup_torch_profiler
+ enabled: False
+
+ #Output directory of trace artifacts
+ output_dir: ${output_dir}/profiling_outputs
+
+ #`torch.profiler.ProfilerActivity` types to trace
+ cpu: True
+ cuda: True
+
+ #trace options passed to `torch.profiler.profile`
+ profile_memory: False
+ with_stack: False
+ record_shapes: True
+ with_flops: False
+
+ # `torch.profiler.schedule` options:
+ # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat
+ wait_steps: 5
+ warmup_steps: 3
+ active_steps: 2
+ num_cycles: 1
diff --git a/recipes/configs/llama3_3/70B_lora.yaml b/recipes/configs/llama3_3/70B_lora.yaml
new file mode 100644
index 0000000000..901c700c22
--- /dev/null
+++ b/recipes/configs/llama3_3/70B_lora.yaml
@@ -0,0 +1,132 @@
+# Config for multi-device LoRA in lora_finetune_distributed.py
+# using a Llama3.3 70B model
+#
+# This config assumes that you've run the following command before launching
+# this run:
+# tune download meta-llama/Llama-3.3-70B-Instruct --ignore-patterns "original/consolidated*"
+#
+# This config needs 8 GPUs to run
+# tune run --nproc_per_node 8 lora_finetune_distributed --config llama3_3/70B_lora
+
+# Model Arguments
+model:
+ _component_: torchtune.models.llama3_3.lora_llama3_3_70b
+ lora_attn_modules: ['q_proj', 'v_proj', 'output_proj']
+ apply_lora_to_mlp: True
+ apply_lora_to_output: False
+ lora_rank: 16 # higher increases accuracy and memory
+ lora_alpha: 32 # usually alpha=2*rank
+ lora_dropout: 0.0
+
+tokenizer:
+ _component_: torchtune.models.llama3.llama3_tokenizer
+ path: /tmp/Llama-3.3-70B-Instruct/original/tokenizer.model
+ max_seq_len: null
+
+checkpointer:
+ _component_: torchtune.training.FullModelHFCheckpointer
+ checkpoint_dir: /tmp/Llama-3.3-70B-Instruct/
+ checkpoint_files: [
+ model-00001-of-00030.safetensors,
+ model-00002-of-00030.safetensors,
+ model-00003-of-00030.safetensors,
+ model-00004-of-00030.safetensors,
+ model-00005-of-00030.safetensors,
+ model-00006-of-00030.safetensors,
+ model-00007-of-00030.safetensors,
+ model-00008-of-00030.safetensors,
+ model-00009-of-00030.safetensors,
+ model-00010-of-00030.safetensors,
+ model-00011-of-00030.safetensors,
+ model-00012-of-00030.safetensors,
+ model-00013-of-00030.safetensors,
+ model-00014-of-00030.safetensors,
+ model-00015-of-00030.safetensors,
+ model-00016-of-00030.safetensors,
+ model-00017-of-00030.safetensors,
+ model-00018-of-00030.safetensors,
+ model-00019-of-00030.safetensors,
+ model-00020-of-00030.safetensors,
+ model-00021-of-00030.safetensors,
+ model-00022-of-00030.safetensors,
+ model-00023-of-00030.safetensors,
+ model-00024-of-00030.safetensors,
+ model-00025-of-00030.safetensors,
+ model-00026-of-00030.safetensors,
+ model-00027-of-00030.safetensors,
+ model-00028-of-00030.safetensors,
+ model-00029-of-00030.safetensors,
+ model-00030-of-00030.safetensors,
+ ]
+ recipe_checkpoint: null
+ output_dir: /tmp/Llama-3.3-70B-Instruct/
+ model_type: LLAMA3
+resume_from_checkpoint: False
+save_adapter_weights_only: True # Set to false to save the whole model + adapter merged
+
+# Dataset and Sampler
+dataset:
+ _component_: torchtune.datasets.alpaca_dataset
+ packed: False # True increases speed
+seed: null
+shuffle: True
+batch_size: 2
+
+# Optimizer and Scheduler
+optimizer:
+ _component_: torch.optim.AdamW
+ fused: True
+ weight_decay: 0.01
+ lr: 3e-4
+lr_scheduler:
+ _component_: torchtune.training.lr_schedulers.get_cosine_schedule_with_warmup
+ num_warmup_steps: 100
+
+loss:
+ _component_: torchtune.modules.loss.CEWithChunkedOutputLoss
+
+# Training
+epochs: 1
+max_steps_per_epoch: null
+gradient_accumulation_steps: 1 # Use to increase virtual batch size
+compile: False # pytorch compile, set to true for better perf/memory
+
+# Logging
+output_dir: /tmp/lora-llama3_3-finetune-output
+metric_logger:
+ _component_: torchtune.training.metric_logging.DiskLogger
+ log_dir: ${output_dir}
+log_every_n_steps: 1
+log_peak_memory_stats: True
+
+# Environment
+device: cuda
+dtype: bf16
+enable_activation_checkpointing: True # True reduces memory
+enable_activation_offloading: False # True reduces memory
+# custom_sharded_layers: ['tok_embeddings', 'output'] # Layers to shard separately (useful for large vocab size models). Lower Memory, but lower speed.
+
+# Profiler (disabled)
+profiler:
+ _component_: torchtune.training.setup_torch_profiler
+ enabled: False
+
+ #Output directory of trace artifacts
+ output_dir: ${output_dir}/profiling_outputs
+
+ #`torch.profiler.ProfilerActivity` types to trace
+ cpu: True
+ cuda: True
+
+ #trace options passed to `torch.profiler.profile`
+ profile_memory: False
+ with_stack: False
+ record_shapes: True
+ with_flops: False
+
+ # `torch.profiler.schedule` options:
+ # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat
+ wait_steps: 5
+ warmup_steps: 3
+ active_steps: 2
+ num_cycles: 1
diff --git a/recipes/configs/llama3_3/70B_qlora.yaml b/recipes/configs/llama3_3/70B_qlora.yaml
new file mode 100644
index 0000000000..e25b196927
--- /dev/null
+++ b/recipes/configs/llama3_3/70B_qlora.yaml
@@ -0,0 +1,132 @@
+# Config for multi-device LoRA in lora_finetune_distributed.py
+# using a Llama3.3 70B model
+#
+# This config assumes that you've run the following command before launching
+# this run:
+# tune download meta-llama/Llama-3.3-70B-Instruct --ignore-patterns "original/consolidated*"
+#
+# This config needs 8 GPUs to run
+# tune run --nproc_per_node 8 lora_finetune_distributed --config llama3_3/70B_lora
+
+# Model Arguments
+model:
+ _component_: torchtune.models.llama3_3.qlora_llama3_3_70b
+ lora_attn_modules: ['q_proj', 'v_proj', 'output_proj']
+ apply_lora_to_mlp: True
+ apply_lora_to_output: False
+ lora_rank: 16 # higher increases accuracy and memory
+ lora_alpha: 32 # usually alpha=2*rank
+ lora_dropout: 0.0
+
+tokenizer:
+ _component_: torchtune.models.llama3.llama3_tokenizer
+ path: /tmp/Llama-3.3-70B-Instruct/original/tokenizer.model
+ max_seq_len: null
+
+checkpointer:
+ _component_: torchtune.training.FullModelHFCheckpointer
+ checkpoint_dir: /tmp/Llama-3.3-70B-Instruct/
+ checkpoint_files: [
+ model-00001-of-00030.safetensors,
+ model-00002-of-00030.safetensors,
+ model-00003-of-00030.safetensors,
+ model-00004-of-00030.safetensors,
+ model-00005-of-00030.safetensors,
+ model-00006-of-00030.safetensors,
+ model-00007-of-00030.safetensors,
+ model-00008-of-00030.safetensors,
+ model-00009-of-00030.safetensors,
+ model-00010-of-00030.safetensors,
+ model-00011-of-00030.safetensors,
+ model-00012-of-00030.safetensors,
+ model-00013-of-00030.safetensors,
+ model-00014-of-00030.safetensors,
+ model-00015-of-00030.safetensors,
+ model-00016-of-00030.safetensors,
+ model-00017-of-00030.safetensors,
+ model-00018-of-00030.safetensors,
+ model-00019-of-00030.safetensors,
+ model-00020-of-00030.safetensors,
+ model-00021-of-00030.safetensors,
+ model-00022-of-00030.safetensors,
+ model-00023-of-00030.safetensors,
+ model-00024-of-00030.safetensors,
+ model-00025-of-00030.safetensors,
+ model-00026-of-00030.safetensors,
+ model-00027-of-00030.safetensors,
+ model-00028-of-00030.safetensors,
+ model-00029-of-00030.safetensors,
+ model-00030-of-00030.safetensors,
+ ]
+ recipe_checkpoint: null
+ output_dir: /tmp/Llama-3.3-70B-Instruct/
+ model_type: LLAMA3
+resume_from_checkpoint: False
+save_adapter_weights_only: True # Set to false to save the whole model + adapter merged
+
+# Dataset and Sampler
+dataset:
+ _component_: torchtune.datasets.alpaca_dataset
+ packed: False # True increases speed
+seed: null
+shuffle: True
+batch_size: 2
+
+# Optimizer and Scheduler
+optimizer:
+ _component_: torch.optim.AdamW
+ fused: True
+ weight_decay: 0.01
+ lr: 3e-4
+lr_scheduler:
+ _component_: torchtune.training.lr_schedulers.get_cosine_schedule_with_warmup
+ num_warmup_steps: 100
+
+loss:
+ _component_: torchtune.modules.loss.CEWithChunkedOutputLoss
+
+# Training
+epochs: 1
+max_steps_per_epoch: null
+gradient_accumulation_steps: 1 # Use to increase virtual batch size
+compile: False # pytorch compile, set to true for better perf/memory
+
+# Logging
+output_dir: /tmp/lora-llama3_3-finetune-output
+metric_logger:
+ _component_: torchtune.training.metric_logging.DiskLogger
+ log_dir: ${output_dir}
+log_every_n_steps: 1
+log_peak_memory_stats: True
+
+# Environment
+device: cuda
+dtype: bf16
+enable_activation_checkpointing: True # True reduces memory
+enable_activation_offloading: False # True reduces memory
+# custom_sharded_layers: ['tok_embeddings', 'output'] # Layers to shard separately (useful for large vocab size models). Lower Memory, but lower speed.
+
+# Profiler (disabled)
+profiler:
+ _component_: torchtune.training.setup_torch_profiler
+ enabled: False
+
+ #Output directory of trace artifacts
+ output_dir: ${output_dir}/profiling_outputs
+
+ #`torch.profiler.ProfilerActivity` types to trace
+ cpu: True
+ cuda: True
+
+ #trace options passed to `torch.profiler.profile`
+ profile_memory: False
+ with_stack: False
+ record_shapes: True
+ with_flops: False
+
+ # `torch.profiler.schedule` options:
+ # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat
+ wait_steps: 5
+ warmup_steps: 3
+ active_steps: 2
+ num_cycles: 1
diff --git a/recipes/configs/mistral/7B_full.yaml b/recipes/configs/mistral/7B_full.yaml
index 23c82e1d71..15a6ec7b89 100644
--- a/recipes/configs/mistral/7B_full.yaml
+++ b/recipes/configs/mistral/7B_full.yaml
@@ -21,6 +21,8 @@
# Single device full finetuning requires more memory optimizations. It's
# best to use 7B_full_single_device.yaml for those cases
+output_dir: /tmp/torchtune/mistral_7B/full # /tmp may be deleted by your system. Change it to your preference.
+
# Tokenizer
tokenizer:
_component_: torchtune.models.mistral.mistral_tokenizer
@@ -46,7 +48,7 @@ checkpointer:
pytorch_model-00002-of-00002.bin
]
recipe_checkpoint: null
- output_dir: /tmp/Mistral-7B-v0.1/
+ output_dir: ${output_dir}
model_type: MISTRAL
resume_from_checkpoint: False
@@ -60,8 +62,8 @@ optimizer:
loss:
_component_: torchtune.modules.loss.CEWithChunkedOutputLoss
max_steps_per_epoch: null
-gradient_accumulation_steps: 1 # Use to increase virtual batch size
-compile: False # pytorch compile, set to true for better perf/memory
+gradient_accumulation_steps: 1 # Use to increase effective batch size
+compile: False # torch.compile the model + loss, True increases speed + decreases memory
optimizer_in_bwd: False # True saves memory. Requires gradient_accumulation_steps=1
# Training env
@@ -77,11 +79,11 @@ dtype: bf16
# Logging
metric_logger:
_component_: torchtune.training.metric_logging.DiskLogger
- log_dir: ${output_dir}
-output_dir: /tmp/Mistral-7B-v0.1/
+ log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: True
+
# Profiler (disabled)
profiler:
_component_: torchtune.training.setup_torch_profiler
diff --git a/recipes/configs/mistral/7B_full_low_memory.yaml b/recipes/configs/mistral/7B_full_low_memory.yaml
index 01de2f11ea..287a66dbd0 100644
--- a/recipes/configs/mistral/7B_full_low_memory.yaml
+++ b/recipes/configs/mistral/7B_full_low_memory.yaml
@@ -23,6 +23,8 @@
#
# This config works only for training on single device.
+output_dir: /tmp/torchtune/mistral_7B/full_low_memory # /tmp may be deleted by your system. Change it to your preference.
+
# Tokenizer
tokenizer:
_component_: torchtune.models.mistral.mistral_tokenizer
@@ -48,7 +50,7 @@ checkpointer:
pytorch_model-00002-of-00002.bin
]
recipe_checkpoint: null
- output_dir: /tmp/Mistral-7B-v0.1/
+ output_dir: ${output_dir}
model_type: MISTRAL
resume_from_checkpoint: False
@@ -61,7 +63,7 @@ optimizer:
loss:
_component_: torchtune.modules.loss.CEWithChunkedOutputLoss
max_steps_per_epoch: null
-gradient_accumulation_steps: 1 # Use to increase virtual batch size
+gradient_accumulation_steps: 1 # Use to increase effective batch size
optimizer_in_bwd: True # True saves memory. Requires gradient_accumulation_steps=1
# Training env
@@ -75,16 +77,16 @@ enable_activation_offloading: True # True reduces memory
dtype: bf16
# Model compilation
-compile: False # pytorch compile, set to true for better perf/memory
+compile: False # torch.compile the model + loss, True increases speed + decreases memory
# Logging
metric_logger:
_component_: torchtune.training.metric_logging.DiskLogger
- log_dir: ${output_dir}
-output_dir: /tmp/Mistral-7B-v0.1/
+ log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: True
+
# Profiler (disabled)
profiler:
_component_: torchtune.training.setup_torch_profiler
diff --git a/recipes/configs/mistral/7B_full_ppo_low_memory.yaml b/recipes/configs/mistral/7B_full_ppo_low_memory.yaml
index 310c9e5bcf..166fbeac1d 100644
--- a/recipes/configs/mistral/7B_full_ppo_low_memory.yaml
+++ b/recipes/configs/mistral/7B_full_ppo_low_memory.yaml
@@ -24,6 +24,8 @@
# tune run ppo_full_finetune_single_device --config mistral/7B_full_ppo_low_memory checkpointer.checkpoint_dir=
#
+output_dir: /tmp/torchtune/mistral_7B/full_ppo_low_memory # /tmp may be deleted by your system. Change it to your preference.
+
# Tokenizer
tokenizer:
_component_: torchtune.models.mistral.mistral_tokenizer
@@ -68,7 +70,7 @@ checkpointer:
]
# this is the only place where you should update `recipe_checkpoint` if resuming training
recipe_checkpoint: null
- output_dir: ${output_dir}/policy
+ output_dir: ${output_dir}
model_type: MISTRAL
# this should be setup identically to the policy model checkpointer at the start of training
@@ -115,7 +117,6 @@ reward_checkpointer:
model_type: REWARD
resume_from_checkpoint: False
-output_dir: /tmp/mistral7b-ppo-finetune
seed: null
shuffle: True
@@ -127,10 +128,10 @@ batch_size: 64
num_steps: 10000
ppo_epochs: 2
ppo_batch_size: 32
-gradient_accumulation_steps: 1 # Use to increase virtual batch size
+gradient_accumulation_steps: 1 # Use to increase effective batch size
# Memory management and performance
-compile: True # pytorch compile, set to true for better perf/memory
+compile: True # torch.compile the model + loss, True increases speed + decreases memory
optimizer:
_component_: bitsandbytes.optim.PagedAdamW
lr: 3e-6
@@ -176,6 +177,6 @@ kl_coeff: 0.01
# Logging
metric_logger:
_component_: torchtune.training.metric_logging.DiskLogger
- log_dir: ${output_dir}
+ log_dir: ${output_dir}/logs
log_every_n_steps: 1
diff --git a/recipes/configs/mistral/7B_lora.yaml b/recipes/configs/mistral/7B_lora.yaml
index f637240b34..ef3c9b0e1b 100644
--- a/recipes/configs/mistral/7B_lora.yaml
+++ b/recipes/configs/mistral/7B_lora.yaml
@@ -22,6 +22,8 @@
# or 7B_qlora_single_device.yaml for those cases
+output_dir: /tmp/torchtune/mistral_7B/lora # /tmp may be deleted by your system. Change it to your preference.
+
# Tokenizer
tokenizer:
_component_: torchtune.models.mistral.mistral_tokenizer
@@ -53,7 +55,7 @@ checkpointer:
pytorch_model-00002-of-00002.bin
]
recipe_checkpoint: null
- output_dir: /tmp/Mistral-7B-v0.1
+ output_dir: ${output_dir}
model_type: MISTRAL
resume_from_checkpoint: False
save_adapter_weights_only: False
@@ -74,8 +76,8 @@ loss:
batch_size: 4
epochs: 1
max_steps_per_epoch: null
-gradient_accumulation_steps: 1 # Use to increase virtual batch size
-compile: False # pytorch compile, set to true for better perf/memory
+gradient_accumulation_steps: 1 # Use to increase effective batch size
+compile: False # torch.compile the model + loss, True increases speed + decreases memory
# Training env
device: cuda
@@ -90,11 +92,11 @@ dtype: bf16
# Logging
metric_logger:
_component_: torchtune.training.metric_logging.DiskLogger
- log_dir: ${output_dir}
-output_dir: /tmp/Mistral-7B-v0.1
+ log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: True
+
# Profiler (disabled)
profiler:
_component_: torchtune.training.setup_torch_profiler
diff --git a/recipes/configs/mistral/7B_lora_single_device.yaml b/recipes/configs/mistral/7B_lora_single_device.yaml
index c11cbe1ad2..c98f23c840 100644
--- a/recipes/configs/mistral/7B_lora_single_device.yaml
+++ b/recipes/configs/mistral/7B_lora_single_device.yaml
@@ -19,6 +19,8 @@
#
# This config works only for training on single device.
+output_dir: /tmp/torchtune/mistral_7B/lora_single_device # /tmp may be deleted by your system. Change it to your preference.
+
# Tokenizer
tokenizer:
_component_: torchtune.models.mistral.mistral_tokenizer
@@ -50,7 +52,7 @@ checkpointer:
pytorch_model-00002-of-00002.bin
]
recipe_checkpoint: null
- output_dir: /tmp/Mistral-7B-v0.1
+ output_dir: ${output_dir}
model_type: MISTRAL
resume_from_checkpoint: False
save_adapter_weights_only: False
@@ -71,8 +73,8 @@ loss:
batch_size: 4
epochs: 1
max_steps_per_epoch: null
-gradient_accumulation_steps: 8 # Use to increase virtual batch size
-compile: False # pytorch compile, set to true for better perf/memory
+gradient_accumulation_steps: 8 # Use to increase effective batch size
+compile: False # torch.compile the model + loss, True increases speed + decreases memory
# Training env
device: cuda
@@ -87,8 +89,7 @@ dtype: bf16
# Logging
metric_logger:
_component_: torchtune.training.metric_logging.DiskLogger
- log_dir: ${output_dir}
-output_dir: /tmp/Mistral-7B-v0.1
+ log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: True
diff --git a/recipes/configs/mistral/7B_qlora_single_device.yaml b/recipes/configs/mistral/7B_qlora_single_device.yaml
index 536f2efdf1..353ad54187 100644
--- a/recipes/configs/mistral/7B_qlora_single_device.yaml
+++ b/recipes/configs/mistral/7B_qlora_single_device.yaml
@@ -20,6 +20,8 @@
# This config works only for training on single device.
+output_dir: /tmp/torchtune/mistral_7B/qlora_single_device # /tmp may be deleted by your system. Change it to your preference.
+
# Tokenizer
tokenizer:
_component_: torchtune.models.mistral.mistral_tokenizer
@@ -51,7 +53,7 @@ checkpointer:
pytorch_model-00002-of-00002.bin
]
recipe_checkpoint: null
- output_dir: /tmp/Mistral-7B-v0.1
+ output_dir: ${output_dir}
model_type: MISTRAL
resume_from_checkpoint: False
save_adapter_weights_only: False
@@ -72,8 +74,8 @@ loss:
batch_size: 4
epochs: 1
max_steps_per_epoch: null
-gradient_accumulation_steps: 8 # Use to increase virtual batch size
-compile: False # pytorch compile, set to true for better perf/memory
+gradient_accumulation_steps: 8 # Use to increase effective batch size
+compile: False # torch.compile the model + loss, True increases speed + decreases memory
# Training env
device: cuda
@@ -88,8 +90,7 @@ dtype: bf16
# Logging
metric_logger:
_component_: torchtune.training.metric_logging.DiskLogger
- log_dir: ${output_dir}
-output_dir: /tmp/Mistral-7B-v0.1
+ log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: True
diff --git a/recipes/configs/phi3/mini_full.yaml b/recipes/configs/phi3/mini_full.yaml
index 594ffdc916..7dc954576d 100644
--- a/recipes/configs/phi3/mini_full.yaml
+++ b/recipes/configs/phi3/mini_full.yaml
@@ -17,6 +17,8 @@
# Single device full finetuning requires more memory optimizations. It's
# best to use mini_low_memory.yaml for those cases
+output_dir: /tmp/torchtune/phi3_mini/full # /tmp may be deleted by your system. Change it to your preference.
+
# Model arguments
model:
_component_: torchtune.models.phi3.phi3_mini
@@ -36,7 +38,7 @@ checkpointer:
model-00002-of-00002.safetensors
]
recipe_checkpoint: null
- output_dir: /tmp/Phi-3-mini-4k-instruct
+ output_dir: ${output_dir}
model_type: PHI3_MINI
resume_from_checkpoint: False
@@ -51,14 +53,14 @@ shuffle: True
epochs: 1
max_steps_per_epoch: null
batch_size: 2
-gradient_accumulation_steps: 8 # Use to increase virtual batch size
+gradient_accumulation_steps: 8 # Use to increase effective batch size
optimizer:
_component_: torch.optim.AdamW
fused: True
lr: 5e-6
loss:
_component_: torchtune.modules.loss.CEWithChunkedOutputLoss
-compile: False # pytorch compile, set to true for better perf/memory
+compile: False # torch.compile the model + loss, True increases speed + decreases memory
optimizer_in_bwd: False # True saves memory. Requires gradient_accumulation_steps=1
# Training env
@@ -70,13 +72,13 @@ enable_activation_offloading: False # True reduces memory
dtype: bf16
# Logging
-output_dir: /tmp/phi3_finetune_output
metric_logger:
_component_: torchtune.training.metric_logging.DiskLogger
- log_dir: /tmp/Phi-3-mini-4k-instruct/logs
+ log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: True
+
# Profiler (disabled)
profiler:
_component_: torchtune.training.setup_torch_profiler
diff --git a/recipes/configs/phi3/mini_full_low_memory.yaml b/recipes/configs/phi3/mini_full_low_memory.yaml
index 05c1db379a..8162e73c18 100644
--- a/recipes/configs/phi3/mini_full_low_memory.yaml
+++ b/recipes/configs/phi3/mini_full_low_memory.yaml
@@ -19,6 +19,8 @@
#
# This config works only for training on single device.
+output_dir: /tmp/torchtune/phi3_mini/full_low_memory # /tmp may be deleted by your system. Change it to your preference.
+
# Model arguments
model:
_component_: torchtune.models.phi3.phi3_mini
@@ -38,7 +40,7 @@ checkpointer:
model-00002-of-00002.safetensors
]
recipe_checkpoint: null
- output_dir: /tmp/Phi-3-mini-4k-instruct
+ output_dir: ${output_dir}
model_type: PHI3_MINI
resume_from_checkpoint: False
@@ -53,14 +55,14 @@ shuffle: True
epochs: 1
max_steps_per_epoch: null
batch_size: 2
-gradient_accumulation_steps: 1 # Use to increase virtual batch size
+gradient_accumulation_steps: 1 # Use to increase effective batch size
optimizer:
_component_: bitsandbytes.optim.PagedAdamW
lr: 5e-6
optimizer_in_bwd: True # True saves memory. Requires gradient_accumulation_steps=1
loss:
_component_: torchtune.modules.loss.CEWithChunkedOutputLoss
-compile: False # pytorch compile, set to true for better perf/memory
+compile: False # torch.compile the model + loss, True increases speed + decreases memory
# Training env
device: cuda
@@ -71,13 +73,13 @@ enable_activation_offloading: True # True reduces memory
dtype: bf16
# Logging
-output_dir: /tmp/phi3_finetune_output
metric_logger:
_component_: torchtune.training.metric_logging.DiskLogger
- log_dir: /tmp/Phi-3-mini-4k-instruct/logs
+ log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: True
+
# Profiler (disabled)
profiler:
_component_: torchtune.training.setup_torch_profiler
diff --git a/recipes/configs/phi3/mini_lora.yaml b/recipes/configs/phi3/mini_lora.yaml
index 0c13048119..429a1c2a6d 100644
--- a/recipes/configs/phi3/mini_lora.yaml
+++ b/recipes/configs/phi3/mini_lora.yaml
@@ -17,6 +17,8 @@
# For single device LoRA finetuning please use mini_lora_single_device.yaml
# or mini_qlora_single_device.yaml
+output_dir: /tmp/torchtune/phi3_mini/lora # /tmp may be deleted by your system. Change it to your preference.
+
# Model arguments
model:
_component_: torchtune.models.phi3.lora_phi3_mini
@@ -42,7 +44,7 @@ checkpointer:
model-00002-of-00002.safetensors
]
recipe_checkpoint: null
- output_dir: /tmp/Phi-3-mini-4k-instruct
+ output_dir: ${output_dir}
model_type: PHI3_MINI
resume_from_checkpoint: False
save_adapter_weights_only: False
@@ -58,7 +60,7 @@ shuffle: True
epochs: 1
max_steps_per_epoch: null
batch_size: 2
-gradient_accumulation_steps: 8 # Use to increase virtual batch size
+gradient_accumulation_steps: 8 # Use to increase effective batch size
optimizer:
_component_: torch.optim.AdamW
fused: True
@@ -69,7 +71,7 @@ lr_scheduler:
num_warmup_steps: 100
loss:
_component_: torchtune.modules.loss.CEWithChunkedOutputLoss
-compile: False # pytorch compile, set to true for better perf/memory
+compile: False # torch.compile the model + loss, True increases speed + decreases memory
# Training env
device: cuda
@@ -80,13 +82,13 @@ enable_activation_offloading: False # True reduces memory
dtype: bf16
# Logging
-output_dir: /tmp/phi3_lora_finetune_output
metric_logger:
_component_: torchtune.training.metric_logging.DiskLogger
- log_dir: /tmp/Phi-3-mini-4k-instruct/logs
+ log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: True
+
# Profiler (disabled)
profiler:
_component_: torchtune.training.setup_torch_profiler
diff --git a/recipes/configs/phi3/mini_lora_single_device.yaml b/recipes/configs/phi3/mini_lora_single_device.yaml
index 3aae4f2b6c..26e5ac457f 100644
--- a/recipes/configs/phi3/mini_lora_single_device.yaml
+++ b/recipes/configs/phi3/mini_lora_single_device.yaml
@@ -15,6 +15,8 @@
#
# This config works only for training on single device.
+output_dir: /tmp/torchtune/phi3_mini/lora_single_device # /tmp may be deleted by your system. Change it to your preference.
+
# Model arguments
model:
_component_: torchtune.models.phi3.lora_phi3_mini
@@ -40,7 +42,7 @@ checkpointer:
model-00002-of-00002.safetensors
]
recipe_checkpoint: null
- output_dir: /tmp/Phi-3-mini-4k-instruct
+ output_dir: ${output_dir}
model_type: PHI3_MINI
resume_from_checkpoint: False
save_adapter_weights_only: False
@@ -56,7 +58,7 @@ shuffle: True
epochs: 1
max_steps_per_epoch: null
batch_size: 2
-gradient_accumulation_steps: 8 # Use to increase virtual batch size
+gradient_accumulation_steps: 8 # Use to increase effective batch size
optimizer:
_component_: torch.optim.AdamW
fused: True
@@ -67,7 +69,7 @@ lr_scheduler:
num_warmup_steps: 100
loss:
_component_: torchtune.modules.loss.CEWithChunkedOutputLoss
-compile: False # pytorch compile, set to true for better perf/memory
+compile: False # torch.compile the model + loss, True increases speed + decreases memory
# Training env
device: cuda
@@ -80,10 +82,9 @@ enable_activation_offloading: False # True reduces memory
dtype: bf16
# Logging
-output_dir: /tmp/phi3_lora_finetune_output
metric_logger:
_component_: torchtune.training.metric_logging.DiskLogger
- log_dir: /tmp/Phi-3-mini-4k-instruct/logs
+ log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: True
diff --git a/recipes/configs/phi3/mini_qlora_single_device.yaml b/recipes/configs/phi3/mini_qlora_single_device.yaml
index f59a68a59d..a81e34f669 100644
--- a/recipes/configs/phi3/mini_qlora_single_device.yaml
+++ b/recipes/configs/phi3/mini_qlora_single_device.yaml
@@ -15,6 +15,8 @@
#
# This config works only for training on single device.
+output_dir: /tmp/torchtune/phi3_mini/qlora_single_device # /tmp may be deleted by your system. Change it to your preference.
+
# Model arguments
model:
_component_: torchtune.models.phi3.qlora_phi3_mini
@@ -40,7 +42,7 @@ checkpointer:
model-00002-of-00002.safetensors
]
recipe_checkpoint: null
- output_dir: /tmp/Phi-3-mini-4k-instruct
+ output_dir: ${output_dir}
model_type: PHI3_MINI
resume_from_checkpoint: False
save_adapter_weights_only: False
@@ -56,7 +58,7 @@ shuffle: True
epochs: 1
max_steps_per_epoch: null
batch_size: 2
-gradient_accumulation_steps: 8 # Use to increase virtual batch size
+gradient_accumulation_steps: 8 # Use to increase effective batch size
optimizer:
_component_: torch.optim.AdamW
fused: True
@@ -67,7 +69,7 @@ lr_scheduler:
num_warmup_steps: 100
loss:
_component_: torchtune.modules.loss.CEWithChunkedOutputLoss
-compile: False # pytorch compile, set to true for better perf/memory
+compile: False # torch.compile the model + loss, True increases speed + decreases memory
# Training env
device: cuda
@@ -80,10 +82,9 @@ enable_activation_offloading: False # True reduces memory
dtype: bf16
# Logging
-output_dir: /tmp/phi3_qlora_finetune_output
metric_logger:
_component_: torchtune.training.metric_logging.DiskLogger
- log_dir: /tmp/Phi-3-mini-4k-instruct/logs
+ log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: True
diff --git a/recipes/configs/qwen2/0.5B_full.yaml b/recipes/configs/qwen2/0.5B_full.yaml
index 84336894be..093887fb59 100644
--- a/recipes/configs/qwen2/0.5B_full.yaml
+++ b/recipes/configs/qwen2/0.5B_full.yaml
@@ -17,6 +17,8 @@
# Single device full finetuning requires more memory optimizations. It's
# best to use 0.5B_full.yaml for those cases
+output_dir: /tmp/torchtune/qwen2_0_5B/full # /tmp may be deleted by your system. Change it to your preference.
+
# Tokenizer
tokenizer:
_component_: torchtune.models.qwen2.qwen2_tokenizer
@@ -42,7 +44,7 @@ checkpointer:
model.safetensors
]
recipe_checkpoint: null
- output_dir: /tmp/Qwen2-0.5B-Instruct-finetune
+ output_dir: ${output_dir}
model_type: QWEN2
resume_from_checkpoint: False
@@ -56,8 +58,8 @@ optimizer:
loss:
_component_: torchtune.modules.loss.CEWithChunkedOutputLoss
max_steps_per_epoch: null
-gradient_accumulation_steps: 8 # Use to increase virtual batch size
-compile: False # pytorch compile, set to true for better perf/memory
+gradient_accumulation_steps: 8 # Use to increase effective batch size
+compile: False # torch.compile the model + loss, True increases speed + decreases memory
optimizer_in_bwd: False # True saves memory. Requires gradient_accumulation_steps=1
# Training env
@@ -73,11 +75,11 @@ dtype: bf16
# Logging
metric_logger:
_component_: torchtune.training.metric_logging.DiskLogger
- log_dir: ${output_dir}
-output_dir: /tmp/Qwen2-0.5B-Instruct-finetune
+ log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: True
+
# Profiler (disabled)
profiler:
_component_: torchtune.training.setup_torch_profiler
diff --git a/recipes/configs/qwen2/0.5B_full_single_device.yaml b/recipes/configs/qwen2/0.5B_full_single_device.yaml
index 8b60a17090..4f670695ca 100644
--- a/recipes/configs/qwen2/0.5B_full_single_device.yaml
+++ b/recipes/configs/qwen2/0.5B_full_single_device.yaml
@@ -15,6 +15,8 @@
#
# This config works only for training on single device.
+output_dir: /tmp/torchtune/qwen2_0_5B/full_single_device # /tmp may be deleted by your system. Change it to your preference.
+
# Tokenizer
tokenizer:
_component_: torchtune.models.qwen2.qwen2_tokenizer
@@ -40,7 +42,7 @@ checkpointer:
model.safetensors
]
recipe_checkpoint: null
- output_dir: /tmp/Qwen2-0.5B-Instruct-finetune
+ output_dir: ${output_dir}
model_type: QWEN2
resume_from_checkpoint: False
@@ -57,8 +59,8 @@ loss:
optimizer_in_bwd: False # True saves memory. Requires gradient_accumulation_steps=1
max_steps_per_epoch: null
-gradient_accumulation_steps: 8 # Use to increase virtual batch size
-compile: False # pytorch compile, set to true for better perf/memory
+gradient_accumulation_steps: 8 # Use to increase effective batch size
+compile: False # torch.compile the model + loss, True increases speed + decreases memory
# Training environment
device: cuda
@@ -73,11 +75,11 @@ dtype: bf16
# Logging
metric_logger:
_component_: torchtune.training.metric_logging.DiskLogger
- log_dir: ${output_dir}
-output_dir: /tmp/Qwen2-0.5B-Instruct-finetune
+ log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: True
+
# Profiler (disabled)
profiler:
_component_: torchtune.training.setup_torch_profiler
diff --git a/recipes/configs/qwen2/0.5B_lora.yaml b/recipes/configs/qwen2/0.5B_lora.yaml
index 16e5955da3..f4ce567afb 100644
--- a/recipes/configs/qwen2/0.5B_lora.yaml
+++ b/recipes/configs/qwen2/0.5B_lora.yaml
@@ -18,6 +18,8 @@
# or 0.5B_qlora_single_device.yaml
+output_dir: /tmp/torchtune/qwen2_0_5B/lora # /tmp may be deleted by your system. Change it to your preference.
+
# Model Arguments
model:
_component_: torchtune.models.qwen2.lora_qwen2_0_5b
@@ -40,7 +42,7 @@ checkpointer:
model.safetensors
]
recipe_checkpoint: null
- output_dir: /tmp/Qwen2-0.5B-Instruct-lora-finetune
+ output_dir: ${output_dir}
model_type: QWEN2
resume_from_checkpoint: False
@@ -70,14 +72,13 @@ loss:
# Training
epochs: 1
max_steps_per_epoch: null
-gradient_accumulation_steps: 8 # Use to increase virtual batch size
-compile: False # pytorch compile, set to true for better perf/memory
+gradient_accumulation_steps: 8 # Use to increase effective batch size
+compile: False # torch.compile the model + loss, True increases speed + decreases memory
# Logging
-output_dir: /tmp/Qwen2-0.5B-Instruct-lora-finetune
metric_logger:
_component_: torchtune.training.metric_logging.DiskLogger
- log_dir: ${output_dir}
+ log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: True
diff --git a/recipes/configs/qwen2/0.5B_lora_single_device.yaml b/recipes/configs/qwen2/0.5B_lora_single_device.yaml
index e54db398fb..9bd95bfedc 100644
--- a/recipes/configs/qwen2/0.5B_lora_single_device.yaml
+++ b/recipes/configs/qwen2/0.5B_lora_single_device.yaml
@@ -16,6 +16,8 @@
# This config works only for training on single device.
+output_dir: /tmp/torchtune/qwen2_0_5B/lora_single_device # /tmp may be deleted by your system. Change it to your preference.
+
# Model Arguments
model:
_component_: torchtune.models.qwen2.lora_qwen2_0_5b
@@ -38,7 +40,7 @@ checkpointer:
model.safetensors
]
recipe_checkpoint: null
- output_dir: /tmp/Qwen2-0.5B-Instruct-lora-finetune
+ output_dir: ${output_dir}
model_type: QWEN2
resume_from_checkpoint: False
@@ -68,14 +70,13 @@ loss:
# Training
epochs: 1
max_steps_per_epoch: null
-gradient_accumulation_steps: 8 # Use to increase virtual batch size
-compile: False # pytorch compile, set to true for better perf/memory
+gradient_accumulation_steps: 8 # Use to increase effective batch size
+compile: False # torch.compile the model + loss, True increases speed + decreases memory
# Logging
-output_dir: /tmp/Qwen2-0.5B-Instruct-lora-finetune
metric_logger:
_component_: torchtune.training.metric_logging.DiskLogger
- log_dir: ${output_dir}
+ log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: True
diff --git a/recipes/configs/qwen2/1.5B_full.yaml b/recipes/configs/qwen2/1.5B_full.yaml
index 37b5c0a926..04017db7ec 100644
--- a/recipes/configs/qwen2/1.5B_full.yaml
+++ b/recipes/configs/qwen2/1.5B_full.yaml
@@ -17,6 +17,8 @@
# Single device full finetuning requires more memory optimizations. It's
# best to use 1.5B_full.yaml for those cases
+output_dir: /tmp/torchtune/qwen2_1_5B/full # /tmp may be deleted by your system. Change it to your preference.
+
# Tokenizer
tokenizer:
_component_: torchtune.models.qwen2.qwen2_tokenizer
@@ -42,7 +44,7 @@ checkpointer:
model.safetensors
]
recipe_checkpoint: null
- output_dir: /tmp/Qwen2-1.5B-Instruct-finetune
+ output_dir: ${output_dir}
model_type: QWEN2
resume_from_checkpoint: False
@@ -56,8 +58,8 @@ optimizer:
loss:
_component_: torchtune.modules.loss.CEWithChunkedOutputLoss
max_steps_per_epoch: null
-gradient_accumulation_steps: 1 # Use to increase virtual batch size
-compile: False # pytorch compile, set to true for better perf/memory
+gradient_accumulation_steps: 1 # Use to increase effective batch size
+compile: False # torch.compile the model + loss, True increases speed + decreases memory
optimizer_in_bwd: False # True saves memory. Requires gradient_accumulation_steps=1
# Training env
@@ -73,11 +75,11 @@ dtype: bf16
# Logging
metric_logger:
_component_: torchtune.training.metric_logging.DiskLogger
- log_dir: ${output_dir}
-output_dir: /tmp/Qwen2-1.5B-Instruct-finetune
+ log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: True
+
# Profiler (disabled)
profiler:
_component_: torchtune.training.setup_torch_profiler
diff --git a/recipes/configs/qwen2/1.5B_full_single_device.yaml b/recipes/configs/qwen2/1.5B_full_single_device.yaml
index 2acdfb3810..d529629823 100644
--- a/recipes/configs/qwen2/1.5B_full_single_device.yaml
+++ b/recipes/configs/qwen2/1.5B_full_single_device.yaml
@@ -19,6 +19,8 @@
#
# This config works only for training on single device.
+output_dir: /tmp/torchtune/qwen2_1_5B/full_single_device # /tmp may be deleted by your system. Change it to your preference.
+
# Tokenizer
tokenizer:
_component_: torchtune.models.qwen2.qwen2_tokenizer
@@ -45,7 +47,7 @@ checkpointer:
model.safetensors
]
recipe_checkpoint: null
- output_dir: /tmp/Qwen2-1.5B-Instruct-finetune
+ output_dir: ${output_dir}
model_type: QWEN2
resume_from_checkpoint: False
@@ -62,8 +64,8 @@ loss:
_component_: torchtune.modules.loss.CEWithChunkedOutputLoss
max_steps_per_epoch: null
-gradient_accumulation_steps: 1 # Use to increase virtual batch size
-compile: False # pytorch compile, set to true for better perf/memory
+gradient_accumulation_steps: 1 # Use to increase effective batch size
+compile: False # torch.compile the model + loss, True increases speed + decreases memory
# Training environment
device: cuda
@@ -78,11 +80,11 @@ dtype: bf16
# Logging
metric_logger:
_component_: torchtune.training.metric_logging.DiskLogger
- log_dir: ${output_dir}
-output_dir: /tmp/Qwen2-1.5B-Instruct-finetune
+ log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: True
+
# Profiler (disabled)
profiler:
_component_: torchtune.training.setup_torch_profiler
diff --git a/recipes/configs/qwen2/1.5B_lora.yaml b/recipes/configs/qwen2/1.5B_lora.yaml
index aea2f79e09..665e13c671 100644
--- a/recipes/configs/qwen2/1.5B_lora.yaml
+++ b/recipes/configs/qwen2/1.5B_lora.yaml
@@ -16,6 +16,8 @@
# This config works best when the model is being fine-tuned on 2+ GPUs.
+output_dir: /tmp/torchtune/qwen2_1_5B/lora # /tmp may be deleted by your system. Change it to your preference.
+
# Model Arguments
model:
_component_: torchtune.models.qwen2.lora_qwen2_1_5b
@@ -38,7 +40,7 @@ checkpointer:
model.safetensors
]
recipe_checkpoint: null
- output_dir: /tmp/Qwen2-1.5B-Instruct-lora-finetune
+ output_dir: ${output_dir}
model_type: QWEN2
resume_from_checkpoint: False
@@ -66,14 +68,13 @@ loss:
# Training
epochs: 1
max_steps_per_epoch: null
-gradient_accumulation_steps: 8 # Use to increase virtual batch size
-compile: False # pytorch compile, set to true for better perf/memory
+gradient_accumulation_steps: 8 # Use to increase effective batch size
+compile: False # torch.compile the model + loss, True increases speed + decreases memory
# Logging
-output_dir: /tmp/Qwen2-1.5B-Instruct-lora-finetune
metric_logger:
_component_: torchtune.training.metric_logging.DiskLogger
- log_dir: ${output_dir}
+ log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: True
diff --git a/recipes/configs/qwen2/1.5B_lora_single_device.yaml b/recipes/configs/qwen2/1.5B_lora_single_device.yaml
index 2c23954be3..47f6afd7bd 100644
--- a/recipes/configs/qwen2/1.5B_lora_single_device.yaml
+++ b/recipes/configs/qwen2/1.5B_lora_single_device.yaml
@@ -16,6 +16,8 @@
# This config works only for training on single device.
+output_dir: /tmp/torchtune/qwen2_1_5B/lora_single_device # /tmp may be deleted by your system. Change it to your preference.
+
# Model Arguments
model:
_component_: torchtune.models.qwen2.lora_qwen2_1_5b
@@ -38,7 +40,7 @@ checkpointer:
model.safetensors
]
recipe_checkpoint: null
- output_dir: /tmp/Qwen2-1.5B-Instruct-lora-finetune
+ output_dir: ${output_dir}
model_type: QWEN2
resume_from_checkpoint: False
@@ -66,14 +68,13 @@ loss:
# Training
epochs: 1
max_steps_per_epoch: null
-gradient_accumulation_steps: 8 # Use to increase virtual batch size
-compile: False # pytorch compile, set to true for better perf/memory
+gradient_accumulation_steps: 8 # Use to increase effective batch size
+compile: False # torch.compile the model + loss, True increases speed + decreases memory
# Logging
-output_dir: /tmp/Qwen2-1.5B-Instruct-lora-finetune
metric_logger:
_component_: torchtune.training.metric_logging.DiskLogger
- log_dir: ${output_dir}
+ log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: True
diff --git a/recipes/configs/qwen2/knowledge_distillation_distributed.yaml b/recipes/configs/qwen2/1.5_to_0.5B_KD_lora_distributed.yaml
similarity index 87%
rename from recipes/configs/qwen2/knowledge_distillation_distributed.yaml
rename to recipes/configs/qwen2/1.5_to_0.5B_KD_lora_distributed.yaml
index d94f15c54e..c9b9dfd5e0 100644
--- a/recipes/configs/qwen2/knowledge_distillation_distributed.yaml
+++ b/recipes/configs/qwen2/1.5_to_0.5B_KD_lora_distributed.yaml
@@ -10,11 +10,13 @@
# tune run --nnodes 1 --nproc_per_node 2 lora_finetune_distributed --config qwen2/1.5B_lora
#
# To launch on 2 devices, run the following command from root:
-# tune run --nnodes 1 --nproc_per_node 2 knowledge_distillation_distributed --config qwen2/knowledge_distillation_distributed
+# tune run --nnodes 1 --nproc_per_node 2 knowledge_distillation_distributed --config qwen2/1.5_to_0.5B_KD_lora_distributed
#
# This config works best for distilling on 2+ devices.
+output_dir: /tmp/torchtune/qwen2_1_5_to_0_5B/KD_lora_distributed # /tmp may be deleted by your system. Change it to your preference.
+
# Model Arguments
model:
_component_: torchtune.models.qwen2.lora_qwen2_0_5b
@@ -39,7 +41,7 @@ checkpointer:
model.safetensors
]
recipe_checkpoint: null
- output_dir: /tmp/Qwen2-0.5B-Instruct-kd
+ output_dir: ${output_dir}
model_type: QWEN2
teacher_checkpointer:
@@ -81,14 +83,13 @@ kd_ratio: 0.5
# Training
epochs: 1
max_steps_per_epoch: null
-compile: False # pytorch compile, set to true for better perf/memory
-gradient_accumulation_steps: 8 # Use to increase virtual batch size
+compile: False # torch.compile the model + loss, True increases speed + decreases memory
+gradient_accumulation_steps: 8 # Use to increase effective batch size
# Logging
-output_dir: /tmp/qwen_kd
metric_logger:
_component_: torchtune.training.metric_logging.DiskLogger
- log_dir: ${output_dir}
+ log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: False
@@ -96,6 +97,7 @@ log_peak_memory_stats: False
device: cuda
dtype: bf16
enable_activation_checkpointing: False # True reduces memory
+enable_activation_offloading: False # True reduces memory
# Show case the usage of pytorch profiler
# Set enabled to False as it's only needed for debugging training
diff --git a/recipes/configs/qwen2/knowledge_distillation_single_device.yaml b/recipes/configs/qwen2/1.5_to_0.5B_KD_lora_single_device.yaml
similarity index 87%
rename from recipes/configs/qwen2/knowledge_distillation_single_device.yaml
rename to recipes/configs/qwen2/1.5_to_0.5B_KD_lora_single_device.yaml
index 70c3496d0e..b7eda6df3e 100644
--- a/recipes/configs/qwen2/knowledge_distillation_single_device.yaml
+++ b/recipes/configs/qwen2/1.5_to_0.5B_KD_lora_single_device.yaml
@@ -10,11 +10,13 @@
# tune run lora_finetune_single_device --config qwen2/1.5B_lora_single_device
#
# To launch on a single device, run the following command from root:
-# tune run knowledge_distillation_single_device --config qwen2/knowledge_distillation_single_device
+# tune run knowledge_distillation_single_device --config qwen2/1.5_to_0.5B_KD_lora_single_device
#
# This config works only for distilling on a single device.
+output_dir: /tmp/torchtune/qwen2_1_5_to_0_5B/KD_lora_single_device # /tmp may be deleted by your system. Change it to your preference.
+
# Model Arguments
model:
_component_: torchtune.models.qwen2.lora_qwen2_0_5b
@@ -39,7 +41,7 @@ checkpointer:
model.safetensors
]
recipe_checkpoint: null
- output_dir: /tmp/Qwen2-0.5B-Instruct
+ output_dir: ${output_dir}
model_type: QWEN2
teacher_checkpointer:
@@ -81,14 +83,13 @@ kd_ratio: 0.5
# Training
epochs: 1
max_steps_per_epoch: null
-gradient_accumulation_steps: 8 # Use to increase virtual batch size
-compile: False # pytorch compile, set to true for better perf/memory
+gradient_accumulation_steps: 8 # Use to increase effective batch size
+compile: False # torch.compile the model + loss, True increases speed + decreases memory
# Logging
-output_dir: /tmp/qwen_kd
metric_logger:
_component_: torchtune.training.metric_logging.DiskLogger
- log_dir: ${output_dir}
+ log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: True
@@ -96,6 +97,9 @@ log_peak_memory_stats: True
device: cuda
dtype: bf16
enable_activation_checkpointing: False # True reduces memory
+enable_activation_offloading: False # True reduces memory
+
+
# Profiler (disabled)
profiler:
diff --git a/recipes/configs/qwen2/7B_full.yaml b/recipes/configs/qwen2/7B_full.yaml
index 20d74346e1..ec82a0d701 100644
--- a/recipes/configs/qwen2/7B_full.yaml
+++ b/recipes/configs/qwen2/7B_full.yaml
@@ -17,6 +17,8 @@
# Single device full finetuning requires more memory optimizations. It's
# best to use 7B_full.yaml for those cases
+output_dir: /tmp/torchtune/qwen2_7B/full # /tmp may be deleted by your system. Change it to your preference.
+
# Tokenizer
tokenizer:
_component_: torchtune.models.qwen2.qwen2_tokenizer
@@ -45,7 +47,7 @@ checkpointer:
model-00004-of-00004.safetensors
]
recipe_checkpoint: null
- output_dir: /tmp/Qwen2-7B-Instruct-finetune
+ output_dir: ${output_dir}
model_type: QWEN2
resume_from_checkpoint: False
@@ -59,8 +61,8 @@ optimizer:
loss:
_component_: torchtune.modules.loss.CEWithChunkedOutputLoss
max_steps_per_epoch: null
-gradient_accumulation_steps: 8 # Use to increase virtual batch size
-compile: False # pytorch compile, set to true for better perf/memory
+gradient_accumulation_steps: 8 # Use to increase effective batch size
+compile: False # torch.compile the model + loss, True increases speed + decreases memory
optimizer_in_bwd: False # True saves memory. Requires gradient_accumulation_steps=1
# Training env
@@ -76,11 +78,11 @@ dtype: bf16
# Logging
metric_logger:
_component_: torchtune.training.metric_logging.DiskLogger
- log_dir: ${output_dir}
-output_dir: /tmp/Qwen2-7B-Instruct-finetune
+ log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: True
+
# Profiler (disabled)
profiler:
_component_: torchtune.training.setup_torch_profiler
diff --git a/recipes/configs/qwen2/7B_full_single_device.yaml b/recipes/configs/qwen2/7B_full_single_device.yaml
index cff3244b18..0b01526ba4 100644
--- a/recipes/configs/qwen2/7B_full_single_device.yaml
+++ b/recipes/configs/qwen2/7B_full_single_device.yaml
@@ -19,6 +19,8 @@
#
# This config works only for training on single device.
+output_dir: /tmp/torchtune/qwen2_7B/full_single_device # /tmp may be deleted by your system. Change it to your preference.
+
# Tokenizer
tokenizer:
_component_: torchtune.models.qwen2.qwen2_tokenizer
@@ -47,7 +49,7 @@ checkpointer:
model-00004-of-00004.safetensors
]
recipe_checkpoint: null
- output_dir: /tmp/Qwen2-7B-Instruct-finetune
+ output_dir: ${output_dir}
model_type: QWEN2
resume_from_checkpoint: False
@@ -61,8 +63,8 @@ optimizer_in_bwd: True # True saves memory. Requires gradient_accumulation_step
loss:
_component_: torchtune.modules.loss.CEWithChunkedOutputLoss
max_steps_per_epoch: null
-gradient_accumulation_steps: 1 # Use to increase virtual batch size
-compile: False # pytorch compile, set to true for better perf/memory
+gradient_accumulation_steps: 1 # Use to increase effective batch size
+compile: False # torch.compile the model + loss, True increases speed + decreases memory
# Training environment
device: cuda
@@ -77,11 +79,11 @@ dtype: bf16
# Logging
metric_logger:
_component_: torchtune.training.metric_logging.DiskLogger
- log_dir: ${output_dir}
-output_dir: /tmp/Qwen2-7B-Instruct-finetune
+ log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: True
+
# Profiler (disabled)
profiler:
_component_: torchtune.training.setup_torch_profiler
diff --git a/recipes/configs/qwen2/7B_lora.yaml b/recipes/configs/qwen2/7B_lora.yaml
index 779e3fdc49..1da8e0de4d 100644
--- a/recipes/configs/qwen2/7B_lora.yaml
+++ b/recipes/configs/qwen2/7B_lora.yaml
@@ -18,6 +18,8 @@
# or 7B_qlora_single_device.yaml
+output_dir: /tmp/torchtune/qwen2_7B/lora # /tmp may be deleted by your system. Change it to your preference.
+
# Model Arguments
model:
_component_: torchtune.models.qwen2.lora_qwen2_7b
@@ -44,7 +46,7 @@ checkpointer:
model-00004-of-00004.safetensors
]
recipe_checkpoint: null
- output_dir: /tmp/Qwen2-7B-Instruct-lora-finetune
+ output_dir: ${output_dir}
model_type: QWEN2
resume_from_checkpoint: False
@@ -72,14 +74,13 @@ loss:
# Training
epochs: 1
max_steps_per_epoch: null
-gradient_accumulation_steps: 8 # Use to increase virtual batch size
-compile: False # pytorch compile, set to true for better perf/memory
+gradient_accumulation_steps: 8 # Use to increase effective batch size
+compile: False # torch.compile the model + loss, True increases speed + decreases memory
# Logging
-output_dir: /tmp/Qwen2-7B-Instruct-lora-finetune
metric_logger:
_component_: torchtune.training.metric_logging.DiskLogger
- log_dir: ${output_dir}
+ log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: True
diff --git a/recipes/configs/qwen2/7B_lora_single_device.yaml b/recipes/configs/qwen2/7B_lora_single_device.yaml
index d8c576fc41..082be9a3fd 100644
--- a/recipes/configs/qwen2/7B_lora_single_device.yaml
+++ b/recipes/configs/qwen2/7B_lora_single_device.yaml
@@ -16,6 +16,8 @@
# This config works only for training on single device.
+output_dir: /tmp/torchtune/qwen2_7B/lora_single_device # /tmp may be deleted by your system. Change it to your preference.
+
# Model Arguments
model:
_component_: torchtune.models.qwen2.lora_qwen2_7b
@@ -42,7 +44,7 @@ checkpointer:
model-00004-of-00004.safetensors
]
recipe_checkpoint: null
- output_dir: /tmp/Qwen2-7B-Instruct-lora-finetune
+ output_dir: ${output_dir}
model_type: QWEN2
resume_from_checkpoint: False
@@ -70,14 +72,13 @@ loss:
# Training
epochs: 1
max_steps_per_epoch: null
-gradient_accumulation_steps: 8 # Use to increase virtual batch size
-compile: False # pytorch compile, set to true for better perf/memory
+gradient_accumulation_steps: 8 # Use to increase effective batch size
+compile: False # torch.compile the model + loss, True increases speed + decreases memory
# Logging
-output_dir: /tmp/Qwen2-7B-Instruct-lora-finetune
metric_logger:
_component_: torchtune.training.metric_logging.DiskLogger
- log_dir: ${output_dir}
+ log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: True
diff --git a/recipes/configs/qwen2_5/0.5B_full.yaml b/recipes/configs/qwen2_5/0.5B_full.yaml
index 1298c058e9..c415425d5b 100644
--- a/recipes/configs/qwen2_5/0.5B_full.yaml
+++ b/recipes/configs/qwen2_5/0.5B_full.yaml
@@ -13,6 +13,8 @@
#
# This config is for fine-tuning on 2+ GPUs.
+output_dir: /tmp/torchtune/qwen2_5_0_5B/full # /tmp may be deleted by your system. Change it to your preference.
+
# Model arguments
model:
_component_: torchtune.models.qwen2_5.qwen2_5_0_5b
@@ -30,7 +32,7 @@ checkpointer:
checkpoint_dir: /tmp/Qwen2.5-0.5B-Instruct
checkpoint_files: [model.safetensors]
recipe_checkpoint: null
- output_dir: /tmp/Qwen2.5-0.5B-Instruct-finetune
+ output_dir: ${output_dir}
model_type: QWEN2
resume_from_checkpoint: False
@@ -45,7 +47,7 @@ shuffle: True
epochs: 1
max_steps_per_epoch: null
batch_size: 2
-gradient_accumulation_steps: 8 # Use to increase virtual batch size
+gradient_accumulation_steps: 8 # Use to increase effective batch size
optimizer:
_component_: torch.optim.AdamW
fused: True
@@ -64,13 +66,13 @@ dtype: bf16
compile: False # torch.compile the model + loss, True increases speed + decreases memory
# Logging
-output_dir: /tmp/Qwen2.5-0.5B-Instruct-finetune
metric_logger:
_component_: torchtune.training.metric_logging.DiskLogger
log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: True
+
# Profiler (disabled)
profiler:
_component_: torchtune.training.setup_torch_profiler
diff --git a/recipes/configs/qwen2_5/0.5B_full_single_device.yaml b/recipes/configs/qwen2_5/0.5B_full_single_device.yaml
index 39dfb2f8a0..2ac3a79f00 100644
--- a/recipes/configs/qwen2_5/0.5B_full_single_device.yaml
+++ b/recipes/configs/qwen2_5/0.5B_full_single_device.yaml
@@ -13,6 +13,8 @@
#
# This config works only for training on single device.
+output_dir: /tmp/torchtune/qwen2_5_0_5B/full_single_device # /tmp may be deleted by your system. Change it to your preference.
+
# Model arguments
model:
_component_: torchtune.models.qwen2_5.qwen2_5_0_5b
@@ -30,7 +32,7 @@ checkpointer:
checkpoint_dir: /tmp/Qwen2.5-0.5B-Instruct
checkpoint_files: [model.safetensors]
recipe_checkpoint: null
- output_dir: /tmp/Qwen2.5-0.5B-Instruct-finetune
+ output_dir: ${output_dir}
model_type: QWEN2
resume_from_checkpoint: False
@@ -45,7 +47,7 @@ shuffle: True
epochs: 1
max_steps_per_epoch: null
batch_size: 2
-gradient_accumulation_steps: 8 # Use to increase virtual batch size
+gradient_accumulation_steps: 8 # Use to increase effective batch size
optimizer:
_component_: torch.optim.AdamW
fused: True
@@ -64,13 +66,13 @@ dtype: bf16
compile: False # torch.compile the model + loss, True increases speed + decreases memory
# Logging
-output_dir: /tmp/Qwen2.5-0.5B-Instruct-finetune
metric_logger:
_component_: torchtune.training.metric_logging.DiskLogger
log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: True
+
# Profiler (disabled)
profiler:
_component_: torchtune.training.setup_torch_profiler
diff --git a/recipes/configs/qwen2_5/0.5B_lora.yaml b/recipes/configs/qwen2_5/0.5B_lora.yaml
index 50fe1a0a28..704aa7ca80 100644
--- a/recipes/configs/qwen2_5/0.5B_lora.yaml
+++ b/recipes/configs/qwen2_5/0.5B_lora.yaml
@@ -13,6 +13,8 @@
#
# This config is for fine-tuning on 2+ GPUs.
+output_dir: /tmp/torchtune/qwen2_5_0_5B/lora # /tmp may be deleted by your system. Change it to your preference.
+
# Model arguments
model:
_component_: torchtune.models.qwen2_5.lora_qwen2_5_0_5b
@@ -35,7 +37,7 @@ checkpointer:
checkpoint_dir: /tmp/Qwen2.5-0.5B-Instruct
checkpoint_files: [model.safetensors]
recipe_checkpoint: null
- output_dir: /tmp/Qwen2.5-0.5B-Instruct-lora-finetune
+ output_dir: ${output_dir}
model_type: QWEN2
resume_from_checkpoint: False
@@ -50,7 +52,7 @@ shuffle: True
epochs: 1
max_steps_per_epoch: null
batch_size: 2
-gradient_accumulation_steps: 8 # Use to increase virtual batch size
+gradient_accumulation_steps: 8 # Use to increase effective batch size
optimizer:
_component_: torch.optim.AdamW
fused: True
@@ -72,13 +74,13 @@ dtype: bf16
compile: False # torch.compile the model + loss, True increases speed + decreases memory
# Logging
-output_dir: /tmp/Qwen2.5-0.5B-Instruct-lora-finetune
metric_logger:
_component_: torchtune.training.metric_logging.DiskLogger
log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: True
+
# Profiler (disabled)
profiler:
_component_: torchtune.training.setup_torch_profiler
diff --git a/recipes/configs/qwen2_5/0.5B_lora_single_device.yaml b/recipes/configs/qwen2_5/0.5B_lora_single_device.yaml
index fa507e3414..20ceb6536d 100644
--- a/recipes/configs/qwen2_5/0.5B_lora_single_device.yaml
+++ b/recipes/configs/qwen2_5/0.5B_lora_single_device.yaml
@@ -13,6 +13,8 @@
#
# This config works only for training on single device.
+output_dir: /tmp/torchtune/qwen2_5_0_5B/lora_single_device # /tmp may be deleted by your system. Change it to your preference.
+
# Model arguments
model:
_component_: torchtune.models.qwen2_5.lora_qwen2_5_0_5b
@@ -35,7 +37,7 @@ checkpointer:
checkpoint_dir: /tmp/Qwen2.5-0.5B-Instruct
checkpoint_files: [model.safetensors]
recipe_checkpoint: null
- output_dir: /tmp/Qwen2.5-0.5B-Instruct-lora-finetune
+ output_dir: ${output_dir}
model_type: QWEN2
resume_from_checkpoint: False
@@ -50,7 +52,7 @@ shuffle: True
epochs: 1
max_steps_per_epoch: null
batch_size: 2
-gradient_accumulation_steps: 8 # Use to increase virtual batch size
+gradient_accumulation_steps: 8 # Use to increase effective batch size
optimizer:
_component_: torch.optim.AdamW
fused: True
@@ -72,13 +74,13 @@ dtype: bf16
compile: False # torch.compile the model + loss, True increases speed + decreases memory
# Logging
-output_dir: /tmp/Qwen2.5-0.5B-Instruct-lora-finetune
metric_logger:
_component_: torchtune.training.metric_logging.DiskLogger
log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: True
+
# Profiler (disabled)
profiler:
_component_: torchtune.training.setup_torch_profiler
diff --git a/recipes/configs/qwen2_5/1.5B_full.yaml b/recipes/configs/qwen2_5/1.5B_full.yaml
index e0fb09c152..431c1b519a 100644
--- a/recipes/configs/qwen2_5/1.5B_full.yaml
+++ b/recipes/configs/qwen2_5/1.5B_full.yaml
@@ -13,6 +13,8 @@
#
# This config is for fine-tuning on 2+ GPUs.
+output_dir: /tmp/torchtune/qwen2_5_1_5B/full # /tmp may be deleted by your system. Change it to your preference.
+
# Model arguments
model:
_component_: torchtune.models.qwen2_5.qwen2_5_1_5b_instruct
@@ -30,7 +32,7 @@ checkpointer:
checkpoint_dir: /tmp/Qwen2.5-1.5B-Instruct
checkpoint_files: [model.safetensors]
recipe_checkpoint: null
- output_dir: /tmp/Qwen2.5-1.5B-Instruct-finetune
+ output_dir: ${output_dir}
model_type: QWEN2
resume_from_checkpoint: False
@@ -45,7 +47,7 @@ shuffle: True
epochs: 1
max_steps_per_epoch: null
batch_size: 2
-gradient_accumulation_steps: 1 # Use to increase virtual batch size
+gradient_accumulation_steps: 1 # Use to increase effective batch size
optimizer:
_component_: torch.optim.AdamW
fused: True
@@ -64,13 +66,13 @@ dtype: bf16
compile: False # torch.compile the model + loss, True increases speed + decreases memory
# Logging
-output_dir: /tmp/Qwen2.5-1.5B-Instruct-finetune
metric_logger:
_component_: torchtune.training.metric_logging.DiskLogger
log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: True
+
# Profiler (disabled)
profiler:
_component_: torchtune.training.setup_torch_profiler
diff --git a/recipes/configs/qwen2_5/1.5B_full_single_device.yaml b/recipes/configs/qwen2_5/1.5B_full_single_device.yaml
index 480249631d..d48176616d 100644
--- a/recipes/configs/qwen2_5/1.5B_full_single_device.yaml
+++ b/recipes/configs/qwen2_5/1.5B_full_single_device.yaml
@@ -17,6 +17,8 @@
#
# This config works only for training on single device.
+output_dir: /tmp/torchtune/qwen2_5_1_5B/full_single_device # /tmp may be deleted by your system. Change it to your preference.
+
# Model arguments
model:
_component_: torchtune.models.qwen2_5.qwen2_5_1_5b_instruct
@@ -34,7 +36,7 @@ checkpointer:
checkpoint_dir: /tmp/Qwen2.5-1.5B-Instruct
checkpoint_files: [model.safetensors]
recipe_checkpoint: null
- output_dir: /tmp/Qwen2.5-1.5B-Instruct-finetune
+ output_dir: ${output_dir}
model_type: QWEN2
resume_from_checkpoint: False
@@ -49,7 +51,7 @@ shuffle: True
epochs: 1
max_steps_per_epoch: null
batch_size: 4
-gradient_accumulation_steps: 1 # Use to increase virtual batch size
+gradient_accumulation_steps: 1 # Use to increase effective batch size
optimizer:
_component_: bitsandbytes.optim.PagedAdamW
lr: 2e-5
@@ -67,13 +69,13 @@ dtype: bf16
compile: False # torch.compile the model + loss, True increases speed + decreases memory
# Logging
-output_dir: /tmp/Qwen2.5-1.5B-Instruct-finetune
metric_logger:
_component_: torchtune.training.metric_logging.DiskLogger
log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: True
+
# Profiler (disabled)
profiler:
_component_: torchtune.training.setup_torch_profiler
diff --git a/recipes/configs/qwen2_5/1.5B_lora.yaml b/recipes/configs/qwen2_5/1.5B_lora.yaml
index 8d530c3670..84d9e2c9bd 100644
--- a/recipes/configs/qwen2_5/1.5B_lora.yaml
+++ b/recipes/configs/qwen2_5/1.5B_lora.yaml
@@ -13,6 +13,8 @@
#
# This config is for fine-tuning on 2+ GPUs.
+output_dir: /tmp/torchtune/qwen2_5_1_5B/lora # /tmp may be deleted by your system. Change it to your preference.
+
# Model arguments
model:
_component_: torchtune.models.qwen2_5.lora_qwen2_5_1_5b_instruct
@@ -35,7 +37,7 @@ checkpointer:
checkpoint_dir: /tmp/Qwen2.5-1.5B-Instruct
checkpoint_files: [model.safetensors]
recipe_checkpoint: null
- output_dir: /tmp/Qwen2.5-1.5B-Instruct-lora-finetune
+ output_dir: ${output_dir}
model_type: QWEN2
resume_from_checkpoint: False
@@ -50,7 +52,7 @@ shuffle: True
epochs: 1
max_steps_per_epoch: null
batch_size: 2
-gradient_accumulation_steps: 8 # Use to increase virtual batch size
+gradient_accumulation_steps: 8 # Use to increase effective batch size
optimizer:
_component_: torch.optim.AdamW
fused: True
@@ -71,13 +73,13 @@ dtype: bf16
compile: False # torch.compile the model + loss, True increases speed + decreases memory
# Logging
-output_dir: /tmp/Qwen2.5-1.5B-Instruct-lora-finetune
metric_logger:
_component_: torchtune.training.metric_logging.DiskLogger
log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: True
+
# Profiler (disabled)
profiler:
_component_: torchtune.training.setup_torch_profiler
diff --git a/recipes/configs/qwen2_5/1.5B_lora_single_device.yaml b/recipes/configs/qwen2_5/1.5B_lora_single_device.yaml
index e784066fe0..579c39bfec 100644
--- a/recipes/configs/qwen2_5/1.5B_lora_single_device.yaml
+++ b/recipes/configs/qwen2_5/1.5B_lora_single_device.yaml
@@ -13,6 +13,8 @@
#
# This config works only for training on single device.
+output_dir: /tmp/torchtune/qwen2_5_1_5B/lora_single_device # /tmp may be deleted by your system. Change it to your preference.
+
# Model arguments
model:
_component_: torchtune.models.qwen2_5.lora_qwen2_5_1_5b_instruct
@@ -35,7 +37,7 @@ checkpointer:
checkpoint_dir: /tmp/Qwen2.5-1.5B-Instruct
checkpoint_files: [model.safetensors]
recipe_checkpoint: null
- output_dir: /tmp/Qwen2.5-1.5B-Instruct-lora-finetune
+ output_dir: ${output_dir}
model_type: QWEN2
resume_from_checkpoint: False
@@ -50,7 +52,7 @@ shuffle: True
epochs: 1
max_steps_per_epoch: null
batch_size: 2
-gradient_accumulation_steps: 8 # Use to increase virtual batch size
+gradient_accumulation_steps: 8 # Use to increase effective batch size
optimizer:
_component_: torch.optim.AdamW
fused: True
@@ -71,13 +73,13 @@ dtype: bf16
compile: False # torch.compile the model + loss, True increases speed + decreases memory
# Logging
-output_dir: /tmp/Qwen2.5-1.5B-Instruct-lora-finetune
metric_logger:
_component_: torchtune.training.metric_logging.DiskLogger
log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: True
+
# Profiler (disabled)
profiler:
_component_: torchtune.training.setup_torch_profiler
diff --git a/recipes/configs/qwen2_5/14B_lora_single_device.yaml b/recipes/configs/qwen2_5/14B_lora_single_device.yaml
index 2886a56664..93220bb466 100644
--- a/recipes/configs/qwen2_5/14B_lora_single_device.yaml
+++ b/recipes/configs/qwen2_5/14B_lora_single_device.yaml
@@ -16,6 +16,8 @@
# This config works only for training on single device.
+output_dir: /tmp/torchtune/qwen2_5_14B/lora_single_device # /tmp may be deleted by your system. Change it to your preference.
+
# Model Arguments
model:
_component_: torchtune.models.qwen2_5.lora_qwen2_5_14b_instruct
@@ -46,7 +48,7 @@ checkpointer:
model-00008-of-00008.safetensors,
]
recipe_checkpoint: null
- output_dir: /tmp/Qwen2_5-14B-Instruct-lora-finetune
+ output_dir: ${output_dir}
model_type: QWEN2
resume_from_checkpoint: False
@@ -74,14 +76,13 @@ loss:
# Training
epochs: 1
max_steps_per_epoch: null
-gradient_accumulation_steps: 8 # Use to increase virtual batch size
-compile: False # pytorch compile, set to true for better perf/memory
+gradient_accumulation_steps: 8 # Use to increase effective batch size
+compile: False # torch.compile the model + loss, True increases speed + decreases memory
# Logging
-output_dir: /tmp/Qwen2_5-14B-Instruct-lora-finetune
metric_logger:
_component_: torchtune.training.metric_logging.DiskLogger
- log_dir: ${output_dir}
+ log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: False
diff --git a/recipes/configs/qwen2_5/32B_lora.yaml b/recipes/configs/qwen2_5/32B_lora.yaml
index bed3868365..6e5ab5174f 100644
--- a/recipes/configs/qwen2_5/32B_lora.yaml
+++ b/recipes/configs/qwen2_5/32B_lora.yaml
@@ -14,6 +14,8 @@
# tune run --nnodes 1 --nproc_per_node 8 lora_finetune_distributed --config qwen2_5/32B_lora checkpointer.checkpoint_dir=
+output_dir: /tmp/torchtune/qwen2_5_32B/lora # /tmp may be deleted by your system. Change it to your preference.
+
# Model Arguments
model:
_component_: torchtune.models.qwen2_5.lora_qwen2_5_32b_instruct
@@ -53,7 +55,7 @@ checkpointer:
model-00017-of-00017.safetensors,
]
recipe_checkpoint: null
- output_dir: /tmp/Qwen2_5-32B-Instruct-lora-finetune
+ output_dir: ${output_dir}
model_type: QWEN2
resume_from_checkpoint: False
@@ -81,14 +83,13 @@ loss:
# Training
epochs: 1
max_steps_per_epoch: null
-gradient_accumulation_steps: 8 # Use to increase virtual batch size
-compile: False # pytorch compile, set to true for better perf/memory
+gradient_accumulation_steps: 8 # Use to increase effective batch size
+compile: False # torch.compile the model + loss, True increases speed + decreases memory
# Logging
-output_dir: /tmp/Qwen2_5-32B-Instruct-lora-finetune
metric_logger:
_component_: torchtune.training.metric_logging.DiskLogger
- log_dir: ${output_dir}
+ log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: False
@@ -97,6 +98,7 @@ device: cuda
dtype: bf16
enable_activation_checkpointing: False # True reduces memory
enable_activation_offloading: False # True reduces memory
+# custom_sharded_layers: ['tok_embeddings'] # Layers to shard separately (useful for large vocab size models). Lower Memory, but lower speed.
# Show case the usage of pytorch profiler
# Set enabled to False as it's only needed for debugging training
diff --git a/recipes/configs/qwen2_5/3B_full.yaml b/recipes/configs/qwen2_5/3B_full.yaml
index 7267dd5efe..217769ad8c 100644
--- a/recipes/configs/qwen2_5/3B_full.yaml
+++ b/recipes/configs/qwen2_5/3B_full.yaml
@@ -17,6 +17,8 @@
# Single device full finetuning requires more memory optimizations. It's
# best to use 3B_full_single_device.yaml for those cases
+output_dir: /tmp/torchtune/qwen2_5_3B/full # /tmp may be deleted by your system. Change it to your preference.
+
# Tokenizer
tokenizer:
_component_: torchtune.models.qwen2_5.qwen2_5_tokenizer
@@ -43,7 +45,7 @@ checkpointer:
model-00002-of-00002.safetensors,
]
recipe_checkpoint: null
- output_dir: /tmp/Qwen2_5-3B-Instruct-finetune
+ output_dir: ${output_dir}
model_type: QWEN2
resume_from_checkpoint: False
@@ -57,8 +59,8 @@ optimizer:
loss:
_component_: torchtune.modules.loss.CEWithChunkedOutputLoss
max_steps_per_epoch: null
-gradient_accumulation_steps: 8 # Use to increase virtual batch size
-compile: False # pytorch compile, set to true for better perf/memory
+gradient_accumulation_steps: 8 # Use to increase effective batch size
+compile: False # torch.compile the model + loss, True increases speed + decreases memory
optimizer_in_bwd: False # True saves memory. Requires gradient_accumulation_steps=1
# Training env
@@ -74,11 +76,11 @@ dtype: bf16
# Logging
metric_logger:
_component_: torchtune.training.metric_logging.DiskLogger
- log_dir: ${output_dir}
-output_dir: /tmp/Qwen2_5-3B-Instruct-finetune
+ log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: False
+
# Profiler (disabled)
profiler:
_component_: torchtune.training.setup_torch_profiler
diff --git a/recipes/configs/qwen2_5/3B_full_single_device.yaml b/recipes/configs/qwen2_5/3B_full_single_device.yaml
index ef8d283098..38b1645817 100644
--- a/recipes/configs/qwen2_5/3B_full_single_device.yaml
+++ b/recipes/configs/qwen2_5/3B_full_single_device.yaml
@@ -19,6 +19,8 @@
#
# This config works only for training on single device.
+output_dir: /tmp/torchtune/qwen2_5_3B/full_single_device # /tmp may be deleted by your system. Change it to your preference.
+
# Tokenizer
tokenizer:
_component_: torchtune.models.qwen2_5.qwen2_5_tokenizer
@@ -45,7 +47,7 @@ checkpointer:
model-00002-of-00002.safetensors,
]
recipe_checkpoint: null
- output_dir: /tmp/Qwen2_5-3B-Instruct-finetune
+ output_dir: ${output_dir}
model_type: QWEN2
resume_from_checkpoint: False
@@ -59,8 +61,8 @@ optimizer_in_bwd: True # True saves memory. Requires gradient_accumulation_step
loss:
_component_: torchtune.modules.loss.CEWithChunkedOutputLoss
max_steps_per_epoch: null
-gradient_accumulation_steps: 1 # Use to increase virtual batch size
-compile: False # pytorch compile, set to true for better perf/memory
+gradient_accumulation_steps: 1 # Use to increase effective batch size
+compile: False # torch.compile the model + loss, True increases speed + decreases memory
# Training environment
device: cuda
@@ -75,11 +77,11 @@ dtype: bf16
# Logging
metric_logger:
_component_: torchtune.training.metric_logging.DiskLogger
- log_dir: ${output_dir}
-output_dir: /tmp/Qwen2_5-3B-Instruct-finetune
+ log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: False
+
# Profiler (disabled)
profiler:
_component_: torchtune.training.setup_torch_profiler
diff --git a/recipes/configs/qwen2_5/3B_lora.yaml b/recipes/configs/qwen2_5/3B_lora.yaml
index 6cde39b86e..152c8da204 100644
--- a/recipes/configs/qwen2_5/3B_lora.yaml
+++ b/recipes/configs/qwen2_5/3B_lora.yaml
@@ -17,6 +17,8 @@
# For single device LoRA finetuning please use 3B_lora_single_device.yaml
+output_dir: /tmp/torchtune/qwen2_5_3B/lora # /tmp may be deleted by your system. Change it to your preference.
+
# Model Arguments
model:
_component_: torchtune.models.qwen2_5.lora_qwen2_5_3b
@@ -40,7 +42,7 @@ checkpointer:
model-00002-of-00002.safetensors,
]
recipe_checkpoint: null
- output_dir: /tmp/Qwen2_5-3B-Instruct-lora-finetune
+ output_dir: ${output_dir}
model_type: QWEN2
resume_from_checkpoint: False
@@ -68,14 +70,13 @@ loss:
# Training
epochs: 1
max_steps_per_epoch: null
-gradient_accumulation_steps: 8 # Use to increase virtual batch size
-compile: False # pytorch compile, set to true for better perf/memory
+gradient_accumulation_steps: 8 # Use to increase effective batch size
+compile: False # torch.compile the model + loss, True increases speed + decreases memory
# Logging
-output_dir: /tmp/Qwen2_5-3B-Instruct-lora-finetune
metric_logger:
_component_: torchtune.training.metric_logging.DiskLogger
- log_dir: ${output_dir}
+ log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: False
diff --git a/recipes/configs/qwen2_5/3B_lora_single_device.yaml b/recipes/configs/qwen2_5/3B_lora_single_device.yaml
index bd3cb9fa68..98ed48f06f 100644
--- a/recipes/configs/qwen2_5/3B_lora_single_device.yaml
+++ b/recipes/configs/qwen2_5/3B_lora_single_device.yaml
@@ -16,6 +16,8 @@
# This config works only for training on single device.
+output_dir: /tmp/torchtune/qwen2_5_3B/lora_single_device # /tmp may be deleted by your system. Change it to your preference.
+
# Model Arguments
model:
_component_: torchtune.models.qwen2_5.lora_qwen2_5_3b
@@ -39,7 +41,7 @@ checkpointer:
model-00002-of-00002.safetensors,
]
recipe_checkpoint: null
- output_dir: /tmp/Qwen2_5-3B-Instruct-lora-finetune
+ output_dir: ${output_dir}
model_type: QWEN2
resume_from_checkpoint: False
@@ -67,14 +69,13 @@ loss:
# Training
epochs: 1
max_steps_per_epoch: null
-gradient_accumulation_steps: 8 # Use to increase virtual batch size
-compile: False # pytorch compile, set to true for better perf/memory
+gradient_accumulation_steps: 8 # Use to increase effective batch size
+compile: False # torch.compile the model + loss, True increases speed + decreases memory
# Logging
-output_dir: /tmp/Qwen2_5-3B-Instruct-lora-finetune
metric_logger:
_component_: torchtune.training.metric_logging.DiskLogger
- log_dir: ${output_dir}
+ log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: False
diff --git a/recipes/configs/qwen2_5/72B_lora.yaml b/recipes/configs/qwen2_5/72B_lora.yaml
index fc7ad2dc7d..41ff800c5a 100644
--- a/recipes/configs/qwen2_5/72B_lora.yaml
+++ b/recipes/configs/qwen2_5/72B_lora.yaml
@@ -14,6 +14,8 @@
# tune run --nnodes 1 --nproc_per_node 8 lora_finetune_distributed --config qwen2_5/72B_lora checkpointer.checkpoint_dir=
+output_dir: /tmp/torchtune/qwen2_5_72B/lora # /tmp may be deleted by your system. Change it to your preference.
+
# Model Arguments
model:
_component_: torchtune.models.qwen2_5.lora_qwen2_5_72b_instruct
@@ -73,7 +75,7 @@ checkpointer:
model-00037-of-00037.safetensors,
]
recipe_checkpoint: null
- output_dir: /tmp/Qwen2_5-72B-Instruct-lora-finetune
+ output_dir: ${output_dir}
model_type: QWEN2
resume_from_checkpoint: False
@@ -101,14 +103,13 @@ loss:
# Training
epochs: 1
max_steps_per_epoch: null
-gradient_accumulation_steps: 1 # Use to increase virtual batch size
-compile: False # pytorch compile, set to true for better perf/memory
+gradient_accumulation_steps: 1 # Use to increase effective batch size
+compile: False # torch.compile the model + loss, True increases speed + decreases memory
# Logging
-output_dir: /tmp/Qwen2_5-72B-Instruct-lora-finetune
metric_logger:
_component_: torchtune.training.metric_logging.DiskLogger
- log_dir: ${output_dir}
+ log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: False
@@ -117,6 +118,7 @@ device: cuda
dtype: bf16
enable_activation_checkpointing: True # True reduces memory
enable_activation_offloading: False # True reduces memory
+# custom_sharded_layers: ['tok_embeddings'] # Layers to shard separately (useful for large vocab size models). Lower Memory, but lower speed.
# Show case the usage of pytorch profiler
# Set enabled to False as it's only needed for debugging training
diff --git a/recipes/configs/qwen2_5/7B_full.yaml b/recipes/configs/qwen2_5/7B_full.yaml
index e1de8d5584..d071687103 100644
--- a/recipes/configs/qwen2_5/7B_full.yaml
+++ b/recipes/configs/qwen2_5/7B_full.yaml
@@ -17,6 +17,8 @@
# Single device full finetuning requires more memory optimizations. It's
# best to use 7B_full_single_device.yaml for those cases
+output_dir: /tmp/torchtune/qwen2_5_7B/full # /tmp may be deleted by your system. Change it to your preference.
+
# Tokenizer
tokenizer:
_component_: torchtune.models.qwen2_5.qwen2_5_tokenizer
@@ -45,7 +47,7 @@ checkpointer:
model-00004-of-00004.safetensors,
]
recipe_checkpoint: null
- output_dir: /tmp/Qwen2_5-7B-Instruct-finetune
+ output_dir: ${output_dir}
model_type: QWEN2
resume_from_checkpoint: False
@@ -59,8 +61,8 @@ optimizer:
loss:
_component_: torchtune.modules.loss.CEWithChunkedOutputLoss
max_steps_per_epoch: null
-gradient_accumulation_steps: 8 # Use to increase virtual batch size
-compile: False # pytorch compile, set to true for better perf/memory
+gradient_accumulation_steps: 8 # Use to increase effective batch size
+compile: False # torch.compile the model + loss, True increases speed + decreases memory
optimizer_in_bwd: False # True saves memory. Requires gradient_accumulation_steps=1
# Training env
@@ -76,11 +78,11 @@ dtype: bf16
# Logging
metric_logger:
_component_: torchtune.training.metric_logging.DiskLogger
- log_dir: ${output_dir}
-output_dir: /tmp/Qwen2_5-7B-Instruct-finetune
+ log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: False
+
# Profiler (disabled)
profiler:
_component_: torchtune.training.setup_torch_profiler
diff --git a/recipes/configs/qwen2_5/7B_full_single_device.yaml b/recipes/configs/qwen2_5/7B_full_single_device.yaml
index 3bc3428410..e6ebbcb8cf 100644
--- a/recipes/configs/qwen2_5/7B_full_single_device.yaml
+++ b/recipes/configs/qwen2_5/7B_full_single_device.yaml
@@ -19,6 +19,8 @@
#
# This config works only for training on single device.
+output_dir: /tmp/torchtune/qwen2_5_7B/full_single_device # /tmp may be deleted by your system. Change it to your preference.
+
# Tokenizer
tokenizer:
_component_: torchtune.models.qwen2_5.qwen2_5_tokenizer
@@ -47,7 +49,7 @@ checkpointer:
model-00004-of-00004.safetensors,
]
recipe_checkpoint: null
- output_dir: /tmp/Qwen2_5-7B-Instruct-finetune
+ output_dir: ${output_dir}
model_type: QWEN2
resume_from_checkpoint: False
@@ -61,8 +63,8 @@ optimizer_in_bwd: True # True saves memory. Requires gradient_accumulation_step
loss:
_component_: torchtune.modules.loss.CEWithChunkedOutputLoss
max_steps_per_epoch: null
-gradient_accumulation_steps: 1 # Use to increase virtual batch size
-compile: False # pytorch compile, set to true for better perf/memory
+gradient_accumulation_steps: 1 # Use to increase effective batch size
+compile: False # torch.compile the model + loss, True increases speed + decreases memory
# Training environment
device: cuda
@@ -77,11 +79,11 @@ dtype: bf16
# Logging
metric_logger:
_component_: torchtune.training.metric_logging.DiskLogger
- log_dir: ${output_dir}
-output_dir: /tmp/Qwen2_5-7B-Instruct-finetune
+ log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: False
+
# Profiler (disabled)
profiler:
_component_: torchtune.training.setup_torch_profiler
diff --git a/recipes/configs/qwen2_5/7B_lora.yaml b/recipes/configs/qwen2_5/7B_lora.yaml
index 460c67d26f..f78c522e8a 100644
--- a/recipes/configs/qwen2_5/7B_lora.yaml
+++ b/recipes/configs/qwen2_5/7B_lora.yaml
@@ -17,6 +17,8 @@
# For single device LoRA finetuning please use 7B_lora_single_device.yaml
+output_dir: /tmp/torchtune/qwen2_5_7B/lora # /tmp may be deleted by your system. Change it to your preference.
+
# Model Arguments
model:
_component_: torchtune.models.qwen2_5.lora_qwen2_5_7b_instruct
@@ -43,7 +45,7 @@ checkpointer:
model-00004-of-00004.safetensors,
]
recipe_checkpoint: null
- output_dir: /tmp/Qwen2_5-7B-Instruct-lora-finetune
+ output_dir: ${output_dir}
model_type: QWEN2
resume_from_checkpoint: False
@@ -71,14 +73,13 @@ loss:
# Training
epochs: 1
max_steps_per_epoch: null
-gradient_accumulation_steps: 8 # Use to increase virtual batch size
-compile: False # pytorch compile, set to true for better perf/memory
+gradient_accumulation_steps: 8 # Use to increase effective batch size
+compile: False # torch.compile the model + loss, True increases speed + decreases memory
# Logging
-output_dir: /tmp/Qwen2_5-7B-Instruct-lora-finetune
metric_logger:
_component_: torchtune.training.metric_logging.DiskLogger
- log_dir: ${output_dir}
+ log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: False
diff --git a/recipes/configs/qwen2_5/7B_lora_single_device.yaml b/recipes/configs/qwen2_5/7B_lora_single_device.yaml
index 5c3353f7e9..3accf271d3 100644
--- a/recipes/configs/qwen2_5/7B_lora_single_device.yaml
+++ b/recipes/configs/qwen2_5/7B_lora_single_device.yaml
@@ -16,6 +16,8 @@
# This config works only for training on single device.
+output_dir: /tmp/torchtune/qwen2_5_7B/lora_single_device # /tmp may be deleted by your system. Change it to your preference.
+
# Model Arguments
model:
_component_: torchtune.models.qwen2_5.lora_qwen2_5_7b_instruct
@@ -42,7 +44,7 @@ checkpointer:
model-00004-of-00004.safetensors,
]
recipe_checkpoint: null
- output_dir: /tmp/Qwen2_5-7B-Instruct-lora-finetune
+ output_dir: ${output_dir}
model_type: QWEN2
resume_from_checkpoint: False
@@ -70,14 +72,13 @@ loss:
# Training
epochs: 1
max_steps_per_epoch: null
-gradient_accumulation_steps: 8 # Use to increase virtual batch size
-compile: False # pytorch compile, set to true for better perf/memory
+gradient_accumulation_steps: 8 # Use to increase effective batch size
+compile: False # torch.compile the model + loss, True increases speed + decreases memory
# Logging
-output_dir: /tmp/Qwen2_5-7B-Instruct-lora-finetune
metric_logger:
_component_: torchtune.training.metric_logging.DiskLogger
- log_dir: ${output_dir}
+ log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: False
diff --git a/recipes/full_finetune_distributed.py b/recipes/full_finetune_distributed.py
index aa48920815..4a227701d7 100644
--- a/recipes/full_finetune_distributed.py
+++ b/recipes/full_finetune_distributed.py
@@ -122,15 +122,6 @@ def __init__(self, cfg: DictConfig) -> None:
"full fp16 training is not supported with this recipe. Please use bf16 or fp32 instead."
)
- if (
- cfg.get("fsdp_cpu_offload", False)
- and cfg.optimizer.get("fused", False)
- and not utils.torch_version_ge("2.4.0")
- ):
- raise RuntimeError(
- "Using fused optimizer on CPU is only supported in PyTorch nightly."
- )
-
# logging attributes
self._output_dir = cfg.output_dir
self._log_every_n_steps = cfg.get("log_every_n_steps", 1)
@@ -955,7 +946,7 @@ def recipe_main(cfg: DictConfig) -> None:
"Distributed finetune recipe should be run via a distributed launcher."
"If using tune CLI, please specify --nnodes 1 and --nproc_per_node [num_gpus]"
)
- init_process_group(backend="gloo" if cfg.device == "cpu" else "nccl")
+ init_process_group("cuda:nccl,cpu:gloo")
if cfg.get("fsdp_cpu_offload", False):
# Utilize all available CPU cores for intra-op parallelism. This provides ~2x
# speed up when benchmarking fused AdamW on CPU
diff --git a/recipes/knowledge_distillation_distributed.py b/recipes/knowledge_distillation_distributed.py
index c920f4b069..d74bc40e2b 100644
--- a/recipes/knowledge_distillation_distributed.py
+++ b/recipes/knowledge_distillation_distributed.py
@@ -971,11 +971,11 @@ def recipe_main(cfg: DictConfig) -> None:
"Distributed finetune recipe should be run via a distributed launcher."
"If using tune CLI, please specify --nnodes 1 and --nproc_per_node [num_gpus]"
)
+ init_process_group("cuda:nccl,cpu:gloo")
if cfg.get("fsdp_cpu_offload", False):
# Utilize all available CPU cores for intra-op parallelism. This provides ~2x
# speed up when benchmarking fused AdamW on CPU
training.set_torch_num_threads()
- init_process_group(backend="gloo" if cfg.device == "cpu" else "nccl")
config.log_config(recipe_name="KDRecipeDistributed", cfg=cfg)
diff --git a/recipes/lora_dpo_distributed.py b/recipes/lora_dpo_distributed.py
index ab37623cc1..993fd2ac1f 100644
--- a/recipes/lora_dpo_distributed.py
+++ b/recipes/lora_dpo_distributed.py
@@ -33,7 +33,6 @@
validate_missing_and_unexpected_for_lora,
)
from torchtune.recipe_interfaces import FTRecipeInterface
-from torchtune.rlhf.loss import SimPOLoss
from tqdm import tqdm
log = utils.get_logger("DEBUG")
@@ -59,6 +58,18 @@ class LoRADPORecipeDistributed(FTRecipeInterface):
come at the cost of training performance. In most cases training can slow-down quite a bit as
a result of this activation recomputation.
+ - Activation Offloading. This can be controlled using the ``enable_activation_offloading``
+ flag. Activation offloading is a technique similar to activations checkpointing that helps
+ reduce the memory footprint to prevent OOMs on CUDA and enable bigger batches. Where activations
+ checkpointing drops the activation in the forward to recompute it later in the backward,
+ activations offloading will drop the activation in the forward to the CPU and bring it
+ back during the backward pass. As always, there is a tradeoff--these savings in memory can
+ come at the cost of training performance and CPU resources. To recover some runtime cost,
+ we've added an option to enable offloading on a different stream to permit overlapping with
+ the computation. This option is currently only available on PyTorch 2.5 or later and will
+ be enabled by default if an acceptable torch version is found. Activation offloading can be
+ used in conjunction with activation checkpointing.
+
- Precision. Full fp32 and bf16 training are supported. Precision is controlled using the ``dtype``
flag. When ``dtype=bf16``, all activations, gradients and optimizer states are in bfloat16. In
most cases this should halve the memory footprint of full precision (fp32) training, without
@@ -97,7 +108,6 @@ class LoRADPORecipeDistributed(FTRecipeInterface):
The following losses are supported in this recipe:
- :class:`~torchtune.rlhf.loss.DPOLoss`: Direct Preference Optimization (DPO).
- :class:`~torchtune.rlhf.loss.RSOPLoss`: Rejection Sampling Optimization (RSO).
- - :class:`~torchtune.rlhf.loss.SimPOLoss`: Simple Preference Optimization (SimPO).
For a full list of example configs for this recipe, run ``tune ls`` on the command line. Each config
has example commands for how to kick-off training.
@@ -109,6 +119,8 @@ class LoRADPORecipeDistributed(FTRecipeInterface):
ValueError: If ``dtype`` is set to fp16.
ValueError: If world_size is 1
RuntimeError: If ``dtype`` is set to bf16 and the hardware does not support bf16.
+ RuntimeError: If ``enable_activation_offloading`` is True and device is not CUDA.
+ RuntimeError: If ``enable_activation_offloading`` is True and ``enable_activation_checkpointing`` is False.
"""
def __init__(self, cfg: DictConfig) -> None:
@@ -135,8 +147,28 @@ def __init__(self, cfg: DictConfig) -> None:
)
self._log_peak_memory_stats = False
- # training attributes
- self._enable_activation_checkpointing = cfg.enable_activation_checkpointing
+ # activation checkpointing/offloading
+ self._enable_activation_checkpointing = cfg.get(
+ "enable_activation_checkpointing", False
+ )
+ self._enable_activation_offloading = cfg.get(
+ "enable_activation_offloading", False
+ )
+ if self._enable_activation_offloading:
+ if self._device.type != "cuda":
+ raise RuntimeError(
+ "enable_activation_offloading should only be True when training on CUDA"
+ )
+ if not self._enable_activation_checkpointing:
+ raise RuntimeError(
+ "enable_activation_offloading should only be True when enable_activation_checkpointing is True"
+ )
+ elif self._enable_activation_checkpointing:
+ utils.log_rank_zero(
+ log,
+ "Hint: enable_activation_checkpointing is True, but enable_activation_offloading isn't. "
+ "Enabling activation offloading should reduce memory further.",
+ )
# These attributes constitute the recipe state and are updated by ``load_checkpoint``
# when ``resume_from_checkpoint`` is ``True``
@@ -232,6 +264,8 @@ def setup(self, cfg: DictConfig) -> None:
self._model = self._setup_model(
cfg_model=cfg.model,
enable_activation_checkpointing=cfg.enable_activation_checkpointing,
+ enable_activation_offloading=self._enable_activation_offloading,
+ custom_sharded_layers=cfg.get("custom_sharded_layers", None),
fsdp_cpu_offload=cfg.get("fsdp_cpu_offload", False),
reshard_after_forward=cfg.get("fsdp_reshard_after_forward", True),
base_model_state_dict=checkpoint_dict[training.MODEL_KEY],
@@ -293,6 +327,7 @@ def _setup_model(
self,
cfg_model: DictConfig,
enable_activation_checkpointing: bool,
+ enable_activation_offloading: bool,
fsdp_cpu_offload: bool,
reshard_after_forward: bool,
base_model_state_dict: Dict[str, Any],
@@ -396,6 +431,12 @@ def _setup_model(
lora_unexpected=lora_unexpected,
)
# Ensure no params and buffers are on meta device
+
+ # activation offloading
+ self.activations_handling_ctx = training.get_act_offloading_ctx_manager(
+ model, enable_activation_offloading
+ )
+
training.validate_no_params_on_meta_device(model)
utils.log_rank_zero(
log,
@@ -581,14 +622,10 @@ def concatenated_forward(
# formed by concatenating an equal number of "chosen" and "rejected".
len_chosen = concatenated_input_ids.shape[0] // 2
- all_logits = model(concatenated_input_ids)
+ with self.activations_handling_ctx:
+ all_logits = model(concatenated_input_ids)
- all_log_probs = rlhf.get_batch_log_probs(
- all_logits,
- concatenated_labels,
- # see :class:`~torchtune.rlhf.loss.dpo.SimPOLoss`
- return_average_logprobs=isinstance(self._loss_fn, SimPOLoss),
- )
+ all_log_probs = rlhf.get_batch_log_probs(all_logits, concatenated_labels)
chosen_log_probs = all_log_probs[:len_chosen]
rejected_log_probs = all_log_probs[len_chosen:]
@@ -647,26 +684,19 @@ def train(self) -> None:
# deleting logits here helps reduce (peak) memory usage - we only need them for metric logging
del policy_chosen_logits, policy_rejected_logits
- if isinstance(self._loss_fn, SimPOLoss):
- loss, chosen_rewards, rejected_rewards = self._loss_fn(
- policy_chosen_log_probs, policy_rejected_log_probs
- )
- else:
- # reference based losses (e.g. DPO) explicitly regularize the objective fn based on
- # the reference model's output - reference-free losses (such as SimPO) don't require this.
- with torch.no_grad(), disable_adapter(self._model):
- (
- reference_chosen_log_probs,
- reference_rejected_log_probs,
- _,
- _,
- ) = self.concatenated_forward(self._model, batch)
- loss, chosen_rewards, rejected_rewards = self._loss_fn(
- policy_chosen_log_probs,
- policy_rejected_log_probs,
+ with torch.no_grad(), disable_adapter(self._model):
+ (
reference_chosen_log_probs,
reference_rejected_log_probs,
- )
+ _,
+ _,
+ ) = self.concatenated_forward(self._model, batch)
+ loss, chosen_rewards, rejected_rewards = self._loss_fn(
+ policy_chosen_log_probs,
+ policy_rejected_log_probs,
+ reference_chosen_log_probs,
+ reference_rejected_log_probs,
+ )
loss = loss.mean()
reward_accuracies = (chosen_rewards > rejected_rewards).float()
@@ -752,11 +782,11 @@ def recipe_main(cfg: DictConfig) -> None:
"Distributed finetune recipe should be run via a distributed launcher."
"If using tune CLI, please specify --nnodes 1 and --nproc_per_node [num_gpus]"
)
+ init_process_group("cuda:nccl,cpu:gloo")
if cfg.get("fsdp_cpu_offload", False):
# Utilize all available CPU cores for intra-op parallelism. This provides ~2x
# speed up when benchmarking fused AdamW on CPU
training.set_torch_num_threads()
- init_process_group(backend="gloo" if cfg.device == "cpu" else "nccl")
config.log_config(recipe_name="LoRADPORecipeDistributed", cfg=cfg)
diff --git a/recipes/lora_dpo_single_device.py b/recipes/lora_dpo_single_device.py
index 53b3b67be5..17f985e75f 100644
--- a/recipes/lora_dpo_single_device.py
+++ b/recipes/lora_dpo_single_device.py
@@ -30,7 +30,6 @@
)
from torchtune.recipe_interfaces import FTRecipeInterface
-from torchtune.rlhf.loss import SimPOLoss
from tqdm import tqdm
log = utils.get_logger("DEBUG")
@@ -44,9 +43,11 @@ class LoRADPORecipeSingleDevice(FTRecipeInterface):
This recipe supports:
- Activation checkpointing. This is enabled by default but is configurable.
+ - Activation offloading - this is enabled by default and should only be used alongside
+ activation checkpointing.
- Full bf16 training for supported HW architectures. We currently check bf16 support via
- the `torch.cuda.is_bf16_supported` API. This is disabled by default but can be enabled via
- setting `dtype=bf16` in configuration.
+ the `torch.cuda.is_bf16_supported` API. This is disabled by default but can be enabled via
+ setting `dtype=bf16` in configuration.
- Checkpointing: of LoRA adapter parameters and their optimizer states. When resuming
from a checkpoint, the adapter parameters are loaded from the checkpoint along
with the base model weights. Note that intra-epoch resumption is not supported.
@@ -56,7 +57,6 @@ class LoRADPORecipeSingleDevice(FTRecipeInterface):
The following losses are supported in this recipe:
- :class:`~torchtune.rlhf.loss.DPOLoss`: Direct Preference Optimization (DPO).
- :class:`~torchtune.rlhf.loss.RSOPLoss`: Rejection Sampling Optimization (RSO).
- - :class:`~torchtune.rlhf.loss.SimPOLoss`: Simple Preference Optimization (SimPO).
Assumptions:
- Checkpoints are ONLY saved at epoch boundaries. In case of failure, work done
@@ -74,6 +74,8 @@ class LoRADPORecipeSingleDevice(FTRecipeInterface):
Raises:
ValueError: If ``dtype`` is set to fp16.
RuntimeError: If ``dtype`` is set to bf16 and the hardware does not support bf16.
+ RuntimeError: If ``enable_activation_offloading`` is True and device is not CUDA.
+ RuntimeError: If ``enable_activation_offloading`` is True and ``enable_activation_checkpointing`` is False.
"""
@@ -101,6 +103,29 @@ def __init__(self, cfg: DictConfig) -> None:
)
self._log_peak_memory_stats = False
+ # activation checkpointing/offloading
+ self._enable_activation_checkpointing = cfg.get(
+ "enable_activation_checkpointing", False
+ )
+ self._enable_activation_offloading = cfg.get(
+ "enable_activation_offloading", False
+ )
+ if self._enable_activation_offloading:
+ if self._device.type != "cuda":
+ raise RuntimeError(
+ "enable_activation_offloading should only be True when training on CUDA"
+ )
+ if not self._enable_activation_checkpointing:
+ raise RuntimeError(
+ "enable_activation_offloading should only be True when enable_activation_checkpointing is True"
+ )
+ elif self._enable_activation_checkpointing:
+ utils.log_rank_zero(
+ log,
+ "Hint: enable_activation_checkpointing is True, but enable_activation_offloading isn't. "
+ "Enabling activation offloading should reduce memory further.",
+ )
+
# These are public properties which are updated by the checkpoint loader
# when ``resume_from_checkpoint`` is `True` or validated in tests
self.seed = training.set_seed(seed=cfg.seed)
@@ -190,6 +215,7 @@ def setup(self, cfg: DictConfig) -> None:
self._model = self._setup_model(
cfg_model=cfg.model,
enable_activation_checkpointing=cfg.enable_activation_checkpointing,
+ enable_activation_offloading=self._enable_activation_offloading,
compile_model=cfg.compile,
base_model_state_dict=checkpoint_dict[training.MODEL_KEY],
lora_weights_state_dict=(
@@ -251,6 +277,7 @@ def _setup_model(
self,
cfg_model: DictConfig,
enable_activation_checkpointing: bool,
+ enable_activation_offloading: bool,
compile_model: bool,
base_model_state_dict: Dict[str, Any],
lora_weights_state_dict: Optional[Dict[str, Any]] = None,
@@ -289,6 +316,11 @@ def _setup_model(
lora_unexpected=lora_unexpected,
)
+ # activation offloading
+ self.activations_handling_ctx = training.get_act_offloading_ctx_manager(
+ model, enable_activation_offloading
+ )
+
log.info(f"Model is initialized with precision {self._dtype}.")
# Compile model, if enabled.
@@ -443,14 +475,10 @@ def concatenated_forward(
# formed by concatenating an equal number of "chosen" and "rejected".
len_chosen = concatenated_input_ids.shape[0] // 2
- all_logits = model(concatenated_input_ids)
+ with self.activations_handling_ctx:
+ all_logits = model(concatenated_input_ids)
- all_log_probs = rlhf.get_batch_log_probs(
- all_logits,
- concatenated_labels,
- # see :class:`~torchtune.rlhf.loss.dpo.SimPOLoss`
- return_average_logprobs=isinstance(self._loss_fn, SimPOLoss),
- )
+ all_log_probs = rlhf.get_batch_log_probs(all_logits, concatenated_labels)
chosen_log_probs = all_log_probs[:len_chosen]
rejected_log_probs = all_log_probs[len_chosen:]
@@ -503,26 +531,19 @@ def train(self) -> None:
# deleting logits here helps reduce (peak) memory usage - we only need them for metric logging
del policy_chosen_logits, policy_rejected_logits
- if isinstance(self._loss_fn, SimPOLoss):
- loss, chosen_rewards, rejected_rewards = self._loss_fn(
- policy_chosen_log_probs, policy_rejected_log_probs
- )
- else:
- # reference based losses (e.g. DPO) explicitly regularize the objective fn based on
- # the reference model's output - reference-free losses (such as SimPO) don't require this.
- with torch.no_grad(), disable_adapter(self._model):
- (
- reference_chosen_log_probs,
- reference_rejected_log_probs,
- _,
- _,
- ) = self.concatenated_forward(self._model, batch)
- loss, chosen_rewards, rejected_rewards = self._loss_fn(
- policy_chosen_log_probs,
- policy_rejected_log_probs,
+ with torch.no_grad(), disable_adapter(self._model):
+ (
reference_chosen_log_probs,
reference_rejected_log_probs,
- )
+ _,
+ _,
+ ) = self.concatenated_forward(self._model, batch)
+ loss, chosen_rewards, rejected_rewards = self._loss_fn(
+ policy_chosen_log_probs,
+ policy_rejected_log_probs,
+ reference_chosen_log_probs,
+ reference_rejected_log_probs,
+ )
loss = loss.mean()
reward_accuracies = (chosen_rewards > rejected_rewards).float()
diff --git a/recipes/lora_finetune_distributed.py b/recipes/lora_finetune_distributed.py
index 45209814a0..e95fbb40c6 100644
--- a/recipes/lora_finetune_distributed.py
+++ b/recipes/lora_finetune_distributed.py
@@ -273,6 +273,7 @@ def setup(self, cfg: DictConfig) -> None:
cfg_model=cfg.model,
enable_activation_checkpointing=self._enable_activation_checkpointing,
enable_activation_offloading=self._enable_activation_offloading,
+ custom_sharded_layers=cfg.get("custom_sharded_layers", None),
fsdp_cpu_offload=cfg.get("fsdp_cpu_offload", False),
reshard_after_forward=cfg.get("fsdp_reshard_after_forward", True),
base_model_state_dict=checkpoint_dict[training.MODEL_KEY],
@@ -919,11 +920,11 @@ def recipe_main(cfg: DictConfig) -> None:
"Distributed finetune recipe should be run via a distributed launcher."
"If using tune CLI, please specify --nnodes 1 and --nproc_per_node [num_gpus]"
)
+ init_process_group("cuda:nccl,cpu:gloo")
if cfg.get("fsdp_cpu_offload", False):
# Utilize all available CPU cores for intra-op parallelism. This provides ~2x
# speed up when benchmarking fused AdamW on CPU
training.set_torch_num_threads()
- init_process_group(backend="gloo" if cfg.device == "cpu" else "nccl")
config.log_config(recipe_name="LoRAFinetuneRecipeDistributed", cfg=cfg)
diff --git a/recipes/qat_distributed.py b/recipes/qat_distributed.py
index ab9cea3eda..e005dc0247 100644
--- a/recipes/qat_distributed.py
+++ b/recipes/qat_distributed.py
@@ -133,15 +133,6 @@ def __init__(self, cfg: DictConfig) -> None:
"full fp16 training is not supported with this recipe. Please use bf16 or fp32 instead."
)
- if (
- cfg.get("fsdp_cpu_offload", False)
- and cfg.optimizer.get("fused", False)
- and not utils.torch_version_ge("2.4.0")
- ):
- raise RuntimeError(
- "Using fused optimizer on CPU is only supported in PyTorch nightly."
- )
-
# logging attributes
self._output_dir = cfg.output_dir
self._log_every_n_steps = cfg.get("log_every_n_steps", 1)
@@ -944,7 +935,7 @@ def recipe_main(cfg: DictConfig) -> None:
"Distributed finetune recipe should be run via a distributed launcher."
"If using tune CLI, please specify --nnodes 1 and --nproc_per_node [num_gpus]"
)
- init_process_group(backend="gloo" if cfg.device == "cpu" else "nccl")
+ init_process_group("cuda:nccl,cpu:gloo")
if cfg.get("fsdp_cpu_offload", False):
# Utilize all available CPU cores for intra-op parallelism. This provides ~2x
# speed up when benchmarking fused AdamW on CPU
diff --git a/recipes/qat_lora_finetune_distributed.py b/recipes/qat_lora_finetune_distributed.py
new file mode 100644
index 0000000000..f9b1fc991f
--- /dev/null
+++ b/recipes/qat_lora_finetune_distributed.py
@@ -0,0 +1,972 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import sys
+import time
+
+from functools import partial
+from typing import Any, Dict, List, Optional, Tuple, Union
+from warnings import warn
+
+import torch
+from omegaconf import DictConfig, ListConfig
+
+from torch import nn
+from torch.distributed import destroy_process_group, init_process_group
+
+from torch.optim import Optimizer
+from torch.utils.data import DataLoader, DistributedSampler
+from torchtune import config, modules, training, utils
+from torchtune.config._utils import _get_component_from_path
+from torchtune.data import padded_collate_packed
+from torchtune.datasets import ConcatDataset
+from torchtune.modules.peft import (
+ DoRALinear,
+ get_adapter_params,
+ get_adapter_state_dict,
+ get_lora_module_names,
+ get_merged_lora_ckpt,
+ LoRALinear,
+ set_trainable_params,
+ validate_missing_and_unexpected_for_lora,
+)
+from torchtune.recipe_interfaces import FTRecipeInterface
+from torchtune.training import DummyProfiler, PROFILER_KEY
+from torchtune.training.quantization import swap_lora_linear_with_qat
+
+from tqdm import tqdm
+
+log = utils.get_logger("DEBUG")
+
+
+class QATLoRAFinetuneRecipeDistributed(FTRecipeInterface):
+ """
+ Distributed quantization-aware training (QAT) and LoRA finetuning recipe for dense transformer-based
+ LLMs such as Llama2. This recipe supports distributed training and can be run on a single node (1 to
+ 8 GPUs). Only compatible with torchao 0.7+.
+
+ Features:
+ - Quantization-aware training (QAT). Perform fake quantization on weights and/or activations
+ during finetuning, with the goal of ultimately producing a quantized model with minimal
+ accuracy degradation. This recipe produces an unquantized model in the original dtype,
+ which can then be quantized separately.
+
+ - FSDP. Supported using PyTorch's FSDP APIs. CPU offload of parameters, gradients, and optimizer states
+ is supported via ``fsdp_cpu_offload``. Resharding of parameters after the forward pass is
+ done by default (corresponding to FULL_SHARD sharding strategy), but can be disabled by setting the config
+ ``fsdp_reshard_after_forward`` to False (this corresponds to SHARD_GRAD_OP sharding strategy).
+ DDP is currently not supported. Training on CPU is not supported.
+
+ - Activation Checkpointing. This can be controlled using the ``enable_activation_checkpointing``
+ flag. Activation checkpointing helps reduce the memory footprint since we no longer keep
+ activations in memory and instead recompute them during the backward pass. This is especially
+ helpful for larger batch sizes when you're memory constrained. But these savings in memory
+ come at the cost of training performance. In most cases training can slow-down quite a bit as
+ a result of this activation recomputation.
+
+ - Activation Offloading. This can be controlled using the ``enable_activation_offloading``
+ flag. Activation offloading is a technique similar to activations checkpointing that helps
+ reduce the memory footprint to prevent OOMs on CUDA and enable bigger batches. Where activations
+ checkpointing drops the activation in the forward to recompute it later in the backward,
+ activations offloading will drop the activation in the forward to the CPU and bring it
+ back during the backward pass. As always, there is a tradeoff--these savings in memory can
+ come at the cost of training performance and CPU resources. To recover some runtime cost,
+ we've added an option to enable offloading on a different stream to permit overlapping with
+ the computation. This option is currently only available on PyTorch 2.5.0 or later and will be
+ enabled by default if an acceptable torch version is found. Activation offloading can be used in
+ conjunction with activation checkpointing.
+
+ - Precision. Full fp32 and bf16 training are supported. Precision is controlled using the ``dtype``
+ flag. When ``dtype=bf16``, all activations, gradients and optimizer states are in bfloat16. In
+ most cases this should halve the memory footprint of full precision (fp32) training, without
+ loss in model quality (will depend on the model, training data and other settings). For
+ GPUs which do not support bfloat16, we fall back to fp32. Mixed precision training and fp16
+ precision are currently not supported.
+
+ - Gradient Accumulation. You can simulate larger batch sizes by accumulating gradients. This is
+ controlled using the ``gradient_accumulation_steps`` flag.
+
+ Total Batch Size = batch_size * number of GPUs * gradient accumulation steps.
+
+ For example: with batch_size=1, nproc_per_node=2 and gradient_accumulation_steps=32 we get a
+ total batch size of 64.
+
+ Gradient accumulation is especially useful when you are memory constrained. In this case,
+ accumulating gradients might give you better training speed than enabling activation
+ checkpointing.
+
+ - Checkpointing. Model weights are checkpointed both at the end of each epoch and at the end of
+ training. Currently we checkpoint both the adapter weights (trainable params only) and the
+ complete merged weights (adapter weights added back to the base model). For more details
+ please take a look at our LoRA tutorial
+ (https://pytorch.org/torchtune/main/tutorials/lora_finetune.html).
+
+ Optimizer State and recipe state (seed, total_epochs, number of epochs run etc) are
+ only saved at the end of a given epoch and used in case of resuming training. Resuming
+ training is controlled by the ``resume_from_checkpoint`` flag. Mid-epoch checkpointing is
+ currently not supported.
+
+ For more details on the checkpointer, please take a look at
+ our checkpointer deepdive (https://pytorch.org/torchtune/main/tutorials/checkpointer.html).
+
+ - Logging. Terminal, Disk, WandB and TensorBoard are all supported.
+
+ - Gradient Clipping. Gradient clipping is supported using the ``clip_grad_norm`` flag. By default,
+ ``clip_grad_norm`` is set to ``None``. If you only want to log the grad norm, you can set
+ ``clip_grad_norm='inf'``.
+
+ For a full list of example configs for this recipe, run ``tune ls`` on the command line. Each config
+ has example commands for how to kick-off training.
+
+ Args:
+ cfg (DictConfig): OmegaConf object parsed from yaml file
+
+ Raises:
+ ValueError: If ``dtype`` is set to fp16.
+ ValueError: If world_size is 1
+ RuntimeError: If ``dtype`` is set to bf16 and the hardware does not support bf16.
+ RuntimeError: If ``left_pad_sequence`` is set as the data collator.
+ RuntimeError: If ``enable_activation_offloading`` is True and device is not CUDA.
+ RuntimeError: If ``enable_activation_offloading`` is True and ``enable_activation_checkpointing`` is False.
+ """
+
+ def __init__(self, cfg: DictConfig) -> None:
+ try:
+ from torchao.quantization import qat # noqa: F401
+ except ImportError as err:
+ raise ValueError(
+ "qat_lora_finetune_distributed is only compatible with torchao 0.7+"
+ ) from err
+
+ self._device = utils.get_device(device=cfg.device)
+ self._dtype = training.get_dtype(cfg.dtype, device=self._device)
+
+ if self._dtype == torch.float16:
+ raise ValueError(
+ "full fp16 training is not supported with this recipe. Please use bf16 or fp32 instead."
+ )
+
+ _, rank = training.get_world_size_and_rank()
+
+ # _is_rank_zero is used primarily for logging. In the future, the logger
+ # should directly take care of this
+ self._is_rank_zero = rank == 0
+
+ # logging attributes
+ self._output_dir = cfg.output_dir
+ self._log_every_n_steps = cfg.get("log_every_n_steps", 1)
+ self._log_peak_memory_stats = cfg.get("log_peak_memory_stats", False)
+
+ if self._log_peak_memory_stats and self._device.type != "cuda":
+ log.info(
+ "log_peak_memory_stats was set to True, however, training does not use cuda. Setting log_peak_memory_stats=False."
+ )
+ self._log_peak_memory_stats = False
+
+ # These attributes constitute the recipe state and are updated by ``load_checkpoint``
+ # when ``resume_from_checkpoint`` is ``True``
+ self.seed = training.set_seed(seed=cfg.seed)
+ self.epochs_run = 0
+ self.total_epochs = cfg.epochs
+ self.max_steps_per_epoch = cfg.max_steps_per_epoch
+ self.global_step = 0
+ self._clip_grad_norm = cfg.get("clip_grad_norm", None)
+
+ self._save_adapter_weights_only = cfg.get("save_adapter_weights_only", False)
+ self._resume_from_checkpoint = cfg.resume_from_checkpoint
+ self._gradient_accumulation_steps = cfg.gradient_accumulation_steps
+
+ # activation checkpointing/offloading
+ self._enable_activation_checkpointing = cfg.get(
+ "enable_activation_checkpointing", False
+ )
+ self._enable_activation_offloading = cfg.get(
+ "enable_activation_offloading", False
+ )
+ if self._enable_activation_offloading:
+ if self._device.type != "cuda":
+ raise RuntimeError(
+ "enable_activation_offloading should only be True when training on CUDA"
+ )
+ if not self._enable_activation_checkpointing:
+ raise RuntimeError(
+ "enable_activation_offloading should only be True when enable_activation_checkpointing is True"
+ )
+ elif (
+ self._enable_activation_checkpointing
+ and cfg.checkpointer.model_type != "LLAMA3_VISION"
+ ):
+ utils.log_rank_zero(
+ log,
+ "Hint: enable_activation_checkpointing is True, but enable_activation_offloading isn't. "
+ "Enabling activation offloading should reduce memory further.",
+ )
+
+ def load_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any]:
+ """
+ Extract the checkpoint state from file and validate. This includes the
+ base model weights. If resume_from_checkpoint is True, this also includes
+ the adapter weights and recipe state
+ """
+ self._checkpointer = config.instantiate(
+ cfg_checkpointer,
+ resume_from_checkpoint=self._resume_from_checkpoint,
+ )
+ checkpoint_dict = self._checkpointer.load_checkpoint()
+
+ # When resuming from checkpoint for LoRA, the recipe expects the adapter weights
+ # and recipe state to be present. The keys should match up with what ``save_checkpoint``
+ # used to create these intermediate checkpoints
+ if self._resume_from_checkpoint:
+ if training.ADAPTER_KEY not in checkpoint_dict:
+ raise ValueError(
+ "Adapter weights not found. Please ensure a valid adapter checkpoint is provided."
+ )
+ # _update_recipe_state will throw an exception if the recipe state is not corrctly loaded
+ # no need to check here
+ self._update_recipe_state(checkpoint_dict)
+ return checkpoint_dict
+
+ def _update_recipe_state(self, ckpt_dict: Dict[str, Any]) -> None:
+ """
+ Updates the recipe state from checkpoint.
+ """
+ try:
+ self.epochs_run = ckpt_dict[training.EPOCHS_KEY]
+
+ # on mismatch, warn the user and prevent the override
+ if self.seed != ckpt_dict[training.SEED_KEY]:
+ warn(
+ message=(
+ "Config value for seed does not match the checkpoint value, "
+ f"using the checkpoint value: {ckpt_dict[training.SEED_KEY]}"
+ )
+ )
+ self.seed = ckpt_dict[training.SEED_KEY]
+ if self.max_steps_per_epoch != ckpt_dict[training.MAX_STEPS_KEY]:
+ warn(
+ message=(
+ "Config value for max_steps_per_epoch does not match the checkpoint value, "
+ f"using the checkpoint value: {ckpt_dict[training.MAX_STEPS_KEY]}"
+ )
+ )
+ self.max_steps_per_epoch = ckpt_dict[training.MAX_STEPS_KEY]
+
+ # on mismatch, warn the user but allow the override
+ if self.total_epochs != ckpt_dict[training.TOTAL_EPOCHS_KEY]:
+ warn(
+ message=(
+ "Config value for total_epochs does not match the checkpoint value, "
+ f"using the config value: {self.total_epochs}"
+ )
+ )
+
+ except KeyError as e:
+ raise KeyError(
+ "Checkpoint does not contain the required keys needed for updating recipe state. "
+ "Are you sure you passed in the right recipe checkpoint?"
+ ) from e
+
+ def setup(self, cfg: DictConfig) -> None:
+ """
+ Setup the recipe state. This includes recipe state (if resume_from_checkpoint is True),
+ model, tokenizer, loss, optimizer, learning rate scheduler, sampler, and dataloader.
+ """
+ if self._is_rank_zero:
+ self._metric_logger = config.instantiate(cfg.metric_logger)
+
+ # log config with parameter override
+ self._metric_logger.log_config(cfg)
+
+ checkpoint_dict = self.load_checkpoint(cfg_checkpointer=cfg.checkpointer)
+ self._compile = cfg.get("compile", False)
+
+ self._model = self._setup_model(
+ cfg_model=cfg.model,
+ enable_activation_checkpointing=self._enable_activation_checkpointing,
+ enable_activation_offloading=self._enable_activation_offloading,
+ fsdp_cpu_offload=cfg.get("fsdp_cpu_offload", False),
+ reshard_after_forward=cfg.get("fsdp_reshard_after_forward", True),
+ base_model_state_dict=checkpoint_dict[training.MODEL_KEY],
+ lora_weights_state_dict=(
+ checkpoint_dict[training.ADAPTER_KEY]
+ if self._resume_from_checkpoint
+ else None
+ ),
+ quantizer_cfg=cfg.get("quantizer", None),
+ )
+ self._tokenizer = config.instantiate(cfg.tokenizer)
+
+ self._optimizer = self._setup_optimizer(
+ cfg_optimizer=cfg.optimizer,
+ opt_state_dict=(
+ checkpoint_dict[training.OPT_KEY]
+ if self._resume_from_checkpoint
+ else None
+ ),
+ )
+
+ # initialize loss
+ self._loss_fn = config.instantiate(cfg.loss)
+
+ if self._compile:
+ training.compile_loss(self._loss_fn, verbose=self._is_rank_zero)
+
+ if self._loss_fn.__class__.__name__ == "CEWithChunkedOutputLoss":
+ # set num_output_chunks for model
+ self._model.set_num_output_chunks(self._loss_fn.num_output_chunks)
+ if self._is_rank_zero:
+ log.info("Loss is initialized.")
+
+ # sampler and dataloader depend on the tokenizer and loss_fn and should be
+ # setup after all of these are setup
+ collate_name = cfg.get("collate_fn", "torchtune.data.padded_collate_sft")
+ self._sampler, self._dataloader = self._setup_data(
+ cfg_dataset=cfg.dataset,
+ shuffle=cfg.shuffle,
+ batch_size=cfg.batch_size,
+ collate_fn=collate_name,
+ )
+
+ # Finally update the recipe state which can only be correctly set after all of the
+ # other components have been initialized and updated.
+
+ # Number of training steps in each epoch depends on the number of batches produced
+ # by the dataloader and the max_steps_per_epoch param set by the user and is used
+ # for logging and tracking training state. This should be computed after the dataloader
+ # has been setup
+ self._steps_per_epoch = (
+ len(self._dataloader) // self._gradient_accumulation_steps
+ )
+ if (
+ self.max_steps_per_epoch is not None
+ and self.max_steps_per_epoch < self._steps_per_epoch
+ ):
+ self._steps_per_epoch = self.max_steps_per_epoch
+ self.global_step = self.epochs_run * self._steps_per_epoch
+
+ # Learning rate scheduler can only be set up after number of steps
+ # has been computed
+ self._lr_scheduler = self._setup_lr_scheduler(
+ cfg_lr_scheduler=cfg.lr_scheduler,
+ num_training_steps=self.total_epochs * self._steps_per_epoch,
+ last_epoch=self.global_step - 1,
+ )
+
+ # Set up profiler, returns DummyProfiler (nullcontext object with no-op `step` method)
+ # if cfg is missing profiler key or if `cfg.profiler.enabled = False`
+ self._profiler = self._setup_profiler(cfg.get(PROFILER_KEY, None))
+
+ # Used to ignore labels for loss computation
+ self.ignore_labels_cache = torch.full(
+ (cfg.batch_size, 1), self._loss_fn.ignore_index, device=self._device
+ )
+
+ def _setup_profiler(
+ self, cfg_profiler: Optional[DictConfig] = None
+ ) -> Union[torch.profiler.profile, DummyProfiler]:
+ """
+ Parses the `profiler` section of top-level `cfg` and sets up profiler
+
+ Args:
+ cfg_profiler (Optional[DictConfig]): ``profiler`` section of the top-level ``cfg`` (the main config passed to
+ `recipe.main`). Default None.
+
+ Returns:
+ profiler: Union[torch.profiler.profile, DummyProfiler] - DummyProfiler is a nullcontext with no-op methods
+ for `start`, `stop`, and `step` that can be used in place of `torch.profiler.profile` if profiler is not enabled such
+ that the instrumented training loop does not need to be changed profiling is disabled.
+
+ The profiler config can be provided in configs under the `profiler` key with the following layout:
+
+ .. code-block:: yaml
+ profiler:
+ enabled: bool
+
+ #Output directory of trace artifacts
+ output_dir: str
+
+ #`torch.profiler.ProfilerActivity` types to trace
+ cpu: bool
+ cuda: bool
+
+ #Trace options
+ profile_memory: bool
+ with_stack: bool
+ record_shapes: bool
+ with_flops: bool
+
+ # `torch.profiler.schedule` options:
+ # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat
+ wait_steps: int
+ warmup_steps: int
+ active_steps: int
+ num_cycles: int
+ """
+ # Missing profiler section in config, assume disabled
+ if cfg_profiler is None:
+ cfg_profiler = DictConfig({"enabled": False})
+
+ # Check that component is included and set correctly
+ if cfg_profiler.get("_component_", None) is None:
+ cfg_profiler["_component_"] = "torchtune.training.setup_torch_profiler"
+ else:
+ assert (
+ cfg_profiler.get("_component_")
+ == "torchtune.training.setup_torch_profiler"
+ ), "Only torch profiler supported currently: component must be `torchtune.training.setup_torch_profiler`"
+
+ profiler, profiler_cfg = config.instantiate(cfg_profiler)
+
+ if self._is_rank_zero:
+ log.info(f" Profiler config after instantiation: {profiler_cfg}")
+
+ self.profiler_profile_memory = profiler_cfg.get("profile_memory", False)
+ if profiler_cfg["enabled"]:
+ self.profiler_wait_steps = profiler_cfg["wait_steps"]
+ self.profiler_warmup_steps = profiler_cfg["warmup_steps"]
+ self.profiler_active_steps = profiler_cfg["active_steps"]
+
+ return profiler
+
+ def _convert_model_to_qat(self, model: nn.Module, quantizer_cfg: DictConfig):
+ """
+ Convert the model to support quantization-aware training during fine-tuning.
+ """
+ for name, child in model.named_modules():
+ if isinstance(child, DoRALinear):
+ raise ValueError("QAT is currently not compatible with DoRA")
+ quantizer = config.instantiate(quantizer_cfg)
+ quantizer.precision = self._dtype
+ quantizer_mode = training.quantization.get_quantizer_mode(quantizer)
+ if "qat" not in quantizer_mode:
+ raise ValueError(
+ "Quantizer mode '%s' is not supported for finetuning" % quantizer_mode
+ )
+ activation_config = quantizer.get_activation_fake_quantize_config()
+ weight_config = quantizer.get_weight_fake_quantize_config()
+ swap_lora_linear_with_qat(model, activation_config, weight_config)
+
+ def _setup_model(
+ self,
+ cfg_model: DictConfig,
+ enable_activation_checkpointing: bool,
+ enable_activation_offloading: bool,
+ fsdp_cpu_offload: bool,
+ reshard_after_forward: bool,
+ base_model_state_dict: Dict[str, Any],
+ custom_sharded_layers: Optional[List[str]] = None,
+ lora_weights_state_dict: Optional[Dict[str, Any]] = None,
+ quantizer_cfg: Optional[DictConfig] = None,
+ ) -> nn.Module:
+ """
+ Model initialization has some important considerations:
+ a. To minimize GPU peak memory, we initialize the model on meta device with
+ the right dtype
+ b. All ranks calls ``load_state_dict`` without peaking CPU RAMs since
+ full state dicts are loaded with ``torch.load(mmap=True)``
+ c. We register (pre-)forward hooks with ``fully_shard`` instead of wrapping `nn.Module`
+ """
+
+ self._lora_rank = cfg_model.lora_rank
+ self._lora_alpha = cfg_model.lora_alpha
+ self._lora_attn_modules = list(cfg_model.lora_attn_modules)
+ self._apply_lora_to_mlp = cfg_model.apply_lora_to_mlp
+ self._apply_lora_to_output = getattr(cfg_model, "apply_lora_to_output", False)
+
+ if self._is_rank_zero:
+ log.info(
+ "FSDP is enabled. Instantiating model and loading checkpoint on Rank 0 ..."
+ )
+ init_start = time.perf_counter()
+
+ if quantizer_cfg is None:
+ raise ValueError("Quantizer must be specified for QAT + LoRA finetuning")
+
+ with training.set_default_dtype(self._dtype), torch.device("meta"):
+ model = config.instantiate(cfg_model)
+ self._convert_model_to_qat(model, quantizer_cfg)
+
+ set_trainable_params(model, get_adapter_params(model))
+
+ if self._compile:
+ training.compile_model(model, verbose=self._is_rank_zero)
+
+ if enable_activation_checkpointing:
+ training.set_activation_checkpointing(
+ model, auto_wrap_policy={modules.TransformerSelfAttentionLayer}
+ )
+
+ # For FSDP sharding
+ fsdp_shard_conditions = [
+ partial(
+ training.get_shard_conditions,
+ names_to_match=custom_sharded_layers,
+ )
+ ]
+ training.shard_model(
+ model=model,
+ shard_conditions=fsdp_shard_conditions,
+ cpu_offload=fsdp_cpu_offload,
+ reshard_after_forward=reshard_after_forward,
+ )
+
+ if lora_weights_state_dict:
+ lora_missing, lora_unexpected = training.load_from_full_model_state_dict(
+ model,
+ lora_weights_state_dict,
+ self._device,
+ self._is_rank_zero,
+ cpu_offload=fsdp_cpu_offload,
+ )
+ else:
+ lora_missing, lora_unexpected = None, None
+
+ # Initialize LoRA params and RoPE buffers
+ with training.set_default_dtype(self._dtype), self._device:
+ lora_device = "cpu" if fsdp_cpu_offload else self._device
+ for m in model.modules():
+ if (
+ isinstance(m, LoRALinear) or isinstance(m, DoRALinear)
+ ) and not lora_weights_state_dict:
+ # lora may not be covered in state dict
+ # if finetune for the 1st time
+ m.lora_a.to_empty(device=lora_device)
+ m.lora_b.to_empty(device=lora_device)
+ m.initialize_parameters()
+ # RoPE is not covered in state dict
+ if hasattr(m, "rope_init"):
+ m.rope_init()
+
+ base_missing, base_unexpected = training.load_from_full_model_state_dict(
+ model,
+ base_model_state_dict,
+ self._device,
+ self._is_rank_zero,
+ cpu_offload=fsdp_cpu_offload,
+ )
+ validate_missing_and_unexpected_for_lora(
+ lora_attn_modules=self._lora_attn_modules,
+ apply_lora_to_mlp=self._apply_lora_to_mlp,
+ apply_lora_to_output=self._apply_lora_to_output,
+ base_missing=base_missing,
+ base_unexpected=base_unexpected,
+ lora_missing=lora_missing,
+ lora_unexpected=lora_unexpected,
+ )
+ # Ensure no params and buffers are on meta device
+ training.validate_no_params_on_meta_device(model)
+
+ # activation offloading
+ self.activations_handling_ctx = training.get_act_offloading_ctx_manager(
+ model, enable_activation_offloading
+ )
+
+ # log
+ if self._is_rank_zero:
+ log.info(
+ f"Instantiating model and loading checkpoint took {time.perf_counter() - init_start:.2f} secs"
+ )
+ memory_stats = training.get_memory_stats(device=self._device)
+ training.log_memory_stats(memory_stats)
+
+ # synchronize before training begins
+ torch.distributed.barrier()
+
+ return model
+
+ def _setup_optimizer(
+ self, cfg_optimizer: DictConfig, opt_state_dict: Optional[Dict[str, Any]] = None
+ ) -> Optimizer:
+ optimizer = config.instantiate(cfg_optimizer, self._model.parameters())
+ if opt_state_dict:
+ training.load_from_full_optimizer_state_dict(
+ optimizer,
+ opt_state_dict,
+ self._device,
+ )
+
+ if self._is_rank_zero:
+ log.info("Optimizer is initialized.")
+ return optimizer
+
+ def _setup_lr_scheduler(
+ self,
+ cfg_lr_scheduler: DictConfig,
+ num_training_steps: int,
+ last_epoch: int,
+ ) -> Optimizer:
+ lr_scheduler = config.instantiate(
+ cfg_lr_scheduler,
+ self._optimizer,
+ num_training_steps=num_training_steps,
+ last_epoch=last_epoch,
+ )
+ if self._is_rank_zero:
+ log.info("Learning rate scheduler is initialized.")
+ return lr_scheduler
+
+ def _setup_data(
+ self,
+ cfg_dataset: DictConfig,
+ shuffle: bool,
+ batch_size: int,
+ collate_fn: str,
+ ) -> Tuple[DistributedSampler, DataLoader]:
+ """
+ All data related setup happens here. Currently this recipe only supports the
+ DistributedSamplers with Map-style Datasets which fit into memory. Other samplers,
+ iterable datasets and streaming datasets are not supported.
+ """
+ world_size, rank = training.get_world_size_and_rank()
+
+ if isinstance(cfg_dataset, ListConfig):
+ datasets = [
+ config.instantiate(single_cfg_dataset, self._tokenizer)
+ for single_cfg_dataset in cfg_dataset
+ ]
+ ds = ConcatDataset(datasets=datasets)
+ packed = False
+ else:
+ ds = config.instantiate(cfg_dataset, self._tokenizer)
+ packed = cfg_dataset.get("packed", False)
+
+ # Instantiate collate_fn
+ if "left_pad_sequence" in collate_fn:
+ raise RuntimeError("left_pad_sequence collator is only for inference.")
+ collate_fn = _get_component_from_path(collate_fn)
+
+ sampler = DistributedSampler(
+ ds, num_replicas=world_size, rank=rank, shuffle=shuffle, seed=0
+ )
+
+ dataloader = DataLoader(
+ dataset=ds,
+ batch_size=batch_size,
+ sampler=sampler,
+ # dropping last avoids shape issues with compile + flex attention
+ drop_last=True,
+ collate_fn=(
+ partial(
+ collate_fn,
+ padding_idx=self._tokenizer.pad_id,
+ ignore_idx=self._loss_fn.ignore_index,
+ )
+ if not packed
+ else padded_collate_packed
+ ),
+ )
+
+ if self._is_rank_zero:
+ log.info("Dataset and Sampler are initialized.")
+
+ return sampler, dataloader
+
+ def save_checkpoint(
+ self,
+ epoch: int,
+ ) -> None:
+ """
+ Checkpoint the state of the recipe. The constructed checkpoint state dict
+ contains the following information:
+ - Merged weights with key MODEL_KEY
+ - Adapter weights with key ADAPTER_KEY
+ - Relevant recipe state if training is not complete
+ - If the `self._save_adapter_weights_only` option is True, the checkpointer will save only the adapter weights
+
+ Checkpointer will save the merged weights, adapter weights and recipe state in
+ different checkpoint files. To correctly resume from training, the adapter weights
+ and recipe state must be provided along with the base model weights.
+ """
+ # final dict passed onto the checkpointer
+ checkpoint_dict = {}
+
+ intermediate_checkpoint = epoch + 1 < self.total_epochs
+
+ if self._is_rank_zero:
+ log.info(
+ "Saving checkpoint. This may take some time. Retrieving full model state dict..."
+ )
+ start = time.perf_counter()
+
+ # To prevent GPU memory from spiking during checkpoint save,
+ # we consolidate the full model and optim state dicts on CPU for rank 0
+ state_dict = self._model.state_dict()
+ if self._save_adapter_weights_only:
+ state_dict = get_adapter_state_dict(state_dict, device=None)
+
+ cpu_state_dict = training.gather_cpu_state_dict(
+ state_dict,
+ self._is_rank_zero,
+ device=self._device,
+ )
+ if self._is_rank_zero:
+ log.info(
+ f"Getting full model state dict took {time.perf_counter() - start:.2f} secs"
+ )
+
+ if intermediate_checkpoint:
+ if self._is_rank_zero:
+ log.info("Retrieving optimizer state dict...")
+ opt_state_dict = training.get_full_optimizer_state_dict(
+ self._optimizer,
+ self._is_rank_zero,
+ device=self._device,
+ )
+ if self._is_rank_zero:
+ log.info(
+ f"Getting optimizer state dict took {time.perf_counter() - start:.2f} secs"
+ )
+ else:
+ opt_state_dict = None
+
+ # Now that we have the model and opt state dict, create the actual checkpoint dict
+ # to be sent to the checkpointer and ultimately written to file
+ if self._is_rank_zero:
+ start = time.perf_counter()
+
+ if self._save_adapter_weights_only:
+ adapter_state_dict = cpu_state_dict
+ else:
+ # Filter out the adapter keys and weights from the model state dict. These will
+ # be saved separately
+ adapter_state_dict = get_adapter_state_dict(cpu_state_dict)
+
+ # merge the adapter weights and base weights to create the model checkpoint
+ merged_state_dict = get_merged_lora_ckpt(
+ cpu_state_dict,
+ rank=self._lora_rank,
+ alpha=self._lora_alpha,
+ )
+ checkpoint_dict.update({training.MODEL_KEY: merged_state_dict})
+ checkpoint_dict.update({training.ADAPTER_KEY: adapter_state_dict})
+
+ # if training is in-progress, checkpoint the optimizer state and recipe state
+ # as well.
+ if intermediate_checkpoint:
+ checkpoint_dict.update(
+ {
+ training.OPT_KEY: opt_state_dict,
+ training.SEED_KEY: self.seed,
+ training.EPOCHS_KEY: self.epochs_run,
+ training.TOTAL_EPOCHS_KEY: self.total_epochs,
+ training.MAX_STEPS_KEY: self.max_steps_per_epoch,
+ }
+ )
+
+ adapter_config = {
+ "r": self._lora_rank,
+ "lora_alpha": self._lora_alpha,
+ "target_modules": get_lora_module_names(
+ self._lora_attn_modules,
+ self._apply_lora_to_mlp,
+ self._apply_lora_to_output,
+ ),
+ "peft_type": "LORA",
+ }
+ checkpoint_dict.update({training.ADAPTER_CONFIG: adapter_config})
+ self._checkpointer.save_checkpoint(
+ checkpoint_dict,
+ epoch=epoch,
+ intermediate_checkpoint=intermediate_checkpoint,
+ adapter_only=self._save_adapter_weights_only,
+ )
+ log.info(f"Saving checkpoint took {time.perf_counter() - start:.2f} secs")
+
+ torch.distributed.barrier()
+
+ def train(self) -> None:
+ """
+ The core training loop.
+ """
+ # clean up before training begins
+ training.cleanup_before_training()
+
+ world_size, rank = training.get_world_size_and_rank()
+
+ # zero out the gradients before starting training
+ self._optimizer.zero_grad()
+
+ # Initialize tokens count and running loss (for grad accumulation)
+ t0 = time.perf_counter()
+ running_loss = 0
+ num_tokens = 0
+
+ self._profiler.start()
+ # self.epochs_run should be non-zero when we're resuming from a checkpoint
+ for curr_epoch in range(self.epochs_run, self.total_epochs):
+
+ # Update the sampler to ensure data is correctly shuffled across epochs
+ # in case shuffle is True
+ self._sampler.set_epoch(curr_epoch)
+
+ pbar = tqdm(total=self._steps_per_epoch, disable=not (rank == 0))
+ for idx, batch in enumerate(self._dataloader):
+ if (
+ self.max_steps_per_epoch is not None
+ and (idx // self._gradient_accumulation_steps)
+ == self.max_steps_per_epoch
+ ):
+ break
+
+ # Start tracking CUDA memory for active steps for just the first epoch
+ if (
+ self._is_rank_zero
+ and curr_epoch == 0
+ and self.profiler_profile_memory
+ and idx == self.profiler_wait_steps + self.profiler_warmup_steps
+ ):
+ torch.cuda.memory._record_memory_history()
+
+ utils.batch_to_device(batch, self._device)
+
+ # Calculate the number of unmasked tokens in the current batch
+ # and increment the total number of tokens seen in the step
+ current_num_tokens = (
+ batch["labels"] != self._loss_fn.ignore_index
+ ).sum()
+ num_tokens += current_num_tokens
+
+ # Shape [b, s], needed for the loss not the model
+ labels = batch.pop("labels")
+
+ with self.activations_handling_ctx:
+ logits = self._model(**batch)
+
+ # Shift labels to compute loss
+ # equivalent to doing labels[..., 1:] and logits[..., :-1, :]
+ # But this way we dont need to slice the logits. We just add an ignore index to labels.
+ labels = torch.hstack(
+ (labels[..., 1:], self.ignore_labels_cache[: labels.shape[0]])
+ )
+ if not isinstance(logits, list):
+ labels = labels.reshape(-1)
+ logits = logits.reshape(-1, logits.size(-1))
+
+ # Compute loss
+ # Loss is normalized by default so we multiply by the number of tokens
+ # This way we can normalize by the total number of tokens if we're accumulating gradients
+ current_loss = self._loss_fn(logits, labels) * current_num_tokens
+
+ # free logits otherwise it peaks backward memory
+ del logits
+
+ running_loss += current_loss
+ current_loss.backward()
+
+ # Step with optimizer
+ if (idx + 1) % self._gradient_accumulation_steps == 0:
+ # Get total number of tokens across all ranks to normalize gradients
+ torch.distributed.all_reduce(num_tokens)
+ # This will ensure that the logged loss matches what we're optimizing
+ torch.distributed.all_reduce(running_loss)
+ # Manually scale the gradients from unnormalized loss by total # of tokens
+ training.scale_grads(self._model, 1 / num_tokens)
+ if self._clip_grad_norm is not None:
+ grad_norm = torch.nn.utils.clip_grad_norm_(
+ self._model.parameters(),
+ max_norm=float(self._clip_grad_norm),
+ )
+ self._optimizer.step()
+ self._optimizer.zero_grad(set_to_none=True)
+ self._lr_scheduler.step()
+
+ # Update the number of steps when the weights are updated
+ self.global_step += 1
+
+ loss_to_log = running_loss.item() / num_tokens
+ pbar.update(1)
+ pbar.set_description(
+ f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}"
+ )
+
+ # Log per-step metrics
+ if (
+ self.global_step % self._log_every_n_steps == 0
+ and self._is_rank_zero
+ ):
+ time_per_step = time.perf_counter() - t0
+ log_dict = {
+ "loss": loss_to_log,
+ "lr": self._optimizer.param_groups[0]["lr"],
+ "tokens_per_second_per_gpu": num_tokens
+ / (time_per_step * world_size),
+ }
+ if self._log_peak_memory_stats:
+ log_dict.update(
+ training.get_memory_stats(device=self._device)
+ )
+
+ if self._clip_grad_norm is not None:
+ log_dict.update({"grad_norm": grad_norm})
+ self._metric_logger.log_dict(
+ log_dict,
+ step=self.global_step,
+ )
+
+ # Reset running stats for the next step
+ running_loss = 0
+ num_tokens = 0
+ t0 = time.perf_counter()
+
+ # Stop tracking CUDA memory now that active steps are complete
+ if (
+ self._is_rank_zero
+ and curr_epoch == 0
+ and self.profiler_profile_memory
+ and idx
+ == self.profiler_wait_steps
+ + self.profiler_warmup_steps
+ + self.profiler_active_steps
+ ):
+ torch.cuda.memory._record_memory_history(enabled=None)
+
+ # Step profiler
+ # Note that this is called within gradient accumulation block, hence
+ # will include multiple forward / backward passes if gradient accumulation > 1
+ self._profiler.step()
+
+ self.epochs_run += 1
+ self.save_checkpoint(epoch=curr_epoch)
+
+ self._profiler.stop()
+
+ def cleanup(self) -> None:
+ if self._is_rank_zero:
+ self._metric_logger.close()
+ destroy_process_group()
+
+
+@config.parse
+def recipe_main(cfg: DictConfig) -> None:
+ """
+ Entry point for the recipe.
+
+ Configurable parameters are read in the following order:
+ - Parameters specified in config (see available configs through ``tune ls``)
+ - Overwritten by arguments from the command-line
+ """
+ if not training.is_distributed():
+ raise RuntimeError(
+ "Distributed finetune recipe should be run via a distributed launcher."
+ "If using tune CLI, please specify --nnodes 1 and --nproc_per_node [num_gpus]"
+ )
+ if cfg.get("fsdp_cpu_offload", False):
+ # Utilize all available CPU cores for intra-op parallelism. This provides ~2x
+ # speed up when benchmarking fused AdamW on CPU
+ training.set_torch_num_threads()
+ init_process_group(backend="gloo" if cfg.device == "cpu" else "nccl")
+
+ config.log_config(recipe_name="QATLoRAFinetuneRecipeDistributed", cfg=cfg)
+
+ recipe = QATLoRAFinetuneRecipeDistributed(cfg=cfg)
+ recipe.setup(cfg=cfg)
+ recipe.train()
+ recipe.cleanup()
+
+
+if __name__ == "__main__":
+ sys.exit(recipe_main())
diff --git a/tests/recipes/test_full_finetune_distributed.py b/tests/recipes/test_full_finetune_distributed.py
index 31f9a137bd..fae8d19d49 100644
--- a/tests/recipes/test_full_finetune_distributed.py
+++ b/tests/recipes/test_full_finetune_distributed.py
@@ -112,6 +112,8 @@ def test_loss(
# should be the same.
if not optim_in_bwd:
cmd.append("clip_grad_norm=100")
+ # Test that gradient clipping works with CPU offload
+ cmd.append("fsdp_cpu_offload=True")
else:
cmd.append("optimizer_in_bwd=True")
diff --git a/tests/recipes/test_knowledge_distillation_distributed.py b/tests/recipes/test_knowledge_distillation_distributed.py
index 2751b004f5..103327dfca 100644
--- a/tests/recipes/test_knowledge_distillation_distributed.py
+++ b/tests/recipes/test_knowledge_distillation_distributed.py
@@ -70,7 +70,7 @@ def test_loss(self, tmpdir, monkeypatch):
cmd = f"""
tune run --nnodes 1 --nproc_per_node 2 knowledge_distillation_distributed \
- --config llama3_2/knowledge_distillation_distributed \
+ --config llama3_2/8B_to_1B_KD_lora_distributed \
output_dir={tmpdir} \
checkpointer._component_=torchtune.training.FullModelTorchTuneCheckpointer \
checkpointer.checkpoint_dir='{ckpt_dir}' \
@@ -128,7 +128,7 @@ def test_training_state_on_resume(self, tmpdir, monkeypatch):
# Train for two epochs
cmd_1 = f"""
tune run --nnodes 1 --nproc_per_node 2 knowledge_distillation_distributed \
- --config llama3_2/knowledge_distillation_distributed \
+ --config llama3_2/8B_to_1B_KD_lora_distributed \
output_dir={tmpdir} \
checkpointer=torchtune.training.FullModelTorchTuneCheckpointer \
checkpointer.checkpoint_dir='{ckpt_dir}' \
@@ -158,7 +158,7 @@ def test_training_state_on_resume(self, tmpdir, monkeypatch):
epoch_folder_minus_one = f"epoch_{int(epoch_folder.split('_')[-1]) - 1}"
cmd_2 = f"""
tune run --nnodes 1 --nproc_per_node 2 knowledge_distillation_distributed \
- --config llama3_2/knowledge_distillation_distributed \
+ --config llama3_2/8B_to_1B_KD_lora_distributed \
output_dir={tmpdir} \
checkpointer=torchtune.training.FullModelTorchTuneCheckpointer \
checkpointer.checkpoint_dir={tmpdir} \
@@ -209,7 +209,7 @@ def test_save_and_load_merged_weights(self, tmpdir, monkeypatch):
cmd = f"""
tune run --nnodes 1 --nproc_per_node 2 knowledge_distillation_distributed \
- --config llama3_2/knowledge_distillation_distributed \
+ --config llama3_2/8B_to_1B_KD_lora_distributed \
output_dir={tmpdir} \
checkpointer._component_={ckpt_component} \
checkpointer.checkpoint_dir='{ckpt_dir}' \
diff --git a/tests/recipes/test_knowledge_distillation_single_device.py b/tests/recipes/test_knowledge_distillation_single_device.py
index af1de4ccd9..76127df629 100644
--- a/tests/recipes/test_knowledge_distillation_single_device.py
+++ b/tests/recipes/test_knowledge_distillation_single_device.py
@@ -73,7 +73,7 @@ def test_loss(
tmpdir,
monkeypatch,
):
- config = "qwen2/knowledge_distillation_single_device"
+ config = "qwen2/1.5_to_0.5B_KD_lora_single_device"
model_type = "llama3"
ckpt_type = "tune"
ckpt_component = CKPT_COMPONENT_MAP[ckpt_type]
@@ -160,7 +160,7 @@ def test_training_state_on_resume(self, tmpdir, monkeypatch):
# Train for two epochs
cmd_1 = f"""
tune run knowledge_distillation_single_device \
- --config qwen2/knowledge_distillation_single_device \
+ --config qwen2/1.5_to_0.5B_KD_lora_single_device \
output_dir={tmpdir} \
checkpointer=torchtune.training.FullModelTorchTuneCheckpointer \
checkpointer.checkpoint_dir='{ckpt_dir}' \
@@ -196,7 +196,7 @@ def test_training_state_on_resume(self, tmpdir, monkeypatch):
epoch_folder_minus_one = f"epoch_{int(epoch_folder.split('_')[-1]) - 1}"
cmd_2 = f"""
tune run knowledge_distillation_single_device \
- --config qwen2/knowledge_distillation_single_device \
+ --config qwen2/1.5_to_0.5B_KD_lora_single_device \
output_dir={tmpdir} \
checkpointer=torchtune.training.FullModelTorchTuneCheckpointer \
checkpointer.checkpoint_dir={tmpdir} \
@@ -252,7 +252,7 @@ def test_save_and_load_merged_weights(self, tmpdir, monkeypatch):
cmd = f"""
tune run knowledge_distillation_single_device \
- --config qwen2/knowledge_distillation_single_device \
+ --config qwen2/1.5_to_0.5B_KD_lora_single_device \
output_dir={tmpdir} \
checkpointer._component_={ckpt_component} \
checkpointer.checkpoint_dir='{ckpt_dir}' \
diff --git a/tests/recipes/test_qat_lora_finetune_distributed.py b/tests/recipes/test_qat_lora_finetune_distributed.py
new file mode 100644
index 0000000000..4d7c4b6899
--- /dev/null
+++ b/tests/recipes/test_qat_lora_finetune_distributed.py
@@ -0,0 +1,271 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+import runpy
+import sys
+from pathlib import Path
+
+import pytest
+import torch
+from omegaconf import OmegaConf
+from tests.common import TUNE_PATH
+from tests.recipes.utils import (
+ CKPT_COMPONENT_MAP,
+ dummy_alpaca_dataset_config,
+ MODEL_TEST_CONFIGS,
+ write_hf_ckpt_config,
+)
+from tests.test_utils import (
+ CKPT_MODEL_PATHS,
+ gen_log_file_name,
+ get_loss_values_from_metric_logger,
+ gpu_test,
+ TOKENIZER_PATHS,
+)
+from torchtune import config
+from torchtune.training.quantization import _torchao_0_7_supported
+
+
+class TestQATLoRAFinetuneDistributedRecipe:
+ def _get_test_config_overrides(self):
+ return [
+ "dataset.train_on_input=False",
+ "seed=9",
+ "epochs=2",
+ "dtype=fp32",
+ "max_steps_per_epoch=2",
+ "optimizer.lr=2e-5",
+ "log_every_n_steps=1",
+ "compile=False",
+ ] + dummy_alpaca_dataset_config()
+
+ def _fetch_expected_loss_values(self, model_type):
+ loss_values_map = {
+ "llama3": [11.9835, 11.9694, 11.9615, 11.9383],
+ }
+ return loss_values_map[model_type]
+
+ @pytest.mark.integration_test
+ @gpu_test(gpu_count=2)
+ @pytest.mark.parametrize(
+ "micro_batch_size, gradient_accumulation_steps, should_compile",
+ [(4, 1, True), (1, 4, False)],
+ )
+ @pytest.mark.skipif(not _torchao_0_7_supported, reason="needs torchao 0.7+")
+ def test_loss(
+ self,
+ micro_batch_size,
+ gradient_accumulation_steps,
+ should_compile,
+ tmpdir,
+ monkeypatch,
+ ):
+ ckpt = "llama3_tune"
+ ckpt_path = Path(CKPT_MODEL_PATHS[ckpt])
+ tokenizer_path = Path(TOKENIZER_PATHS["llama3"])
+ ckpt_dir = ckpt_path.parent
+ log_file = gen_log_file_name(tmpdir)
+ cmd = f"""
+ tune run --nnodes 1 --nproc_per_node 2 qat_lora_finetune_distributed
+ --config llama3/8B_qat_lora \
+ batch_size={micro_batch_size} \
+ gradient_accumulation_steps={gradient_accumulation_steps} \
+ output_dir={tmpdir} \
+ checkpointer=torchtune.training.FullModelTorchTuneCheckpointer \
+ checkpointer.checkpoint_dir='{ckpt_dir}' \
+ checkpointer.checkpoint_files=[{ckpt_path}]\
+ checkpointer.output_dir={tmpdir} \
+ checkpointer.model_type=LLAMA3 \
+ metric_logger.filename={log_file} \
+ tokenizer.path={tokenizer_path} \
+ tokenizer.prompt_template=null \
+ compile={should_compile} \
+ enable_activation_checkpointing=False \
+ enable_activation_offloading=False \
+ quantizer.groupsize=32 \
+ """.split()
+
+ model_config = MODEL_TEST_CONFIGS["llama3_lora"]
+
+ cmd = cmd + self._get_test_config_overrides() + model_config
+ monkeypatch.setattr(sys, "argv", cmd)
+ runpy.run_path(TUNE_PATH, run_name="__main__")
+ loss_values = get_loss_values_from_metric_logger(log_file)
+ expected_loss_values = self._fetch_expected_loss_values("llama3")
+ torch.testing.assert_close(
+ loss_values, expected_loss_values, rtol=1e-5, atol=1e-5
+ )
+
+ @pytest.mark.integration_test
+ @gpu_test(gpu_count=2)
+ @pytest.mark.parametrize(
+ "config, model_type, ckpt_type, save_adapter_weights_only",
+ [
+ ("llama3/8B_qat_lora", "llama3", "tune", False),
+ ],
+ )
+ @pytest.mark.skipif(not _torchao_0_7_supported, reason="needs torchao 0.7+")
+ def test_training_state_on_resume(
+ self,
+ config,
+ model_type,
+ ckpt_type,
+ tmpdir,
+ monkeypatch,
+ save_adapter_weights_only,
+ ):
+ """Test whether the recipe state is correctly updated on resume. Since this
+ is model agnostic, we should run this on the small model only. The test
+ consists of three stages:
+ - Train a model for 2 epochs
+ - Resume training after epoch 1
+ - Make sure final loss matches the expected value of a model successfully resumed from a ckpt
+ """
+ ckpt_component = CKPT_COMPONENT_MAP[ckpt_type]
+ ckpt = model_type + "_" + ckpt_type
+ expected_loss_values = self._fetch_expected_loss_values(model_type)
+
+ ckpt_path = Path(CKPT_MODEL_PATHS[ckpt])
+ tokenizer_path = Path(TOKENIZER_PATHS[model_type])
+ ckpt_dir = ckpt_path.parent
+ log_file = gen_log_file_name(tmpdir)
+
+ # Config file needed for model conversion.
+ # Create a second copy for training resume
+ write_hf_ckpt_config(ckpt_dir)
+ write_hf_ckpt_config(tmpdir)
+
+ # Train for two epochs
+ cmd_1 = f"""
+ tune run --nnodes 1 --nproc_per_node 2 qat_lora_finetune_distributed \
+ --config {config} \
+ batch_size=4 \
+ gradient_accumulation_steps=1 \
+ output_dir={tmpdir} \
+ checkpointer._component_={ckpt_component} \
+ checkpointer.checkpoint_dir='{ckpt_dir}' \
+ checkpointer.checkpoint_files=[{ckpt_path}]\
+ checkpointer.output_dir={tmpdir} \
+ checkpointer.model_type={model_type.upper()} \
+ tokenizer.path='{tokenizer_path}' \
+ tokenizer.prompt_template=null \
+ save_adapter_weights_only={save_adapter_weights_only} \
+ enable_activation_checkpointing=True \
+ enable_activation_offloading=True \
+ quantizer.groupsize=32 \
+ """.split()
+
+ model_config = MODEL_TEST_CONFIGS[model_type + "_lora"]
+
+ cmd_1 = cmd_1 + self._get_test_config_overrides() + model_config
+ monkeypatch.setattr(sys, "argv", cmd_1)
+ runpy.run_path(TUNE_PATH, run_name="__main__")
+
+ # Resume training
+ cmd_2 = f"""
+ tune run --nnodes 1 --nproc_per_node 2 qat_lora_finetune_distributed \
+ --config {config} \
+ batch_size=4 \
+ gradient_accumulation_steps=1 \
+ output_dir={tmpdir} \
+ checkpointer._component_={ckpt_component} \
+ checkpointer.checkpoint_dir={tmpdir} \
+ checkpointer.checkpoint_files=[{ckpt_path}]\
+ checkpointer.adapter_checkpoint={os.path.join(tmpdir, "adapter_0.pt")}
+ checkpointer.recipe_checkpoint={os.path.join(tmpdir, "recipe_state.pt")}
+ checkpointer.output_dir={tmpdir} \
+ checkpointer.model_type={model_type.upper()} \
+ tokenizer.path='{tokenizer_path}' \
+ tokenizer.prompt_template=null \
+ resume_from_checkpoint=True \
+ metric_logger.filename={log_file} \
+ enable_activation_checkpointing=True \
+ enable_activation_offloading=True \
+ quantizer.groupsize=32 \
+ """.split()
+
+ cmd_2 = cmd_2 + self._get_test_config_overrides() + model_config
+ monkeypatch.setattr(sys, "argv", cmd_2)
+ runpy.run_path(TUNE_PATH, run_name="__main__")
+
+ expected_loss_values = self._fetch_expected_loss_values(model_type)[2:]
+
+ loss_values = get_loss_values_from_metric_logger(log_file)
+ torch.testing.assert_close(
+ loss_values, expected_loss_values, rtol=1e-5, atol=1e-5
+ )
+
+ @pytest.mark.integration_test
+ @pytest.mark.parametrize(
+ "recipe_config, model_type, ckpt_type",
+ [
+ ("llama3/8B_qat_lora", "llama3", "tune"),
+ ],
+ )
+ @gpu_test(gpu_count=2)
+ @pytest.mark.skipif(not _torchao_0_7_supported, reason="needs torchao 0.7+")
+ def test_save_and_load_merged_weights(
+ self, recipe_config, model_type, ckpt_type, tmpdir, monkeypatch
+ ):
+ ckpt_component = CKPT_COMPONENT_MAP[ckpt_type]
+ ckpt = model_type + "_" + ckpt_type
+ ckpt_path = Path(CKPT_MODEL_PATHS[ckpt])
+ tokenizer_path = Path(TOKENIZER_PATHS[model_type])
+ ckpt_dir = ckpt_path.parent
+ cmd = f"""
+ tune run --nnodes 1 --nproc_per_node 2 qat_lora_finetune_distributed \
+ --config {recipe_config} \
+ batch_size=4 \
+ gradient_accumulation_steps=1 \
+ output_dir={tmpdir} \
+ model=torchtune.models.lora_small_test_model \
+ checkpointer._component_={ckpt_component} \
+ checkpointer.checkpoint_dir='{ckpt_dir}' \
+ checkpointer.checkpoint_files=[{ckpt_path}]\
+ checkpointer.output_dir={tmpdir} \
+ checkpointer.model_type={model_type.upper()} \
+ tokenizer.path='{tokenizer_path}' \
+ tokenizer.prompt_template=null \
+ enable_activation_checkpointing=True \
+ enable_activation_offloading=True \
+ quantizer.groupsize=32 \
+ """.split()
+
+ model_config = MODEL_TEST_CONFIGS[model_type + "_lora"]
+
+ cmd = cmd + self._get_test_config_overrides() + model_config
+ monkeypatch.setattr(sys, "argv", cmd)
+ runpy.run_path(TUNE_PATH, run_name="__main__")
+
+ # Next load both the merged weights in a base model
+ # and the base model weights + trained adapter weights in the LoRA model
+ # The results of calling forward on dummy inputs should be the same.
+ inputs = torch.randint(low=0, high=32_000, size=(2, 100))
+
+ # Build LoRA model for loading base + adapter weights separately
+ lora_model = config.instantiate(OmegaConf.from_dotlist(model_config).model)
+
+ # Build base model for loading merged weights
+ base_config = MODEL_TEST_CONFIGS[model_type]
+ model = config.instantiate(OmegaConf.from_dotlist(base_config).model)
+
+ # Load base model and trained adapter weights into LoRA model and call fwd
+ with open(f"{tmpdir}/adapter_1.pt", "rb") as f:
+ lora_sd = torch.load(f, weights_only=True)
+ with open(ckpt_path, "rb") as f:
+ base_model_sd = torch.load(f, weights_only=True)
+ lora_model.load_state_dict(lora_sd, strict=False)
+ lora_model.load_state_dict(base_model_sd, strict=False)
+ baseline_out = lora_model(inputs)
+
+ # Load merged final ckpt directly into model and call fwd
+ with open(f"{tmpdir}/torchtune_model_1.pt", "rb") as f:
+ sd = torch.load(f, weights_only=True)
+ model.load_state_dict(sd)
+ merged_ckpt_out = model(inputs)
+
+ torch.testing.assert_close(baseline_out, merged_ckpt_out, rtol=1e-5, atol=1e-5)
diff --git a/tests/torchtune/data/test_messages.py b/tests/torchtune/data/test_messages.py
index a46cfd9349..86b7d7319f 100644
--- a/tests/torchtune/data/test_messages.py
+++ b/tests/torchtune/data/test_messages.py
@@ -9,6 +9,7 @@
import pytest
from PIL import Image
+from tests.common import ASSETS
from tests.test_utils import (
assert_dialogue_equal,
CHAT_SAMPLE,
@@ -24,6 +25,8 @@
validate_messages,
)
+PYTORCH_RGB_IMAGE_AS_PIL = Image.open(ASSETS / "rgb_pytorch.png")
+
class TestMessage:
@pytest.fixture
@@ -106,6 +109,60 @@ def sample(self):
"maybe_output": "hello world",
}
+ @pytest.mark.parametrize(
+ "input_image, expected_image",
+ [
+ ("rgb_pytorch.png", PYTORCH_RGB_IMAGE_AS_PIL),
+ (ASSETS / "rgb_pytorch.png", PYTORCH_RGB_IMAGE_AS_PIL),
+ (PYTORCH_RGB_IMAGE_AS_PIL, PYTORCH_RGB_IMAGE_AS_PIL),
+ ],
+ )
+ def test_call_with_image(self, sample, input_image, expected_image):
+ # Add the image to the sample
+ sample["image"] = input_image
+
+ # Create the transform
+ transform = InputOutputToMessages(
+ column_map={
+ "input": "maybe_input",
+ "output": "maybe_output",
+ "image": "image",
+ },
+ # Need to test if the image_dir is properly joined w/ image
+ image_dir=ASSETS if isinstance(input_image, str) else None,
+ )
+ actual = transform(sample)
+ expected = [
+ Message(
+ role="user",
+ content=[
+ {"type": "text", "content": "hello world"},
+ {"type": "image", "content": expected_image},
+ ],
+ masked=True,
+ eot=True,
+ ),
+ Message(role="assistant", content="hello world", masked=False, eot=True),
+ ]
+ assert_dialogue_equal(actual["messages"], expected)
+
+ def test_call_with_image_fails_when_bad_image_inputs_are_passed(self, sample):
+ # Construct a bad column_map without an 'image' key
+ column_map = {
+ "input": "maybe_input",
+ "output": "maybe_output",
+ }
+
+ # Create a transform that expects an image column
+ with pytest.raises(
+ ValueError,
+ match="Please specify an 'image' key in column_map",
+ ):
+ transform = InputOutputToMessages(
+ column_map=column_map,
+ image_dir=ASSETS,
+ )
+
def test_call(self, sample):
transform = InputOutputToMessages(
column_map={"input": "maybe_input", "output": "maybe_output"}
diff --git a/tests/torchtune/datasets/multimodal/test_llava_instruct_dataset.py b/tests/torchtune/datasets/multimodal/test_llava_instruct_dataset.py
index 5df17bb877..11e039e66d 100644
--- a/tests/torchtune/datasets/multimodal/test_llava_instruct_dataset.py
+++ b/tests/torchtune/datasets/multimodal/test_llava_instruct_dataset.py
@@ -86,3 +86,12 @@ def test_get_item(self, load_image, load_dataset, tokenizer, test_image_pil):
assert Counter(input) == expected_count
assert labels.count(CROSS_ENTROPY_IGNORE_IDX) == 11
assert images == [test_image_pil]
+
+ def test_dataset_fails_with_packed(self, tokenizer):
+ with pytest.raises(
+ ValueError, match="Multimodal datasets don't support packing yet."
+ ):
+ llava_instruct_dataset(
+ model_transform=tokenizer,
+ packed=True,
+ )
diff --git a/tests/torchtune/datasets/multimodal/test_multimodal_chat_dataset.py b/tests/torchtune/datasets/multimodal/test_multimodal_chat_dataset.py
new file mode 100644
index 0000000000..8b12d3a85e
--- /dev/null
+++ b/tests/torchtune/datasets/multimodal/test_multimodal_chat_dataset.py
@@ -0,0 +1,24 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import pytest
+from tests.test_utils import DummyTokenizer
+
+from torchtune.datasets.multimodal import multimodal_chat_dataset
+
+
+class TestMultimodalChatDataset:
+ @pytest.fixture
+ def tokenizer(self):
+ return DummyTokenizer()
+
+ def test_dataset_fails_with_packed(self, tokenizer):
+ with pytest.raises(
+ ValueError, match="Multimodal datasets don't support packing yet."
+ ):
+ multimodal_chat_dataset(
+ model_transform=tokenizer, source="json", packed=True
+ )
diff --git a/tests/torchtune/datasets/multimodal/test_the_cauldron_dataset.py b/tests/torchtune/datasets/multimodal/test_the_cauldron_dataset.py
index ed8ed40ec7..ebc485d8dd 100644
--- a/tests/torchtune/datasets/multimodal/test_the_cauldron_dataset.py
+++ b/tests/torchtune/datasets/multimodal/test_the_cauldron_dataset.py
@@ -79,3 +79,13 @@ def test_get_item(self, load_dataset, tokenizer, test_image_pil):
]
assert labels.count(CROSS_ENTROPY_IGNORE_IDX) == 24
assert images == [test_image_pil]
+
+ def test_dataset_fails_with_packed(self, tokenizer):
+ with pytest.raises(
+ ValueError, match="Multimodal datasets don't support packing yet."
+ ):
+ the_cauldron_dataset(
+ model_transform=tokenizer,
+ subset="dummy",
+ packed=True,
+ )
diff --git a/tests/torchtune/datasets/multimodal/test_vqa_dataset.py b/tests/torchtune/datasets/multimodal/test_vqa_dataset.py
index 6ca36d9615..d2b80fdea7 100644
--- a/tests/torchtune/datasets/multimodal/test_vqa_dataset.py
+++ b/tests/torchtune/datasets/multimodal/test_vqa_dataset.py
@@ -47,3 +47,13 @@ def test_get_item(self, tokenizer):
assert prompt == expected_tokens[i]
assert label == expected_labels[i]
assert isinstance(image[0], PngImageFile)
+
+ def test_dataset_fails_with_packed(self, tokenizer):
+ with pytest.raises(
+ ValueError, match="Multimodal datasets don't support packing yet."
+ ):
+ vqa_dataset(
+ model_transform=tokenizer,
+ source="json",
+ packed=True,
+ )
diff --git a/tests/torchtune/datasets/test_hh_rlhf_helpful_dataset.py b/tests/torchtune/datasets/test_hh_rlhf_helpful_dataset.py
index 40a7d02c8c..834e8d78a9 100644
--- a/tests/torchtune/datasets/test_hh_rlhf_helpful_dataset.py
+++ b/tests/torchtune/datasets/test_hh_rlhf_helpful_dataset.py
@@ -107,3 +107,14 @@ def test_dataset_get_item(self, mock_load_dataset, train_on_input):
else:
# Check that the input is masked
assert sample["rejected_labels"].count(CROSS_ENTROPY_IGNORE_IDX) == 16
+
+ def test_dataset_fails_with_packed(self):
+ with pytest.raises(
+ ValueError,
+ match="Packed is currently not supported for preference datasets",
+ ):
+ hh_rlhf_helpful_dataset(
+ tokenizer=DummyTokenizer(),
+ train_on_input=True,
+ packed=True,
+ )
diff --git a/tests/torchtune/datasets/test_preference_dataset.py b/tests/torchtune/datasets/test_preference_dataset.py
index e6bbc264b3..4f5cba7a8d 100644
--- a/tests/torchtune/datasets/test_preference_dataset.py
+++ b/tests/torchtune/datasets/test_preference_dataset.py
@@ -155,3 +155,17 @@ def test_load_local_json(self):
assert expected_chosen_labels[0] == ds[0]["chosen_labels"]
assert expected_rejected_labels[0] == ds[0]["rejected_labels"]
+
+ def test_dataset_fails_with_packed(self):
+ with pytest.raises(
+ ValueError,
+ match="Packed is currently not supported for preference datasets.",
+ ):
+ preference_dataset(
+ tokenizer=DummyTokenizer(),
+ source="json",
+ data_files=str(ASSETS / "hh_rlhf_tiny.json"),
+ train_on_input=False,
+ split="train",
+ packed=True,
+ )
diff --git a/tests/torchtune/datasets/test_stack_exchange_paired_dataset.py b/tests/torchtune/datasets/test_stack_exchange_paired_dataset.py
index f9bcbebc08..6e8a9a4eb8 100644
--- a/tests/torchtune/datasets/test_stack_exchange_paired_dataset.py
+++ b/tests/torchtune/datasets/test_stack_exchange_paired_dataset.py
@@ -100,6 +100,16 @@ def test_dataset_get_item(self, mock_load_dataset, train_on_input):
# Check that the input is masked
assert sample["rejected_labels"].count(CROSS_ENTROPY_IGNORE_IDX) == 52
+ def test_dataset_fails_with_packed(self):
+ with pytest.raises(
+ ValueError,
+ match="Packed is currently not supported for preference datasets",
+ ):
+ stack_exchange_paired_dataset(
+ tokenizer=DummyTokenizer(),
+ packed=True,
+ )
+
class TestStackExchangePairedToMessages:
@pytest.fixture
diff --git a/tests/torchtune/modules/_export/test_attention.py b/tests/torchtune/modules/_export/test_attention.py
new file mode 100644
index 0000000000..ed2c022c3e
--- /dev/null
+++ b/tests/torchtune/modules/_export/test_attention.py
@@ -0,0 +1,230 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+import tempfile
+import unittest
+
+import torch
+from torch._inductor.package import load_package, package_aoti
+from torch.testing import assert_close
+from torchtune.models.llama3_1._position_embeddings import Llama3ScaledRoPE
+from torchtune.modules._export.attention import (
+ MultiHeadAttention as ExportMultiHeadAttention,
+)
+from torchtune.modules.attention import MultiHeadAttention as TTMultiHeadAttention
+from torchtune.utils import torch_version_ge
+
+
+class AttentionTest(unittest.TestCase):
+ def setUp(self):
+ super().setUp()
+ torch.manual_seed(0)
+ # Constants
+ self.embed_dim = 2048
+ self.num_heads = 8
+ self.num_kv_heads = 8
+ self.head_dim = 64
+ self.max_seq_len = 128
+ self.rope_base = 500_000
+ self.scale_factor = 32
+
+ # Module dependency injections.
+ self.q_proj = torch.nn.Linear(
+ self.embed_dim, self.num_heads * self.head_dim, bias=False
+ )
+ self.k_proj = torch.nn.Linear(
+ self.embed_dim, self.num_kv_heads * self.head_dim, bias=False
+ )
+ self.k_proj.weight.requires_grad = False
+ self.v_proj = torch.nn.Linear(
+ self.embed_dim, self.num_kv_heads * self.head_dim, bias=False
+ )
+ self.v_proj.weight.requires_grad = False
+ self.output_proj = torch.nn.Linear(
+ self.num_heads * self.head_dim, self.embed_dim, bias=False
+ )
+ self.pos_embeddings = Llama3ScaledRoPE(
+ dim=self.head_dim,
+ max_seq_len=self.max_seq_len,
+ base=self.rope_base,
+ scale_factor=self.scale_factor,
+ )
+
+ # Original TorchTune reference module to test accuracy against.
+ self.tt_mha = TTMultiHeadAttention(
+ embed_dim=self.embed_dim,
+ num_heads=self.num_heads,
+ num_kv_heads=self.num_kv_heads,
+ head_dim=self.head_dim,
+ q_proj=self.q_proj,
+ k_proj=self.k_proj,
+ v_proj=self.v_proj,
+ output_proj=self.output_proj,
+ pos_embeddings=self.pos_embeddings,
+ max_seq_len=self.max_seq_len,
+ )
+
+ # Source transformed module that we are testing.
+ self.et_mha = ExportMultiHeadAttention(
+ embed_dim=self.embed_dim,
+ num_heads=self.num_heads,
+ num_kv_heads=self.num_kv_heads,
+ head_dim=self.head_dim,
+ q_proj=self.q_proj,
+ k_proj=self.k_proj,
+ v_proj=self.v_proj,
+ output_proj=self.output_proj,
+ pos_embeddings=self.pos_embeddings,
+ max_seq_len=self.max_seq_len,
+ )
+ self.et_mha.load_state_dict(self.tt_mha.state_dict())
+ # Common inputs.
+ seq_len = 10
+ self.x = torch.randn(1, seq_len, self.embed_dim)
+ self.input_pos = torch.arange(seq_len).unsqueeze(0) # shape [1, seq_len]
+ seq_len_dim = torch.export.Dim("seq_len", min=1, max=100)
+ self.dynamic_shapes = (
+ {0: torch.export.Dim.STATIC, 1: seq_len_dim, 2: torch.export.Dim.STATIC},
+ {0: torch.export.Dim.STATIC, 1: seq_len_dim, 2: torch.export.Dim.STATIC},
+ {0: torch.export.Dim.STATIC, 1: seq_len_dim},
+ )
+ self.causal_mask = torch.tril(
+ torch.ones(
+ size=(self.max_seq_len, self.max_seq_len),
+ dtype=torch.bool,
+ )
+ )
+
+ @unittest.skipUnless(
+ torch_version_ge("2.6.0"), reason="torch.cond only works for 2.6.0"
+ )
+ def test_attention_eager(self):
+ et_res = self.et_mha(self.x, self.x) # Self attention.
+ tt_res = self.tt_mha(self.x, self.x) # Self attention.
+
+ assert_close(et_res, tt_res)
+
+ # test with kv cache
+ self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=20)
+ self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=20)
+
+ et_res = self.et_mha(self.x, self.x) # Self attention.
+ tt_res = self.tt_mha(self.x, self.x) # Self attention.
+
+ assert_close(et_res, tt_res)
+ self.et_mha.reset_cache()
+ self.tt_mha.reset_cache()
+
+ et_res = self.et_mha(
+ self.x, self.x, input_pos=self.input_pos
+ ) # Self attention with input pos.
+ tt_res = self.tt_mha(
+ self.x, self.x, input_pos=self.input_pos
+ ) # Self attention with input pos.
+
+ assert_close(et_res, tt_res)
+
+ # test kv cache read. Input pos can be [10, 11, ..., 19]
+ next_input_pos = torch.arange(10, 20).unsqueeze(0)
+ et_res = self.et_mha(
+ self.x, self.x, input_pos=next_input_pos
+ ) # Self attention with input pos.
+ tt_res = self.tt_mha(
+ self.x, self.x, input_pos=next_input_pos
+ ) # Self attention with input pos.
+
+ assert_close(et_res, tt_res)
+
+ @unittest.skipUnless(
+ torch_version_ge("2.6.0.dev20241117"), reason="Need recent fixes for export"
+ )
+ def test_attention_export(self):
+ # Self attention.
+
+ # test with kv cache
+ self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100)
+ self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100)
+ with torch.no_grad():
+ et_mha_ep = torch.export.export(
+ self.et_mha,
+ (self.x, self.x),
+ kwargs={"input_pos": self.input_pos},
+ dynamic_shapes=self.dynamic_shapes,
+ )
+ et_res = et_mha_ep.module()(self.x, self.x, input_pos=self.input_pos)
+ tt_res = self.tt_mha(self.x, self.x, input_pos=self.input_pos)
+
+ assert_close(et_res, tt_res)
+
+ @unittest.skipUnless(
+ torch_version_ge("2.6.0.dev20241117"), reason="Need recent fixes for aoti"
+ )
+ def test_attention_aoti(self):
+ # Self attention.
+
+ # test with kv cache
+ self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100)
+ self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100)
+ with torch.no_grad():
+ so = torch._export.aot_compile(
+ self.et_mha,
+ args=(self.x, self.x),
+ kwargs={"input_pos": self.input_pos},
+ options={
+ "aot_inductor.package": True,
+ "reorder_for_peak_memory": False,
+ },
+ dynamic_shapes=self.dynamic_shapes,
+ )
+ with tempfile.TemporaryDirectory() as tempdir:
+ path = package_aoti(os.path.join(tempdir, "mha.pt2"), so)
+ mha_aoti = load_package(path)
+
+ aoti_res = mha_aoti(self.x, self.x, input_pos=self.input_pos)
+ tt_res = self.tt_mha(self.x, self.x, input_pos=self.input_pos)
+ assert_close(aoti_res, tt_res)
+
+ @unittest.skipUnless(
+ torch_version_ge("2.6.0"), reason="torch.cond only works for 2.6.0"
+ )
+ def test_attention_torch_cond_eager(self):
+ # Different from vanilla torchtune MHA, we rewrite the if condition with torch.cond. We need to make sure they
+ # are giving the same results regarding the if condition.
+ # For the first run of MHA we provide `y` (self.x) but for the second run it will be a tensor full of nan.
+ self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len)
+ self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len)
+
+ # mask
+ mask = self.causal_mask[self.input_pos, :]
+ # First run
+ et_res = self.et_mha(
+ self.x, self.x, mask=mask, input_pos=self.input_pos
+ ) # Self attention with input pos.
+ tt_res = self.tt_mha(
+ self.x, self.x, mask=mask, input_pos=self.input_pos
+ ) # Self attention with input pos.
+
+ assert_close(et_res, tt_res)
+
+ # Second run test kv cache read. Input pos is [10, 11, ..., 19]
+ next_input_pos = torch.arange(10, 20).unsqueeze(0)
+
+ empty_y = torch.full_like(self.x, torch.nan)
+ mask = self.causal_mask[next_input_pos, :]
+ et_res = self.et_mha(
+ self.x, empty_y, mask=mask, input_pos=next_input_pos
+ ) # Self attention with input pos.
+ tt_res = self.tt_mha(
+ self.x, None, mask=mask, input_pos=next_input_pos
+ ) # Self attention with input pos.
+
+ assert_close(et_res, tt_res)
diff --git a/tests/torchtune/modules/_export/test_export_position_embeddings.py b/tests/torchtune/modules/_export/test_export_position_embeddings.py
new file mode 100644
index 0000000000..20bfb84deb
--- /dev/null
+++ b/tests/torchtune/modules/_export/test_export_position_embeddings.py
@@ -0,0 +1,184 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+import tempfile
+import unittest
+
+import torch
+from torch._inductor.package import load_package, package_aoti
+from torch.testing import assert_close
+from torchtune.models.clip import (
+ TiledTokenPositionalEmbedding as TuneTiledTokenPositionalEmbedding,
+ TilePositionalEmbedding as TuneTilePositionalEmbedding,
+)
+from torchtune.modules._export._position_embeddings import (
+ replace_tile_positional_embedding,
+ replace_tiled_token_positional_embedding,
+ TiledTokenPositionalEmbedding,
+ TilePositionalEmbedding,
+)
+from torchtune.utils import torch_version_ge
+
+
+class TilePositionalEmbeddingTest(unittest.TestCase):
+ def setUp(self):
+ super().setUp()
+ self.tpe = TilePositionalEmbedding(4, 1280)
+ self.ref_tpe = TuneTilePositionalEmbedding(4, 1280)
+ self.x = torch.randn(1, 4, 1600, 1280)
+ self.aspect_ratio = torch.tensor([[1, 1]])
+ num_tiles_dim = torch.export.Dim("num_tiles", min=1, max=4)
+ num_tokens = torch.export.Dim("num_tokens", min=1, max=1600)
+
+ self.dynamic_shape = {
+ 0: 1, # batch
+ 1: num_tiles_dim, # num tiles
+ 2: num_tokens, # num tokens
+ 3: 1280, # embedding dim
+ }
+
+ def test_tile_positional_embedding_smoke(self):
+ y = self.tpe(self.x, self.aspect_ratio)
+ ref_y = self.ref_tpe(self.x, self.aspect_ratio)
+
+ self.assertTrue(torch.allclose(y, ref_y))
+
+ @unittest.skipUnless(
+ torch_version_ge("2.6.0.dev20241117"), reason="Need recent fixes for export"
+ )
+ def test_tile_positional_embedding_export(self):
+
+ tpe_ep = torch.export.export(
+ self.tpe,
+ (self.x, self.aspect_ratio),
+ dynamic_shapes=(
+ self.dynamic_shape,
+ None,
+ ), # assuming aspect ratio is static
+ )
+
+ y = tpe_ep.module()(self.x, self.aspect_ratio)
+ ref_y = self.ref_tpe(self.x, self.aspect_ratio)
+
+ self.assertTrue(torch.allclose(y, ref_y))
+
+ @unittest.skipUnless(
+ torch_version_ge("2.6.0.dev20241117"), reason="Need recent fixes for aoti"
+ )
+ def test_tile_positional_embedding_aoti(self):
+ so = torch._export.aot_compile(
+ self.tpe,
+ args=(self.x, self.aspect_ratio),
+ options={"aot_inductor.package": True},
+ dynamic_shapes=(
+ self.dynamic_shape,
+ None,
+ ), # assuming aspect ratio is static
+ )
+ with tempfile.TemporaryDirectory() as tmpdir:
+ path = package_aoti(os.path.join(tmpdir, "tpe.pt2"), so)
+ tpe_aoti = load_package(path)
+
+ y = tpe_aoti(self.x, self.aspect_ratio)
+ ref_y = self.ref_tpe(self.x, self.aspect_ratio)
+
+ self.assertTrue(torch.allclose(y, ref_y))
+
+ def test_replace_tile_positional_embedding(self):
+ class Module(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.tpe = TuneTilePositionalEmbedding(4, 1280)
+
+ def forward(self, x, aspect_ratio):
+ return self.tpe(x, aspect_ratio)
+
+ m = Module()
+ m = replace_tile_positional_embedding(m)
+ self.assertTrue(isinstance(m.tpe, TilePositionalEmbedding))
+
+
+class TiledTokenPositionalEmbeddingTest(unittest.TestCase):
+ def setUp(self):
+ super().setUp()
+ self.tpe = TiledTokenPositionalEmbedding(4, 1280, 40, 1)
+ self.ref_tpe = TuneTiledTokenPositionalEmbedding(4, 1280, 40, 1)
+ self.tpe.load_state_dict(self.ref_tpe.state_dict())
+ self.x = torch.randn(1, 4, 1601, 1280)
+ self.aspect_ratio = torch.tensor([[1, 2]])
+ num_tiles_dim = torch.export.Dim("num_tiles", min=1, max=4)
+
+ self.dynamic_shape = {
+ 0: 1, # batch
+ 1: num_tiles_dim, # num tiles
+ 2: 1601, # num tokens
+ 3: 1280, # embedding dim
+ }
+
+ def test_tiled_token_positional_embedding_smoke(self):
+ y = self.tpe(self.x, self.aspect_ratio)
+ ref_y = self.ref_tpe(self.x, self.aspect_ratio)
+
+ assert_close(y, ref_y)
+
+ @unittest.skipUnless(
+ torch_version_ge("2.6.0.dev20241117"), reason="Need recent fixes for export"
+ )
+ def test_tiled_token_positional_embedding_export(self):
+
+ tpe_ep = torch.export.export(
+ self.tpe,
+ (self.x, self.aspect_ratio),
+ dynamic_shapes=(
+ self.dynamic_shape,
+ None,
+ ), # assuming aspect ratio is static
+ )
+
+ y = tpe_ep.module()(self.x, self.aspect_ratio)
+ ref_y = self.ref_tpe(self.x, self.aspect_ratio)
+
+ assert_close(y, ref_y)
+
+ @unittest.skipUnless(
+ torch_version_ge("2.6.0.dev20241117"), reason="Need recent fixes for aoti"
+ )
+ def test_tiled_token_positional_embedding_aoti(self):
+ tpe_ep = torch.export.export(
+ self.tpe,
+ (self.x, self.aspect_ratio),
+ dynamic_shapes=(
+ self.dynamic_shape,
+ None,
+ ), # assuming aspect ratio is static
+ )
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ path = torch._inductor.aoti_compile_and_package(
+ tpe_ep,
+ (self.x, self.aspect_ratio),
+ package_path=os.path.join(tmpdir, "tpe.pt2"),
+ )
+ tpe_aoti = load_package(path)
+
+ y = tpe_aoti(self.x, self.aspect_ratio)
+ ref_y = self.ref_tpe(self.x, self.aspect_ratio)
+
+ assert_close(y, ref_y)
+
+ def test_replace_tiled_token_positional_embedding(self):
+ class Module(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.tpe = TuneTiledTokenPositionalEmbedding(4, 1280, 40, 1)
+
+ def forward(self, x, aspect_ratio):
+ return self.tpe(x, aspect_ratio)
+
+ m = Module()
+ m = replace_tiled_token_positional_embedding(m)
+ self.assertTrue(isinstance(m.tpe, TiledTokenPositionalEmbedding))
diff --git a/tests/torchtune/modules/peft/test_lora.py b/tests/torchtune/modules/peft/test_lora.py
index 80d2b2d767..ff03b1d3c4 100644
--- a/tests/torchtune/modules/peft/test_lora.py
+++ b/tests/torchtune/modules/peft/test_lora.py
@@ -14,14 +14,17 @@
from torchao.dtypes.nf4tensor import NF4Tensor, to_nf4
from torchtune import training
from torchtune.modules.common_utils import reparametrize_as_dtype_state_dict_post_hook
-from torchtune.modules.peft import LoRALinear
+from torchtune.modules.peft import LoRALinear, QATLoRALinear
+from torchtune.training.quantization import _torchao_0_7_supported
from torchtune.training.seed import set_seed
+
RANK = 4
ALPHA = 1.0
BSZ = 2
SEQ_LEN = 32
EXPECTED_VAL = 1.1252
+QAT_EXPECTED_VAL = 0.6291
@pytest.fixture(autouse=True)
@@ -232,3 +235,12 @@ def test_quantized_state_dict(self, dtype):
assert torch.allclose(
lora_linear.weight.quantized_data, lora_linear_reload.weight.quantized_data
)
+
+ @pytest.mark.skipif(not _torchao_0_7_supported, reason="needs torchao 0.7+")
+ def test_qat_lora_forward(self, inputs, lora_linear, out_dim) -> None:
+ lora_linear = lora_linear(use_bias=True, dtype=torch.float32)
+ qat_lora_linear = QATLoRALinear.from_lora_linear(lora_linear)
+ expected = torch.tensor(QAT_EXPECTED_VAL)
+ actual = qat_lora_linear(inputs)
+ assert actual.shape == (BSZ, SEQ_LEN, out_dim)
+ torch.testing.assert_close(actual.mean(), expected, atol=1e-4, rtol=1e-6)
diff --git a/tests/torchtune/modules/test_rms_norm.py b/tests/torchtune/modules/test_rms_norm.py
index c570e583f0..747b4d860d 100644
--- a/tests/torchtune/modules/test_rms_norm.py
+++ b/tests/torchtune/modules/test_rms_norm.py
@@ -5,12 +5,10 @@
# LICENSE file in the root directory of this source tree.
import pytest
-
import torch
from tests.test_utils import assert_expected
from torch.nn.functional import normalize
-
from torchtune.modules.rms_norm import RMSNorm
from torchtune.training.seed import set_seed
@@ -66,6 +64,7 @@ def test_forward_fp16(self, rms_norm, input_random_fp16, dim) -> None:
# convert input to float since rms_norm computes in fp32
expected_fp16 = normalize(input_random_fp16.float(), p=2, dim=-1) * (dim**0.5)
+ expected_fp16 = expected_fp16.to(torch.float16)
assert_expected(output_fp16, expected_fp16, atol=1e-7, rtol=1e-3)
- assert output_fp16.dtype == torch.float32
+ assert output_fp16.dtype == torch.float16
diff --git a/tests/torchtune/modules/test_vq_embeddings.py b/tests/torchtune/modules/test_vq_embeddings.py
new file mode 100644
index 0000000000..b8c1e83286
--- /dev/null
+++ b/tests/torchtune/modules/test_vq_embeddings.py
@@ -0,0 +1,114 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import pytest
+
+import torch
+from tests.test_utils import assert_expected
+from torch import tensor
+from torchtune.modules.vq_embeddings import VectorQuantizedEmbeddings
+
+
+@pytest.fixture(autouse=True)
+def random_seed():
+ torch.manual_seed(4)
+
+
+class TestVectorQuantizedEmbeddings:
+ @pytest.fixture
+ def num_embeddings(self):
+ return 4
+
+ @pytest.fixture
+ def embedding_dim(self):
+ return 5
+
+ @pytest.fixture
+ def embedding_weights(self):
+ # This is 4x5
+ return tensor(
+ [
+ [1.0, 0.0, -1.0, -1.0, 2.0],
+ [2.0, -2.0, 0.0, 0.0, 1.0],
+ [2.0, 1.0, 0.0, 1.0, 1.0],
+ [-1.0, -2.0, 0.0, 2.0, 0.0],
+ ]
+ )
+
+ @pytest.fixture
+ def codebook(self, num_embeddings, embedding_dim, embedding_weights):
+ vq = VectorQuantizedEmbeddings(
+ num_embeddings=num_embeddings,
+ embedding_dim=embedding_dim,
+ )
+ vq.embedding.data = embedding_weights
+ return vq
+
+ @pytest.fixture
+ def encoded(self):
+ # This is 2x3x5
+ encoded = tensor(
+ [
+ [
+ [-1.0, 2.0, 0.0, 0.0, -2.0],
+ [0.0, 1.0, -1.0, 2.0, -1.0],
+ [1.0, 0.0, -1.0, -1.0, 1.0],
+ ],
+ [
+ [2.0, 1.0, 0.0, 1.0, 1.0],
+ [2.0, -1.0, 0.0, 2.0, 0.0],
+ [-1.0, -2.0, 0.0, 1.0, 0.0],
+ ],
+ ]
+ )
+ encoded.requires_grad_()
+
+ return encoded
+
+ def test_quantized_output(self, codebook, encoded):
+ actual = codebook(encoded)
+
+ expected_quantized = tensor(
+ [
+ [
+ [2.0, 1.0, 0.0, 1.0, 1.0],
+ [2.0, 1.0, 0.0, 1.0, 1.0],
+ [1.0, 0.0, -1.0, -1.0, 2.0],
+ ],
+ [
+ [2.0, 1.0, 0.0, 1.0, 1.0],
+ [2.0, -2.0, 0.0, 0.0, 1.0],
+ [-1.0, -2.0, 0.0, 2.0, 0.0],
+ ],
+ ]
+ )
+ expected_token_ids = tensor([[2.0, 2.0, 0.0], [2.0, 1.0, 3.0]]).type(
+ torch.LongTensor
+ )
+
+ assert_expected(actual[0], expected_quantized)
+ assert_expected(actual[1], expected_token_ids)
+
+ def test_decode(self, codebook):
+ indices_flat = tensor([[0, 1]]) # (b, seq_len)
+ indices_shaped = tensor([[[0, 1], [2, 3]]]) # (b, shape)
+ actual_quantized_flat = codebook.decode(indices_flat)
+ actual_quantized = codebook.decode(indices_shaped)
+ expected_quantized_flat = tensor(
+ [[[1.0, 0.0, -1.0, -1.0, 2.0], [2.0, -2.0, 0.0, 0.0, 1.0]]]
+ )
+ expected_quantized = tensor(
+ [
+ [
+ [[1.0, 0.0, -1.0, -1.0, 2.0], [2.0, -2.0, 0.0, 0.0, 1.0]],
+ [[2.0, 1.0, 0.0, 1.0, 1.0], [-1.0, -2.0, 0.0, 2.0, 0.0]],
+ ]
+ ]
+ )
+ assert_expected(
+ actual_quantized_flat, expected_quantized_flat, rtol=0.0, atol=1e-4
+ )
+ assert_expected(actual_quantized, expected_quantized, rtol=0.0, atol=1e-4)
diff --git a/tests/torchtune/training/test_precision.py b/tests/torchtune/training/test_precision.py
index 6f94ffd9db..5f9fd3e89c 100644
--- a/tests/torchtune/training/test_precision.py
+++ b/tests/torchtune/training/test_precision.py
@@ -89,3 +89,9 @@ def test_validate_expected_param_dtype(self):
m = torch.nn.Linear(10, 10)
with pytest.raises(ValueError, match=f"has dtype {next(m.parameters()).dtype}"):
validate_expected_param_dtype(m.named_parameters(), dtype=torch.float16)
+
+ validate_expected_param_dtype(
+ m.named_parameters(),
+ dtype=torch.float16,
+ exclude_param_names=[name for name, _ in m.named_parameters()],
+ )
diff --git a/torchtune/_recipe_registry.py b/torchtune/_recipe_registry.py
index 5bbf860482..5b12efe11c 100644
--- a/torchtune/_recipe_registry.py
+++ b/torchtune/_recipe_registry.py
@@ -106,6 +106,7 @@ class Recipe:
Config(name="llama3_2/3B_full", file_path="llama3_2/3B_full.yaml"),
Config(name="llama3/70B_full", file_path="llama3/70B_full.yaml"),
Config(name="llama3_1/70B_full", file_path="llama3_1/70B_full.yaml"),
+ Config(name="llama3_3/70B_full", file_path="llama3_3/70B_full.yaml"),
Config(name="mistral/7B_full", file_path="mistral/7B_full.yaml"),
Config(name="gemma/2B_full", file_path="gemma/2B_full.yaml"),
Config(name="gemma/7B_full", file_path="gemma/7B_full.yaml"),
@@ -302,6 +303,10 @@ class Recipe:
name="llama2/7B_lora_dpo_single_device",
file_path="llama2/7B_lora_dpo_single_device.yaml",
),
+ Config(
+ name="llama3_1/8B_lora_dpo_single_device",
+ file_path="llama3_1/8B_lora_dpo_single_device.yaml",
+ ),
],
supports_distributed=False,
),
@@ -313,6 +318,10 @@ class Recipe:
name="llama2/7B_lora_dpo",
file_path="llama2/7B_lora_dpo.yaml",
),
+ Config(
+ name="llama3_1/8B_lora_dpo",
+ file_path="llama3_1/8B_lora_dpo.yaml",
+ ),
],
supports_distributed=True,
),
@@ -345,6 +354,8 @@ class Recipe:
Config(name="llama3/8B_dora", file_path="llama3/8B_dora.yaml"),
Config(name="llama3/70B_lora", file_path="llama3/70B_lora.yaml"),
Config(name="llama3_1/70B_lora", file_path="llama3_1/70B_lora.yaml"),
+ Config(name="llama3_3/70B_lora", file_path="llama3_3/70B_lora.yaml"),
+ Config(name="llama3_3/70B_qlora", file_path="llama3_3/70B_qlora.yaml"),
Config(name="llama3/8B_lora", file_path="llama3/8B_lora.yaml"),
Config(name="llama3_1/8B_lora", file_path="llama3_1/8B_lora.yaml"),
Config(name="llama3_2/1B_lora", file_path="llama3_2/1B_lora.yaml"),
@@ -456,17 +467,28 @@ class Recipe:
],
supports_distributed=True,
),
+ Recipe(
+ name="qat_lora_finetune_distributed",
+ file_path="qat_lora_finetune_distributed.py",
+ configs=[
+ Config(name="llama3/8B_qat_lora", file_path="llama3/8B_qat_lora.yaml"),
+ Config(name="llama3_1/8B_qat_lora", file_path="llama3_1/8B_qat_lora.yaml"),
+ Config(name="llama3_2/1B_qat_lora", file_path="llama3_2/1B_qat_lora.yaml"),
+ Config(name="llama3_2/3B_qat_lora", file_path="llama3_2/3B_qat_lora.yaml"),
+ ],
+ supports_distributed=True,
+ ),
Recipe(
name="knowledge_distillation_single_device",
file_path="knowledge_distillation_single_device.py",
configs=[
Config(
- name="qwen2/knowledge_distillation_single_device",
- file_path="qwen2/knowledge_distillation_single_device.yaml",
+ name="qwen2/1.5_to_0.5B_KD_lora_single_device",
+ file_path="qwen2/1.5_to_0.5B_KD_lora_single_device.yaml",
),
Config(
- name="llama3_2/knowledge_distillation_single_device",
- file_path="llama3_2/knowledge_distillation_single_device.yaml",
+ name="llama3_2/8B_to_1B_KD_lora_single_device",
+ file_path="llama3_2/8B_to_1B_KD_lora_single_device.yaml",
),
],
supports_distributed=False,
@@ -476,12 +498,12 @@ class Recipe:
file_path="knowledge_distillation_distributed.py",
configs=[
Config(
- name="qwen2/knowledge_distillation_distributed",
- file_path="qwen2/knowledge_distillation_distributed.yaml",
+ name="qwen2/1.5_to_0.5B_KD_lora_distributed",
+ file_path="qwen2/1.5_to_0.5B_KD_lora_distributed.yaml",
),
Config(
- name="llama3_2/knowledge_distillation_distributed",
- file_path="llama3_2/knowledge_distillation_distributed.yaml",
+ name="llama3_2/8B_to_1B_KD_lora_distributed",
+ file_path="llama3_2/8B_to_1B_KD_lora_distributed.yaml",
),
],
supports_distributed=True,
diff --git a/torchtune/data/_messages.py b/torchtune/data/_messages.py
index a6b356b0ca..bbd3ae5981 100644
--- a/torchtune/data/_messages.py
+++ b/torchtune/data/_messages.py
@@ -170,6 +170,7 @@ class InputOutputToMessages(Transform):
Raises:
ValueError: If ``column_map`` is provided and ``input`` not in ``column_map``, or
``output`` not in ``column_map``.
+ ValueError: If ``image_dir`` is provided but ``image`` not in ``column_map``.
"""
def __init__(
@@ -196,6 +197,14 @@ def __init__(
else:
self.column_map = {"input": "input", "output": "output", "image": "image"}
+ # Ensure that if a user seems to want to construct a multimodal transform, they provide
+ # a proper column_mapping
+ if "image" not in self.column_map.keys() and image_dir is not None:
+ raise ValueError(
+ f"image_dir is specified as {image_dir} but 'image' is not in column_map. "
+ "Please specify an 'image' key in column_map."
+ )
+
self.image_dir = image_dir
def __call__(self, sample: Mapping[str, Any]) -> Mapping[str, Any]:
@@ -206,8 +215,13 @@ def __call__(self, sample: Mapping[str, Any]) -> Mapping[str, Any]:
if is_multimodal:
image_path = sample[self.column_map["image"]]
if isinstance(image_path, str):
+ # Convert image_path to Path obj
+ image_path = Path(image_path)
+
+ # If image_dir is not None, prepend image_dir to image_path
if self.image_dir is not None:
image_path = self.image_dir / image_path
+
# Load if not loaded
pil_image = load_image(image_path)
else:
diff --git a/torchtune/datasets/_preference.py b/torchtune/datasets/_preference.py
index 1cc53b3626..dea4eec852 100644
--- a/torchtune/datasets/_preference.py
+++ b/torchtune/datasets/_preference.py
@@ -89,9 +89,14 @@ class requires the dataset to have "chosen" and "rejected" model responses. Thes
filter_fn (Optional[Callable]): callable used to filter the dataset prior to any pre-processing. See
the Hugging Face `docs `_ for more
details.
+ packed (bool): Whether or not to pack the dataset to ``max_seq_len`` prior to training. Default is False. Packed is
+ currently not supported for ``PreferenceDataset`` and a ``ValueError`` will be raised if this is set to True.
**load_dataset_kwargs (Dict[str, Any]): additional keyword arguments to pass to ``load_dataset``. See Hugging
Face's `API ref `_
for more details.
+
+ Raises:
+ ValueError: If ``packed`` is True, this feature is not supported for ``PreferenceDataset``.
"""
def __init__(
@@ -101,8 +106,14 @@ def __init__(
message_transform: Transform,
tokenizer: ModelTokenizer,
filter_fn: Optional[Callable] = None,
+ packed: bool = False,
**load_dataset_kwargs: Dict[str, Any],
) -> None:
+ if packed:
+ raise ValueError(
+ "Packed is currently not supported for preference datasets."
+ )
+
self._tokenizer = tokenizer
self._message_transform = message_transform
self._data = load_dataset(source, **load_dataset_kwargs)
diff --git a/torchtune/datasets/_slimorca.py b/torchtune/datasets/_slimorca.py
index da3e343fd4..126b6b92e4 100644
--- a/torchtune/datasets/_slimorca.py
+++ b/torchtune/datasets/_slimorca.py
@@ -65,7 +65,7 @@ def slimorca_dataset(
ValueError: If ``packed=True`` and ``tokenizer.max_seq_len`` is not set.
Example:
- >>> ds = slimorca_dataset(model_transform=tokenizer)
+ >>> ds = slimorca_dataset(tokenizer=tokenizer)
>>> for input, label in ds:
>>> print(input)
>>> print(label)
diff --git a/torchtune/datasets/multimodal/_llava_instruct.py b/torchtune/datasets/multimodal/_llava_instruct.py
index 0dff69879d..2f218731ba 100644
--- a/torchtune/datasets/multimodal/_llava_instruct.py
+++ b/torchtune/datasets/multimodal/_llava_instruct.py
@@ -12,7 +12,6 @@
from torchtune.modules.transforms import Transform
-# TODO: point to Flamingo model transform as an example
def llava_instruct_dataset(
model_transform: Transform,
*,
@@ -119,6 +118,8 @@ def __call__(self, sample: Mapping[str, Any]) -> Mapping[str, Any]:
>>> print(f"Batch size: {len(batch)}")
>>> Batch size: 8
"""
+ if packed:
+ raise ValueError("Multimodal datasets don't support packing yet.")
message_transform = ShareGPTToMessages(
train_on_input=False,
@@ -137,6 +138,5 @@ def __call__(self, sample: Mapping[str, Any]) -> Mapping[str, Any]:
data_files=data_files,
**load_dataset_kwargs,
)
- if packed:
- raise ValueError("Multimodal datasets don't support packing yet.")
+
return ds
diff --git a/torchtune/datasets/multimodal/_multimodal.py b/torchtune/datasets/multimodal/_multimodal.py
index ad519155a8..83673a2e1a 100644
--- a/torchtune/datasets/multimodal/_multimodal.py
+++ b/torchtune/datasets/multimodal/_multimodal.py
@@ -18,6 +18,7 @@ def multimodal_chat_dataset(
source: str,
column_map: Optional[Dict[str, str]] = None,
new_system_prompt: Optional[str] = None,
+ packed: bool = False,
image_tag: Optional[str] = None,
image_dir: Optional[str] = None,
filter_fn: Optional[Callable] = None,
@@ -79,6 +80,7 @@ def multimodal_chat_dataset(
new_system_prompt (Optional[str]): if specified, prepend a system message. This can
serve as instructions to guide the model response. Setting this will OVERRIDE any system
messages already present in the dataset. Default is None.
+ packed (bool): Whether or not to pack the dataset to ``max_seq_len`` prior to training. Default is False.
image_tag (Optional[str]): placeholder tags in the text content of each message to be replaced by dictionaries
indicating to the tokenizer where to place image tokens. If images are present and this is None,
then will prepend image tokens to the first user message in the sample by default. If text-only, leave
@@ -119,39 +121,47 @@ def multimodal_chat_dataset(
::
>>> from torchtune.datasets.multimodal import multimodal_chat_dataset
- >>> from torchtune.models.flamingo import FlamingoTransform
- >>> model_transform = FlamingoTransform(
- ... path="/tmp/Meta-Llama-3-8B-Instruct/original/tokenizer.model",
- ... tile_size=224,
- ... patch_size=14,
- ... )
- >>> dataset = multimodal_chat_dataset(
- ... model_transform=model_transform,
- ... source="json",
- ... data_files="my_dataset.json",
- ... column_map={
- ... "dialogue": "conversations",
- ... "image_path": "image",
- ... },
- ... image_dir="/home/user/dataset/", # /home/user/dataset/images/clock.jpg
- ... image_tag="",
- ... split="train",
- ... )
- >>> tokens = dataset[0]["tokens"]
- >>> model_transform.decode(tokens, skip_special_tokens=True)
- "What time is it on the clock?It is 10:00 AM."
- >>> print(dataset[0]["encoder_input"]["images"][0].shape) # (num_tiles, num_channels, tile_height, tile_width)
- torch.Size([4, 3, 224, 224])
-
+ >>> from torchtune.models.llama3_2_vision import llama3_2_vision_transform
+ >>> from torchtune.datasets.multimodal import multimodal_chat_dataset
+ >>> model_transform = Llama3VisionTransform(
+ >>> path="/tmp/Meta-Llama-3-8B-Instruct/original/tokenizer.model",
+ >>> prompt_template="torchtune.data.QuestionAnswerTemplate",
+ >>> max_seq_len=8192,
+ >>> image_size=560,
+ >>> )
+ >>> ds = multimodal_chat_dataset(
+ >>> model_transform=model_transform,
+ >>> source="json",
+ >>> data_files="data/my_data.json",
+ >>> column_map={
+ >>> "dialogue": "conversations",
+ >>> "image_path": "image",
+ >>> },
+ >>> image_dir="/home/user/dataset/", # /home/user/dataset/images/clock.jpg
+ >>> image_tag="",
+ >>> split="train",
+ >>> )
+ >>> tokenized_dict = ds[0]
+ >>> print(model_transform.decode(tokenized_dict["tokens"], skip_special_tokens=False))
+ >>> # '<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nQuestion:<|image|>What time is it on the clock?Answer:<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nIt is 10:00AM.<|eot_id|>' # noqa
+ >>> print(tokenized_dict["encoder_input"]["images"][0].shape) # (num_tiles, num_channels, tile_height, tile_width)
+ >>> # torch.Size([4, 3, 224, 224])
This can also be accomplished via the yaml config:
.. code-block:: yaml
+ tokenizer:
+ _component_: torchtune.models.llama3_2_vision_transform
+ path: /tmp/Meta-Llama-3-8B-Instruct/original/tokenizer.model
+ prompt_template: torchtune.data.QuestionAnswerTemplate
+ max_seq_len: 8192
+
dataset:
_component_: torchtune.datasets.multimodal.multimodal_chat_dataset
source: json
- data_files: my_dataset.json
+ data_files: data/my_data.json
+ split: train
column_map:
dialogue: conversations
image_path: image
@@ -161,7 +171,14 @@ def multimodal_chat_dataset(
Returns:
SFTDataset: the configured :class:`~torchtune.datasets.SFTDataset`
+
+ Raises:
+ ValueError: If ``packed`` is True, they are not supported for multimodal datasets yet.
+
"""
+ if packed:
+ raise ValueError("Multimodal datasets don't support packing yet.")
+
message_transform = ShareGPTToMessages(
train_on_input=False,
column_map=column_map,
diff --git a/torchtune/datasets/multimodal/_the_cauldron.py b/torchtune/datasets/multimodal/_the_cauldron.py
index 4aa54dd2dc..8887edf827 100644
--- a/torchtune/datasets/multimodal/_the_cauldron.py
+++ b/torchtune/datasets/multimodal/_the_cauldron.py
@@ -56,7 +56,7 @@ class TheCauldronToMessages(Transform):
]
Args:
- column_map (Optional[Dict[str, str]]): a mapping to change the expected "texts"
+ column_map (Optional[Dict[str, str]]): a mapping to change the expected "texts" and "image"
column names to the actual column names in the dataset. Default is None,
keeping the default column names.
new_system_prompt (Optional[str]): if specified, prepend a system message. This can
@@ -121,7 +121,6 @@ def __call__(self, sample: Mapping[str, Any]) -> Mapping[str, Any]:
return {"messages": messages}
-# TODO: point to Flamingo model transform as an example
def the_cauldron_dataset(
model_transform: Transform,
*,
@@ -217,6 +216,8 @@ def __call__(self, sample: Mapping[str, Any]) -> Mapping[str, Any]:
>>> print(f"Batch size: {len(batch)}")
>>> Batch size: 8
"""
+ if packed:
+ raise ValueError("Multimodal datasets don't support packing yet.")
message_transform = TheCauldronToMessages(
column_map=column_map,
@@ -232,6 +233,5 @@ def __call__(self, sample: Mapping[str, Any]) -> Mapping[str, Any]:
split=split,
**load_dataset_kwargs,
)
- if packed:
- raise ValueError("Multimodal datasets don't support packing yet.")
+
return ds
diff --git a/torchtune/datasets/multimodal/_vqa.py b/torchtune/datasets/multimodal/_vqa.py
index 27e8bbe1d4..ce991f07ec 100644
--- a/torchtune/datasets/multimodal/_vqa.py
+++ b/torchtune/datasets/multimodal/_vqa.py
@@ -18,6 +18,7 @@ def vqa_dataset(
image_dir: str = None,
column_map: Optional[Dict[str, str]] = None,
new_system_prompt: Optional[str] = None,
+ packed: bool = False,
filter_fn: Optional[Callable] = None,
split: str = "train",
**load_dataset_kwargs: Dict[str, Any],
@@ -63,6 +64,7 @@ def vqa_dataset(
new_system_prompt (Optional[str]): if specified, prepend a system message. This can
serve as instructions to guide the model response. Setting this will OVERRIDE any system
messages already present in the dataset. Default is None.
+ packed (bool): Whether or not to pack the dataset to ``max_seq_len`` prior to training. Default is False.
filter_fn (Optional[Callable]): callable used to filter the dataset prior to any pre-processing. See
the Hugging Face `docs `_ for more
details.
@@ -122,7 +124,14 @@ def vqa_dataset(
Returns:
SFTDataset: the configured :class:`~torchtune.datasets.SFTDataset`
+
+ Raises:
+ ValueError: If ``packed`` is True, they are not supported for multimodal datasets yet.
+
"""
+ if packed:
+ raise ValueError("Multimodal datasets don't support packing yet.")
+
message_transform = InputOutputToMessages(
column_map=column_map, new_system_prompt=new_system_prompt, image_dir=image_dir
)
diff --git a/torchtune/models/llama3_3/__init__.py b/torchtune/models/llama3_3/__init__.py
new file mode 100644
index 0000000000..cd5ac4d306
--- /dev/null
+++ b/torchtune/models/llama3_3/__init__.py
@@ -0,0 +1,13 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from ._model_builders import llama3_3_70b, lora_llama3_3_70b, qlora_llama3_3_70b # noqa
+
+__all__ = [
+ "llama3_3_70b",
+ "lora_llama3_3_70b",
+ "qlora_llama3_3_70b",
+]
diff --git a/torchtune/models/llama3_3/_model_builders.py b/torchtune/models/llama3_3/_model_builders.py
new file mode 100644
index 0000000000..a55973e136
--- /dev/null
+++ b/torchtune/models/llama3_3/_model_builders.py
@@ -0,0 +1,37 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+from torchtune.models.llama3_1._model_builders import (
+ llama3_1_70b,
+ lora_llama3_1_70b,
+ qlora_llama3_1_70b,
+)
+
+"""
+Model builders build specific instantiations using component builders. The Llama3.3 model
+builders all call the Llama3.1 models as they're identical models apart from the checkpoints.
+"""
+
+llama3_3_70b = llama3_1_70b
+
+llama3_3_70b.__doc__ = """
+Builder for creating a Llama3.3 model initialized w/ the default 70B parameter values.
+Please see `llama3_1_70b` for full API arguments.
+"""
+
+lora_llama3_3_70b = lora_llama3_1_70b
+
+lora_llama3_3_70b.__doc__ = """
+Builder for creating a Llama3.3 70B model with LoRA enabled.
+Please see `lora_llama3_1_70b` for full API arguments.
+"""
+
+qlora_llama3_3_70b = qlora_llama3_1_70b
+
+qlora_llama3_1_70b.__doc__ = """
+Builder for creating a Llama3.3 70B model with QLoRA enabled. Base model weights in linear layers
+that LoRA is applied to are quantized per the QLoRA paper: https://arxiv.org/abs/2305.14314.
+Please see `lora_llama3_1_70b` for full API arguments.
+"""
diff --git a/torchtune/modules/__init__.py b/torchtune/modules/__init__.py
index 29c014c33e..338c9ef53e 100644
--- a/torchtune/modules/__init__.py
+++ b/torchtune/modules/__init__.py
@@ -29,6 +29,7 @@
TransformerSelfAttentionLayer,
)
from .vision_transformer import VisionTransformer
+from .vq_embeddings import VectorQuantizedEmbeddings
__all__ = [
"MultiHeadAttention",
@@ -38,6 +39,7 @@
"KVCache",
"RotaryPositionalEmbeddings",
"VisionRotaryPositionalEmbeddings",
+ "VectorQuantizedEmbeddings",
"RMSNorm",
"TiedLinear",
"Fp32LayerNorm",
diff --git a/torchtune/modules/_export/README.md b/torchtune/modules/_export/README.md
index 49ea5ac851..6df701ded9 100644
--- a/torchtune/modules/_export/README.md
+++ b/torchtune/modules/_export/README.md
@@ -1,3 +1,16 @@
# Export
This directory provides [exportable](https://pytorch.org/docs/stable/export.html) variants of torchtune modules.
+
+Modules in this directory:
+
+* Take the same arguments to `__init__()` and `forward()` as the corresponding reference modules in torchtune.
+* Give the output as the reference module in torchtune (unless stated otherwise in the docstring).
+* Are guaranteed to work out of the box with torch.export.export().
+* Should work out of the box with torch.aot_compile().
+
+All modules should be covered by unit tests (under `tests/torchtune/modules/_export/`) that runs daily and on PRs touching this directory.
+
+These modules are subject to change so proceed with caution.
+
+Contributors: @larryliu0820, @Jack-Khuu, @dvorjackz
diff --git a/torchtune/modules/_export/_position_embeddings.py b/torchtune/modules/_export/_position_embeddings.py
new file mode 100644
index 0000000000..0489b7f345
--- /dev/null
+++ b/torchtune/modules/_export/_position_embeddings.py
@@ -0,0 +1,747 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# An torch.export() friendly version of torchtune's positional embeddings.
+# Added torch._check() to make sure guards on symints are enforced.
+# See https://github.com/pytorch/torchtune/blob/main/torchtune/models/clip/_position_embeddings.py
+
+import logging
+import math
+from typing import Any, Dict, Tuple
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+from torch.distributed._tensor import distribute_tensor, DTensor
+
+FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
+logging.basicConfig(level=logging.INFO, format=FORMAT)
+
+
+class TilePositionalEmbedding(nn.Module):
+ """
+ Positional embedding for tiles, different for every tile, same for every token within a tile.
+
+ Notice that tile is different from patch (token). For details, please check the documentation of
+ :class:`torchtune.modules.vision_transformer.VisionTransformer`.
+
+ Args:
+ max_num_tiles (int): The maximum number of tiles an image can be divided into.
+ embed_dim (int): The dimensionality of each tile embedding.
+ """
+
+ def __init__(
+ self,
+ max_num_tiles: int,
+ embed_dim: int,
+ ):
+ super().__init__()
+ self.max_num_tiles = max_num_tiles
+ self.embed_dim = embed_dim
+
+ scale = embed_dim**-0.5
+ self.embedding = nn.Parameter(
+ scale * torch.randn(max_num_tiles, max_num_tiles, 1, embed_dim)
+ )
+ self.gate = nn.Parameter(torch.zeros(1))
+
+ # Register load hook to interpolate positional embeddings
+ self._register_load_state_dict_pre_hook(self._load_state_dict_hook)
+
+ @torch.no_grad()
+ def _load_state_dict_hook(
+ self,
+ state_dict: Dict[str, Any],
+ prefix: str,
+ *args: Tuple[Any],
+ **kwargs: Dict[str, Any],
+ ):
+ """
+ Interpolates positional embeddings to accomodate different number of tiles,
+ in case the model was instantiated with different
+ settings than the one you are loading the state dict from.
+
+ For more info, check self._dynamic_resize function.
+
+ Args:
+ state_dict (Dict[str, Any]): The state dict to load.
+ prefix (str): The prefix of the state dict.
+ *args (Tuple[Any]): Additional positional arguments.
+ **kwargs (Dict[str, Any]): Additional keyword arguments.
+
+ Raises:
+ ValueError: if the shape of the loaded embedding is not compatible with the current embedding.
+ ValueError: if max_num_tiles_x, max_num_tiles_y are not equal.
+ ValueError: if after interpolation, the shape of the loaded embedding is not compatible with the current embedding.
+ """
+
+ embedding = state_dict.get(prefix + "embedding")
+
+ if embedding is not None:
+
+ # We can only apply F.interpolate to vanilla tensors, not DTensors
+ # If pos embeds are a DTensor, we gather the full tensor, apply
+ # interpolate, and then reshard after
+ if isinstance(embedding, DTensor):
+ embedding_is_sharded = True
+ device_mesh = embedding.device_mesh
+ placements = embedding.placements
+ embedding = embedding.full_tensor()
+ else:
+ embedding_is_sharded = False
+
+ # ckpt pos emb
+ (
+ tgt_max_num_tiles_x,
+ tgt_max_num_tiles_y,
+ tgt_num_tokens,
+ tgt_emb,
+ ) = self.embedding.shape
+
+ # instantiated pos emb
+ (
+ inpt_max_num_tiles_x,
+ inpt_max_num_tiles_y,
+ inpt_num_tokens,
+ inpt_emb,
+ ) = state_dict[prefix + "embedding"].shape
+
+ # sanity check
+ if inpt_num_tokens != tgt_num_tokens or inpt_emb != tgt_emb:
+ raise ValueError(
+ "Expected embedding shape to be (..., num_tokens, tgt_emb) to match"
+ f" but found shapes {self.embedding.shape} and {state_dict[prefix + 'embedding'].shape}"
+ )
+
+ if inpt_max_num_tiles_x != inpt_max_num_tiles_y:
+ raise ValueError(
+ "Expected max_num_tiles_x, max_num_tiles_y to be equal but found, but found"
+ f"(max_num_tiles_x, max_num_tiles_y, 1, embed_dim) = {self.embedding.shape}"
+ )
+
+ # resize ckpt to match instantiated shape
+ embedding_new = self._resize_position_embedding(
+ embedding, tgt_max_num_tiles=tgt_max_num_tiles_x
+ )
+
+ if embedding_is_sharded:
+ embedding_new = distribute_tensor(
+ embedding_new,
+ device_mesh=device_mesh,
+ placements=placements,
+ )
+
+ # update state dict
+ state_dict[prefix + "embedding"] = embedding_new
+ if embedding_new.shape != self.embedding.shape:
+ raise ValueError(
+ "Expected embedding shape and embedding_new.shape to match"
+ f" but found shapes {self.embedding.shape} and {embedding_new.shape}"
+ )
+
+ @staticmethod
+ def _resize_position_embedding(
+ embedding: torch.Tensor, tgt_max_num_tiles: int
+ ) -> torch.Tensor:
+ """
+ Interpolates positional embeddings to accomodate a different max_num_tiles. These
+ are the only dimensions that changes during interpolation.
+
+ Args:
+ embedding (torch.Tensor): torch.Tensor with shape (max_num_tiles, max_num_tiles, 1, embed_dim
+ tgt_max_num_tiles (int): The number of tiles to resize to.
+
+ Returns:
+ torch.Tensor: The resized embedding.
+
+ Example:
+ >>> import torch
+ >>> # create dummy embedding
+ >>> embedding = torch.arange(2*2*2*2).reshape(2, 2, 2, 2).float()
+ >>> resized_embed = _dynamic_resize(embedding, tgt_max_num_tiles=1)
+ >>> print(resized_embed.shape)
+ >>> torch.Size([1, 1, 2, 2])
+ """
+ # set max_num_tiles to the last dimension
+ embedding = embedding.permute(2, 3, 0, 1)
+
+ embedding = F.interpolate(
+ embedding,
+ size=(tgt_max_num_tiles, tgt_max_num_tiles),
+ mode="bilinear",
+ align_corners=True,
+ )
+ # permute to the original shape
+ embedding = embedding.permute(2, 3, 0, 1)
+ return embedding.contiguous()
+
+ def forward(self, x: torch.Tensor, aspect_ratio: torch.Tensor) -> torch.Tensor:
+ """
+ args:
+ x (torch.Tensor): torch.Tensor with shape (bsz * n_imgs, n_tiles, n_tokens, embed_dim).
+ aspect_ratio (torch.Tensor): torch.Tensor with shape (bsz * n_imgs, 2),
+ representing the aspect ratio of the image before tile-cropping, e.g. (2,1).
+ returns:
+ torch.Tensor: The input tensor with added positional embeddings.
+ """
+ bsz_and_n_imgs, n_tiles, n_tokens, embed_dim = x.shape
+ torch._check(n_tiles <= self.max_num_tiles)
+
+ for batch_idx, (n_tiles_h, n_tiles_w) in enumerate(aspect_ratio):
+ # When we batch images, all are padded to the same amount of tiles.
+ # The aspect_ratio lets us know the non padded tiles for each image.
+ # We only add positional encoding to those.
+ n_tiles_h = n_tiles_h.item()
+ n_tiles_w = n_tiles_w.item()
+
+ n_non_padded_tiles = int(n_tiles_h * n_tiles_w)
+
+ # We get only the positional encoding for non padded tiles,
+ # i.e. n_tiles_h, n_tiles_w.
+ torch._check_is_size(n_tiles_h)
+ torch._check_is_size(n_tiles_w)
+ torch._check(n_tiles_h >= 1)
+ torch._check(n_tiles_w >= 1)
+ torch._check(n_tiles_h <= self.max_num_tiles)
+ torch._check(n_tiles_w <= self.max_num_tiles)
+ # TODO: Remove this once pytorch/pytorch#120288 is fixed
+ padded_embedding = F.pad(self.embedding, (0, 0, 0, 0, 0, 1, 0, 1))
+ pos_embed = padded_embedding[:n_tiles_h, :n_tiles_w, :, :]
+
+ # We need to do a clone here in order to make this model export
+ # friendly as the reshape is collapsing dim 0 and dim 1 into a
+ # single dim.
+ pos_embed = pos_embed.clone()
+ pos_embed = pos_embed.reshape(n_non_padded_tiles, 1, self.embed_dim)
+
+ x = F.pad(x, (0, 0, 0, 0, 0, 1, 0, 0))
+ torch._check_is_size(n_non_padded_tiles)
+ torch._check(n_non_padded_tiles < x.size(1))
+ x[batch_idx, :n_non_padded_tiles, :, :] += pos_embed * self.gate.tanh()
+ x = x[:, :n_tiles, :, :]
+
+ return x
+
+
+class TiledTokenPositionalEmbedding(nn.Module):
+ """
+
+ Token positional embedding for tiled images, different for every tile, different for every token.
+
+ There are two positional embeddings in this module:
+
+ * local_token_positional_embedding: same for every tile, different for every token. Equivalent \
+ to :class:`torchtune.models.clip._position_embeddings.TokenPositionalEmbedding`, but gated.
+ * global_token_positional_embedding: different for every tile, different for every token.
+
+ Notice that tile is different from patch (token). For details, please check the documentation of
+ :class:`torchtune.modules.vision_transformer.VisionTransformer`.
+
+ Args:
+ max_num_tiles (int): The maximum number of tiles an image can be divided into.
+ embed_dim (int): The dimensionality of each token embedding.
+ tile_size (int): The size of your image tiles, if the image was tile-cropped in advance. Otherwise,
+ the size of the input image. In this case, the function will consider your image as a single tile.
+ patch_size (int): The size of each patch. Used to divide the tiles into patches.
+ E.g. for ``patch_size=40``, a tile of shape (400, 400) will have 10x10 grid of patches
+ with shape (40, 40) each.
+ """
+
+ def __init__(
+ self, max_num_tiles: int, embed_dim: int, tile_size: int, patch_size: int
+ ) -> None:
+ super().__init__()
+
+ patch_grid_size = tile_size // patch_size
+ self.n_tokens_per_tile = patch_grid_size**2 + 1 # +1 for cls token
+ scale = embed_dim**-0.5
+
+ # different for every token, same for every tile
+ self.local_token_positional_embedding = nn.Parameter(
+ scale * torch.randn((self.n_tokens_per_tile, embed_dim))
+ )
+
+ # different for every token, different for every tile
+ self.global_token_positional_embedding = nn.Parameter(
+ scale
+ * torch.randn(
+ max_num_tiles,
+ max_num_tiles,
+ self.n_tokens_per_tile,
+ embed_dim,
+ )
+ )
+ self.max_num_tiles = max_num_tiles
+ self.gate = nn.Parameter(torch.zeros(1))
+
+ self._register_load_state_dict_pre_hook(self._load_state_dict_hook)
+
+ @torch.no_grad()
+ def _load_state_dict_hook(
+ self,
+ state_dict: Dict[str, Any],
+ prefix: str,
+ *args: Tuple[Any],
+ **kwargs: Dict[str, Any],
+ ) -> None:
+ """
+ Interpolates positional embeddings to accomodate different number of tiles
+ and tokens per tile, in case the model was instantiated with different
+ settings than the one you are loading the state dict from.
+
+ For more info, please check self._resize_local_position_embedding and
+ self._resize_global_position_embedding functions.
+
+ Args:
+ state_dict (Dict[str, Any]): The state dict to load.
+ prefix (str): The prefix of the state dict.
+ *args (Tuple[Any]): Additional positional arguments.
+ **kwargs (Dict[str, Any]): Additional keyword arguments.
+
+ Raises:
+ ValueError: if loaded local or global embedding n_tokens_per_tile is not derived
+ from a squared grid.
+ ValueError: if after interpolation, the shape of the loaded local embedding
+ is not compatible with the current embedding.
+ ValueError: if after interpolation, the shape of the loaded global embedding
+ is not compatible with the current embedding.
+ """
+
+ # process local_token_positional_embedding
+ inpt_local_pos_embed = state_dict.get(
+ prefix + "local_token_positional_embedding"
+ )
+
+ if inpt_local_pos_embed is not None:
+
+ # We can only apply F.interpolate to vanilla tensors, not DTensors
+ # If pos embeds are a DTensor, we gather the full tensor, apply
+ # interpolate, and then reshard after
+ if isinstance(inpt_local_pos_embed, DTensor):
+ local_embed_is_sharded = True
+ local_embed_device_mesh = inpt_local_pos_embed.device_mesh
+ local_embed_placements = inpt_local_pos_embed.placements
+ inpt_local_pos_embed = inpt_local_pos_embed.full_tensor()
+ else:
+ local_embed_is_sharded = False
+
+ # sanity check
+ inpt_n_tokens_per_tile, inpt_embed_dim = inpt_local_pos_embed.shape
+ if math.sqrt(inpt_n_tokens_per_tile - 1) % 1 != 0:
+ raise ValueError(
+ f"Loaded local positional embedding has shape {inpt_n_tokens_per_tile=}, "
+ f"which indicates a grid_size that is not squared. This is currently not supported."
+ )
+
+ # instantiated pos emb
+ (
+ tgt_n_tokens_per_tile,
+ tgt_embed_dim,
+ ) = self.local_token_positional_embedding.shape
+
+ # resize ckpt to match instantiated shape
+ inpt_local_pos_embed = self._resize_local_position_embedding(
+ local_pos_embed=inpt_local_pos_embed,
+ tgt_patch_grid_size=int(math.sqrt(tgt_n_tokens_per_tile - 1)),
+ )
+
+ if local_embed_is_sharded:
+ inpt_local_pos_embed = distribute_tensor(
+ inpt_local_pos_embed,
+ device_mesh=local_embed_device_mesh,
+ placements=local_embed_placements,
+ )
+
+ # update state dict
+ state_dict[
+ prefix + "local_token_positional_embedding"
+ ] = inpt_local_pos_embed
+ if (
+ inpt_local_pos_embed.shape
+ != self.local_token_positional_embedding.shape
+ ):
+ raise ValueError(
+ f"Loaded local positional embedding has shape {inpt_local_pos_embed.shape}, "
+ f"after interpolation. Expected shape {self.local_token_positional_embedding.shape}."
+ )
+
+ # process global_token_positional_embedding
+ inpt_global_pos_embed = state_dict.get(
+ prefix + "global_token_positional_embedding"
+ )
+
+ if inpt_global_pos_embed is not None:
+
+ # We can only apply F.interpolate to vanilla tensors, not DTensors
+ # If pos embeds are a DTensor, we gather the full tensor, apply
+ # interpolate, and then reshard after
+ if isinstance(inpt_global_pos_embed, DTensor):
+ global_embed_is_sharded = True
+ global_embed_device_mesh = inpt_global_pos_embed.device_mesh
+ global_embed_placements = inpt_global_pos_embed.placements
+ inpt_global_pos_embed = inpt_global_pos_embed.full_tensor()
+ else:
+ global_embed_is_sharded = False
+
+ _, _, inpt_n_tokens_per_tile, _ = inpt_global_pos_embed.shape
+
+ # sanity check
+ if math.sqrt(inpt_n_tokens_per_tile - 1) % 1 != 0:
+ raise ValueError(
+ f"Loaded local positional embedding has shape {inpt_n_tokens_per_tile=}, "
+ f"which indicates a grid_size that is not squared. This is currently not supported."
+ )
+
+ # instantiated pos emb
+ (
+ tgt_max_num_tiles_x,
+ tgt_max_num_tiles_y, # not used, same as tgt_max_num_tiles_x
+ tgt_n_tokens_per_tile,
+ tgt_embed_dim,
+ ) = self.global_token_positional_embedding.shape
+
+ # resize ckpt to match instantiated shape
+ inpt_global_pos_embed = self._resize_global_position_embedding(
+ global_pos_embed=inpt_global_pos_embed,
+ tgt_max_num_tiles=tgt_max_num_tiles_x,
+ tgt_patch_grid_size=int(math.sqrt(tgt_n_tokens_per_tile - 1)),
+ )
+
+ if global_embed_is_sharded:
+ inpt_global_pos_embed = distribute_tensor(
+ inpt_global_pos_embed,
+ device_mesh=global_embed_device_mesh,
+ placements=global_embed_placements,
+ )
+
+ # update state dict
+ state_dict[
+ prefix + "global_token_positional_embedding"
+ ] = inpt_global_pos_embed
+ if (
+ inpt_global_pos_embed.shape
+ != self.global_token_positional_embedding.shape
+ ):
+ raise ValueError(
+ f"Loaded global positional embedding has shape {inpt_global_pos_embed.shape}, "
+ f"after interpolation. Expected shape {self.global_token_positional_embedding.shape}."
+ )
+
+ @staticmethod
+ def _resize_local_position_embedding(
+ local_pos_embed: torch.Tensor, tgt_patch_grid_size: int
+ ) -> torch.Tensor:
+ """
+ Interpolates the local position embedding for a vision encoder to accommodate
+ a different number of tokens per tile. This is the only dimension that
+ changes during interpolation.
+
+ Args:
+ local_pos_embed (torch.Tensor): The position embeddings tensor to be resized. It
+ has shape [n_tokens_per_tile, emb_dim], where the first token is the CLS token
+ and n_tokens_per_tile = patch_grid_size**2 + 1.
+ tgt_patch_grid_size (int): The target size of each patch grid, i.e.,
+ the square root of the number of tokens per tile, excluding the class token.
+
+ Returns:
+ torch.Tensor: The resized position embeddings tensor of shape
+ [tgt_n_tokens_per_tile, dim], where tgt_n_tokens_per_tile = tgt_patch_grid_size**2 + 1.
+
+ Example:
+ >>> import torch
+ >>> import math
+ >>> local_pos_embed = torch.randn((10*10+1, 64)) # Example input tensor
+ >>> tgt_patch_grid_size = 20 # Target number of tokens per tile
+ >>> resized_pos_embed = _resize_local_position_embedding(local_pos_embed, tgt_patch_grid_size)
+ >>> print(resized_pos_embed.shape)
+ torch.Size([20*20+1, 64])
+ """
+ # inverse n_tokens_per_tile = patch_grid_size**2 + 1, where +1 is the cls token
+ inpt_n_tokens_per_tile, inpt_embed_dim = local_pos_embed.shape
+ inpt_patch_grid_size = int(math.sqrt(inpt_n_tokens_per_tile - 1))
+
+ # split tokens between cls and img tokens.
+ # we don't want to interpolate cls token.
+ cls_token, local_pos_embed = (
+ local_pos_embed[[0]], # cls token
+ local_pos_embed[1:], # image tokens
+ )
+
+ # we reshape n_tokens_per_tile - 1 --> (inpt_patch_grid_size, inpt_patch_grid_size)
+ # and permute to have inpt_patch_grid_size as the last two dimensions
+ # we also add a batch dim to the tensor, since F.interpolate expects it
+ local_pos_embed = local_pos_embed.reshape(
+ 1, inpt_patch_grid_size, inpt_patch_grid_size, -1
+ ).permute(0, 3, 1, 2)
+
+ local_pos_embed = F.interpolate(
+ local_pos_embed,
+ size=[tgt_patch_grid_size, tgt_patch_grid_size],
+ mode="bilinear",
+ align_corners=True, # defaults from internal-llama-models
+ )
+
+ # reshape back to [1, tokens_per_tile, embed_dim]
+ local_pos_embed = local_pos_embed.permute(0, 2, 3, 1).reshape(
+ 1, -1, inpt_embed_dim
+ )
+
+ # remove batch dim added previously
+ local_pos_embed = local_pos_embed.squeeze(0)
+
+ # add cls token back in
+ local_pos_embed = torch.cat([cls_token, local_pos_embed], dim=0)
+
+ return local_pos_embed.contiguous()
+
+ # TODO: Switch to public method after 2.5 is stable
+ @staticmethod
+ def _resize_global_position_embedding(
+ global_pos_embed: torch.Tensor,
+ tgt_max_num_tiles: int,
+ tgt_patch_grid_size: int,
+ ) -> torch.Tensor:
+ """
+ Interpolates the global position embedding for a vision encoder to accommodate new grid dimensions.
+ The embedding dimension is not changed during interpolation, only max_num_tiles and num_tokens_per_tile.
+
+ Args:
+ global_pos_embed (torch.Tensor): The input global position embeddings tensor of shape
+ [max_num_tiles_x, max_num_tiles_y, num_tokens_per_tile, embed_dim],
+ where num_tokens_per_tile = inpt_patch_grid_size * inpt_patch_grid_size + 1 (CLS token), and
+ max_num_tiles_x == max_num_tiles_y.
+ tgt_max_num_tiles (int): The target maximum number of tiles along one dimension (assumed square grid).
+ tgt_patch_grid_size (int): The target size of each patch grid, i.e., the square root of the number of tokens
+ per tile, excluding the class token.
+
+
+ Returns:
+ torch.Tensor: The resized global position embeddings tensor of shape
+ [tgt_max_num_tiles, tgt_max_num_tiles, tgt_patch_grid_size * tgt_patch_grid_size + 1, embed_dim].
+
+ Example:
+ >>> import torch
+ >>> global_pos_embed = torch.arange(3*3*(2*2+1)*4).reshape((3, 3, 2*2+1, 4)) # Example input tensor
+ >>> tgt_max_num_tiles = 2 # Target maximum number of tiles
+ >>> tgt_patch_grid_size = 3 # Target patch grid size
+ >>> resized_global_pos_embed = (
+ >>> _resize_global_position_embedding(global_pos_embed, tgt_max_num_tiles, tgt_patch_grid_size))
+ >>> print(resized_global_pos_embed.shape)
+ torch.Size([2, 2, 3*3+1, 4])
+ """
+
+ # remove cls token to interpolate it separately
+ pos_embed = global_pos_embed[:, :, 1:, :]
+ cls_embed = global_pos_embed[:, :, [0], :]
+
+ (
+ max_num_tiles_x,
+ max_num_tiles_y,
+ n_tokens_per_tile,
+ embed_dim,
+ ) = pos_embed.shape
+
+ # tokens_per_tile == inpt_patch_grid_size**2
+ # we reshape n_tokens_per_tile --> (inpt_patch_grid_size, inpt_patch_grid_size)
+ inpt_patch_grid_size = int(math.sqrt(n_tokens_per_tile))
+ pos_embed = pos_embed.reshape(
+ max_num_tiles_x,
+ max_num_tiles_y,
+ inpt_patch_grid_size,
+ inpt_patch_grid_size,
+ embed_dim,
+ )
+
+ # combine max_num_tiles and patch_grid_size into one dimension
+ pos_embed = pos_embed.permute(0, 2, 1, 3, 4).contiguous()
+ pos_embed = pos_embed.reshape(
+ max_num_tiles_x * inpt_patch_grid_size,
+ max_num_tiles_y * inpt_patch_grid_size,
+ embed_dim,
+ )
+
+ # add batch dim for interpolation
+ pos_embed = pos_embed.unsqueeze(0)
+
+ tgt_size = (
+ int(tgt_max_num_tiles * tgt_patch_grid_size),
+ int(tgt_max_num_tiles * tgt_patch_grid_size),
+ )
+
+ # move to the last two dim for interpolation
+ pos_embed = pos_embed.permute(0, 3, 1, 2)
+ pos_embed = F.interpolate(
+ pos_embed,
+ size=tgt_size,
+ mode="bilinear",
+ align_corners=True, # defaults from internal-llama-models
+ )
+
+ # return to original shape and remove batch dim
+ pos_embed = pos_embed.permute(0, 2, 3, 1).squeeze(0)
+
+ # move it back in place
+ pos_embed = pos_embed.view(
+ tgt_max_num_tiles,
+ tgt_patch_grid_size,
+ tgt_max_num_tiles,
+ tgt_patch_grid_size,
+ embed_dim,
+ )
+ pos_embed = pos_embed.permute(0, 2, 1, 3, 4).contiguous()
+ pos_embed = pos_embed.view(
+ tgt_max_num_tiles,
+ tgt_max_num_tiles,
+ int(tgt_patch_grid_size**2),
+ embed_dim,
+ )
+
+ # interpolate cls token
+ cls_embed = cls_embed.permute(2, 3, 0, 1)
+ cls_embed_resized = F.interpolate(
+ cls_embed,
+ size=(tgt_max_num_tiles, tgt_max_num_tiles),
+ mode="bilinear",
+ align_corners=True, # defaults from internal-llama-models
+ )
+ cls_embed = cls_embed_resized.permute(2, 3, 0, 1)
+
+ # add cls token back in
+ global_pos_embed = torch.cat([cls_embed, pos_embed], dim=2)
+
+ return global_pos_embed.contiguous()
+
+ def forward(self, x: torch.Tensor, aspect_ratio: torch.Tensor) -> torch.Tensor:
+ """
+ Args:
+ x (torch.Tensor): torch.Tensor with shape
+ (bsz * n_imgs, n_tiles, n_tokens_per_tile, embed_dim).
+ aspect_ratio (torch.Tensor): torch.Tensor with shape (bsz * n_imgs, 2),
+ where aspect_ratio[k] represents the aspect ratio of the k^th image
+ of the batch before tile-cropping, e.g. aspect_ratio[k] = (2,1).
+ Returns:
+ torch.Tensor: The input tensor with added positional embeddings.
+ """
+ bsz_and_n_imgs, n_tiles, n_tokens_per_tile, embed_dim = x.shape
+
+ # apply local position embedding (same for every tile)
+ x = x + (self.local_token_positional_embedding * (1 - self.gate.tanh()))
+
+ # apply global positional embedding (different for every tile)
+ x = x.view(bsz_and_n_imgs, n_tiles, n_tokens_per_tile, embed_dim)
+ for batch_idx, (n_tiles_h, n_tiles_w) in enumerate(aspect_ratio):
+ # When we batch images, all are padded to the same amount of tiles.
+ # The aspect_ratio lets us know the non padded tiles for each image.
+ # We only add positional encoding to those.
+ n_tiles_h = n_tiles_h.item()
+ n_tiles_w = n_tiles_w.item()
+
+ n_non_padded_tiles = int(n_tiles_h * n_tiles_w)
+
+ # We get only the positional encoding for non padded tiles,
+ # i.e. n_tiles_h, n_tiles_w.
+ torch._check(n_tiles_h > 0)
+ torch._check(n_tiles_w > 0)
+ torch._check(n_tiles_h <= self.max_num_tiles)
+ torch._check(n_tiles_w <= self.max_num_tiles)
+ padded_embedding = F.pad(
+ self.global_token_positional_embedding, (0, 0, 0, 0, 0, 1, 0, 1)
+ )
+
+ pos_embed = padded_embedding[:n_tiles_h, :n_tiles_w, :, :]
+
+ # Add pos encoding to the non padded tiles.
+ pos_embed = pos_embed.clone()
+ pos_embed = pos_embed.reshape(
+ n_non_padded_tiles, self.n_tokens_per_tile, embed_dim
+ )
+ pos_embed = pos_embed * self.gate.tanh()
+ x = F.pad(x, (0, 0, 0, 0, 0, 1, 0, 0))
+ torch._check(n_non_padded_tiles < self.max_num_tiles + 1)
+ torch._check(n_non_padded_tiles < x.size(1))
+ x[batch_idx, :n_non_padded_tiles, :, :] += pos_embed
+ x = x[:, :n_tiles, :, :]
+
+ return x
+
+
+def replace_tile_positional_embedding(model: nn.Module) -> nn.Module:
+ """
+ Replace the tile positional embedding from torchtune with an export-friendly one.
+ Recursively searches the submodules of the model and replaces the tile positional embedding if found.
+ Args:
+ model (nn.Module): The model to replace the tile positional embedding in.
+
+ Returns:
+ nn.Module: The model after replacing the tile positional embedding.
+
+ """
+ from torchtune.models.clip._position_embeddings import (
+ TilePositionalEmbedding as TuneTilePositionalEmbedding,
+ )
+
+ for name, module in model.named_children():
+ if isinstance(module, TuneTilePositionalEmbedding):
+ logging.info(
+ f"Replacing tile positional embedding in {name} with export-friendly one."
+ )
+ max_num_tiles, _, _, embed_dim = module.embedding.shape
+ mod = TilePositionalEmbedding(
+ max_num_tiles=max_num_tiles,
+ embed_dim=embed_dim,
+ )
+ mod.load_state_dict(module.state_dict())
+ setattr(
+ model,
+ name,
+ mod,
+ )
+ else:
+ replace_tile_positional_embedding(module)
+ return model
+
+
+def replace_tiled_token_positional_embedding(model: nn.Module) -> nn.Module:
+ """
+ Replace the tiled token positional embedding from torchtune with an export-friendly one.
+ Recursively searches the submodules of the model and replaces the tiled token positional embedding if found.
+ Args:
+ model (nn.Module): The model to replace the tiled token positional embedding in.
+
+ Returns:
+ nn.Module: The model after replacing the tiled token positional embedding.
+
+ """
+ from torchtune.models.clip._position_embeddings import (
+ TiledTokenPositionalEmbedding as TuneTiledTokenPositionalEmbedding,
+ )
+
+ for name, module in model.named_children():
+ if isinstance(module, TuneTiledTokenPositionalEmbedding):
+ logging.info(
+ f"Replacing tiled token positional embedding in {name} with export-friendly one."
+ )
+ (
+ max_num_tiles,
+ _,
+ n_tokens_per_tile,
+ embed_dim,
+ ) = module.global_token_positional_embedding.shape
+ mod = TiledTokenPositionalEmbedding(
+ max_num_tiles=max_num_tiles,
+ embed_dim=embed_dim,
+ tile_size=int(math.sqrt((n_tokens_per_tile - 1))),
+ patch_size=1,
+ )
+ mod.load_state_dict(module.state_dict())
+ setattr(
+ model,
+ name,
+ mod,
+ )
+ else:
+ replace_tiled_token_positional_embedding(module)
+ return model
diff --git a/torchtune/modules/_export/attention.py b/torchtune/modules/_export/attention.py
new file mode 100644
index 0000000000..bb3fe4a94b
--- /dev/null
+++ b/torchtune/modules/_export/attention.py
@@ -0,0 +1,423 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import logging
+from typing import Optional
+
+import torch
+import torchtune.modules.attention as TorchTuneAttention
+from torch import nn
+from torchtune.modules._export.kv_cache import KVCache as InferenceKVCache
+from torchtune.modules.attention_utils import _MaskType, _sdpa_or_flex_attention
+from torchtune.modules.kv_cache import KVCache
+
+logger = logging.getLogger(__name__)
+
+
+class MultiHeadAttention(nn.Module):
+ """
+ NOTE: torch.export.export() friendly MultiHeadAttention, modified from
+ torchtune.modules.attention.MultiHeadAttention
+ Major differences:
+ - Rewrite `if y is None` to torch.cond().
+ - Logic becomes `if all values of y are NaN`, to make torch.compile() happy.
+ - No input mutations in both false and true branches, so we need to copy kv
+ values back into kv cache after torch.cond().
+ - Added a SDPA module
+ - SDPA module includes transpose and expanding kv dimensions.
+ - Makes it easy to swap with custom SDPAs that are needed by the users of exported
+ program.
+ - Uses new kv cache
+ - This potentially can be merged with torchtune.modules.kv_cache.
+ - Changed += to .add_ to avoid mutating module attributes.
+ - Added clone() method.
+
+ Multi-headed attention layer with support for grouped query
+ attention (GQA) introduced in https://arxiv.org/abs/2305.13245v1.
+
+ GQA is a version of multiheaded attention (MHA) which uses fewer
+ key/value heads than query heads by grouping n query heads for each
+ key and value head. Multi-Query Attention is an extreme
+ version where we have a single key and value head shared by all
+ query heads.
+
+ Following is an example of MHA, GQA and MQA with num_heads = 4
+
+ (credit for the documentation:
+ `litgpt.Config `_).
+
+
+ ::
+
+ ┌───┐┌───┐┌───┐┌───┐ ┌───┐ ┌───┐ ┌───┐
+ │ v ││ v ││ v ││ v │ │ v │ │ v │ │ v │
+ └───┘└───┘└───┘└───┘ └───┘ └───┘ └───┘
+ │ │ │ │ │ │ │
+ ┌───┐┌───┐┌───┐┌───┐ ┌───┐ ┌───┐ ┌───┐
+ │ k ││ k ││ k ││ k │ │ k │ │ k │ │ k │
+ └───┘└───┘└───┘└───┘ └───┘ └───┘ └───┘
+ │ │ │ │ ┌──┴──┐ ┌──┴──┐ ┌────┬──┴─┬────┐
+ ┌───┐┌───┐┌───┐┌───┐ ┌───┐┌───┐┌───┐┌───┐ ┌───┐┌───┐┌───┐┌───┐
+ │ q ││ q ││ q ││ q │ │ q ││ q ││ q ││ q │ │ q ││ q ││ q ││ q │
+ └───┘└───┘└───┘└───┘ └───┘└───┘└───┘└───┘ └───┘└───┘└───┘└───┘
+ ◀──────────────────▶ ◀──────────────────▶ ◀──────────────────▶
+ MHA GQA MQA
+ n_kv_heads =4 n_kv_heads=2 n_kv_heads=1
+
+ Args:
+ embed_dim (int): embedding dimension for the model
+ num_heads (int): number of query heads. For MHA this is also the
+ number of heads for key and value
+ num_kv_heads (int): number of key and value heads. User should ensure
+ ``num_heads % num_kv_heads == 0``. For standard MHA set ``num_kv_heads == num_heads``,
+ for GQA ``num_kv_heads < num_heads``, and for MQA set ``num_kv_heads == 1``.
+ head_dim (int): dimension of each head, calculated by ``embed_dim // num_heads``.
+ q_proj (nn.Module): projection layer for query.
+ k_proj (nn.Module): projection layer for key.
+ v_proj (nn.Module): projection layer for value.
+ output_proj (nn.Module): projection layer for output.
+ pos_embeddings (Optional[nn.Module]): positional embeddings layer, e.g. RotaryPositionalEmbeddings.
+ q_norm (Optional[nn.Module]): normalization layer for query, e.g. RMSNorm. For decoding, this is applied
+ before updating from kv_cache. This means it will only support token wide normalization and not
+ batch or sequence wide normalization.
+ k_norm (Optional[nn.Module]): normalization layer for key, must be set if q_norm is.
+ kv_cache (Optional[KVCache]): KVCache object used to cache key and value
+ max_seq_len (int): maximum sequence length supported by the model.
+ This is needed to compute the RoPE Cache. Default: 4096.
+ is_causal (bool): sets the default mask to causal when no mask is provided
+ attn_dropout (float): dropout value passed onto the scaled_dot_product_attention function.
+ Default value is 0.0.
+
+ Raises:
+ ValueError: If ``num_heads % num_kv_heads != 0``
+ ValueError: If ``embed_dim % num_heads != 0``
+ ValueError: If ``attn_dropout < 0`` or ``attn_dropout > 1``
+ ValueError: if q_norm is defined without k_norm or vice versa
+ """
+
+ def __init__(
+ self,
+ *,
+ embed_dim: int,
+ num_heads: int,
+ num_kv_heads: int,
+ head_dim: int,
+ q_proj: nn.Module,
+ k_proj: nn.Module,
+ v_proj: nn.Module,
+ output_proj: nn.Module,
+ pos_embeddings: Optional[nn.Module] = None,
+ q_norm: Optional[nn.Module] = None,
+ k_norm: Optional[nn.Module] = None,
+ kv_cache: Optional[KVCache] = None,
+ max_seq_len: int = 4096,
+ is_causal: bool = True,
+ attn_dropout: float = 0.0,
+ ) -> None:
+ super().__init__()
+ if num_heads % num_kv_heads != 0:
+ raise ValueError(
+ f"num_heads ({num_heads}) must be divisible by "
+ f"num_kv_heads ({num_kv_heads})"
+ )
+
+ if embed_dim % num_heads != 0:
+ raise ValueError(
+ f"embed_dim ({embed_dim}) must be divisible by "
+ f"num_heads ({num_heads})"
+ )
+
+ if attn_dropout < 0 or attn_dropout > 1:
+ raise ValueError(f"attn_dropout ({embed_dim}) must be between 0.0 and 1.0")
+
+ if bool(q_norm) ^ bool(k_norm):
+ raise ValueError("q and k norm must be set together")
+
+ # Set attributes
+ self.num_heads = num_heads
+ self.num_kv_heads = num_kv_heads
+ self.embed_dim = embed_dim
+ self.attn_dropout = attn_dropout
+ self.head_dim = head_dim
+ self.max_seq_len = max_seq_len
+ self.is_causal = is_causal
+
+ # Set layers
+ self.kv_cache = kv_cache
+ self.q_proj = q_proj
+ self.k_proj = k_proj
+ self.v_proj = v_proj
+ self.output_proj = output_proj
+ self.q_norm = q_norm
+ self.k_norm = k_norm
+ self.pos_embeddings = pos_embeddings
+
+ # Use flex attention if supported and we are sample packing
+ self._attention_call = _sdpa_or_flex_attention()
+ self._sdpa = SDPA(
+ num_kv_heads=self.num_kv_heads,
+ num_heads=self.num_heads,
+ head_dim=self.head_dim,
+ attn_dropout=self.attn_dropout if self.training else 0.0,
+ is_causal=self.is_causal,
+ attention_fn=self._attention_call,
+ kv_cache=self.kv_cache,
+ )
+
+ # this flag indicates whether to update the kv-cache during forward
+ # passes. when disabled, we can have the cache setup but still
+ # perform normal forward passes
+ self.cache_enabled = False
+
+ def setup_cache(
+ self, batch_size: int, dtype: torch.dtype, max_seq_len: int
+ ) -> None:
+ """Setup key value caches for attention calculation. If called
+ after kv_cache is already setup, this will be skipped.
+
+ Args:
+ batch_size (int): batch size for the caches.
+ dtype (torch.dtype): dtype for the caches.
+ max_seq_len (int): maximum sequence length model will be run with.
+ """
+ # Don't overwrite user defined kv_cache from init
+ if self.kv_cache is not None:
+ logger.warning(
+ "Key value caches are already setup. You cannot call ``setup_caches()`` twice. Skipping."
+ )
+ else:
+ self.kv_cache = InferenceKVCache(
+ batch_size=batch_size,
+ max_seq_len=max_seq_len,
+ num_kv_heads=self.num_kv_heads,
+ head_dim=self.head_dim,
+ dtype=dtype,
+ transpose_cache=False,
+ )
+ self._sdpa.kv_cache = self.kv_cache
+ self.cache_enabled = True
+
+ def reset_cache(self):
+ """Reset the key value caches."""
+ if self.kv_cache is None:
+ raise RuntimeError(
+ "Key value caches are not setup. Call ``setup_caches()`` first."
+ )
+ self.kv_cache.reset()
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ y: torch.Tensor,
+ *,
+ mask: Optional[_MaskType] = None,
+ input_pos: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ """
+ Args:
+ x (torch.Tensor): input tensor with shape [b x s_x x d] for the query
+ y (torch.Tensor): second input tensor with shape [b x s_y x d], is the input
+ for k and v. For self attention, x=y. If all values are NaN, we read from kv cache.
+ mask (Optional[_MaskType]): Used to mask the scores after the query-key multiplication
+ and before the softmax. Either:
+
+ A boolean tensor with shape ``[b x s x s]``, ``[b x s x self.encoder_max_cache_seq_len]``,
+ or ``[b x s x self.encoder_max_cache_seq_len]`` if using KV-cacheing with encoder/decoder layers.
+ A value of True in row ``i`` and column ``j`` means token ``i`` attends to token ``j``. A value of False means
+ token ``i`` does not attend to token ``j``. If no mask is specified, a causal mask
+ is used by default.
+
+ A :class:`~torch.nn.attention.flex_attention.BlockMask` for document masking in a packed sequence
+ created via `create_block_mask `_. We use
+ :func:`~torch.nn.attention.flex_attention.flex_attention` when computing attention with block masks.
+ Default is None.
+ input_pos (Optional[torch.Tensor]): Optional tensor which contains the position ids
+ of each token. During training, this is used to indicate the positions
+ of each token relative to its sample when packed, shape [b x s].
+ During inference, this indicates the position of the current token.
+ If none, assume the index of the token is its position id. Default is None.
+
+ Returns:
+ torch.Tensor: output tensor with attention applied
+
+ Notation used for tensor shapes:
+ - b: batch size
+ - s_x: sequence length for x
+ - s_y: sequence length for y
+ - n_h: num heads
+ - n_kv: num kv heads
+ - d: embed dim
+ - h_d: head dim
+ """
+ # x has shape [b, s_x, d]
+ # y has shape [b, s_y, d]
+ b, s_x, _ = x.shape
+
+ # q has shape [b, s_x, num_heads * head_dim]
+ q = self.q_proj(x)
+
+ # number of queries per key/value
+ q_per_kv = self.num_heads // self.num_kv_heads
+ q = q.view(b, s_x, self.num_kv_heads * q_per_kv, self.head_dim)
+
+ # Apply positional embeddings
+ if self.pos_embeddings is not None:
+ q = self.pos_embeddings(q, input_pos=input_pos)
+
+ # Normalize q
+ if self.q_norm is not None:
+ q = self.q_norm(q)
+
+ def calculate_kv(y):
+ # Update k and v shape, positional embeddings, and normalization
+ s_y = y.shape[1]
+ # k has shape [b, s_y, num_kv_heads * head_dim]
+ # v has shape [b, s_y, num_kv_heads * head_dim]
+ k = self.k_proj(y)
+ v = self.v_proj(y)
+
+ # Apply positional embeddings
+ # k: [b, s_y, n_kv, h_d]
+ k = k.view(b, s_y, -1, self.head_dim)
+ v = v.view(b, s_y, -1, self.head_dim)
+ if self.pos_embeddings is not None:
+ k = self.pos_embeddings(k, input_pos=input_pos)
+
+ # Normalize k
+ if self.k_norm is not None:
+ k = self.k_norm(k)
+ return k, v
+
+ def true_fn(y):
+ kv_cache = self.kv_cache.clone()
+ return kv_cache.k_cache, kv_cache.v_cache, kv_cache.cache_pos
+
+ def false_fn(y):
+ k, v = calculate_kv(y)
+ kv_cache = self.kv_cache.clone()
+ kv_cache.update(k, v)
+ return kv_cache.k_cache, kv_cache.v_cache, kv_cache.cache_pos
+
+ # If kv cache is None, we expect y to be provided
+ if self.kv_cache is None:
+ assert (
+ y is not None
+ ), "Must provide y input or use kv_cache to enable streaming decoding"
+ k, v = calculate_kv(y)
+ else:
+ # Expecting the k, v returning here to be the same size of self.kv_cache
+ # In eager, we expect this predicate to specialize. In export, this will
+ # become a SymBool so it's not specialized.
+ k, v, cache_pos = torch.cond(
+ torch.isnan(y).all().item(), true_fn, false_fn, (y,)
+ )
+ # Update key-value cache
+ self.kv_cache.k_cache.copy_(k)
+ self.kv_cache.v_cache.copy_(v)
+ self.kv_cache.cache_pos.copy_(cache_pos)
+
+ output = self._sdpa(q, k, v, b, s_x, mask=mask)
+ return self.output_proj(output)
+
+
+class SDPA(nn.Module):
+ """
+ TorchTune's SDPA which can be optimized and can be swapped
+ out for a more efficient implementations.
+ """
+
+ def __init__(
+ self,
+ num_kv_heads: int,
+ num_heads: int,
+ head_dim: int,
+ attn_dropout: float,
+ is_causal: bool,
+ attention_fn,
+ kv_cache,
+ ) -> None:
+ super().__init__()
+ self.num_kv_heads = num_kv_heads
+ self.num_heads = num_heads
+ self.head_dim = head_dim
+ self.q_per_kv = self.num_heads // self.num_kv_heads
+ self.attn_dropout = attn_dropout
+ self.is_causal = is_causal
+ self._attention_fn = attention_fn
+ self.kv_cache = kv_cache
+
+ def forward(
+ self,
+ q: torch.Tensor, # [b, s, n_h, h_d]
+ k: torch.Tensor, # [b, s, n_kv, h_d]
+ v: torch.Tensor, # [b, s, n_kv, h_d]
+ bsz: int,
+ seq_len: int,
+ mask: Optional[_MaskType] = None,
+ ) -> torch.Tensor:
+ # View + expand + reshape bring num_kv_heads to num_heads for k and v
+ # to match q.
+
+ # [bsz, n_h, s, h_d]
+ q = q.transpose(1, 2)
+ k = k.transpose(1, 2)
+ v = v.transpose(1, 2)
+
+ # Expand the key and value tensors to have the same shape
+ # as the query tensor by copying values across the relevant dim
+ if self.num_heads != self.num_kv_heads:
+ expand_shape = (-1, -1, self.q_per_kv, -1, -1)
+ k = k.unsqueeze(2).expand(expand_shape).flatten(1, 2)
+ v = v.unsqueeze(2).expand(expand_shape).flatten(1, 2)
+
+ output = self._attention_fn(
+ q,
+ k,
+ v,
+ mask=mask,
+ dropout_p=self.attn_dropout,
+ is_causal=self.kv_cache is None and mask is None and self.is_causal,
+ )
+ # Reshape the output to be the same shape as the input
+ return output.transpose(1, 2).contiguous().view(bsz, seq_len, -1)
+
+
+def _replace_mha_with_inference_mha(module: torch.nn.Module) -> None:
+ for name, child in module.named_children():
+ if isinstance(child, TorchTuneAttention.MultiHeadAttention):
+ setattr(
+ module,
+ name,
+ MultiHeadAttention(
+ embed_dim=child.embed_dim,
+ num_heads=child.num_heads,
+ num_kv_heads=child.num_kv_heads,
+ head_dim=child.head_dim,
+ q_proj=child.q_proj,
+ k_proj=child.k_proj,
+ v_proj=child.v_proj,
+ output_proj=child.output_proj,
+ pos_embeddings=child.pos_embeddings,
+ q_norm=child.q_norm,
+ k_norm=child.k_norm,
+ kv_cache=child.kv_cache,
+ max_seq_len=child.max_seq_len,
+ is_causal=child.is_causal,
+ attn_dropout=child.attn_dropout,
+ ),
+ )
+ else:
+ replace_mha_with_inference_mha(child)
+
+
+def replace_mha_with_inference_mha(module: torch.nn.Module) -> torch.nn.Module:
+ """
+ Replace TorchTune's MHA with an inference friendly version of MHA that
+ separates out the inference-related parts for further optimization.
+ """
+ _replace_mha_with_inference_mha(module)
+ return module
diff --git a/torchtune/modules/_export/install_requirements.sh b/torchtune/modules/_export/install_requirements.sh
new file mode 100644
index 0000000000..d4f5b5e01e
--- /dev/null
+++ b/torchtune/modules/_export/install_requirements.sh
@@ -0,0 +1,11 @@
+#!/bin/bash
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+NIGHTLY_VERSION="dev20241121"
+
+# Install pytorch nightly for export-friendly modules to run.
+pip install torch==2.6.0.${NIGHTLY_VERSION} torchvision==0.20.0.${NIGHTLY_VERSION} --extra-index-url https://download.pytorch.org/whl/nightly/cpu
diff --git a/torchtune/modules/_export/kv_cache.py b/torchtune/modules/_export/kv_cache.py
new file mode 100644
index 0000000000..8e0b7047e5
--- /dev/null
+++ b/torchtune/modules/_export/kv_cache.py
@@ -0,0 +1,152 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Tuple
+
+import torch
+from torchtune.modules.kv_cache import KVCache as TuneKVCache
+
+
+class KVCache(TuneKVCache):
+ """
+ NOTE: torch.export.export() friendly KVCache implementation modified from KVCache:
+ https://github.com/pytorch/torchtune/blob/main/torchtune/modules/kv_cache.py
+ Major differences:
+ - Changed += to .add_ to avoid mutating module attributes.
+ - Added clone() method.
+ - Takes a new `transpose_cache` argument to be able to store transposed kv values.
+
+ Standalone ``nn.Module`` containing a kv-cache to cache past key and values during inference.
+
+ Args:
+ batch_size (int): batch size model will be run with
+ max_seq_len (int): maximum sequence length model will be run with
+ num_kv_heads (int): number of key/value heads.
+ head_dim (int): per-attention head embedding dimension
+ dtype (torch.dtype): dtype for the caches
+ transpose_cache (bool): whether we transpose(1, 2) for kv cache.
+ """
+
+ def __init__(
+ self,
+ batch_size: int,
+ max_seq_len: int,
+ num_kv_heads: int,
+ head_dim: int,
+ dtype: torch.dtype,
+ transpose_cache: bool = True,
+ ) -> None:
+ super().__init__(
+ batch_size=batch_size,
+ max_seq_len=max_seq_len,
+ num_kv_heads=num_kv_heads,
+ head_dim=head_dim,
+ dtype=dtype,
+ )
+ self.transpose_cache = transpose_cache
+ self.max_seq_len = max_seq_len
+ if self.transpose_cache:
+ cache_shape = (batch_size, num_kv_heads, max_seq_len, head_dim)
+ else:
+ cache_shape = (batch_size, max_seq_len, num_kv_heads, head_dim)
+
+ self.register_buffer(
+ "k_cache", torch.zeros(cache_shape, dtype=dtype), persistent=False
+ )
+ self.register_buffer(
+ "v_cache", torch.zeros(cache_shape, dtype=dtype), persistent=False
+ )
+ self.register_buffer(
+ "cache_pos", torch.arange(0, self.max_seq_len), persistent=False
+ )
+ self.batch_size = batch_size
+
+ def update(
+ self, k_val: torch.Tensor, v_val: torch.Tensor
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Update KV cache with the new ``k_val``, ``v_val`` and return the updated cache.
+
+ Note:
+ When updating the KV cache, it is assumed that subsequent updates should update key-value
+ positions in consecutive sequence positions. If you wish to update cache values which have
+ already been filled, use ``.reset()``, which will reset the cache to the zero-th position.
+
+ Example:
+ >>> cache = KVCache(batch_size=2, max_seq_len=16, num_kv_heads=4, head_dim=32, dtype=torch.bfloat16)
+ >>> keys, values = torch.ones((2, 4, 8, 32)), torch.ones((2, 4, 8, 32))
+ >>> cache.update(keys, values)
+ >>> # now positions 0 through 7 are filled
+ >>> cache.size
+ >>> 8
+ >>> keys, values = torch.ones((2, 4, 1, 32)), torch.ones((2, 4, 1, 32))
+ >>> cache.update(keys, values)
+ >>> # this will fill at position 8
+ >>> cache.size
+ >>> 9
+
+ Args:
+ k_val (torch.Tensor): Current key tensor with shape [B, H, S, D]
+ v_val (torch.Tensor): Current value tensor with shape [B, H, S, D]
+
+ Returns:
+ Tuple[torch.Tensor, torch.Tensor]: Updated key and value cache tensors, respectively.
+
+ Raises:
+ AssertionError: if the sequence length of ``k_val`` is longer than the maximum cache sequence length.
+ ValueError: if the batch size of the new key (or value) tensor is greater than the batch size
+ used during cache setup.
+ """
+ if self.transpose_cache:
+ bsz, _, seq_len, _ = k_val.shape
+ else:
+ bsz, seq_len, _, _ = k_val.shape
+ if bsz > self.k_cache.shape[0]:
+ raise ValueError(
+ f"The current cache has been setup with a batch size of {self.k_cache.shape[0]}"
+ f", but found new key tensors with batch size {k_val.shape[0]}!"
+ )
+
+ assert (self.cache_pos[0] + seq_len) <= self.max_seq_len
+
+ k_out = self.k_cache
+ v_out = self.v_cache
+
+ if self.transpose_cache:
+ k_out[:, :, self.cache_pos[:seq_len]] = k_val
+ v_out[:, :, self.cache_pos[:seq_len]] = v_val
+ else:
+ k_out[:, self.cache_pos[:seq_len]] = k_val
+ v_out[:, self.cache_pos[:seq_len]] = v_val
+
+ # forward cache_pos seq_len positions along
+ # cache_pos starts at (0, 1, 2, 3, 4, 5, ...)
+ # an update of seq_len = 5 tokens brings it to
+ # (5, 6, 7, 8, 9, ...)
+ # this allows us to track the current position in the cache
+ # after the last update in a compile-friendly way without any dynamism
+ # e.g. relying on an int size tracker, or re-creating cache_pos every time
+ self.cache_pos.add_(seq_len)
+
+ return k_out, v_out
+
+ def clone(self) -> "KVCache":
+ """Create a clone of the KVCache."""
+ if self.transpose_cache:
+ num_kv_heads = self.k_cache.shape[1]
+ else:
+ num_kv_heads = self.k_cache.shape[2]
+ clone = KVCache(
+ batch_size=self.batch_size,
+ max_seq_len=self.max_seq_len,
+ num_kv_heads=num_kv_heads,
+ head_dim=self.k_cache.shape[3],
+ dtype=self.k_cache.dtype,
+ transpose_cache=self.transpose_cache,
+ )
+ clone.k_cache.copy_(self.k_cache)
+ clone.v_cache.copy_(self.v_cache)
+ clone.cache_pos.copy_(self.cache_pos)
+ return clone
diff --git a/torchtune/modules/model_fusion/_deep_fusion.py b/torchtune/modules/model_fusion/_deep_fusion.py
index 6a61c43744..67de2372e4 100644
--- a/torchtune/modules/model_fusion/_deep_fusion.py
+++ b/torchtune/modules/model_fusion/_deep_fusion.py
@@ -46,7 +46,7 @@ class DeepFusionModel(nn.Module):
>>> # DeepFusionModel combines the encoder and decoder
>>> model = DeepFusionModel(decoder, encoder)
>>>
- >>> # Load full fused checkpoints (e.g. a Flamingo checkpoint)
+ >>> # Load full fused checkpoints (e.g. a Llama3.2 Vision checkpoint)
>>> model.load_state_dict(...)
>>>
>>> # Or load pretrained individual models (fusion_params are not loaded)
diff --git a/torchtune/modules/peft/__init__.py b/torchtune/modules/peft/__init__.py
index 165559df9c..2959bc3bb6 100644
--- a/torchtune/modules/peft/__init__.py
+++ b/torchtune/modules/peft/__init__.py
@@ -17,13 +17,14 @@
validate_missing_and_unexpected_for_lora,
)
from .dora import DoRALinear
-from .lora import LoRALinear
+from .lora import LoRALinear, QATLoRALinear
__all__ = [
+ "AdapterModule",
"DoRALinear",
"LoRALinear",
- "AdapterModule",
+ "QATLoRALinear",
"get_adapter_params",
"set_trainable_params",
"validate_missing_and_unexpected_for_lora",
diff --git a/torchtune/modules/peft/lora.py b/torchtune/modules/peft/lora.py
index e03d854f1f..f6303b798c 100644
--- a/torchtune/modules/peft/lora.py
+++ b/torchtune/modules/peft/lora.py
@@ -4,7 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import math
-from typing import List
+from typing import List, Optional
import torch
import torch.nn.functional as F
@@ -131,6 +131,165 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
return out + lora_out
+class QATLoRALinear(LoRALinear):
+ """
+ LoRA linear layer with quantization-aware training (QAT) applied to the
+ activations and/or weights before the low rank adapters.
+
+ QAT leverages fake quantization to simulate the quantization numerics during
+ training without actually casting the data to lower precision. This class
+ combines LoRA with QAT to improve the final quantized accuracy during inference
+ while reducing the memory required during training.
+
+ Args:
+ in_dim (int): input dimension
+ out_dim (int): output dimension
+ rank (int): rank of the low-rank approximation
+ alpha (float): scaling factor for the low-rank approximation
+ dropout (float): dropout probability. Default: 0.0
+ activation_qat_config (Optional[FakeQuantizeConfig]): config for specifying
+ how input activations will be fake quantized, defaults to None
+ weight_qat_config (Optional[FakeQuantizeConfig]): config for specifying
+ how weights will be fake quantized, defaults to None
+
+ Raises:
+ ValueError: If `in_dim` is not divisible by weight `group_size`
+
+ Example usage::
+
+ activation_qat_config = FakeQuantizeConfig(
+ dtype=torch.int8,
+ granularity="per_token",
+ is_symmetric=False,
+ )
+ weight_qat_config = FakeQuantizeConfig(
+ dtype=torch.int4,
+ group_size=8,
+ is_symmetric=True,
+ )
+ qat_lora_linear = QATLoRALinear(
+ in_dim=512,
+ out_dim=1024,
+ rank=8,
+ alpha=16,
+ dropout=0.0,
+ activation_qat_config=activation_qat_config,
+ weight_qat_config=weight_qat_config,
+ )
+ qat_lora_linear(torch.randn(512))
+ """
+
+ def __init__(
+ self,
+ in_dim: int,
+ out_dim: int,
+ rank: int,
+ alpha: float,
+ dropout: float = 0.0,
+ # fake quantize configs
+ # TODO: make the types Optional[FakeQuantizeConfig] once we
+ # support torchao 0.7+ by default
+ activation_qat_config: Optional["FakeQuantizeConfig"] = None,
+ weight_qat_config: Optional["FakeQuantizeConfig"] = None,
+ ):
+ super().__init__(
+ in_dim,
+ out_dim,
+ rank,
+ alpha,
+ dropout,
+ use_bias=False,
+ quantize_base=False,
+ )
+
+ try:
+ from torchao.quantization.qat.api import FakeQuantizeConfig
+ from torchao.quantization.qat.fake_quantizer import FakeQuantizer
+ except ImportError as err:
+ raise ValueError(
+ "QATLoRALinear is only compatible with torchao 0.7+"
+ ) from err
+
+ # initialize activation fake quantizer
+ if activation_qat_config is not None:
+ assert isinstance(activation_qat_config, FakeQuantizeConfig)
+ self.activation_fake_quantizer = FakeQuantizer(activation_qat_config)
+ else:
+ self.activation_fake_quantizer = nn.Identity()
+
+ # initialize weight fake quantizer
+ if weight_qat_config is not None:
+ assert isinstance(weight_qat_config, FakeQuantizeConfig)
+ group_size = weight_qat_config.group_size
+ if group_size is not None and in_dim % group_size != 0:
+ raise ValueError(
+ "in_dim (%s) must be divisible by group_size (%s)"
+ % (in_dim, group_size)
+ )
+ self.weight_fake_quantizer = FakeQuantizer(weight_qat_config)
+ else:
+ self.weight_fake_quantizer = nn.Identity()
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Args:
+ x (torch.Tensor): input tensor with shape ``(..., in_dim)``
+
+ Returns:
+ torch.Tensor: output tensor with shape ``(..., out_dim)``
+
+ """
+ _x = self.activation_fake_quantizer(x)
+ w = self.weight_fake_quantizer(self.weight)
+ out = F.linear(_x, w)
+ if self.disabled:
+ return out
+ lora_out = self.lora_a(self.dropout(x))
+ lora_out = (self.alpha / self.rank) * self.lora_b(lora_out)
+ return out + lora_out
+
+ @classmethod
+ def from_lora_linear(
+ cls,
+ lora_linear: LoRALinear,
+ # TODO: make the types Optional[FakeQuantizeConfig] once we
+ # support torchao 0.7+ by default
+ activation_qat_config: Optional["FakeQuantizeConfig"] = None,
+ weight_qat_config: Optional["FakeQuantizeConfig"] = None,
+ ) -> "QATLoRALinear":
+ """
+ Create a `QATLoRALinear` from an existing `LoRALinear`,
+ preserving the weights and adapters.
+ """
+ if lora_linear.bias is not None:
+ ValueError("Bias is not supported in QAT + LoRA yet")
+ if lora_linear._quantize_base:
+ ValueError("quantize_base is not compatible with QAT + LoRA")
+ if isinstance(lora_linear.dropout, nn.Dropout):
+ dropout = lora_linear.dropout.p
+ else:
+ dropout = 0.0
+ new_linear = cls(
+ lora_linear.in_dim,
+ lora_linear.out_dim,
+ lora_linear.rank,
+ lora_linear.alpha,
+ dropout,
+ activation_qat_config,
+ weight_qat_config,
+ )
+ # In distributed training, the model may be instantiated
+ # on the meta device, in which case there is no need to
+ # copy the weights, and doing so will result in an error
+ if lora_linear.weight.device != torch.device("meta"):
+ new_linear.weight = lora_linear.weight
+ if lora_linear.lora_a.weight.device != torch.device("meta"):
+ new_linear.lora_a.weight = lora_linear.lora_a.weight
+ if lora_linear.lora_b.weight.device != torch.device("meta"):
+ new_linear.lora_b.weight = lora_linear.lora_b.weight
+ return new_linear
+
+
def _lora_a_init_params(x: nn.Linear) -> None:
"""
Initialize LoRA A weight to Kaiming uniform.
diff --git a/torchtune/modules/rms_norm.py b/torchtune/modules/rms_norm.py
index 299e6fb428..f829811ce8 100644
--- a/torchtune/modules/rms_norm.py
+++ b/torchtune/modules/rms_norm.py
@@ -5,18 +5,15 @@
# LICENSE file in the root directory of this source tree.
import torch
-
+import torch.nn.functional as F
from torch import nn
class RMSNorm(nn.Module):
"""
- Implements Root Mean Square Normalization introduced in
- https://arxiv.org/abs/1910.07467.
+ Root Mean Square Normalization in fp32.
- Reference implementation (used for correctness verification)
- can be found here:
- https://github.com/facebookresearch/llama/blob/main/llama/model.py
+ See: https://pytorch.org/docs/stable/generated/torch.nn.RMSNorm.html
Args:
dim (int): embedding size
@@ -25,6 +22,7 @@ class RMSNorm(nn.Module):
def __init__(self, dim: int, eps: float = 1e-6) -> None:
super().__init__()
+ self.normalized_shape = (dim,)
self.eps = eps
self.scale = nn.Parameter(torch.ones(dim))
@@ -37,8 +35,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
torch.Tensor: The normalized and scaled tensor having the same shape as ``x``.
"""
# computation is in fp32
- x_fp32 = x.float()
- x_normed = (
- x_fp32 * torch.rsqrt(x_fp32.pow(2).mean(-1, keepdim=True) + self.eps)
- ).type_as(x)
- return x_normed * self.scale
+ return F.rms_norm(
+ x.float(),
+ normalized_shape=self.normalized_shape,
+ weight=self.scale,
+ eps=self.eps,
+ ).to(x.dtype)
diff --git a/torchtune/modules/vq_embeddings.py b/torchtune/modules/vq_embeddings.py
new file mode 100644
index 0000000000..14d6cef995
--- /dev/null
+++ b/torchtune/modules/vq_embeddings.py
@@ -0,0 +1,86 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Tuple
+
+import torch
+from torch import nn, Tensor
+from torch.nn import functional as F
+
+
+class VectorQuantizedEmbeddings(nn.Module):
+ """
+ Vector quantized embedding layer that takes in the output of an encoder
+ and performs a nearest-neighbor lookup in the embedding space.
+ Vector quantization was introduced in Oord et al. 2017 (https://arxiv.org/pdf/1711.00937.pdf)
+ to generate high-fidelity images, videos, and audio data.
+
+ This module currently does not support pre-training of the embeddings via EMA.
+
+ Code was adapted from torchmultimodal's `Codebook module
+ `_.
+
+ Args:
+ num_embeddings (int): Number of vectors in the embedding space.
+ embedding_dim (int): Dimensionality of the embedding vectors.
+ """
+
+ def __init__(
+ self,
+ num_embeddings: int,
+ embedding_dim: int,
+ ) -> None:
+ super().__init__()
+ self.embedding = nn.Parameter(torch.empty(num_embeddings, embedding_dim))
+ self.num_embeddings = num_embeddings
+ self.embedding_dim = embedding_dim
+
+ def forward(self, z: Tensor) -> Tuple[Tensor, Tensor]:
+ """
+ Args:
+ z (Tensor): Tensor containing a batch of encoder outputs of shape ``(b, s, d)``, where
+ b is batch size, s is sequence length or time, and d is ``embedding_dim``.
+
+ Returns:
+ Tuple[Tensor, Tensor]: The quantized input and the embedding vector ids that were used.
+
+ Raises:
+ ValueError: if input embedding dimension does not match embedding dimension of module
+ """
+ bsz, seq_len, z_embed_dim = z.shape
+ if z_embed_dim != self.embedding_dim:
+ raise ValueError(
+ f"Expected last dimension of input tensor ({z_embed_dim}) to be embedding size of {self.embedding_dim}"
+ )
+
+ # Flatten into batch dimension
+ z_flat = z.view(-1, z_embed_dim)
+ # Calculate distances from each encoder, E(x), output vector to each embedding vector, e, ||E(x) - e||^2
+ distances = torch.cdist(z_flat, self.embedding, p=2.0) ** 2
+
+ # Encoding - select closest embedding vectors, shape [b * s, ]
+ token_ids_flat = torch.argmin(distances, dim=1)
+
+ # Quantize - shape [b * s, d]
+ quantized_flat = self.decode(token_ids_flat)
+
+ # Straight through estimator
+ quantized_flat = z_flat + (quantized_flat - z_flat).detach()
+
+ # Reshape to original - [b, s, d] and [b, s]
+ quantized = quantized_flat.view(bsz, seq_len, z_embed_dim)
+ token_ids = token_ids_flat.view(bsz, seq_len)
+
+ return quantized, token_ids
+
+ def extra_repr(self) -> str:
+ return "num_embeddings={}, embedding_dim={}".format(
+ self.num_embeddings, self.embedding_dim
+ )
+
+ def decode(self, token_ids: Tensor) -> Tensor:
+ # Returns the embeddings of shape [b, s, d]
+ return F.embedding(token_ids, self.embedding)
diff --git a/torchtune/rlhf/loss/dpo.py b/torchtune/rlhf/loss/dpo.py
index 29f66a20c3..b19e0d93ca 100644
--- a/torchtune/rlhf/loss/dpo.py
+++ b/torchtune/rlhf/loss/dpo.py
@@ -9,6 +9,7 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
+from torchtune.utils._logging import deprecated
class DPOLoss(nn.Module):
@@ -160,6 +161,7 @@ def forward(
return losses, chosen_rewards, rejected_rewards
+@deprecated(msg="SimPOLoss will be deprecated in an upcoming release.")
class SimPOLoss(nn.Module):
"""
SimPO: Simple Preference Optimization with a Reference-Free Reward: https://arxiv.org/abs/2405.14734.
diff --git a/torchtune/training/precision.py b/torchtune/training/precision.py
index 6da300be72..85a2c07e4f 100644
--- a/torchtune/training/precision.py
+++ b/torchtune/training/precision.py
@@ -5,7 +5,7 @@
# LICENSE file in the root directory of this source tree.
import contextlib
-from typing import Dict, Generator, Iterable, Optional, Tuple
+from typing import Dict, Generator, Iterable, List, Optional, Tuple
import torch
@@ -53,6 +53,7 @@ def verify_bf16_support() -> bool:
- NCCL is available and version >= 2.10
- MPS is available and torch was built with MPS
- NPU is available and supports bf16
+ - XPU is available and supports bf16
Returns:
bool: True if bf16 is available, False otherwise.
@@ -66,7 +67,8 @@ def verify_bf16_support() -> bool:
)
mps_support = torch.backends.mps.is_available() and torch.backends.mps.is_built()
npu_support = is_npu_available and torch.npu.is_bf16_supported()
- return cuda_support or mps_support or npu_support
+ xpu_support = torch.xpu.is_available() and torch.xpu.is_bf16_supported()
+ return cuda_support or mps_support or npu_support or xpu_support
def get_dtype(
@@ -147,7 +149,9 @@ def set_default_dtype(dtype: torch.dtype) -> Generator[None, None, None]:
def validate_expected_param_dtype(
- named_params: Iterable[Tuple[str, torch.nn.Parameter]], dtype: torch.dtype
+ named_params: Iterable[Tuple[str, torch.nn.Parameter]],
+ dtype: torch.dtype,
+ exclude_param_names: Optional[List[str]] = None,
) -> None:
"""
Validates that all input parameters have the expected dtype.
@@ -155,11 +159,15 @@ def validate_expected_param_dtype(
Args:
named_params (Iterable[Tuple[str, torch.nn.Parameter]]): Iterable of named parameters.
dtype (torch.dtype): Expected dtype.
+ exclude_param_names (Optional[List[str]]): Optional list of parameter names to exclude from dtype checking
Raises:
ValueError: If any parameter has a different dtype than `dtype`.
"""
for name, param in named_params:
+ if exclude_param_names is not None:
+ if any(n in name for n in exclude_param_names):
+ continue
if param.dtype != dtype:
raise ValueError(
f"Parameter {name} has dtype {param.dtype}, but expected {dtype}"
diff --git a/torchtune/training/quantization.py b/torchtune/training/quantization.py
index 7ff9315f41..4e21cb4936 100644
--- a/torchtune/training/quantization.py
+++ b/torchtune/training/quantization.py
@@ -7,6 +7,10 @@
from typing import Callable, Optional
from warnings import warn
+from torch import nn
+from torchtune.modules.peft.lora import LoRALinear, QATLoRALinear
+
+
try:
# torchao 0.7+
from torchao.dtypes import TensorCoreTiledLayout
@@ -55,6 +59,12 @@
]
+_torchao_0_7_supported = True
+try:
+ from torchao.quantization import qat # noqa: F401
+except ImportError:
+ _torchao_0_7_supported = False
+
_quantizer_to_mode = {}
_quantizer_mode_to_disable_fake_quant = {}
_quantizer_mode_to_enable_fake_quant = {}
@@ -185,3 +195,45 @@ def _get_enable_fake_quant(quantizer_mode: str) -> Callable:
If the quantizer is not recognized as a known QAT quantizer, return None.
"""
return _quantizer_mode_to_enable_fake_quant.get(quantizer_mode, None)
+
+
+def swap_lora_linear_with_qat(
+ module: nn.Module,
+ # TODO: make the types Optional[FakeQuantizeConfig] once we
+ # support torchao 0.7+ by default
+ activation_qat_config: Optional["FakeQuantizeConfig"] = None,
+ weight_qat_config: Optional["FakeQuantizeConfig"] = None,
+) -> None:
+ """
+ Swap all `LoRALinear` in the model with `QATLoRALinear`.
+
+ This is used for combining QAT + LoRA during finetuning. The resulting linear layers
+ will apply the following transformation instead:
+
+ x -> fake_quantize(W_frozen) @ fake_quantize(x) + BAx
+
+ Fake quantization here refers to simulating the quantization numerics without actual
+ dtype casting, with the goal of providing improved accuracies when the model is
+ ultimately quantized after finetuning.
+
+ Args:
+ module (nn.Module): The model to swap linear layers on
+ activation_qat_config (Optional[FakeQuantizeConfig]): The config for specifying
+ how to fake quantize input activations in the base linear layer
+ weight_qat_config (Optional[FakeQuantizeConfig]): The config for specifying
+ how to fake quantize base linear weights
+ """
+ for name, child in module.named_children():
+ if isinstance(child, LoRALinear):
+ new_linear = QATLoRALinear.from_lora_linear(
+ child,
+ activation_qat_config,
+ weight_qat_config,
+ )
+ setattr(module, name, new_linear)
+ else:
+ swap_lora_linear_with_qat(
+ child,
+ activation_qat_config,
+ weight_qat_config,
+ )
diff --git a/torchtune/utils/_device.py b/torchtune/utils/_device.py
index 36ca14a358..d4f84cd63e 100644
--- a/torchtune/utils/_device.py
+++ b/torchtune/utils/_device.py
@@ -89,6 +89,8 @@ def _get_device_type_from_env() -> str:
device = "cuda"
elif is_npu_available:
device = "npu"
+ elif torch.xpu.is_available():
+ device = "xpu"
else:
device = "cpu"
return device
@@ -136,7 +138,7 @@ def get_device(device: Optional[str] = None) -> torch.device:
If CUDA-like is available and being used, this function also sets the CUDA-like device.
Args:
- device (Optional[str]): The name of the device to use, e.g. "cuda" or "cpu" or "npu".
+ device (Optional[str]): The name of the device to use, e.g. "cuda" or "cpu" or "npu" or "xpu".
Example:
>>> device = get_device("cuda")
@@ -149,7 +151,7 @@ def get_device(device: Optional[str] = None) -> torch.device:
if device is None:
device = _get_device_type_from_env()
device = torch.device(device)
- if device.type in ["cuda", "npu"]:
+ if device.type in ["cuda", "npu", "xpu"]:
device = _setup_device(device)
_validate_device_from_env(device)
return device
@@ -184,16 +186,18 @@ def batch_to_device(batch: dict, device: torch.device) -> None:
class DeviceSupport(Enum):
"""
This is a simple enum for compute devices,
- This currently only supports CPU, CUDA, NPU.
+ This currently only supports CPU, CUDA, NPU, and XPU.
The following enumeration defines various device configurations with attributes:
- 1. `device_type` (str): The type of device (e.g., "cpu", "cuda", "npu").
- 2. `device_name` (str): A user-friendly name for the device (e.g., "CPU", "GPU", "NPU").
- 3. `communication_backend` (str): Specifies the backend used for communication on this device (e.g., "gloo", "nccl", "hccl").
+ 1. `device_type` (str): The type of device (e.g., "cpu", "cuda", "npu", "xpu").
+ 2. `device_name` (str): A user-friendly name for the device (e.g., "CPU", "GPU", "NPU", "XPU").
+ 3. `communication_backend` (str): Specifies the backend used for communication on this device
+ (e.g., "gloo", "nccl", "hccl", "ccl").
"""
CPU = ("cpu", "CPU", "gloo")
CUDA = ("cuda", "GPU", "nccl")
NPU = ("npu", "NPU", "hccl")
+ XPU = ("xpu", "XPU", "ccl")
def __init__(
self,
@@ -216,7 +220,7 @@ def from_type(device_type: str):
def get_device_support() -> DeviceSupport:
"""function that gets the DeviceSupport with compute devices based on the current machine.
- This currently only supports CPU, CUDA, NPU.
+ This currently only supports CPU, CUDA, NPU, XPU.
Returns:
device_support: DeviceSupport