Skip to content

Commit

Permalink
Merge pull request #3 from pomonam/documentation
Browse files Browse the repository at this point in the history
Code refactor + complete tests
  • Loading branch information
pomonam authored Mar 19, 2024
2 parents 0f2075a + 493ebc3 commit f37d027
Show file tree
Hide file tree
Showing 65 changed files with 5,292 additions and 2,627 deletions.
410 changes: 404 additions & 6 deletions DOCUMENTATION.md

Large diffs are not rendered by default.

18 changes: 7 additions & 11 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,13 @@

---

> **Kronfluence** is a PyTorch-based library designed to compute [influence functions](https://arxiv.org/abs/1703.04730) using [Kronecker-factored Approximate Curvature (KFAC)](https://arxiv.org/abs/1503.05671) or [Eigenvalue-corrected KFAC (EKFAC)](https://arxiv.org/abs/1806.03884).
> **Kronfluence** is a repository designed to compute [influence functions](https://arxiv.org/abs/1703.04730) using [Kronecker-factored Approximate Curvature (KFAC)](https://arxiv.org/abs/1503.05671) or [Eigenvalue-corrected KFAC (EKFAC)](https://arxiv.org/abs/1806.03884).
For a detailed description of the methodology, see the [**paper**](https://arxiv.org/abs/2308.03296) *Studying Large Language Model Generalization with Influence Functions*.

---

> [!WARNING]
> This library is under active development and has not reached its first stable release.
> This repository is under active development and has not reached its first stable release.
## Installation

Expand All @@ -50,17 +50,13 @@ pip install -e .

## Getting Started

Kronfluence currently supports influence computations on `nn.Linear` and `nn.Conv2d` modules.
It also supports several other Hessian approximation techniques: `identity`, `diagonal`, `KFAC`, and `EKFAC`.
The implementation is compatible with [Distributed Data Parallel (DDP)](https://pytorch.org/tutorials/intermediate/ddp_tutorial.html),
[Fully Sharded Data Parallel (FSDP)](https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html), and [`torch.compile`](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html).
See [DOCUMENTATION.md](https://github.com/pomonam/kronfluence/blob/main/DOCUMENTATION.md) for detailed description on how to configure the experiment.
Kronfluence supports influence computations on `nn.Linear` and `nn.Conv2d` modules. See the [**Technical Documentation**](https://github.com/pomonam/kronfluence/blob/main/DOCUMENTATION.md)
page for a comprehensive guide on configuring the experiment.

### Examples

The [examples](https://github.com/pomonam/kronfluence/tree/main/examples) folder contains several examples on how to use Kronfluence.

**TL;DR:** You need to prepare the trained model and datasets, and pass them into the `Analyzer`.
We plan to add more language model examples. **TL;DR** You need to prepare the trained model and datasets, and pass them into the `Analyzer`.

```python
import torch
Expand Down Expand Up @@ -94,7 +90,7 @@ eval_dataset = torchvision.datasets.MNIST(
train=True,
)

# Initialize the task for MNIST with relevant loss and measurement function.
# Initialize the task with relevant loss and measurement.
task = MnistTask()

# Prepare the model for influence computation with the specified task.
Expand All @@ -104,7 +100,7 @@ analyzer = Analyzer(analysis_name="mnist", model=model, task=task)
# Fit all EKFAC factors for the given model on the training dataset.
analyzer.fit_all_factors(factors_name="ekfac", dataset=train_dataset)

# Compute all pairwise influence scores using the fitted factors.
# Compute all pairwise influence scores using the computed factors.
analyzer.compute_pairwise_scores(
scores_name="pairwise_scores",
factors_name="ekfac",
Expand Down
3 changes: 2 additions & 1 deletion examples/_test_requirements.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
scikit-learn
jupyter
jupyter
evaluate
22 changes: 12 additions & 10 deletions examples/cifar/pipeline.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import copy
import math
from typing import Dict, List, Optional, Tuple
from typing import List, Optional

import datasets
import numpy as np
import torch
import torchvision
Expand Down Expand Up @@ -32,15 +33,15 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:


def construct_resnet9() -> nn.Module:
# ResNet-9 architecture from: https://github.com/MadryLab/trak/blob/main/examples/cifar_quickstart.ipynb.
def conv_bn(
channels_in: int,
channels_out: int,
kernel_size: int = 3,
stride: int = 1,
padding: int = 1,
groups=1,
groups: int = 1,
) -> nn.Module:
assert groups == 1
return torch.nn.Sequential(
torch.nn.Conv2d(
channels_in,
Expand Down Expand Up @@ -73,10 +74,10 @@ def conv_bn(

def get_cifar10_dataset(
split: str,
do_corrupt: bool,
indices: List[int] = None,
data_dir: str = "data/",
):
corrupt_percentage: Optional[float] = None,
dataset_dir: str = "data/",
) -> datasets.Dataset:
assert split in ["train", "eval_train", "valid"]

normalize = torchvision.transforms.Normalize(mean=(0.4914, 0.4822, 0.4465), std=(0.247, 0.243, 0.261))
Expand All @@ -98,16 +99,17 @@ def get_cifar10_dataset(
)

dataset = torchvision.datasets.CIFAR10(
root=data_dir,
root=dataset_dir,
download=True,
train=split in ["train", "eval_train", "eval_train_with_aug"],
train=split in ["train", "eval_train"],
transform=transform_config,
)

if do_corrupt:
if corrupt_percentage is not None:
if split == "valid":
raise NotImplementedError("Performing corruption on the validation dataset is not supported.")
num_corrupt = math.ceil(len(dataset) * 0.1)
assert 0.0 < corrupt_percentage <= 1.0
num_corrupt = math.ceil(len(dataset) * corrupt_percentage)
original_targets = np.array(copy.deepcopy(dataset.targets[:num_corrupt]))
new_targets = torch.randint(
0,
Expand Down
192 changes: 192 additions & 0 deletions examples/cifar/train.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,192 @@
import argparse
import logging
import os
from typing import Tuple

import numpy as np
import torch
import torch.nn.functional as F
from accelerate.utils import set_seed
from torch import nn
from torch.optim import lr_scheduler
from torch.utils import data
from tqdm import tqdm

from examples.cifar.pipeline import construct_resnet9, get_cifar10_dataset

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def parse_args():
parser = argparse.ArgumentParser(description="Train ResNet-9 model on CIFAR-10 dataset.")

parser.add_argument(
"--corrupt_percentage",
type=float,
default=None,
help="Percentage of the training dataset to corrupt.",
)
parser.add_argument(
"--dataset_dir",
type=str,
default="./data",
help="A folder to download or load CIFAR-10 dataset.",
)

parser.add_argument(
"--train_batch_size",
type=int,
default=512,
help="Batch size for the training dataloader.",
)
parser.add_argument(
"--eval_batch_size",
type=int,
default=1024,
help="Batch size for the evaluation dataloader.",
)

parser.add_argument(
"--learning_rate",
type=float,
default=0.4,
help="Initial learning rate to train the model.",
)
parser.add_argument(
"--weight_decay",
type=float,
default=0.001,
help="Weight decay to train the model.",
)
parser.add_argument(
"--num_train_epochs",
type=int,
default=25,
help="Total number of epochs to train the model.",
)

parser.add_argument(
"--seed",
type=int,
default=1004,
help="A seed for reproducible training pipeline.",
)
parser.add_argument(
"--checkpoint_dir",
type=str,
default="./checkpoints",
help="A path to store the final checkpoint.",
)

args = parser.parse_args()

if args.checkpoint_dir is not None:
os.makedirs(args.checkpoint_dir, exist_ok=True)

return args


def train(
dataset: data.Dataset,
batch_size: int,
num_train_epochs: int,
learning_rate: float,
weight_decay: float,
disable_tqdm: bool = False,
) -> nn.Module:
train_dataloader = data.DataLoader(
dataset=dataset,
batch_size=batch_size,
shuffle=True,
drop_last=True,
)

model = construct_resnet9().to(DEVICE)
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

iters_per_epoch = len(train_dataloader)
lr_peak_epoch = num_train_epochs // 4
lr_schedule = np.interp(
np.arange((num_train_epochs + 1) * iters_per_epoch),
[0, lr_peak_epoch * iters_per_epoch, num_train_epochs * iters_per_epoch],
[0, 1, 0],
)
scheduler = lr_scheduler.LambdaLR(optimizer, lr_schedule.__getitem__)

model.train()
for epoch in range(num_train_epochs):
total_loss = 0.0
with tqdm(train_dataloader, unit="batch", disable=disable_tqdm) as tepoch:
for batch in tepoch:
tepoch.set_description(f"Epoch {epoch}")
model.zero_grad()
inputs, labels = batch
inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
outputs = model(inputs)
loss = F.cross_entropy(outputs, labels)
loss.backward()
optimizer.step()
scheduler.step()
total_loss += loss.detach().float()
tepoch.set_postfix(loss=total_loss.item() / len(train_dataloader))
return model


def evaluate(model: nn.Module, dataset: data.Dataset, batch_size: int) -> Tuple[float, float]:
dataloader = data.DataLoader(
dataset=dataset,
batch_size=batch_size,
shuffle=False,
drop_last=False,
)

model.eval()
total_loss, total_correct = 0.0, 0
for batch in dataloader:
with torch.no_grad():
inputs, labels = batch
inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
outputs = model(inputs)
loss = F.cross_entropy(outputs, labels, reduction="sum")
total_loss += loss.detach().float()
total_correct += outputs.detach().argmax(1).eq(labels).sum()

return total_loss.item() / len(dataloader.dataset), total_correct.item() / len(dataloader.dataset)


def main():
args = parse_args()
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger()

if args.seed is not None:
set_seed(args.seed)

train_dataset = get_cifar10_dataset(
split="train", corrupt_percentage=args.corrupt_percentage, dataset_dir=args.dataset_dir
)
model = train(
dataset=train_dataset,
batch_size=args.train_batch_size,
num_train_epochs=args.num_train_epochs,
learning_rate=args.learning_rate,
weight_decay=args.weight_decay,
)

eval_train_dataset = get_cifar10_dataset(split="eval_train", dataset_dir=args.dataset_dir)
train_loss, train_acc = evaluate(model=model, dataset=eval_train_dataset, batch_size=args.eval_batch_size)
logger.info(f"Train loss: {train_loss}, Train Accuracy: {train_acc}")

eval_dataset = get_cifar10_dataset(split="valid", dataset_dir=args.dataset_dir)
eval_loss, eval_acc = evaluate(model=model, dataset=eval_dataset, batch_size=args.eval_batch_size)
logger.info(f"Evaluation loss: {eval_loss}, Evaluation Accuracy: {eval_acc}")

if args.checkpoint_dir is not None:
model_name = "model"
if args.corrupt_percentage is not None:
model_name += "_corrupt_" + str(args.corrupt_percentage)
torch.save(model.state_dict(), os.path.join(args.checkpoint_dir, f"{model_name}.pth"))


if __name__ == "__main__":
main()
7 changes: 3 additions & 4 deletions examples/glue/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,13 @@
}


def construct_bert(data_name) -> nn.Module:
def construct_bert(data_name: str = "sst2") -> nn.Module:
config = AutoConfig.from_pretrained(
"bert-base-cased",
num_labels=2,
finetuning_task=data_name,
trust_remote_code=True,
)

return AutoModelForSequenceClassification.from_pretrained(
"bert-base-cased",
from_tf=False,
Expand All @@ -42,14 +41,14 @@ def get_glue_dataset(
data_name: str,
split: str,
indices: List[int] = None,
data_path: str = "data/",
dataset_dir: str = "data/",
) -> Dataset:
assert split in ["train", "eval_train", "valid"]

raw_datasets = load_dataset(
path="glue",
name=data_name,
data_dir=data_path,
# data_dir=dataset_dir,
)
label_list = raw_datasets["train"].features["label"].names
num_labels = len(label_list)
Expand Down
Loading

0 comments on commit f37d027

Please sign in to comment.