Commit b54af93: Fix linting
pomonam committed Mar 12, 2024
1 parent d2f8041 commit b54af93
Showing 10 changed files with 41 additions and 226 deletions.
6 changes: 0 additions & 6 deletions .github/workflows/linting.yml
@@ -45,9 +45,3 @@ jobs:
       - name: Run isort
         run: |
           isort --profile black kronfluence
-
-  actionlint:
-    runs-on: ubuntu-latest
-    steps:
-      - uses: actions/checkout@v4
-      - uses: reviewdog/action-actionlint@v1
8 changes: 8 additions & 0 deletions .github/workflows/ruff.yml
@@ -0,0 +1,8 @@
+name: Ruff
+on: [push, pull_request]
+jobs:
+  ruff:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v4
+      - uses: chartboost/ruff-action@v1
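The same checks can be reproduced locally before pushing. The snippet below is a minimal sketch, not part of this commit: it assumes ruff and isort are installed in the active environment, and that chartboost/ruff-action@v1 runs a plain ruff check over the repository (the exact arguments the CI actions pass may differ).

import subprocess

# Mirror the two lint jobs: the Ruff workflow added above and the existing isort step.
for cmd in (
    ["ruff", "check", "kronfluence"],
    ["isort", "--profile", "black", "kronfluence"],
):
    subprocess.run(cmd, check=True)  # raises CalledProcessError if a check fails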
98 changes: 10 additions & 88 deletions examples/cifar/pipeline.py
@@ -1,14 +1,13 @@
 import copy
 import math
 from typing import Dict, List, Optional, Tuple
 
-from torch import nn
 import numpy as np
 import torch
+import torch.nn as nn
 import torchvision
 
 
-class Mul(torch.nn.Module):
+class Mul(nn.Module):
     def __init__(self, weight: float) -> None:
         super().__init__()
         self.weight = weight
@@ -17,12 +16,12 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return x * self.weight
 
 
-class Flatten(torch.nn.Module):
+class Flatten(nn.Module):
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         return x.view(x.size(0), -1)
 
 
-class Residual(torch.nn.Module):
+class Residual(nn.Module):
     def __init__(self, module: nn.Module) -> None:
         super().__init__()
         self.module = module
@@ -71,107 +70,37 @@ def conv_bn(
     return model
 
 
-# def get_hyperparameters(data_name: str) -> Dict[str, float]:
-#     wd = 0.001
-#     if data_name == "cifar2":
-#         lr = 0.5
-#         epochs = 100
-#     elif data_name == "cifar10":
-#         lr = 0.4
-#         epochs = 25
-#     else:
-#         raise NotImplementedError()
-#     return {"lr": lr, "wd": wd, "epochs": epochs}
-
-
 def get_cifar10_dataset(
     split: str,
     do_corrupt: bool,
     indices: List[int] = None,
-    data_path: str = "data/",
+    data_dir: str = "data/",
 ):
     assert split in ["train", "eval_train", "valid"]
 
     normalize = torchvision.transforms.Normalize(mean=(0.4914, 0.4822, 0.4465), std=(0.247, 0.243, 0.261))
 
     if split in ["train", "eval_train"]:
-        transforms = torchvision.transforms.Compose(
+        transform_config = torchvision.transforms.Compose(
             [
                 torchvision.transforms.RandomCrop(32, padding=4),
                 torchvision.transforms.RandomHorizontalFlip(),
                 torchvision.transforms.ToTensor(),
                 normalize,
             ]
         )
     else:
-        transforms = torchvision.transforms.Compose(
-            [
-                torchvision.transforms.ToTensor(),
-                torchvision.transforms.Normalize(mean=MEAN, std=STD),
-            ]
-        )
-
-    if split == "train":
-        transform_config = [
-            torchvision.transforms.RandomResizedCrop(size=224, scale=(0.08, 1.0), ratio=(0.75, 4.0 / 3.0)),
-            torchvision.transforms.RandomHorizontalFlip(),
-        ]
-        transform_config.extend([torchvision.transforms.ToTensor(), normalize])
-        transform_config = torchvision.transforms.Compose(transform_config)
-
-    else:
         transform_config = torchvision.transforms.Compose(
             [
-                torchvision.transforms.Resize(size=256),
-                torchvision.transforms.CenterCrop(size=224),
                 torchvision.transforms.ToTensor(),
                 normalize,
             ]
         )
-
-    folder = "train" if split in ["train", "eval_train"] else "val"
-    dataset = torchvision.datasets.ImageFolder(
-        root=os.path.join(data_path, folder),
-        transform=transform_config,
-    )
-
-    if indices is not None:
-        dataset = torch.utils.data.Subset(dataset, indices)
-
-    return dataset
-
-
-def get_cifar10_dataloader(
-    batch_size: int,
-    split: str = "train",
-    indices: List[int] = None,
-    do_corrupt: bool = False,
-    num_workers: int = 4,
-) -> torch.utils.data.DataLoader:
-    MEAN = (0.4914, 0.4822, 0.4465)
-    STD = (0.247, 0.243, 0.261)
-
-    if split in ["train", "eval_train"]:
-        transforms = torchvision.transforms.Compose(
-            [
-                torchvision.transforms.RandomCrop(32, padding=4),
-                torchvision.transforms.RandomHorizontalFlip(),
-                torchvision.transforms.ToTensor(),
-                torchvision.transforms.Normalize(mean=MEAN, std=STD),
-            ]
-        )
-    else:
-        transforms = torchvision.transforms.Compose(
-            [
-                torchvision.transforms.ToTensor(),
-                torchvision.transforms.Normalize(mean=MEAN, std=STD),
-            ]
-        )
-
     dataset = torchvision.datasets.CIFAR10(
-        root="/tmp/cifar/",
+        root=data_dir,
         download=True,
         train=split in ["train", "eval_train", "eval_train_with_aug"],
-        transform=transforms,
+        transform=transform_config,
     )
 
     if do_corrupt:
@@ -198,11 +127,4 @@ def get_cifar10_dataloader(
     if indices is not None:
         dataset = torch.utils.data.Subset(dataset, indices)
 
-    return torch.utils.data.DataLoader(
-        dataset=dataset,
-        shuffle=split == "train",
-        batch_size=batch_size,
-        num_workers=num_workers,
-        drop_last=split == "train",
-        pin_memory=True,
-    )
+    return dataset
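With get_cifar10_dataloader removed, callers now wrap the returned dataset themselves. A minimal sketch of the replacement pattern, mirroring the settings of the deleted helper (the batch size and worker count are illustrative, not from the commit):

import torch

from examples.cifar.pipeline import get_cifar10_dataset

split = "train"
dataset = get_cifar10_dataset(split=split, do_corrupt=False, data_dir="data/")
loader = torch.utils.data.DataLoader(
    dataset=dataset,
    batch_size=512,              # illustrative value
    shuffle=split == "train",    # same rule the removed helper used
    num_workers=4,               # illustrative value
    drop_last=split == "train",
    pin_memory=True,
)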
98 changes: 8 additions & 90 deletions examples/glue/train.py
@@ -10,7 +10,6 @@
 from transformers import default_data_collator
 
 from examples.glue.pipeline import construct_bert, get_glue_dataset
-from examples.mnist.pipeline import construct_mnist_mlp, get_mnist_dataset
 
 
 def parse_args():
@@ -87,100 +86,19 @@ def main():
     args = parse_args()
 
     logging.basicConfig(level=logging.INFO)
-    logger = logging.getLogger()
 
     if args.seed is not None:
         set_seed(args.seed)
 
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-    train_dataset = get_glue_dataset(data_name=args.dataset_name, split="train", data_path=args.dataset_dir)
-    train_dataloader = DataLoader(
-        dataset=train_dataset,
-        batch_size=args.train_batch_size,
-        shuffle=True,
-        collate_fn=default_data_collator,
-        drop_last=True,
-    )
-    model = construct_bert(args.data_name).to(device=device)
-    # optimizer = torch.optim.SGD(
-    #     model.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay
-    # )
-    #
-    # logger.info("Start training the model.")
-    # model.train()
-    # for epoch in range(args.num_train_epochs):
-    #
-    #     total_loss = 0
-    #
-    #     with tqdm(train_dataloader, unit="batch") as tepoch:
-    #
-    #         for batch in tepoch:
-    #             tepoch.set_description(f"Epoch {epoch}")
-    #             inputs, labels = batch
-    #             inputs, labels = inputs.to(device), labels.to(device)
-    #             logits = model(inputs)
-    #             loss = F.cross_entropy(logits, labels)
-    #             total_loss += loss.detach().float()
-    #             loss.backward()
-    #             optimizer.step()
-    #             optimizer.zero_grad()
-    #             tepoch.set_postfix(loss=total_loss.item() / len(train_dataloader))
-    #
-    # logger.info("Start evaluating the model.")
-    # model.eval()
-    # train_eval_dataset = get_mnist_dataset(
-    #     split="eval_train", data_path=args.dataset_dir
-    # )
-    # train_eval_dataloader = DataLoader(
-    #     dataset=train_eval_dataset,
-    #     batch_size=args.eval_batch_size,
-    #     shuffle=False,
-    #     drop_last=False,
-    # )
-    # eval_dataset = get_mnist_dataset(split="valid", data_path=args.dataset_dir)
-    # eval_dataloader = DataLoader(
-    #     dataset=eval_dataset,
-    #     batch_size=args.eval_batch_size,
-    #     shuffle=False,
-    #     drop_last=False,
-    # )
-    #
-    # total_loss = 0
-    # correct = 0
-    # for batch in train_eval_dataloader:
-    #     with torch.no_grad():
-    #         inputs, labels = batch
-    #         inputs, labels = inputs.to(device), labels.to(device)
-    #         logits = model(inputs)
-    #         loss = F.cross_entropy(logits, labels)
-    #         preds = logits.argmax(dim=1, keepdim=True)
-    #         correct += preds.eq(labels.view_as(preds)).sum().item()
-    #         total_loss += loss.detach().float()
-    #
-    # logger.info(
-    #     f"Train loss: {total_loss.item() / len(train_eval_dataloader.dataset)} | "
-    #     f"Train Accuracy: {100 * correct / len(train_eval_dataloader.dataset)}"
-    # )
-    #
-    # total_loss = 0
-    # correct = 0
-    # for batch in eval_dataloader:
-    #     with torch.no_grad():
-    #         inputs, labels = batch
-    #         inputs, labels = inputs.to(device), labels.to(device)
-    #         logits = model(inputs)
-    #         loss = F.cross_entropy(logits, labels)
-    #         preds = logits.argmax(dim=1, keepdim=True)
-    #         correct += preds.eq(labels.view_as(preds)).sum().item()
-    #         total_loss += loss.detach().float()
-    #
-    # logger.info(
-    #     f"Train loss: {total_loss.item() / len(eval_dataloader.dataset)} | "
-    #     f"Train Accuracy: {100 * correct / len(eval_dataloader.dataset)}"
+    # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    # train_dataset = get_glue_dataset(data_name=args.dataset_name, split="train", data_path=args.dataset_dir)
+    # train_dataloader = DataLoader(
+    #     dataset=train_dataset,
+    #     batch_size=args.train_batch_size,
+    #     shuffle=True,
+    #     collate_fn=default_data_collator,
+    #     drop_last=True,
     # )
     #
     # if args.checkpoint_dir is not None:
     #     torch.save(model.state_dict(), os.path.join(args.checkpoint_dir, "model.pth"))
 
 
 if __name__ == "__main__":
22 changes: 1 addition & 21 deletions examples/imagenet/analyze.py
@@ -98,10 +98,9 @@ def main():
     args = parse_args()
 
     logging.basicConfig(level=logging.INFO)
-    logger = logging.getLogger()
 
     train_dataset = get_imagenet_dataset(split="eval_train", data_path=args.dataset_dir)
-    eval_dataset = get_imagenet_dataset(split="valid", data_path=args.dataset_dir)
+    # eval_dataset = get_imagenet_dataset(split="valid", data_path=args.dataset_dir)
 
     model = construct_resnet50()

@@ -125,26 +124,7 @@ def main():
         factor_args=factor_args,
         per_device_batch_size=1024,
         overwrite_output_dir=True,
-        dataloader_num_workers=2,
-        dataloader_pin_memory=True,
     )
-    # analyzer.perform_eigendecomposition(
-    #     factor_name=args.factor_strategy,
-    #     factor_args=factor_args,
-    #     overwrite_output_dir=True,
-    # )
-    # analyzer.fit_lambda(train_dataset, per_device_batch_size=None)
-    #
-    # score_name = "full_pairwise"
-    # analyzer.compute_pairwise_scores(
-    #     score_name=score_name,
-    #     query_dataset=eval_dataset,
-    #     per_device_query_batch_size=len(eval_dataset),
-    #     train_dataset=train_dataset,
-    #     per_device_train_batch_size=len(train_dataset),
-    # )
-    # scores = analyzer.load_pairwise_scores(score_name=score_name)
-    # print(scores.shape)
 
 
 if __name__ == "__main__":
10 changes: 2 additions & 8 deletions examples/imagenet/ddp_analyze.py
@@ -1,21 +1,18 @@
 import argparse
 import logging
 import math
 import os
-from typing import Dict, Tuple
+from typing import Tuple
 
 import torch
 import torch.distributed as dist
-import torch.nn.functional as F
 from analyzer import Analyzer, prepare_model
 from arguments import FactorArguments
-from module.utils import wrap_tracked_modules
 from task import Task
 from torch import nn
 from torch.nn.parallel.distributed import DistributedDataParallel
 
 from examples.imagenet.pipeline import construct_resnet50, get_imagenet_dataset
-from examples.mnist.pipeline import construct_mnist_mlp, get_mnist_dataset
 
 BATCH_DTYPE = Tuple[torch.Tensor, torch.Tensor]
 LOCAL_RANK = int(os.environ["LOCAL_RANK"])
@@ -107,10 +104,9 @@ def main():
     args = parse_args()
 
     logging.basicConfig(level=logging.INFO)
-    logger = logging.getLogger()
 
     train_dataset = get_imagenet_dataset(split="eval_train", data_path=args.dataset_dir)
-    eval_dataset = get_imagenet_dataset(split="valid", data_path=args.dataset_dir)
+    # eval_dataset = get_imagenet_dataset(split="valid", data_path=args.dataset_dir)
 
     dist.init_process_group("nccl", rank=WORLD_RANK, world_size=WORLD_SIZE)
     device = torch.device("cuda:{}".format(LOCAL_RANK))
@@ -144,8 +140,6 @@ def main():
         factor_args=factor_args,
         per_device_batch_size=None,
         overwrite_output_dir=True,
-        dataloader_num_workers=2,
-        dataloader_pin_memory=True,
     )
     # analyzer.perform_eigendecomposition(
     #     factor_name=args.factor_strategy,
5 changes: 0 additions & 5 deletions examples/imagenet/pipeline.py
@@ -48,8 +48,3 @@ def get_imagenet_dataset(
         dataset = torch.utils.data.Subset(dataset, indices)
 
     return dataset
-
-
-if __name__ == "__main__":
-    model = construct_resnet50()
-    print(model)
1 change: 0 additions & 1 deletion examples/uci/analyze.py
@@ -98,7 +98,6 @@ def main():
     args = parse_args()
 
     logging.basicConfig(level=logging.INFO)
-    logger = logging.getLogger()
 
     train_dataset = get_regression_dataset(data_name=args.dataset_name, split="train", data_path=args.dataset_dir)
     eval_dataset = get_regression_dataset(data_name=args.dataset_name, split="valid", data_path=args.dataset_dir)
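The repeated deletion of logger = logging.getLogger() across the example scripts is safe because the binding was never used: after logging.basicConfig(level=logging.INFO), module-level calls go through the root logger directly. A minimal sketch of the pattern these scripts rely on (the message text is illustrative):

import logging

logging.basicConfig(level=logging.INFO)
logging.info("Starting analysis.")  # no intermediate logger object needed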
(Diffs for the remaining two changed files are not shown.)