Add an LLM fine-tuning example #90

Merged · 43 commits · Nov 15, 2024
Commits
f103347
WIP: Add an LLM finetuning example
lebrice Nov 7, 2024
9b9d698
WIP: add / rename more configs
lebrice Nov 7, 2024
d0500b4
Finetuning example seems to be working
lebrice Nov 8, 2024
e68a1dc
Making progress, more self-contained example
lebrice Nov 8, 2024
59060d9
Works! (need to fix the hash used for path though)
lebrice Nov 8, 2024
a678a93
Improve hashing, reduce default block size
lebrice Nov 8, 2024
4c11b99
Fix val_loss logging and add docstring
lebrice Nov 8, 2024
7b9fb19
Increase the number of dataloader workers
lebrice Nov 8, 2024
cfa3a26
Use smaller model for now
lebrice Nov 11, 2024
10c07eb
Use FSDP in the example
lebrice Nov 11, 2024
5c4f659
Fix bug in id generation from config classes
lebrice Nov 11, 2024
f1fced9
Tweak config, try to setup mid-epoch checkpointing
lebrice Nov 11, 2024
f02e4da
Rename `HFExample` -> `TextClassificationExample`
lebrice Nov 11, 2024
09a5b74
Fix broken links in nav
lebrice Nov 11, 2024
74161cc
Remove "huggingface" datamodule config
lebrice Nov 12, 2024
ef454af
Fix issues in config/tests for text_classification
lebrice Nov 12, 2024
504ece9
Add an entry to test the llm_finetuning_example
lebrice Nov 12, 2024
32f4f91
Fix issues in the text classification example
lebrice Nov 12, 2024
ade8bc0
Fix weird docstring issues with hydra-zen
lebrice Nov 12, 2024
5bca6b9
Fix test and config of text_classification_example
lebrice Nov 12, 2024
8749842
Move test from main_test.py to example_test.py
lebrice Nov 12, 2024
4f5e4fb
forward_pass is a method of LearningAlgorithmTests
lebrice Nov 12, 2024
107176e
Various type hint fixes and tweaks
lebrice Nov 12, 2024
3f08a75
WIP: Adding some tests for LLM finetuning example
lebrice Nov 12, 2024
e0a26b9
Fix issue in `jax.md`
lebrice Nov 12, 2024
3077364
Add link to the example page in index.md
lebrice Nov 12, 2024
d2834bc
Fix tests for the llm finetuning example
lebrice Nov 13, 2024
254224b
Fix issue with tuples in regression files
lebrice Nov 13, 2024
b9bc199
Fix test for `get_hash_of`
lebrice Nov 13, 2024
6f5e367
Remove unused _field function
lebrice Nov 13, 2024
ce61959
Fix issue with built-in modules in autoref plugin
lebrice Nov 13, 2024
39f2226
Add a bit of info in the example doc
lebrice Nov 13, 2024
2ae7e1f
Add more links in the doc of the module
lebrice Nov 13, 2024
54efae6
Fix issue with the text classification example
lebrice Nov 13, 2024
63a530d
Add skipif mark for LLM finetuning test
lebrice Nov 13, 2024
10c52fc
Fix data_dir of text_classification_example
lebrice Nov 14, 2024
7b69c00
Use the "auto" strategy for LLM Finetuning tests
lebrice Nov 14, 2024
59e0673
Fix error in fork_rng of LLM finetuning example
lebrice Nov 14, 2024
922cdb7
Try a hacky fix for failing test
lebrice Nov 14, 2024
e36de4a
Don't run llm finetuning tests on github Cloud CI
lebrice Nov 14, 2024
b06f3bf
Add missing regression files
lebrice Nov 15, 2024
04be192
Rename llm_finetuning_example -> llm_finetuning
lebrice Nov 15, 2024
102c51c
Fix import error
lebrice Nov 15, 2024
Files changed
@@ -0,0 +1,35 @@
attention_mask:
  device: cpu
  max: 1
  mean: '1.021e-01'
  min: 0
  shape:
  - 32
  - 128
  sum: 418
input_ids:
  device: cpu
  max: 29043
  mean: '1.648e+02'
  min: 0
  shape:
  - 32
  - 128
  sum: 675172
labels:
  device: cpu
  max: -1
  mean: '-1.e+00'
  min: -1
  shape:
  - 32
  sum: -32
token_type_ids:
  device: cpu
  max: 0
  mean: '0.e+00'
  min: 0
  shape:
  - 32
  - 128
  sum: 0
@@ -0,0 +1,35 @@
attention_mask:
  device: cpu
  max: 1
  mean: '8.374e-02'
  min: 0
  shape:
  - 32
  - 128
  sum: 343
input_ids:
  device: cpu
  max: 26101
  mean: '1.597e+02'
  min: 0
  shape:
  - 32
  - 128
  sum: 654306
labels:
  device: cpu
  max: 1
  mean: '7.188e-01'
  min: 0
  shape:
  - 32
  sum: 23
token_type_ids:
  device: cpu
  max: 0
  mean: '0.e+00'
  min: 0
  shape:
  - 32
  - 128
  sum: 0
@@ -0,0 +1,35 @@
attention_mask:
  device: cpu
  max: 1
  mean: '9.277e-02'
  min: 0
  shape:
  - 32
  - 128
  sum: 380
input_ids:
  device: cpu
  max: 29043
  mean: '1.362e+02'
  min: 0
  shape:
  - 32
  - 128
  sum: 557879
labels:
  device: cpu
  max: 1
  mean: '7.5e-01'
  min: 0
  shape:
  - 32
  sum: 24
token_type_ids:
  device: cpu
  max: 0
  mean: '0.e+00'
  min: 0
  shape:
  - 32
  - 128
  sum: 0
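
Each regression file above stores, for every tensor in a recorded batch, its device, extrema, mean, shape and sum. A minimal sketch of how such an entry could be produced (a hypothetical helper, not the project's actual code):

```python
import torch


def tensor_stats(t: torch.Tensor) -> dict:
    # Hypothetical helper mirroring the fields in the regression files
    # above: device, max, mean, min, shape and sum for one tensor.
    return {
        "device": t.device.type,
        "max": t.max().item(),
        "mean": f"{t.float().mean().item():.3e}",
        "min": t.min().item(),
        "shape": list(t.shape),
        "sum": t.sum().item(),
    }


batch = {"input_ids": torch.randint(0, 30000, (32, 128))}
stats = {name: tensor_stats(t) for name, t in batch.items()}
```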
3 changes: 2 additions & 1 deletion docs/SUMMARY.md
@@ -9,7 +9,8 @@
* [Examples 🧪](examples/index.md)
* [Image Classification (⚡)](examples/torch_sl_example.md)
* [Image Classification (jax+⚡)](examples/jax_sl_example.md)
* [NLP (🤗+⚡)](examples/nlp.md)
* [Text Classification (🤗+⚡)](examples/text_classification.md)
* [Fine-tuning an LLM (🤗+⚡)](examples/llm_finetuning.md)
* [RL (jax)](examples/jax_rl_example.md)
* [Running sweeps](examples/sweeps.md)
* [Profiling your code📎](examples/profiling.md)
23 changes: 17 additions & 6 deletions docs/examples/index.md
@@ -1,10 +1,21 @@
---
additional_python_references:
- project.algorithms.jax_rl_example
- project.algorithms.example
- project.algorithms.jax_example
- project.algorithms.text_classification_example
- project.algorithms.llm_finetuning
- project.trainers.jax_trainer
---

# Examples

This template includes examples that use either Jax, PyTorch, or both!

| Example link | Research Area | Reference link | Frameworks |
| --------------------------------------- | ------------------------------------------ | ------------------ | --------------- |
| [ExampleAlgorithm](torch_sl_example.md) | Supervised Learning (image classification) | `ExampleAlgorithm` | Torch + ⚡ |
| [JaxExample](jax_sl_example.md) | Supervised Learning (image classification) | `JaxExample` | Torch + Jax + ⚡ |
| [HFExample](nlp.md) | NLP (text classification) | `HFExample` | Torch + 🤗 + ⚡ |
| [JaxRLExample](jax_rl_example.md) | RL | `JaxRLExample` | Jax |
| Example link | Research Area | Reference link | Frameworks |
| --------------------------------------------------- | ------------------------------------------ | --------------------------- | --------------- |
| [ExampleAlgorithm](torch_sl_example.md) | Supervised Learning (image classification) | `ExampleAlgorithm` | Torch + ⚡ |
| [JaxExample](jax_sl_example.md) | Supervised Learning (image classification) | `JaxExample` | Torch + Jax + ⚡ |
| [TextClassificationExample](text_classification.md) | NLP (text classification) | `TextClassificationExample` | Torch + 🤗 + ⚡ |
| [JaxRLExample](jax_rl_example.md) | RL | `JaxRLExample` | Jax |
| [LLMFinetuningExample](llm_finetuning.md)           | NLP (Causal language modeling)             | `LLMFinetuningExample`      | Torch + 🤗 + ⚡  |
22 changes: 22 additions & 0 deletions docs/examples/llm_finetuning.md
@@ -0,0 +1,22 @@
---
additional_python_references:
- project.algorithms.llm_finetuning
---
# Fine-tuning LLMs

This example is based on [this language modeling example from the HuggingFace transformers documentation](https://huggingface.co/docs/transformers/en/tasks/language_modeling).

To better understand what's going on in this example, it is a good idea to read through these tutorials first:
* [Causal language modeling simple example - HuggingFace docs](https://huggingface.co/docs/transformers/en/tasks/language_modeling)
* [Fine-tune a language model - Colab Notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/language_modeling.ipynb#scrollTo=X6HrpprwIrIz)

The main difference from the original HuggingFace example is that `LLMFinetuningExample` is a `LightningModule` that is trained by a `lightning.Trainer`.

This also means that this example doesn't use [`accelerate`](https://huggingface.co/docs/accelerate/en/index) or the HuggingFace Trainer.
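
As a rough illustration of the structure, here is a minimal sketch of a `LightningModule` that fine-tunes a causal language model. This is a simplified stand-in written for this page, not the actual `LLMFinetuningExample` class:

```python
import lightning
import torch
from transformers import AutoModelForCausalLM


class CausalLMFinetuningSketch(lightning.LightningModule):
    """Simplified stand-in for an LLM fine-tuning LightningModule."""

    def __init__(self, model_name: str = "gpt2", learning_rate: float = 5e-5):
        super().__init__()
        self.model = AutoModelForCausalLM.from_pretrained(model_name)
        self.learning_rate = learning_rate

    def training_step(self, batch: dict[str, torch.Tensor], batch_index: int):
        # The batch holds `input_ids`, `attention_mask` and `labels`;
        # HuggingFace models compute the causal-LM loss internally.
        outputs = self.model(**batch)
        self.log("train_loss", outputs.loss)
        return outputs.loss

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=self.learning_rate)
```

A `lightning.Trainer` then handles the training loop, checkpointing, and the distributed strategy (such as FSDP, which this PR's commits enable) for such a module.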


## Running the example

```console
python project/main.py experiment=llm_finetuning_example
```
42 changes: 0 additions & 42 deletions docs/examples/nlp.md

This file was deleted.

41 changes: 41 additions & 0 deletions docs/examples/text_classification.md
@@ -0,0 +1,41 @@
# Text Classification (⚡ + 🤗)

## Overview

The [TextClassificationExample][project.algorithms.text_classification_example.TextClassificationExample] is a [LightningModule][lightning.pytorch.core.module.LightningModule] for a simple text classification task.

It accepts a [TextClassificationDataModule][project.datamodules.text.TextClassificationDataModule] as input, along with a network.

??? note "Click to show the code for HFExample"
{{ inline('project.algorithms.text_classification_example.TextClassificationExample', 4) }}

## Config files

### Algorithm config

??? note "Click to show the Algorithm config"
Source: project/configs/algorithm/text_classification_example.yaml

{{ inline('project/configs/algorithm/text_classification_example.yaml', 4) }}

### Datamodule config

??? note "Click to show the Datamodule config"
Source: project/configs/datamodule/glue_cola.yaml

{{ inline('project/configs/datamodule/glue_cola.yaml', 4) }}

## Running the example

Here is a configuration file that you can use to launch a simple experiment:

??? note "Click to show the yaml config file"
Source: project/configs/experiment/text_classification_example.yaml

{{ inline('project/configs/experiment/text_classification_example.yaml', 4) }}

You can use it like so:

```console
python project/main.py experiment=text_classification_example
```
16 changes: 9 additions & 7 deletions docs/features/jax.md
@@ -3,7 +3,7 @@ additional_python_references:
- project.algorithms.jax_rl_example
- project.algorithms.example
- project.algorithms.jax_example
- project.algorithms.hf_example
- project.algorithms.text_classification_example
- project.trainers.jax_trainer
---

@@ -13,12 +13,14 @@

This template includes examples that use either Jax, PyTorch, or both!

| Example link | Reference | Framework | Lightning? |
| ------------------------------------------------- | ------------------ | ----------- | ------------ |
| [ExampleAlgorithm](../examples/jax_sl_example.md) | `ExampleAlgorithm` | Torch | yes |
| [JaxExample](../examples/jax_sl_example.md) | `JaxExample` | Torch + Jax | yes |
| [HFExample](../examples/nlp.md) | `HFExample` | Torch + 🤗 | yes |
| [JaxRLExample](../examples/jax_rl_example.md) | `JaxRLExample` | Jax | no (almost!) |
<!-- TODO: De-duplicate: This is a bit like a duplicate of the table from the examples/index.md -->

| Example link | Reference | Framework | Lightning? |
| --------------------------------------------------------------- | --------------------------- | ----------- | ------------ |
| [ExampleAlgorithm](../examples/torch_sl_example.md)             | `ExampleAlgorithm`          | Torch       | yes          |
| [JaxExample](../examples/jax_sl_example.md) | `JaxExample` | Torch + Jax | yes |
| [TextClassificationExample](../examples/text_classification.md) | `TextClassificationExample` | Torch + 🤗 | yes |
| [JaxRLExample](../examples/jax_rl_example.md) | `JaxRLExample` | Jax | no (almost!) |


In fact, here you can mix and match Jax and Torch code. For example, you can use Jax for your dataloading, your network, or the learning algorithm, all while still benefiting from the conveniences of PyTorch-Lightning (logging, checkpointing, distributed training, and so on).
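
For a rough sense of what crossing that boundary can look like, here is a naive, gradient-free sketch of calling a jitted Jax function on a Torch tensor (illustrative only; it is not the template's actual interop code, which would also need to carry gradients across):

```python
import jax
import jax.numpy as jnp
import torch


@jax.jit
def jax_forward(x: jnp.ndarray) -> jnp.ndarray:
    # Any Jax computation; here, a fixed elementwise nonlinearity.
    return jax.nn.relu(x) * 2.0


def call_jax_from_torch(x: torch.Tensor) -> torch.Tensor:
    # Naive bridge through NumPy on CPU: no gradients flow through this,
    # and real code would prefer zero-copy (dlpack) conversion.
    y = jax_forward(jnp.asarray(x.detach().cpu().numpy()))
    return torch.from_numpy(jax.device_get(y).copy())
```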
4 changes: 2 additions & 2 deletions project/algorithms/__init__.py
@@ -1,13 +1,13 @@
from .example import ExampleAlgorithm
from .hf_example import HFExample
from .jax_example import JaxExample
from .jax_rl_example import JaxRLExample
from .no_op import NoOp
from .text_classification_example import TextClassificationExample

__all__ = [
"ExampleAlgorithm",
"JaxExample",
"NoOp",
"HFExample",
"TextClassificationExample",
"JaxRLExample",
]
8 changes: 8 additions & 0 deletions project/algorithms/callbacks/samples_per_second.py
@@ -1,6 +1,8 @@
import time
from typing import Any, Literal

import jax
import torch
from lightning import LightningModule, Trainer
from torch import Tensor
from torch.optim import Optimizer
@@ -90,6 +92,12 @@ def log(
    def get_num_samples(self, batch: BatchType) -> int:
        if is_sequence_of(batch, Tensor):
            return batch[0].shape[0]
        if isinstance(batch, dict):
            return next(
                v.shape[0]
                for v in jax.tree.leaves(batch)
                if isinstance(v, torch.Tensor) and v.ndim > 1
            )
        raise NotImplementedError(
            f"Don't know how many 'samples' there are in batch of type {type(batch)}"
        )
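
The new `dict` branch above flattens the batch with `jax.tree.leaves` and takes the batch dimension of the first multi-dimensional tensor it finds. A standalone illustration (the batch below is made up for this note):

```python
import jax
import torch

# Made-up dict batch, shaped like the ones in the regression files above.
batch = {
    "input_ids": torch.zeros(32, 128, dtype=torch.long),
    "attention_mask": torch.ones(32, 128, dtype=torch.long),
    "labels": torch.zeros(32, dtype=torch.long),
}
# `jax.tree.leaves` returns the dict's values as leaves; the first tensor
# with more than one dimension gives the batch size (32 here).
num_samples = next(
    v.shape[0] for v in jax.tree.leaves(batch) if isinstance(v, torch.Tensor) and v.ndim > 1
)
assert num_samples == 32
```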
17 changes: 17 additions & 0 deletions project/algorithms/example_test.py
@@ -1,9 +1,13 @@
"""Example showing how the test suite can be used to add tests for a new algorithm."""

import pytest
import torch
from transformers import PreTrainedModel

from project.algorithms.testsuites.algorithm_tests import LearningAlgorithmTests
from project.configs import Config
from project.conftest import command_line_overrides
from project.datamodules.image_classification.cifar10 import CIFAR10DataModule
from project.datamodules.image_classification.image_classification import (
    ImageClassificationDataModule,
)
@@ -12,6 +16,19 @@
from .example import ExampleAlgorithm


@pytest.mark.parametrize(
    command_line_overrides.__name__, ["algorithm=example datamodule=cifar10"], indirect=True
)
def test_example_experiment_defaults(experiment_config: Config) -> None:
    """Test to check that the datamodule is required (even when just an algorithm is set?!)."""

    assert experiment_config.algorithm["_target_"] == (
        ExampleAlgorithm.__module__ + "." + ExampleAlgorithm.__qualname__
    )

    assert isinstance(experiment_config.datamodule, CIFAR10DataModule)


@run_for_all_configs_of_type("algorithm", ExampleAlgorithm)
@run_for_all_configs_of_type("datamodule", ImageClassificationDataModule)
@run_for_all_configs_of_type("algorithm/network", torch.nn.Module, excluding=PreTrainedModel)
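
The test added above relies on pytest's indirect parametrization: the parametrized string is routed into the `command_line_overrides` fixture rather than passed to the test directly. A self-contained sketch of the pattern (the fixture here is a stand-in, not the project's actual one):

```python
import pytest


@pytest.fixture
def command_line_overrides(request: pytest.FixtureRequest) -> str:
    # With `indirect=True`, the parametrized value arrives on
    # `request.param`, and the fixture decides how to use it.
    return getattr(request, "param", "")


@pytest.mark.parametrize("command_line_overrides", ["algorithm=example"], indirect=True)
def test_receives_override(command_line_overrides: str):
    assert command_line_overrides == "algorithm=example"
```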