diff --git a/.gitignore b/.gitignore
index bd27852..719dff4 100644
--- a/.gitignore
+++ b/.gitignore
@@ -141,3 +141,8 @@
 trak_results/
 # session
 Session.vim
+local/
+*trak_results/
+slurm-*.out
+A100.pt
+H100.pt
diff --git a/docs/source/conf.py b/docs/source/conf.py
index 94cd0c4..cafa4bd 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -22,8 +22,8 @@
 author = 'Kristian Georgiev'
 
 # The full version, including alpha/beta/rc tags
-release = '0.2.1'
-version = '0.2.1'
+release = '0.2.2'
+version = '0.2.2'
 
 
 # -- General configuration ---------------------------------------------------
diff --git a/examples/qnli.py b/examples/qnli.py
index b5538fd..da9e48f 100644
--- a/examples/qnli.py
+++ b/examples/qnli.py
@@ -10,7 +10,6 @@
 """
 from argparse import ArgumentParser
-import sys
 
 from tqdm import tqdm
 import torch as ch
@@ -21,7 +20,6 @@
 
 # Huggingface
 from datasets import load_dataset
-import transformers
 from transformers import (
     AutoConfig,
     AutoModelForSequenceClassification,
@@ -30,7 +28,6 @@
 )
 
-
 GLUE_TASK_TO_KEYS = {
     "cola": ("sentence", None),
     "mnli": ("premise", "hypothesis"),
@@ -44,8 +41,8 @@
 }
 
 # NOTE: CHANGE THIS IF YOU WANT TO RUN ON FULL DATASET
-TRAIN_SET_SIZE = 5_000
-VAL_SET_SIZE = 1_00
+TRAIN_SET_SIZE = 50_000
+VAL_SET_SIZE = 5_463
 
 
 class SequenceClassificationModel(nn.Module):
@@ -76,8 +73,8 @@ def __init__(self):
 
     def forward(self, input_ids, token_type_ids, attention_mask):
         return self.model(input_ids=input_ids,
-                         token_type_ids=token_type_ids,
-                         attention_mask=attention_mask).logits
+                          token_type_ids=token_type_ids,
+                          attention_mask=attention_mask).logits
 
 
 def get_dataset(split, inds=None):
@@ -88,10 +85,9 @@
         use_auth_token=None,
     )
     label_list = raw_datasets["train"].features["label"].names
-    num_labels = len(label_list)
     sentence1_key, sentence2_key = GLUE_TASK_TO_KEYS['qnli']
-    label_to_id = None #{v: i for i, v in enumerate(label_list)}
+    label_to_id = None  # {v: i for i, v in enumerate(label_list)}
 
     tokenizer = AutoTokenizer.from_pretrained(
         'bert-base-cased',
@@ -102,7 +98,7 @@
     )
     padding = "max_length"
-    max_seq_length=128
+    max_seq_length = 128
 
     def preprocess_function(examples):
         # Tokenize the texts
@@ -113,7 +109,7 @@
 
         # Map labels to IDs (not necessary for GLUE tasks)
         if label_to_id is not None and "label" in examples:
-            result["label"] = [(label_to_id[l] if l != -1 else -1) for l in examples["label"]]
+            result["label"] = [(label_to_id[lbl] if lbl != -1 else -1) for lbl in examples["label"]]
         return result
 
     raw_datasets = raw_datasets.map(
@@ -178,10 +174,13 @@
 
     traker.finalize_features()
 
-    traker.start_scoring_checkpoint(model.state_dict(), model_id=0, num_targets=VAL_SET_SIZE)
+    traker.start_scoring_checkpoint(exp_name='qnli',
+                                    checkpoint=model.state_dict(),
+                                    model_id=0,
+                                    num_targets=VAL_SET_SIZE)
     for batch in tqdm(loader_val, desc='Scoring..'):
         batch = process_batch(batch)
         batch = [x.cuda() for x in batch]
         traker.score(batch=batch, num_samples=batch[0].shape[0])
 
-    scores = traker.finalize_scores()
\ No newline at end of file
+    scores = traker.finalize_scores(exp_name='qnli')
diff --git a/setup.py b/setup.py
index 09c7e92..ba5f231 100644
--- a/setup.py
+++ b/setup.py
@@ -2,7 +2,7 @@
 from setuptools import setup
 
 setup(name="traker",
-      version="0.2.1",
+      version="0.2.2",
       description="TRAK: Attributing Model Behavior at Scale",
       long_description="Check https://trak.csail.mit.edu/ to learn more about TRAK",
      author="MadryLab",
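For context on the examples/qnli.py hunk above: scoring is now keyed by an experiment name, so both start_scoring_checkpoint and finalize_scores take exp_name. A minimal sketch of the updated scoring flow, assuming `traker`, `model`, `loader_val`, `process_batch`, and `VAL_SET_SIZE` are set up exactly as in the example script (they are not part of the TRAK API themselves):

```python
# Sketch of the new exp_name-based scoring loop (mirrors examples/qnli.py above).
traker.start_scoring_checkpoint(exp_name='qnli',
                                checkpoint=model.state_dict(),
                                model_id=0,
                                num_targets=VAL_SET_SIZE)
for batch in loader_val:
    batch = [x.cuda() for x in process_batch(batch)]
    traker.score(batch=batch, num_samples=batch[0].shape[0])

scores = traker.finalize_scores(exp_name='qnli')  # scores belong to the 'qnli' experiment
```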
diff --git a/tests/test_integration_cifar.py b/tests/test_integration_cifar.py
index 8ba5b99..cad31c1 100644
--- a/tests/test_integration_cifar.py
+++ b/tests/test_integration_cifar.py
@@ -49,4 +49,3 @@ def test_cifar10(tmp_path, device='cpu'):
 @pytest.mark.cuda
 def test_cifar10_cuda(tmp_path):
     test_cifar10(tmp_path, device='cuda:0')
-
diff --git a/tests/test_integration_qnli.py b/tests/test_integration_qnli.py
new file mode 100644
index 0000000..58364a3
--- /dev/null
+++ b/tests/test_integration_qnli.py
@@ -0,0 +1,165 @@
+from tqdm import tqdm
+from torch.utils.data import DataLoader
+import torch.nn as nn
+import pytest
+import logging
+
+from trak import TRAKer
+
+from datasets import load_dataset
+from transformers import (
+    AutoConfig,
+    AutoModelForSequenceClassification,
+    AutoTokenizer,
+    default_data_collator,
+)
+
+
+GLUE_TASK_TO_KEYS = {
+    "cola": ("sentence", None),
+    "mnli": ("premise", "hypothesis"),
+    "mrpc": ("sentence1", "sentence2"),
+    "qnli": ("question", "sentence"),
+    "qqp": ("question1", "question2"),
+    "rte": ("sentence1", "sentence2"),
+    "sst2": ("sentence", None),
+    "stsb": ("sentence1", "sentence2"),
+    "wnli": ("sentence1", "sentence2"),
+}
+
+# for testing purposes
+TRAIN_SET_SIZE = 20
+VAL_SET_SIZE = 10
+
+
+class SequenceClassificationModel(nn.Module):
+    """
+    Wrapper for HuggingFace sequence classification models.
+    """
+    def __init__(self):
+        super().__init__()
+        self.config = AutoConfig.from_pretrained(
+            'bert-base-cased',
+            num_labels=2,
+            finetuning_task='qnli',
+            cache_dir=None,
+            revision='main',
+            use_auth_token=None,
+        )
+
+        self.model = AutoModelForSequenceClassification.from_pretrained(
+            'bert-base-cased',
+            config=self.config,
+            cache_dir=None,
+            revision='main',
+            use_auth_token=None,
+            ignore_mismatched_sizes=False
+        )
+
+        self.model.eval().cuda()
+
+    def forward(self, input_ids, token_type_ids, attention_mask):
+        return self.model(input_ids=input_ids,
+                          token_type_ids=token_type_ids,
+                          attention_mask=attention_mask).logits
+
+
+def get_dataset(split, inds=None):
+    raw_datasets = load_dataset(
+        "glue",
+        'qnli',
+        cache_dir=None,
+        use_auth_token=None,
+    )
+    sentence1_key, sentence2_key = GLUE_TASK_TO_KEYS['qnli']
+
+    tokenizer = AutoTokenizer.from_pretrained(
+        'bert-base-cased',
+        cache_dir=None,
+        use_fast=True,
+        revision='main',
+        use_auth_token=False
+    )
+
+    padding = "max_length"
+    max_seq_length = 128
+
+    def preprocess_function(examples):
+        # Tokenize the texts
+        args = (
+            (examples[sentence1_key],) if sentence2_key is None else (examples[sentence1_key], examples[sentence2_key])
+        )
+        result = tokenizer(*args, padding=padding, max_length=max_seq_length, truncation=True)
+
+        return result
+
+    raw_datasets = raw_datasets.map(
+        preprocess_function,
+        batched=True,
+        load_from_cache_file=(not False),
+        desc="Running tokenizer on dataset",
+    )
+
+    if split == 'train':
+        train_dataset = raw_datasets["train"]
+        ds = train_dataset
+    else:
+        eval_dataset = raw_datasets["validation"]
+        ds = eval_dataset
+    return ds
+
+
+def init_loaders(batch_size=10):
+    ds_train = get_dataset('train')
+    ds_train = ds_train.select(range(TRAIN_SET_SIZE))
+    ds_val = get_dataset('val')
+    ds_val = ds_val.select(range(VAL_SET_SIZE))
+    return DataLoader(ds_train, batch_size=batch_size, shuffle=False, collate_fn=default_data_collator), \
+        DataLoader(ds_val, batch_size=batch_size, shuffle=False, collate_fn=default_data_collator)
+
+
+def process_batch(batch):
+    return batch['input_ids'], batch['token_type_ids'], batch['attention_mask'], batch['labels']
+
+
+# model too large to test on CPU
+@pytest.mark.cuda
+def test_qnli(tmp_path, device='cuda'):
+    loader_train, loader_val = init_loaders()
+
+    # no need to load model from checkpoint, just testing featurization and scoring
+    model = SequenceClassificationModel()
+
+    logger = logging.getLogger('QNLI')
+    logger.setLevel(logging.DEBUG)
+    logger.info(f'Initializing TRAKer with device {device}')
+
+    traker = TRAKer(model=model,
+                    task='text_classification',
+                    train_set_size=TRAIN_SET_SIZE,
+                    save_dir=tmp_path,
+                    device=device,
+                    logging_level=logging.DEBUG,
+                    proj_dim=512)
+
+    logger.info('Loading checkpoint')
+    traker.load_checkpoint(model.state_dict(), model_id=0)
+    logger.info('Loaded checkpoint')
+    for batch in tqdm(loader_train, desc='Featurizing..'):
+        # process batch into compatible form for TRAKer TextClassificationModelOutput
+        batch = process_batch(batch)
+        batch = [x.to(device) for x in batch]
+        traker.featurize(batch=batch, num_samples=batch[0].shape[0])
+
+    traker.finalize_features()
+
+    traker.start_scoring_checkpoint(exp_name='qnli',
+                                    checkpoint=model.state_dict(),
+                                    model_id=0,
+                                    num_targets=VAL_SET_SIZE)
+    for batch in tqdm(loader_val, desc='Scoring..'):
+        batch = process_batch(batch)
+        batch = [x.to(device) for x in batch]
+        traker.score(batch=batch, num_samples=batch[0].shape[0])
+
+    traker.finalize_scores(exp_name='qnli')
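The detail that both the example and this test rely on is the batch layout handed to TRAKer for task='text_classification': a tuple of (input_ids, token_type_ids, attention_mask, labels), which is what process_batch produces and what the text-classification model output unpacks. A hedged, standalone illustration (the shapes and values below are fake and only show the field order):

```python
import torch as ch

# Illustrative only: fabricate a tokenized batch with the field order that
# process_batch() returns and that the 'text_classification' task expects.
batch_size, seq_len = 4, 128
batch = (
    ch.randint(0, 28996, (batch_size, seq_len)),   # input_ids (bert-base-cased vocab size, illustrative)
    ch.zeros(batch_size, seq_len, dtype=ch.long),  # token_type_ids
    ch.ones(batch_size, seq_len, dtype=ch.long),   # attention_mask
    ch.randint(0, 2, (batch_size,)),               # labels
)
# traker.featurize(batch=batch, num_samples=batch[0].shape[0]) consumes this ordering;
# passing the fields in any other order would silently mismatch them.
```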
diff --git a/tests/test_jl.py b/tests/test_jl.py
index 6c54da4..5e3285e 100644
--- a/tests/test_jl.py
+++ b/tests/test_jl.py
@@ -8,6 +8,7 @@ from trak.projectors import CudaProjector, ProjectionType
 
 BasicProjector = CudaProjector
+MAX_BATCH_SIZE = 32
 PARAM = list(product([0, 1, 10**8],  # seed
                      [ProjectionType.normal, ProjectionType.rademacher],  # proj type
                      [ch.float16, ch.float32],  # dtype
@@ -43,7 +44,8 @@ def test_seed_consistency(seed,
                          proj_type=proj_type,
                          seed=seed,
                          device='cuda:0',
-                         dtype=dtype
+                         dtype=dtype,
+                         max_batch_size=MAX_BATCH_SIZE
                          )
     result = proj.project(g, model_id=0)
@@ -70,7 +72,8 @@ def test_seed_consistency_2(seed,
                          proj_type=proj_type,
                          seed=seed,
                          device='cuda:0',
-                         dtype=dtype
+                         dtype=dtype,
+                         max_batch_size=MAX_BATCH_SIZE
                          )
     result = proj.project(g, model_id=0)
@@ -80,7 +83,9 @@
                                proj_type=proj_type,
                                seed=seed,
                                device='cuda:0',
-                               dtype=dtype)
+                               dtype=dtype,
+                               max_batch_size=MAX_BATCH_SIZE
+                               )
     result_again = proj_again.project(g, model_id=0)
     testing.assert_close(result, result_again, equal_nan=True)
@@ -102,7 +107,8 @@ def test_norm_preservation(seed,
                          proj_type=proj_type,
                          seed=seed,
                          device='cuda:0',
-                         dtype=dtype
+                         dtype=dtype,
+                         max_batch_size=MAX_BATCH_SIZE
                          )
     p = proj.project(g, model_id=0)
@@ -145,7 +151,8 @@ def test_prod_preservation(seed,
                          proj_type=proj_type,
                          seed=seed,
                          device='cuda:0',
-                         dtype=dtype
+                         dtype=dtype,
+                         max_batch_size=MAX_BATCH_SIZE
                          )
 
     # check that things break with a garbage matrix
@@ -193,7 +200,8 @@ def test_single_nonzero_feature(seed,
                          proj_type=proj_type,
                          seed=seed,
                          device='cuda:0',
-                         dtype=dtype
+                         dtype=dtype,
+                         max_batch_size=MAX_BATCH_SIZE
                          )
     p = proj.project(g, model_id=0)
     assert (~ch.isclose(p, ch.zeros_like(p))).all().item()
@@ -218,7 +226,8 @@ def test_first_nonzero_feature(seed,
                          proj_type=proj_type,
                          seed=seed,
                          device='cuda:0',
-                         dtype=dtype
+                         dtype=dtype,
+                         max_batch_size=MAX_BATCH_SIZE
                          )
     p = proj.project(g, model_id=0)
     assert (~ch.isclose(p, ch.zeros_like(p))).all().item()
@@ -243,7 +252,8 @@ def test_last_nonzero_feature(seed,
                          proj_type=proj_type,
                          seed=seed,
                          device='cuda:0',
-                         dtype=dtype
+                         dtype=dtype,
+                         max_batch_size=MAX_BATCH_SIZE
                          )
     p = proj.project(g, model_id=0)
     assert (~ch.isclose(p, ch.zeros_like(p))).all().item()
@@ -268,7 +278,8 @@ def test_same_features(seed,
                          proj_type=proj_type,
                          seed=seed,
                          device='cuda:0',
-                         dtype=dtype
+                         dtype=dtype,
+                         max_batch_size=MAX_BATCH_SIZE
                          )
     p = proj.project(g, model_id=0)
@@ -294,7 +305,9 @@ def test_orthogonality(seed,
                          proj_type=proj_type,
                          seed=seed,
                          device='cuda:0',
-                         dtype=dtype)
+                         dtype=dtype,
+                         max_batch_size=MAX_BATCH_SIZE
+                         )
 
     num_successes = 0
     num_trials = 100
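Every CudaProjector instantiation in tests/test_jl.py now passes max_batch_size explicitly. A small sketch of the call pattern the tests exercise (dimensions are arbitrary; requires a CUDA device with the fast_jl extension installed):

```python
import torch as ch
from torch import testing
from trak.projectors import CudaProjector, ProjectionType

# Fake per-sample gradients: 8 samples, 100k parameters each.
g = testing.make_tensor(8, 100_000, device='cuda:0', dtype=ch.float32)

proj = CudaProjector(grad_dim=100_000,
                     proj_dim=1024,
                     proj_type=ProjectionType.rademacher,
                     seed=0,
                     device='cuda:0',
                     dtype=ch.float32,
                     max_batch_size=32)  # caps the batch size handed to fast_jl
p = proj.project(g, model_id=0)
print(p.shape)  # (8, 1024)
```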
diff --git a/tests/test_jl_additional.py b/tests/test_jl_additional.py
new file mode 100644
index 0000000..90d870d
--- /dev/null
+++ b/tests/test_jl_additional.py
@@ -0,0 +1,115 @@
+import pytest
+from itertools import product
+import torch as ch
+from torch import testing
+
+from trak.projectors import CudaProjector, ProjectionType
+
+MAX_BATCH_SIZE = 32
+
+# TEST CASES 1
+PARAM = list(product([123],  # seed
+                     [ProjectionType.rademacher],  # proj type
+                     [ch.float32],  # dtype
+                     [
+                         (8, 180645096),   # pass: np.prod(shape) < np.iinfo(np.int32).max
+                         (16, 180645096),  # pass: np.prod(shape) > np.iinfo(np.int32).max
+                         (31, 180645096),  # fail: np.prod(shape) > np.iinfo(np.int32).max
+                         (32, 180645096),  # fail: np.prod(shape) > np.iinfo(np.int32).max
+                         (33, 180645096),  # pass: np.prod(shape) > np.iinfo(np.int32).max
+                         (48, 180645096),  # pass: np.prod(shape) > np.iinfo(np.int32).max
+                         (50, 180645096),  # pass: np.prod(shape) > np.iinfo(np.int32).max
+                     ],  # input shape
+                     [15_360],  # proj dim
+                     ))
+
+# TEST CASES 2
+PARAM = list(product([123],  # seed
+                     [ProjectionType.rademacher],  # proj type
+                     [ch.float32],  # dtype
+                     [
+                         (1, 780645096),  # pass: np.prod(shape) < np.iinfo(np.int32).max
+                         (5, 780645096),  # pass: np.prod(shape) > np.iinfo(np.int32).max
+                         (6, 780645096),  # pass: np.prod(shape) > np.iinfo(np.int32).max
+                         (7, 780645096),  # fail: np.prod(shape) > np.iinfo(np.int32).max
+                         (8, 780645096),  # fail: np.prod(shape) > np.iinfo(np.int32).max
+                     ],  # input shape
+                     [4_096],  # proj dim
+                     ))
+
+
+# TEST CASES 3 (ONLY for test_same_features_diff_sms)
+PARAM = list(product([123],  # seed
+                     [ProjectionType.rademacher],  # proj type
+                     [ch.float32],  # dtype
+                     [
+                         (32, 100_000),
+                     ],  # input shape
+                     [4_096],  # proj dim
+                     ))
+
+
+@pytest.mark.parametrize("seed, proj_type, dtype, input_shape, proj_dim", PARAM)
+@pytest.mark.cuda
+def test_same_features(seed,
+                       proj_type,
+                       dtype,
+                       proj_dim,
+                       input_shape,
+                       ):
+    """
+    Check that output is the same for the same features
+    """
+    g = testing.make_tensor(*input_shape, device='cuda:0', dtype=dtype)
+    g[-1] = g[0]
+
+    proj = CudaProjector(grad_dim=input_shape[-1],
+                         proj_dim=proj_dim,
+                         proj_type=proj_type,
+                         seed=seed,
+                         device='cuda:0',
+                         dtype=dtype,
+                         max_batch_size=MAX_BATCH_SIZE
+                         )
+    p = proj.project(g, model_id=0)
+
+    assert ch.allclose(p[0], p[-1])
+
+
+@pytest.mark.parametrize("seed, proj_type, dtype, input_shape, proj_dim", PARAM)
+@pytest.mark.cuda
+def test_same_features_diff_sms(seed,
+                                proj_type,
+                                dtype,
+                                proj_dim,
+                                input_shape,
+                                ):
+    """
+    Check that the output is the same when projecting with a reduced number of SMs
+    """
+    g = testing.make_tensor(*input_shape, device='cuda:0', dtype=dtype)
+
+    # project with all SMs available
+    proj_full_sms = CudaProjector(grad_dim=input_shape[-1],
+                                  proj_dim=proj_dim,
+                                  proj_type=proj_type,
+                                  seed=seed,
+                                  device='cuda:0',
+                                  dtype=dtype,
+                                  max_batch_size=MAX_BATCH_SIZE
+                                  )
+    p_full_sms = proj_full_sms.project(g, model_id=0)
+
+    # project with half SMs available
+    proj_half_sms = CudaProjector(grad_dim=input_shape[-1],
+                                  proj_dim=proj_dim,
+                                  proj_type=proj_type,
+                                  seed=seed,
+                                  device='cuda:0',
+                                  dtype=dtype,
+                                  max_batch_size=MAX_BATCH_SIZE
+                                  )
+
+    proj_half_sms.num_sms = max(proj_half_sms.num_sms // 2, 1)
+    p_half_sms = proj_half_sms.project(g, model_id=0)
+
+    assert ch.allclose(p_full_sms, p_half_sms)
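The pass/fail annotations in the PARAM blocks above appear to track whether batch_size * grad_dim exceeds the 32-bit integer limit referenced in the comments. A quick way to check where a given input shape lands relative to that limit (numpy is used only for the iinfo constant):

```python
import numpy as np

INT32_MAX = np.iinfo(np.int32).max  # 2_147_483_647

for shape in [(8, 180645096), (16, 180645096), (32, 180645096), (7, 780645096)]:
    n_elements = shape[0] * shape[1]
    status = 'over int32 limit' if n_elements > INT32_MAX else 'within int32 limit'
    print(shape, n_elements, status)
```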
diff --git a/tests/test_jl_gpu_compatibility/test_jl_gpu_compatibility.py b/tests/test_jl_gpu_compatibility/test_jl_gpu_compatibility.py
new file mode 100644
index 0000000..5a7e380
--- /dev/null
+++ b/tests/test_jl_gpu_compatibility/test_jl_gpu_compatibility.py
@@ -0,0 +1,76 @@
+import os
+import pytest
+from itertools import product
+import torch as ch
+from torch import testing
+
+from trak.projectors import CudaProjector, ProjectionType
+
+MAX_BATCH_SIZE = 32
+
+# TEST CASES 1
+PARAM = list(product([123],  # seed
+                     [ProjectionType.rademacher],  # proj type
+                     [ch.float32],  # dtype
+                     [
+                         (32, 100_000),  # pass: np.prod(shape) < np.iinfo(np.int32).max
+                     ],  # input shape
+                     [4_096],  # proj dim
+                     [108],  # num sms
+                     ))
+
+
+@pytest.mark.parametrize("seed, proj_type, dtype, input_shape, proj_dim, num_sms", PARAM)
+@pytest.mark.cuda
+def test_create_proj(seed,
+                     proj_type,
+                     dtype,
+                     proj_dim,
+                     input_shape,
+                     num_sms,
+                     ):
+    """
+    Compute the output for each GPU type
+    """
+    GPU_NAME = os.environ['GPU_NAME']
+    print(f'GPU: {GPU_NAME}')
+
+    if os.path.exists(f'./{GPU_NAME}.pt'):
+        os.remove(f'./{GPU_NAME}.pt')
+
+    g = testing.make_tensor(*input_shape, device='cuda:0', dtype=dtype)
+
+    proj = CudaProjector(grad_dim=input_shape[-1],
+                         proj_dim=proj_dim,
+                         proj_type=proj_type,
+                         seed=seed,
+                         device='cuda:0',
+                         dtype=dtype,
+                         max_batch_size=MAX_BATCH_SIZE
+                         )
+
+    proj.num_sms = num_sms
+    print(f'# Projector SMs: {proj.num_sms}')
+
+    p = proj.project(g, model_id=0)
+
+    ch.save(p.cpu(), f'./{GPU_NAME}.pt')
+
+
+@pytest.mark.parametrize("seed, proj_type, dtype, input_shape, proj_dim, num_sms", PARAM)
+@pytest.mark.cuda
+def test_same_proj(seed,
+                   proj_type,
+                   dtype,
+                   proj_dim,
+                   input_shape,
+                   num_sms,
+                   ):
+    """
+    Check that output is the same for different GPUs
+    """
+
+    proj_a100 = ch.load('./A100.pt')
+    proj_h100 = ch.load('./H100.pt')
+
+    assert ch.allclose(proj_a100, proj_h100), 'GPUs have different projection'
\ No newline at end of file
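test_create_proj reads the GPU_NAME environment variable and writes ./{GPU_NAME}.pt, and test_same_proj then compares the A100 and H100 dumps. Outside of the SLURM script below, the same two-phase flow could be driven from Python; this is a hypothetical driver, not part of the repo, and it assumes each pytest.main call runs on the corresponding GPU node with the cuda marker registered:

```python
import os
import pytest

# Phase 1: run once per GPU type; each run saves <GPU_NAME>.pt next to the test.
os.environ['GPU_NAME'] = 'A100'  # set to 'H100' when running on the H100 node
pytest.main(['-m', 'cuda', 'test_jl_gpu_compatibility.py::test_create_proj'])

# Phase 2: once both A100.pt and H100.pt exist, compare them on any node.
pytest.main(['-m', 'cuda', 'test_jl_gpu_compatibility.py::test_same_proj'])
```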
diff --git a/tests/test_jl_gpu_compatibility/test_jl_gpu_compatibility.sh b/tests/test_jl_gpu_compatibility/test_jl_gpu_compatibility.sh
new file mode 100755
index 0000000..1d9a989
--- /dev/null
+++ b/tests/test_jl_gpu_compatibility/test_jl_gpu_compatibility.sh
@@ -0,0 +1,47 @@
+#!/bin/bash
+#SBATCH --job-name=jl_unit_test
+#SBATCH --partition=high-priority
+#SBATCH hetjob
+#SBATCH --output=/mnt/xfs/home/alaakh/installs/temp/trak_fixes/tests/test_jl_gpu_compatibility/%u-%x-%j.log
+
+CODE_PATH="/mnt/xfs/home/alaakh/installs/temp/trak_fixes/tests/test_jl_gpu_compatibility"
+cd $CODE_PATH
+
+export PYTHONPATH="${PYTHONPATH}:${CODE_PATH}"
+export LD_LIBRARY_PATH=$CONDA_PREFIX/lib:$LD_LIBRARY_PATH
+
+CREATE_PROJ_TEST="test_jl_gpu_compatibility.py::test_create_proj"
+VERIFY_PROJ_TEST="test_jl_gpu_compatibility.py::test_same_proj"
+
+# Component for a100 GPU
+#SBATCH --nodes=1
+#SBATCH --nodelist="deep-chungus-[1-5,7-11]"
+#SBATCH --gres=gpu:a100:1
+#SBATCH --ntasks-per-node=1
+#SBATCH --cpus-per-task=1
+#SBATCH --partition=high-priority
+
+env GPU_NAME="A100" \
+srun --ntasks=1 \
+    python -m pytest $CREATE_PROJ_TEST
+
+# Component for h100 GPU
+#SBATCH hetjob
+#SBATCH --nodes=1
+#SBATCH --nodelist="deep-h-[1-3]"
+#SBATCH --gres=gpu:h100:1
+#SBATCH --ntasks-per-node=1
+#SBATCH --cpus-per-task=1
+#SBATCH --partition=high-priority
+
+env GPU_NAME="H100" \
+srun --ntasks=1 \
+    python -m pytest $CREATE_PROJ_TEST
+
+srun --ntasks=1 \
+    python -m pytest $VERIFY_PROJ_TEST
+
+rm "${CODE_PATH}/A100.pt"
+rm "${CODE_PATH}/H100.pt"
+
+echo "done"
\ No newline at end of file
diff --git a/tests/test_rademacher.py b/tests/test_rademacher.py
index 2a83cb4..7a0beb1 100644
--- a/tests/test_rademacher.py
+++ b/tests/test_rademacher.py
@@ -10,13 +10,14 @@ from assertpy import assert_that
 
-bs_error_str = 'CUDA error: too many resources requested for launch\nCUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.\nFor debugging consider passing CUDA_LAUNCH_BLOCKING=1.\nCompile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.\n'
+bs_error_str = 'CUDA error: too many resources requested for launch\nCUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.\nFor debugging consider passing CUDA_LAUNCH_BLOCKING=1.\nCompile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.\n'  # noqa
 
-new_bs_error_str = f'The batch size of the CudaProjector is too large for your GPU. Reduce it by using the max_batch_size argument of the CudaProjector.\nOriginal error: {bs_error_str}'
+new_bs_error_str = f'The batch size of the CudaProjector is too large for your GPU. Reduce it by using the max_batch_size argument of the CudaProjector.\nOriginal error: {bs_error_str}'  # noqa
 
 PARAM = list(product([8], [1024, 2048], [512, 1024, 2048], [0, 1]))
 
+
 @pytest.mark.parametrize("bs, input_size, output_size, seed ", PARAM)
 @pytest.mark.cuda
 def test_shape(bs: int, input_size: int, output_size: int, seed: int):
@@ -35,6 +36,7 @@ def test_shape(bs: int, input_size: int, output_size: int, seed: int):
 
     assert_that(result.shape).is_equal_to((bs, output_size))
 
+
 @pytest.mark.cuda
 def test_running():
     bs = 8
@@ -55,6 +57,7 @@ def test_running():
 
     print(result.sum())
 
+
 @pytest.mark.cuda
 def test_even():
     bs = 8
@@ -75,6 +78,7 @@ def test_even():
 
     assert_that(ch.all(result % 2 == 0)).is_true()
 
+
 @pytest.mark.cuda
 def test_odd():
     bs = 8
@@ -94,4 +98,3 @@ def test_odd():
         raise e
 
     assert_that(ch.all(result % 2 == 1)).is_true()
-
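Both rademacher test files call straight into the fast_jl extension. The call pattern they rely on, in one place (requires a CUDA device with fast_jl installed; the sizes are the ones used in test_running):

```python
import torch as ch
import fast_jl

input_data = ch.ones((8, 256), dtype=ch.float16, device='cuda:0')
num_sms = ch.cuda.get_device_properties(ch.cuda.current_device()).multi_processor_count

# project_rademacher_8(input, output_size, seed, num_sms) -> tensor of shape (batch, output_size)
result = fast_jl.project_rademacher_8(input_data, 512, 17, num_sms)
print(result.shape)  # torch.Size([8, 512])
```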
diff --git a/tests/test_rademacher_additional.py b/tests/test_rademacher_additional.py
new file mode 100644
index 0000000..53a4f86
--- /dev/null
+++ b/tests/test_rademacher_additional.py
@@ -0,0 +1,101 @@
+import pytest
+from itertools import product
+
+import torch as ch
+
+try:
+    import fast_jl
+except ModuleNotFoundError:
+    print('No fast_jl available!')
+
+from assertpy import assert_that
+
+bs_error_str = 'CUDA error: too many resources requested for launch\nCUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.\nFor debugging consider passing CUDA_LAUNCH_BLOCKING=1.\nCompile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.\n'  # noqa
+
+new_bs_error_str = f'The batch size of the CudaProjector is too large for your GPU. Reduce it by using the max_batch_size argument of the CudaProjector.\nOriginal error: {bs_error_str}'  # noqa
+
+
+PARAM = list(product([8, 16, 32, 48], [180645096], [2048, 4096, 15_360], [0]))
+# PARAM = list(product([1, 6, 7, 8], [780645096], [4096, 15_360], [0]))
+
+
+@pytest.mark.parametrize("bs, input_size, output_size, seed ", PARAM)
+@pytest.mark.cuda
+def test_shape(bs: int, input_size: int, output_size: int, seed: int):
+    print(output_size)
+    input_data = ch.ones((bs, input_size), dtype=ch.float16, device="cuda:0")
+
+    num_sms = ch.cuda.get_device_properties(ch.cuda.current_device()).multi_processor_count
+
+    try:
+        result = fast_jl.project_rademacher_8(input_data, output_size, seed, num_sms)
+    except RuntimeError as e:
+        if str(e) == bs_error_str:
+            raise RuntimeError(new_bs_error_str)
+        else:
+            raise e
+
+    assert_that(result.shape).is_equal_to((bs, output_size))
+
+
+@pytest.mark.cuda
+def test_running():
+    bs = 8
+    input_size = 256
+    seed = 17
+    output_size = 512
+    input_data = ch.ones((bs, input_size), dtype=ch.float16, device="cuda:0")
+
+    num_sms = ch.cuda.get_device_properties(ch.cuda.current_device()).multi_processor_count
+
+    try:
+        result = fast_jl.project_rademacher_8(input_data, output_size, seed, num_sms)
+    except RuntimeError as e:
+        if str(e) == bs_error_str:
+            raise RuntimeError(new_bs_error_str)
+        else:
+            raise e
+
+    print(result.sum())
+
+
+@pytest.mark.cuda
+def test_even():
+    bs = 8
+    input_size = 10240
+    seed = 64
+    output_size = 1024
+    input_data = ch.ones((bs, input_size), dtype=ch.float16, device="cuda:0")
+
+    num_sms = ch.cuda.get_device_properties(ch.cuda.current_device()).multi_processor_count
+
+    try:
+        result = fast_jl.project_rademacher_8(input_data, output_size, seed, num_sms)
+    except RuntimeError as e:
+        if str(e) == bs_error_str:
+            raise RuntimeError(new_bs_error_str)
+        else:
+            raise e
+
+    assert_that(ch.all(result % 2 == 0)).is_true()
+
+
+@pytest.mark.cuda
+def test_odd():
+    bs = 8
+    input_size = 10241
+    seed = 78
+    output_size = 2048
+    input_data = ch.ones((bs, input_size), dtype=ch.float16, device="cuda:0")
+
+    num_sms = ch.cuda.get_device_properties(ch.cuda.current_device()).multi_processor_count
+
+    try:
+        result = fast_jl.project_rademacher_8(input_data, output_size, seed, num_sms)
+    except RuntimeError as e:
+        if str(e) == bs_error_str:
+            raise RuntimeError(new_bs_error_str)
+        else:
+            raise e
+
+    assert_that(ch.all(result % 2 == 1)).is_true()
diff --git a/trak/__init__.py b/trak/__init__.py
index 1ad5c65..f627868 100644
--- a/trak/__init__.py
+++ b/trak/__init__.py
@@ -1,6 +1,6 @@
 from .traker import TRAKer
 from .utils import test_install
 
-__version__ = '0.2.1'
+__version__ = '0.2.2'
 
 VERSION = __version__
diff --git a/trak/gradient_computers.py b/trak/gradient_computers.py
index f50a90d..4a3bbdc 100644
--- a/trak/gradient_computers.py
+++ b/trak/gradient_computers.py
@@ -28,6 +28,8 @@ def __init__(self,
                  model: torch.nn.Module,
                  task: AbstractModelOutput,
                  grad_dim: Optional[int] = None,
+                 dtype: Optional[torch.dtype] = torch.float16,
+                 device: Optional[torch.device] = 'cuda',
                  ) -> None:
         """
         Initializes attributes, nothing too interesting happening.
@@ -39,11 +41,17 @@ def __init__(self,
             grad_dim (int, optional):
                 Size of the gradients (number of model parameters). Defaults
                 to None.
+            dtype (torch.dtype, optional):
+                Torch dtype of the gradients. Defaults to torch.float16.
+            device (torch.device, optional):
+                Torch device where gradients will be stored. Defaults to 'cuda'.
 
         """
         self.model = model
         self.modelout_fn = task
         self.grad_dim = grad_dim
+        self.dtype = dtype
+        self.device = device
 
     @abstractmethod
     def load_model_params(self, model) -> None:
@@ -62,8 +70,10 @@ class FunctionalGradientComputer(AbstractGradientComputer):
     def __init__(self,
                  model: torch.nn.Module,
                  task: AbstractModelOutput,
-                 grad_dim: int) -> None:
-        super().__init__(model, task, grad_dim)
+                 grad_dim: int,
+                 dtype: torch.dtype,
+                 device: torch.device) -> None:
+        super().__init__(model, task, grad_dim, dtype, device)
         self.model = model
         self.num_params = get_num_params(self.model)
         self.load_model_params(model)
@@ -103,8 +113,8 @@ def compute_per_sample_grad(self, batch: Iterable[Tensor]) -> Tensor:
         grads_loss = torch.func.grad(self.modelout_fn.get_output, has_aux=False, argnums=1)
         # map over batch dimensions (hence 0 for each batch dimension, and None for model params)
         grads = torch.empty(size=(batch[0].shape[0], self.num_params),
-                            dtype=batch[0].dtype,
-                            device=batch[0].device)
+                            dtype=self.dtype,
+                            device=self.device)
 
         vectorize(torch.func.vmap(grads_loss,
                                   in_dims=(None, None, None, *([0] * len(batch))),
@@ -151,8 +161,10 @@ class IterativeGradientComputer(AbstractGradientComputer):
     def __init__(self,
                  model,
                  task: AbstractModelOutput,
-                 grad_dim: int) -> None:
-        super().__init__(model, task, grad_dim)
+                 grad_dim: int,
+                 dtype: torch.dtype,
+                 device: torch.device) -> None:
+        super().__init__(model, task, grad_dim, dtype, device)
         self.load_model_params(model)
 
     def load_model_params(self, model) -> Tensor:
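With the gradient_computers.py change above, the per-sample gradient buffer is allocated with the computer's configured dtype and device rather than inheriting them from the first batch tensor, presumably so that, e.g., int64 input_ids no longer dictate the gradient dtype. A rough, standalone illustration of the difference (fake sizes; the second allocation assumes a CUDA device, mirroring the new defaults):

```python
import torch

num_params = 1_000_000
batch = [torch.randint(0, 100, (4, 128))]  # fake int64 input_ids on CPU

# Old behaviour: buffer dtype/device copied from batch[0] (int64 on CPU here).
grads_old = torch.empty(size=(batch[0].shape[0], num_params),
                        dtype=batch[0].dtype, device=batch[0].device)

# New behaviour: buffer uses the gradient computer's dtype/device (float16 on 'cuda' by default).
grads_new = torch.empty(size=(batch[0].shape[0], num_params),
                        dtype=torch.float16, device='cuda')

print(grads_old.dtype, grads_new.dtype)  # torch.int64 torch.float16
```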
""" self.model = model self.modelout_fn = task self.grad_dim = grad_dim + self.dtype = dtype + self.device = device @abstractmethod def load_model_params(self, model) -> None: @@ -62,8 +70,10 @@ class FunctionalGradientComputer(AbstractGradientComputer): def __init__(self, model: torch.nn.Module, task: AbstractModelOutput, - grad_dim: int) -> None: - super().__init__(model, task, grad_dim) + grad_dim: int, + dtype: torch.dtype, + device: torch.device) -> None: + super().__init__(model, task, grad_dim, dtype, device) self.model = model self.num_params = get_num_params(self.model) self.load_model_params(model) @@ -103,8 +113,8 @@ def compute_per_sample_grad(self, batch: Iterable[Tensor]) -> Tensor: grads_loss = torch.func.grad(self.modelout_fn.get_output, has_aux=False, argnums=1) # map over batch dimensions (hence 0 for each batch dimension, and None for model params) grads = torch.empty(size=(batch[0].shape[0], self.num_params), - dtype=batch[0].dtype, - device=batch[0].device) + dtype=self.dtype, + device=self.device) vectorize(torch.func.vmap(grads_loss, in_dims=(None, None, None, *([0] * len(batch))), @@ -151,8 +161,10 @@ class IterativeGradientComputer(AbstractGradientComputer): def __init__(self, model, task: AbstractModelOutput, - grad_dim: int) -> None: - super().__init__(model, task, grad_dim) + grad_dim: int, + dtype: torch.dtype, + device: torch.device) -> None: + super().__init__(model, task, grad_dim, dtype, device) self.load_model_params(model) def load_model_params(self, model) -> Tensor: diff --git a/trak/modelout_functions.py b/trak/modelout_functions.py index 26eda8c..8a0b6e6 100644 --- a/trak/modelout_functions.py +++ b/trak/modelout_functions.py @@ -387,11 +387,14 @@ def get_output(model, attention_mask: Tensor, label: Tensor, ) -> Tensor: + kw_inputs = {'input_ids': input_id.unsqueeze(0), + 'token_type_ids': token_type_id.unsqueeze(0), + 'attention_mask': attention_mask.unsqueeze(0)} + logits = ch.func.functional_call(model, (weights, buffers), - args=(input_id.unsqueeze(0), - token_type_id.unsqueeze(0), - attention_mask.unsqueeze(0))) + args=(), + kwargs=kw_inputs) bindex = ch.arange(logits.shape[0]).to(logits.device, non_blocking=False) logits_correct = logits[bindex, label.unsqueeze(0)] @@ -403,7 +406,13 @@ def get_output(model, def get_out_to_loss_grad(self, model, weights, buffers, batch: Iterable[Tensor]) -> Tensor: input_ids, token_type_ids, attention_mask, labels = batch - logits = ch.func.functional_call(model, (weights, buffers), input_ids, token_type_ids, attention_mask) + kw_inputs = {'input_ids': input_ids, + 'token_type_ids': token_type_ids, + 'attention_mask': attention_mask} + logits = ch.func.functional_call(model, + (weights, buffers), + args=(), + kwargs=kw_inputs) ps = self.softmax(logits / self.loss_temperature)[ch.arange(logits.size(0)), labels] return (1 - ps).clone().detach().unsqueeze(-1) diff --git a/trak/projectors.py b/trak/projectors.py index fbeaff9..d736bfa 100644 --- a/trak/projectors.py +++ b/trak/projectors.py @@ -158,9 +158,15 @@ class BasicProjector(AbstractProjector): a CUDA-enabled device with compute capability >=7.0 (see https://developer.nvidia.com/cuda-gpus). 
""" - def __init__(self, grad_dim: int, proj_dim: int, seed: int, proj_type: - ProjectionType, device, block_size: int = 200, dtype=ch.float32, - model_id=0, *args, **kwargs) -> None: + def __init__(self, grad_dim: int, + proj_dim: int, + seed: int, + proj_type: ProjectionType, + device: torch.device, + block_size: int = 100, + dtype: torch.dtype = ch.float32, + model_id=0, + *args, **kwargs) -> None: super().__init__(grad_dim, proj_dim, seed, proj_type, device) self.block_size = min(self.proj_dim, block_size) diff --git a/trak/traker.py b/trak/traker.py index e250e99..0ec22f4 100644 --- a/trak/traker.py +++ b/trak/traker.py @@ -1,6 +1,6 @@ from .modelout_functions import AbstractModelOutput, TASK_TO_MODELOUT from .projectors import ProjectionType, AbstractProjector, CudaProjector, BasicProjector -from .gradient_computers import FunctionalGradientComputer,\ +from .gradient_computers import FunctionalGradientComputer, \ AbstractGradientComputer from .score_computers import AbstractScoreComputer, BasicScoreComputer from .savers import AbstractSaver, MmapSaver, ModelIDException @@ -38,6 +38,7 @@ def __init__(self, logging_level=logging.INFO, use_half_precision: bool = True, proj_max_batch_size: int = 32, + projector_seed: int = 0, ) -> None: """ @@ -90,6 +91,12 @@ def __init__(self, If True, TRAK will use half precision (float16) for all computations and arrays will be stored in float16. Otherwise, it will use float32. Defaults to True. + proj_max_batch_size (int): + Batch size used by fast_jl if teh CudaProjector is used. Must be + a multiple of 8. The maximum batch size is 32 for A100 GPUs, 16 + for V100 GPUs, 40 for H100 GPUs. Defaults to 32. + projecotr_seed (int): + Random seed used by the projector. Defaults to 0. """ @@ -107,7 +114,10 @@ def __init__(self, self.num_params = get_num_params(self.model) # inits self.projector - self.init_projector(projector, proj_dim, proj_max_batch_size) + self.proj_seed = projector_seed + self.init_projector(projector=projector, + proj_dim=proj_dim, + proj_max_batch_size=proj_max_batch_size) # normalize to make X^TX numerically stable # doing this instead of normalizing the projector matrix @@ -121,7 +131,9 @@ def __init__(self, self.gradient_computer = gradient_computer(model=self.model, task=self.task, - grad_dim=self.num_params) + grad_dim=self.num_params, + dtype=self.dtype, + device=self.device) if score_computer is None: score_computer = BasicScoreComputer @@ -163,25 +175,37 @@ def init_projector(self, projector, proj_dim, proj_max_batch_size) -> None: else: self.proj_dim = proj_dim - try: - import fast_jl - test_gradient = ch.ones(1, self.num_params).cuda() - num_sms = ch.cuda.get_device_properties('cuda').multi_processor_count - fast_jl.project_rademacher_8(test_gradient, self.proj_dim, 0, num_sms) - projector = CudaProjector - - except (ImportError, RuntimeError, AttributeError) as e: - self.logger.error(f'Could not use CudaProjector.\nReason: {str(e)}') - self.logger.error('Defaulting to BasicProjector.') + if self.device == 'cpu': + self.logger.info('Using BasicProjector since device is CPU') projector = BasicProjector - + # Sampling from bernoulli distribution is not supported for + # dtype float16 on CPU; playing it safe here by defaulting to + # normal projection, rather than rademacher + proj_type = ProjectionType.normal + self.logger.info('Using Normal projection') + else: + try: + import fast_jl + test_gradient = ch.ones(1, self.num_params).cuda() + num_sms = ch.cuda.get_device_properties('cuda').multi_processor_count + 
diff --git a/trak/traker.py b/trak/traker.py
index e250e99..0ec22f4 100644
--- a/trak/traker.py
+++ b/trak/traker.py
@@ -1,6 +1,6 @@
 from .modelout_functions import AbstractModelOutput, TASK_TO_MODELOUT
 from .projectors import ProjectionType, AbstractProjector, CudaProjector, BasicProjector
-from .gradient_computers import FunctionalGradientComputer,\
+from .gradient_computers import FunctionalGradientComputer, \
     AbstractGradientComputer
 from .score_computers import AbstractScoreComputer, BasicScoreComputer
 from .savers import AbstractSaver, MmapSaver, ModelIDException
@@ -38,6 +38,7 @@ def __init__(self,
                  logging_level=logging.INFO,
                  use_half_precision: bool = True,
                  proj_max_batch_size: int = 32,
+                 projector_seed: int = 0,
                  ) -> None:
         """
 
@@ -90,6 +91,12 @@
                 If True, TRAK will use half precision (float16) for all
                 computations and arrays will be stored in float16. Otherwise,
                 it will use float32. Defaults to True.
+            proj_max_batch_size (int):
+                Batch size used by fast_jl if the CudaProjector is used. Must be
+                a multiple of 8. The maximum batch size is 32 for A100 GPUs, 16
+                for V100 GPUs, 40 for H100 GPUs. Defaults to 32.
+            projector_seed (int):
+                Random seed used by the projector. Defaults to 0.
 
         """
 
@@ -107,7 +114,10 @@
         self.num_params = get_num_params(self.model)
         # inits self.projector
-        self.init_projector(projector, proj_dim, proj_max_batch_size)
+        self.proj_seed = projector_seed
+        self.init_projector(projector=projector,
+                            proj_dim=proj_dim,
+                            proj_max_batch_size=proj_max_batch_size)
 
         # normalize to make X^TX numerically stable
         # doing this instead of normalizing the projector matrix
@@ -121,7 +131,9 @@
 
         self.gradient_computer = gradient_computer(model=self.model,
                                                    task=self.task,
-                                                   grad_dim=self.num_params)
+                                                   grad_dim=self.num_params,
+                                                   dtype=self.dtype,
+                                                   device=self.device)
 
         if score_computer is None:
             score_computer = BasicScoreComputer
@@ -163,25 +175,37 @@ def init_projector(self, projector, proj_dim, proj_max_batch_size) -> None:
         else:
             self.proj_dim = proj_dim
 
-        try:
-            import fast_jl
-            test_gradient = ch.ones(1, self.num_params).cuda()
-            num_sms = ch.cuda.get_device_properties('cuda').multi_processor_count
-            fast_jl.project_rademacher_8(test_gradient, self.proj_dim, 0, num_sms)
-            projector = CudaProjector
-
-        except (ImportError, RuntimeError, AttributeError) as e:
-            self.logger.error(f'Could not use CudaProjector.\nReason: {str(e)}')
-            self.logger.error('Defaulting to BasicProjector.')
+        if self.device == 'cpu':
+            self.logger.info('Using BasicProjector since device is CPU')
             projector = BasicProjector
-
+            # Sampling from bernoulli distribution is not supported for
+            # dtype float16 on CPU; playing it safe here by defaulting to
+            # normal projection, rather than rademacher
+            proj_type = ProjectionType.normal
+            self.logger.info('Using Normal projection')
+        else:
+            try:
+                import fast_jl
+                test_gradient = ch.ones(1, self.num_params).cuda()
+                num_sms = ch.cuda.get_device_properties('cuda').multi_processor_count
+                fast_jl.project_rademacher_8(test_gradient, self.proj_dim, 0, num_sms)
+                projector = CudaProjector
+
+            except (ImportError, RuntimeError, AttributeError) as e:
+                self.logger.error(f'Could not use CudaProjector.\nReason: {str(e)}')
+                self.logger.error('Defaulting to BasicProjector.')
+                projector = BasicProjector
+            proj_type = ProjectionType.rademacher
+
+        self.logger.debug(f'Initializing projector with grad_dim {self.num_params}')
         self.projector = projector(grad_dim=self.num_params,
                                    proj_dim=self.proj_dim,
-                                   seed=0,
-                                   proj_type=ProjectionType.rademacher,
+                                   seed=self.proj_seed,
+                                   proj_type=proj_type,
                                    max_batch_size=proj_max_batch_size,
                                    dtype=self.dtype,
                                    device=self.device)
+        self.logger.debug(f'Initialized projector with proj_dim {self.proj_dim}')
 
     def load_checkpoint(self,
                         checkpoint: Iterable[Tensor],
@@ -242,11 +266,11 @@ def featurize(self,
                 Number of samples in the batch. Defaults to None.
 
         """
-        assert self.ckpt_loaded == self.saver.current_model_id,\
+        assert self.ckpt_loaded == self.saver.current_model_id, \
             "Load a checkpoint using traker.load_checkpoint before featurizing"
-        assert (inds is None) or (num_samples is None),\
+        assert (inds is None) or (num_samples is None), \
            "Exactly one of num_samples and inds should be specified"
-        assert (inds is not None) or (num_samples is not None),\
+        assert (inds is not None) or (num_samples is not None), \
            "Exactly one of num_samples and inds should be specified"
 
         if num_samples is not None:
@@ -374,9 +398,9 @@ def score(self,
                 Number of samples in the batch. Defaults to None.
 
         """
-        assert (inds is None) or (num_samples is None),\
+        assert (inds is None) or (num_samples is None), \
            "Exactly one of num_samples and inds should be specified"
-        assert (inds is not None) or (num_samples is not None),\
+        assert (inds is not None) or (num_samples is not None), \
            "Exactly one of num_samples and inds should be specified"
 
         if self.saver.model_ids[self.saver.current_model_id]['is_finalized'] == 0: