Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

0.2.2 #49

Merged
merged 12 commits into from
Oct 25, 2023
5 changes: 5 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -141,3 +141,8 @@ trak_results/

# session
Session.vim
local/
*trak_results/
slurm-*.out
A100.pt
H100.pt
4 changes: 2 additions & 2 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@
author = 'Kristian Georgiev'

# The full version, including alpha/beta/rc tags
release = '0.2.1'
version = '0.2.1'
release = '0.2.2'
version = '0.2.2'


# -- General configuration ---------------------------------------------------
Expand Down
25 changes: 12 additions & 13 deletions examples/qnli.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
"""

from argparse import ArgumentParser
import sys
from tqdm import tqdm

import torch as ch
Expand All @@ -21,7 +20,6 @@

# Huggingface
from datasets import load_dataset
import transformers
from transformers import (
AutoConfig,
AutoModelForSequenceClassification,
Expand All @@ -30,7 +28,6 @@
)



GLUE_TASK_TO_KEYS = {
"cola": ("sentence", None),
"mnli": ("premise", "hypothesis"),
Expand All @@ -44,8 +41,8 @@
}

# NOTE: CHANGE THIS IF YOU WANT TO RUN ON FULL DATASET
TRAIN_SET_SIZE = 5_000
VAL_SET_SIZE = 1_00
TRAIN_SET_SIZE = 50_000
VAL_SET_SIZE = 5_463


class SequenceClassificationModel(nn.Module):
Expand Down Expand Up @@ -76,8 +73,8 @@ def __init__(self):

def forward(self, input_ids, token_type_ids, attention_mask):
return self.model(input_ids=input_ids,
token_type_ids=token_type_ids,
attention_mask=attention_mask).logits
token_type_ids=token_type_ids,
attention_mask=attention_mask).logits


def get_dataset(split, inds=None):
Expand All @@ -88,10 +85,9 @@ def get_dataset(split, inds=None):
use_auth_token=None,
)
label_list = raw_datasets["train"].features["label"].names
num_labels = len(label_list)
sentence1_key, sentence2_key = GLUE_TASK_TO_KEYS['qnli']

label_to_id = None #{v: i for i, v in enumerate(label_list)}
label_to_id = None # {v: i for i, v in enumerate(label_list)}

tokenizer = AutoTokenizer.from_pretrained(
'bert-base-cased',
Expand All @@ -102,7 +98,7 @@ def get_dataset(split, inds=None):
)

padding = "max_length"
max_seq_length=128
max_seq_length = 128

def preprocess_function(examples):
# Tokenize the texts
Expand All @@ -113,7 +109,7 @@ def preprocess_function(examples):

# Map labels to IDs (not necessary for GLUE tasks)
if label_to_id is not None and "label" in examples:
result["label"] = [(label_to_id[l] if l != -1 else -1) for l in examples["label"]]
result["label"] = [(label_to_id[lbl] if lbl != -1 else -1) for lbl in examples["label"]]
return result

raw_datasets = raw_datasets.map(
Expand Down Expand Up @@ -178,10 +174,13 @@ def process_batch(batch):

traker.finalize_features()

traker.start_scoring_checkpoint(model.state_dict(), model_id=0, num_targets=VAL_SET_SIZE)
traker.start_scoring_checkpoint(exp_name='qnli',
checkpoint=model.state_dict(),
model_id=0,
num_targets=VAL_SET_SIZE)
for batch in tqdm(loader_val, desc='Scoring..'):
batch = process_batch(batch)
batch = [x.cuda() for x in batch]
traker.score(batch=batch, num_samples=batch[0].shape[0])

scores = traker.finalize_scores()
scores = traker.finalize_scores(exp_name='qnli')
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from setuptools import setup

setup(name="traker",
version="0.2.1",
version="0.2.2",
description="TRAK: Attributing Model Behavior at Scale",
long_description="Check https://trak.csail.mit.edu/ to learn more about TRAK",
author="MadryLab",
Expand Down
1 change: 0 additions & 1 deletion tests/test_integration_cifar.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,4 +49,3 @@ def test_cifar10(tmp_path, device='cpu'):
@pytest.mark.cuda
def test_cifar10_cuda(tmp_path):
test_cifar10(tmp_path, device='cuda:0')

165 changes: 165 additions & 0 deletions tests/test_integration_qnli.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
from tqdm import tqdm
from torch.utils.data import DataLoader
import torch.nn as nn
import pytest
import logging

from trak import TRAKer

from datasets import load_dataset
from transformers import (
AutoConfig,
AutoModelForSequenceClassification,
AutoTokenizer,
default_data_collator,
)


# Maps each GLUE task name to the pair of dataset columns holding its text
# inputs; the second entry is None for single-sentence tasks.
GLUE_TASK_TO_KEYS = {
    "cola": ("sentence", None),
    "mnli": ("premise", "hypothesis"),
    "mrpc": ("sentence1", "sentence2"),
    "qnli": ("question", "sentence"),
    "qqp": ("question1", "question2"),
    "rte": ("sentence1", "sentence2"),
    "sst2": ("sentence", None),
    "stsb": ("sentence1", "sentence2"),
    "wnli": ("sentence1", "sentence2"),
}

# Deliberately tiny subset sizes: this module is only used for testing.
TRAIN_SET_SIZE = 20
VAL_SET_SIZE = 10


class SequenceClassificationModel(nn.Module):
    """
    Wrapper for HuggingFace sequence classification models.

    Builds a two-label 'bert-base-cased' classifier configured for the
    'qnli' fine-tuning task, switches it to eval mode and moves it to CUDA.
    Calling the wrapper returns only the ``logits`` of the wrapped model's
    output.
    """

    def __init__(self):
        super().__init__()
        checkpoint_name = 'bert-base-cased'
        # Keyword arguments shared by the config and the model loaders.
        hub_kwargs = dict(cache_dir=None, revision='main', use_auth_token=None)

        self.config = AutoConfig.from_pretrained(
            checkpoint_name,
            num_labels=2,
            finetuning_task='qnli',
            **hub_kwargs,
        )
        self.model = AutoModelForSequenceClassification.from_pretrained(
            checkpoint_name,
            config=self.config,
            ignore_mismatched_sizes=False,
            **hub_kwargs,
        )

        # Inference only, and always on the GPU (independent of any device
        # argument used by callers).
        self.model.eval().cuda()

    def forward(self, input_ids, token_type_ids, attention_mask):
        """Run the wrapped model and return its ``logits`` attribute."""
        outputs = self.model(
            input_ids=input_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
        )
        return outputs.logits


def get_dataset(split, inds=None):
    """
    Load GLUE/QNLI, tokenize it with the 'bert-base-cased' tokenizer and
    return one of its splits.

    Args:
        split: 'train' selects the "train" split; any other value selects
            the "validation" split.
        inds: unused by this function; kept so the signature stays
            unchanged.

    Returns:
        The tokenized dataset split.
    """
    glue_splits = load_dataset(
        "glue",
        'qnli',
        cache_dir=None,
        use_auth_token=None,
    )
    first_key, second_key = GLUE_TASK_TO_KEYS['qnli']

    tokenizer = AutoTokenizer.from_pretrained(
        'bert-base-cased',
        cache_dir=None,
        use_fast=True,
        revision='main',
        use_auth_token=False
    )

    def preprocess_function(examples):
        # Tokenize one text column, or a pair of columns when the task has
        # a second sentence key.
        if second_key is None:
            texts = (examples[first_key],)
        else:
            texts = (examples[first_key], examples[second_key])
        # Every example is padded/truncated to a fixed length of 128.
        return tokenizer(*texts, padding="max_length", max_length=128, truncation=True)

    glue_splits = glue_splits.map(
        preprocess_function,
        batched=True,
        load_from_cache_file=True,
        desc="Running tokenizer on dataset",
    )

    return glue_splits["train" if split == 'train' else "validation"]


def init_loaders(batch_size=10):
    """
    Build the train and validation DataLoaders for the test.

    The first TRAIN_SET_SIZE training examples and the first VAL_SET_SIZE
    validation examples are kept; neither loader shuffles.

    Args:
        batch_size: batch size used for both loaders.

    Returns:
        A ``(train_loader, val_loader)`` tuple.
    """
    train_subset = get_dataset('train').select(range(TRAIN_SET_SIZE))
    val_subset = get_dataset('val').select(range(VAL_SET_SIZE))

    # Both loaders share the same settings.
    loader_kwargs = dict(
        batch_size=batch_size,
        shuffle=False,
        collate_fn=default_data_collator,
    )
    train_loader = DataLoader(train_subset, **loader_kwargs)
    val_loader = DataLoader(val_subset, **loader_kwargs)
    return train_loader, val_loader


def process_batch(batch):
    """
    Pull the model inputs and labels out of a collated batch mapping.

    Returns:
        A tuple ``(input_ids, token_type_ids, attention_mask, labels)``
        taken from the corresponding keys of ``batch``.
    """
    field_order = ('input_ids', 'token_type_ids', 'attention_mask', 'labels')
    return tuple(batch[field] for field in field_order)


# model too large to test on CPU
@pytest.mark.cuda
def test_qnli(tmp_path, device='cuda'):
    """
    Run TRAKer featurizing and scoring end to end on a small QNLI subset.

    Args:
        tmp_path: directory passed to TRAKer as ``save_dir``.
        device: device passed to TRAKer and used to move every batch tensor.
    """
    loader_train, loader_val = init_loaders()

    # no need to load model from checkpoint, just testing featurization and scoring
    model = SequenceClassificationModel()

    logger = logging.getLogger('QNLI')
    logger.setLevel(logging.DEBUG)
    logger.info(f'Initializing TRAKer with device {device}')

    traker = TRAKer(
        model=model,
        task='text_classification',
        train_set_size=TRAIN_SET_SIZE,
        save_dir=tmp_path,
        device=device,
        logging_level=logging.DEBUG,
        proj_dim=512,
    )

    def _prepared(loader, desc):
        # process each batch into compatible form for TRAKer
        # TextClassificationModelOutput and move its tensors to `device`
        for raw_batch in tqdm(loader, desc=desc):
            yield [tensor.to(device) for tensor in process_batch(raw_batch)]

    logger.info('Loading checkpoint')
    traker.load_checkpoint(model.state_dict(), model_id=0)
    logger.info('Loaded checkpoint')

    for tensors in _prepared(loader_train, 'Featurizing..'):
        traker.featurize(batch=tensors, num_samples=tensors[0].shape[0])

    traker.finalize_features()

    traker.start_scoring_checkpoint(
        exp_name='qnli',
        checkpoint=model.state_dict(),
        model_id=0,
        num_targets=VAL_SET_SIZE,
    )
    for tensors in _prepared(loader_val, 'Scoring..'):
        traker.score(batch=tensors, num_samples=tensors[0].shape[0])

    traker.finalize_scores(exp_name='qnli')
Loading
Loading