Skip to content

Commit

Permalink
Merge pull request #30 from pomonam/dev
Browse files Browse the repository at this point in the history
Merge dev branch
  • Loading branch information
pomonam authored Jul 13, 2024
2 parents 1464e19 + ad2d082 commit d204e2d
Show file tree
Hide file tree
Showing 147 changed files with 21,350 additions and 4,258 deletions.
4 changes: 3 additions & 1 deletion .github/workflows/python-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,10 @@ jobs:
pytest -vx tests/test_dataset_utils.py
pytest -vx tests/test_testable_tasks.py
pytest -vx tests/factors/test_covariances.py
pytest -vx tests/factors/test_eigens.py
pytest -vx tests/factors/test_eigendecompositions.py
pytest -vx tests/factors/test_lambdas.py
pytest -vx tests/modules/test_modules.py
pytest -vx tests/modules/test_per_sample_gradients.py
pytest -vx tests/modules/test_matmul.py
pytest -vx tests/scores/test_pairwise_scores.py
pytest -vx tests/scores/test_self_scores.py
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,8 @@ cython_debug/

# Checkpoints and influence outputs
checkpoints/
analyses/
influence_results/
data/
cache/
*.pth
*.pt
211 changes: 125 additions & 86 deletions DOCUMENTATION.md

Large diffs are not rendered by default.

65 changes: 55 additions & 10 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,11 @@

---

> **Kronfluence** is a research repository designed to compute [influence functions](https://arxiv.org/abs/1703.04730) using [Kronecker-factored Approximate Curvature (KFAC)](https://arxiv.org/abs/1503.05671) or [Eigenvalue-corrected KFAC (EKFAC)](https://arxiv.org/abs/1806.03884).
For a detailed description of the methodology, see the [**paper**](https://arxiv.org/abs/2308.03296) *Studying Large Language Model Generalization with Influence Functions*.
> **Kronfluence** is a PyTorch package designed to compute [influence functions](https://arxiv.org/abs/1703.04730) using [Kronecker-factored Approximate Curvature (KFAC)](https://arxiv.org/abs/1503.05671) or [Eigenvalue-corrected KFAC (EKFAC)](https://arxiv.org/abs/1806.03884).
For detailed description of the methodology, see the [**paper**](https://arxiv.org/abs/2308.03296), *Studying Large Language Model Generalization with Influence Functions*.

---

> [!WARNING]
> This repository is under active development and has not reached its first stable release.
## Installation

> [!IMPORTANT]
Expand All @@ -53,11 +50,9 @@ pip install -e .

## Getting Started

Kronfluence supports influence computations on `nn.Linear` and `nn.Conv2d` modules. See the [**Technical Documentation**](https://github.com/pomonam/kronfluence/blob/main/DOCUMENTATION.md) page for a comprehensive guide.

### Learn More
Kronfluence supports influence computations on [`nn.Linear`](https://pytorch.org/docs/stable/generated/torch.nn.Linear.html) and [`nn.Conv2d`](https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html) modules.
See the [**Technical Documentation**](https://github.com/pomonam/kronfluence/blob/main/DOCUMENTATION.md) page for a comprehensive guide.

The [examples](https://github.com/pomonam/kronfluence/tree/main/examples) folder contains several examples demonstrating how to use Kronfluence. More examples will be added in the future.
**TL;DR** You need to prepare a trained model and datasets, and pass them into the `Analyzer` class.

```python
Expand Down Expand Up @@ -115,6 +110,30 @@ analyzer.compute_pairwise_scores(
scores = analyzer.load_pairwise_scores(scores_name="my_scores")
```

Kronfluence supports various PyTorch features; the following table summarizes the supported features:

<div align="center">

| Feature | Supported |
|-----------------------------------------------------------------------------------------------------------------------------|:---------:|
| [Distributed Data Parallel (DDP)](https://pytorch.org/docs/master/generated/torch.nn.parallel.DistributedDataParallel.html) ||
| [Automatic Mixed Precision (AMP)](https://pytorch.org/docs/stable/amp.html) ||
| [Torch Compile](https://pytorch.org/docs/stable/generated/torch.compile.html) ||
| [Gradient Checkpointing](https://pytorch.org/docs/stable/checkpoint.html) ||
| [Fully Sharded Data Parallel (FSDP)](https://pytorch.org/docs/stable/fsdp.html) ||

</div>

The [examples](https://github.com/pomonam/kronfluence/tree/main/examples) folder contains several examples demonstrating how to use Kronfluence.

## LogIX

While Kronfluence supports influence function computations on large-scale models like `Meta-Llama-3-8B-Instruct`, for those
interested in running influence analysis on even larger models or with a large number of query data points, our
project [LogIX](https://github.com/logix-project/logix) may be worth exploring. It integrates with frameworks like
[HuggingFace Trainer](https://huggingface.co/docs/transformers/en/main_classes/trainer) and [PyTorch Lightning](https://lightning.ai/docs/pytorch/stable/)
and is also compatible with many PyTorch features (DDP & FSDP & [DeepSpeed](https://github.com/microsoft/DeepSpeed)).

## Contributing

Contributions are welcome! To get started, please review our [Code of Conduct](https://github.com/pomonam/kronfluence/blob/main/CODE_OF_CONDUCT.md). For bug fixes, please submit a pull request.
Expand All @@ -131,10 +150,36 @@ cd kronfluence
pip install -e ."[dev]"
```

### Style Testing

To maintain code quality and consistency, we run ruff and linting tests on pull requests. Before submitting a
pull request, please ensure that your code adheres to our formatting and linting guidelines. The following commands will
modify your code. It is recommended to create a Git commit before running them to easily revert any unintended changes.

Sort import orderings using [isort](https://pycqa.github.io/isort/):

```bash
isort kronfluence
```

Format code using [ruff](https://docs.astral.sh/ruff/):

```bash
ruff format kronfluence
```

To view all [pylint](https://www.pylint.org/) complaints, run the following command:

```bash
pylint kronfluence
```

Please address any reported issues before submitting your PR.

## Acknowledgements

[Omkar Dige](https://github.com/xeon27) contributed to the profiling, DDP, and FSDP utilities, and [Adil Asif](https://github.com/adil-a/) provided valuable insights and suggestions on structuring the DDP and FSDP implementations.
I also thank Hwijeen Ahn, Sang Keun Choe, Youngseog Chung, Minsoo Kang, Lev McKinney, Laura Ruis, Andrew Wang, and Kewen Zhao for their feedback.
I also thank Hwijeen Ahn, Sang Keun Choe, Youngseog Chung, Minsoo Kang, Sophie Liao, Lev McKinney, Laura Ruis, Andrew Wang, and Kewen Zhao for their feedback.

## License

Expand Down
27 changes: 14 additions & 13 deletions dev_requirements.txt
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
torch
torchvision
accelerate
einops
einconv
opt_einsum
safetensors
tqdm
datasets
transformers
torch>=2.1.0
torchvision>=0.16.0
accelerate>=0.31.0
einops>=0.8.0
einconv>=0.1.0
opt_einsum>=3.3.0
scikit-learn>=1.4.0
safetensors>=0.4.2
tqdm>=4.66.4
datasets>=2.20.0
transformers>=4.42.0
isort==5.13.2
pylint==3.0.3
pytest==8.0.0
ruff==0.3.0
pylint==3.2.3
pytest==8.2.2
ruff==0.4.0
33 changes: 33 additions & 0 deletions examples/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# Kronfluence: Examples

For detailed technical documentation of Kronfluence, please refer to the [Technical Documentation](https://github.com/pomonam/kronfluence/blob/main/DOCUMENTATION.md) page.

## Getting Started

To run all examples, install the necessary packages:

```bash
pip install -r requirements.txt
```

Alternatively, navigate to each example folder and run `pip install -r requirements.txt`.

## List of Tasks

Our examples cover the following tasks:

<div align="center">

| Task | Example Datasets |
|----------------------|:------------------------:|
| Regression | UCI |
| Image Classification | CIFAR-10 & ImageNet |
| Text Classification | GLUE |
| Multiple-Choice | SWAG |
| Summarization | DNN/DailyMail |
| Language Modeling | WikiText-2 & OpenWebText |

</div>

These examples demonstrate various use cases of Kronfluence, including the usage of AMP (Automatic Mixed Precision) and DDP (Distributed Data Parallel).
Many examples aim to replicate the settings used in [our paper](https://arxiv.org/abs/2405.12186). If you would like to see more examples added to this repository, please leave an issue.
123 changes: 110 additions & 13 deletions examples/cifar/README.md
Original file line number Diff line number Diff line change
@@ -1,57 +1,154 @@
# CIFAR-10 & ResNet-9 Example

This directory contains scripts for training ResNet-9 on CIFAR-10. The pipeline is motivated from
[TRAK repository](https://github.com/MadryLab/trak/blob/main/examples/cifar_quickstart.ipynb).
This directory contains scripts for training ResNet-9 and computing influence scores on CIFAR-10 dataset. The pipeline is motivated from
[TRAK repository](https://github.com/MadryLab/trak/blob/main/examples/cifar_quickstart.ipynb). To get started, please install the necessary packages by running the following command:

```bash
pip install -r requirements.txt
```

## Training

To train ResNet-9 on CIFAR-10 dataset, run the following command:
To train ResNet-9 on the CIFAR-10 dataset, run the following command:

```bash
python train.py --dataset_dir ./data \
--checkpoint_dir ./checkpoints \
--train_batch_size 512 \
--eval_batch_size 1024 \
--learning_rate 0.4 \
--weight_decay 0.0001 \
--weight_decay 0.001 \
--num_train_epochs 25 \
--seed 1004
```

This will train the model using the specified hyperparameters and save the trained checkpoint in the `./checkpoints` directory.

## Computing Pairwise Influence Scores

To obtain pairwise influence scores on 2000 query data points using `ekfac`, run the following command:
To compute pairwise influence scores on 2000 query data points using the `ekfac` strategy, run the following command:

```bash
python analyze.py --query_batch_size 1000 \
--dataset_dir ./data \
--checkpoint_dir ./checkpoints \
--factor_strategy ekfac
```
You can also use `identity`, `diagonal`, and `kfac`. On A100 (80GB), it takes roughly 1.5 minutes to compute the
pairwise scores (including computing EKFAC factors).

In addition to `ekfac`, you can also use `identity`, `diagonal`, and `kfac` as the `factor_strategy`. On an A100 (80GB) GPU, it takes roughly 2 minutes to compute the pairwise scores (including computing the EKFAC factors):

```
----------------------------------------------------------------------------------------------------------------------------------
| Action | Mean duration (s) | Num calls | Total time (s) | Percentage % |
----------------------------------------------------------------------------------------------------------------------------------
| Total | - | 11 | 106.38 | 100 % |
----------------------------------------------------------------------------------------------------------------------------------
| Compute Pairwise Score | 46.745 | 1 | 46.745 | 43.941 |
| Fit Lambda | 34.885 | 1 | 34.885 | 32.793 |
| Fit Covariance | 22.538 | 1 | 22.538 | 21.187 |
| Perform Eigendecomposition | 0.91424 | 1 | 0.91424 | 0.85941 |
| Save Pairwise Score | 0.81219 | 1 | 0.81219 | 0.76348 |
| Save Covariance | 0.22351 | 1 | 0.22351 | 0.21011 |
| Save Eigendecomposition | 0.21617 | 1 | 0.21617 | 0.20321 |
| Save Lambda | 0.031038 | 1 | 0.031038 | 0.029177 |
| Load Eigendecomposition | 0.010442 | 1 | 0.010442 | 0.0098156 |
| Load All Factors | 0.0026517 | 1 | 0.0026517 | 0.0024927 |
| Load Covariance | 0.0016419 | 1 | 0.0016419 | 0.0015435 |
----------------------------------------------------------------------------------------------------------------------------------
```

To use AMP when computing influence scores, run:

```bash
python analyze.py --query_batch_size 1000 \
--dataset_dir ./data \
--checkpoint_dir ./checkpoints \
--factor_strategy ekfac \
--use_half_precision
```

This reduces computation time to about 40 seconds on an A100 (80GB) GPU:

```
----------------------------------------------------------------------------------------------------------------------------------
| Action | Mean duration (s) | Num calls | Total time (s) | Percentage % |
----------------------------------------------------------------------------------------------------------------------------------
| Total | - | 11 | 35.965 | 100 % |
----------------------------------------------------------------------------------------------------------------------------------
| Compute Pairwise Score | 18.012 | 1 | 18.012 | 50.082 |
| Fit Lambda | 9.2271 | 1 | 9.2271 | 25.656 |
| Fit Covariance | 7.134 | 1 | 7.134 | 19.836 |
| Perform Eigendecomposition | 0.87962 | 1 | 0.87962 | 2.4457 |
| Save Pairwise Score | 0.45432 | 1 | 0.45432 | 1.2632 |
| Save Covariance | 0.12861 | 1 | 0.12861 | 0.35759 |
| Save Eigendecomposition | 0.11296 | 1 | 0.11296 | 0.31407 |
| Save Lambda | 0.010712 | 1 | 0.010712 | 0.029784 |
| Load All Factors | 0.002736 | 1 | 0.002736 | 0.0076074 |
| Load Covariance | 0.0016696 | 1 | 0.0016696 | 0.0046421 |
| Load Eigendecomposition | 0.0014892 | 1 | 0.0014892 | 0.0041406 |
----------------------------------------------------------------------------------------------------------------------------------
```

You can run `half_precision_analysis.py` to verify that the scores computed with AMP have high correlations with those of the default configuration.

<p align="center">
<a href="#"><img width="380" img src="figure/half_precision.png" alt="Half Precision"/></a>
</p>

## Visualizing Influential Training Images

[This Colab notebook](https://colab.research.google.com/drive/1KIwIbeJh_om4tRwceuZ005fVKDsiXKgr?usp=sharing) provides a tutorial on visualizing the top influential training images.

## Mislabeled Data Detection

We can use self-influence scores (see Section 5.4 for the [paper](https://arxiv.org/pdf/1703.04730.pdf)) to detect mislabeled examples.
First, train the model with 10% of training examples mislabeled by running the following command:
We can use self-influence scores (see **Section 5.4** for the [paper](https://arxiv.org/pdf/1703.04730.pdf)) to detect mislabeled examples.
First, train the model with 10% of the training examples mislabeled by running:

```bash
python train.py --dataset_dir ./data \
--corrupt_percentage 0.1 \
--checkpoint_dir ./checkpoints \
--train_batch_size 512 \
--eval_batch_size 1024 \
--learning_rate 0.4 \
--weight_decay 0.0001 \
--weight_decay 0.001 \
--num_train_epochs 25 \
--seed 1004
```

Then, compute self-influence scores with the following command:
Then, compute the self-influence scores with:

```bash
python detect_mislabeled_dataset.py --dataset_dir ./data \
--corrupt_percentage 0.1 \
--checkpoint_dir ./checkpoints \
--factor_strategy ekfac
```

On A100 (80GB), it takes roughly 1.5 minutes to compute the self-influence scores.
We can detect around 82% of mislabeled data points by inspecting 10% of the dataset (96% by inspecting 20%).
On an A100 (80GB) GPU, it takes roughly 2 minutes to compute the self-influence scores:

```
----------------------------------------------------------------------------------------------------------------------------------
| Action | Mean duration (s) | Num calls | Total time (s) | Percentage % |
----------------------------------------------------------------------------------------------------------------------------------
| Total | - | 11 | 121.85 | 100 % |
----------------------------------------------------------------------------------------------------------------------------------
| Compute Self-Influence Score | 62.778 | 1 | 62.778 | 51.519 |
| Fit Lambda | 35.174 | 1 | 35.174 | 28.866 |
| Fit Covariance | 22.582 | 1 | 22.582 | 18.532 |
| Perform Eigendecomposition | 0.82656 | 1 | 0.82656 | 0.67832 |
| Save Covariance | 0.2478 | 1 | 0.2478 | 0.20336 |
| Save Eigendecomposition | 0.22042 | 1 | 0.22042 | 0.18088 |
| Save Lambda | 0.018463 | 1 | 0.018463 | 0.015152 |
| Load All Factors | 0.0027554 | 1 | 0.0027554 | 0.0022612 |
| Load Covariance | 0.0016607 | 1 | 0.0016607 | 0.0013628 |
| Load Eigendecomposition | 0.0015408 | 1 | 0.0015408 | 0.0012645 |
| Save Self-Influence Score | 0.0010841 | 1 | 0.0010841 | 0.00088966 |
----------------------------------------------------------------------------------------------------------------------------------
```

Around 80% of mislabeled data points can be detected by inspecting 10% of the dataset (97% by inspecting 20%).

<p align="center">
<a href="#"><img width="380" img src="figure/mislabel.png" alt="Mislabeled Data Detection"/></a>
</p>
Loading

0 comments on commit d204e2d

Please sign in to comment.