diff --git a/.assets/kronfluence.png b/.assets/kronfluence.png deleted file mode 100644 index f4b857c..0000000 Binary files a/.assets/kronfluence.png and /dev/null differ diff --git a/.assets/kronfluence.svg b/.assets/kronfluence.svg new file mode 100644 index 0000000..c9e78b2 --- /dev/null +++ b/.assets/kronfluence.svg @@ -0,0 +1,4 @@ + + + + \ No newline at end of file diff --git a/.github/workflows/linting.yml b/.github/workflows/linting.yml index c8293f2..bc0c58a 100644 --- a/.github/workflows/linting.yml +++ b/.github/workflows/linting.yml @@ -46,23 +46,8 @@ jobs: run: | isort --profile black kronfluence - black: - runs-on: ubuntu-latest - - steps: - - name: Checkout Repository - uses: actions/checkout@v2 - - - name: Set up Python 3.9 - uses: actions/setup-python@v2 - with: - python-version: 3.9 - - - name: Install black - run: | - pip install --upgrade pip - pip install black==24.1.1 - - - name: Run black - run: | - black --check kronfluence \ No newline at end of file +  actionlint: +    runs-on: ubuntu-latest +    steps: +    - uses: actions/checkout@v4 +    - uses: reviewdog/action-actionlint@v1 \ No newline at end of file diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 5abfe88..dab54d5 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -4,7 +4,8 @@ We welcome contributions to the `kronfluence` project. Whether it's bug fixes, f ## Setting Up Development Environment -To contribute to `kronfluence`, you will need to set up a development environment on your machine. This setup includes all the dependencies required for linting, testing, and documentation. +To contribute to `kronfluence`, you will need to set up a development environment on your machine. +This setup includes all the dependencies required for linting and testing. ```bash git clone https://github.com/pomonam/kronfluence.git diff --git a/DOCUMENTATION.md b/DOCUMENTATION.md index b1c5744..f0fb085 100644 --- a/DOCUMENTATION.md +++ b/DOCUMENTATION.md @@ -1,7 +1,12 @@ # Kronfluence: Technical Documentation & FAQs +(To be added.) + ## Supported Modules Kronfluence only supports influence computation on supported `nn.Module`. The following modules are supported: 1. `nn.Linear` and `nn.Conv2d` +## Supported Strategies + +- Identity, diagonal, KFAC, EKFAC \ No newline at end of file diff --git a/README.md b/README.md index 956763e..5bea162 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,5 @@
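A note on the `Supported Strategies` list added to DOCUMENTATION.md above: the strategy is selected through `FactorArguments.strategy`, whose accepted values correspond to the `FactorStrategy` enum updated later in this diff (`identity`, `diagonal`, `kfac`, `ekfac`). Below is a minimal sketch of the intended usage, assuming `model`, `task`, and `train_dataset` are already defined, and assuming the `Analyzer` constructor accepts `model` and `task` keyword arguments as in the repository's examples; exact argument names may differ.

```python
from kronfluence.analyzer import Analyzer, prepare_model
from kronfluence.arguments import FactorArguments

# Pick one of the four supported strategies: "identity", "diagonal", "kfac", "ekfac".
factor_args = FactorArguments(strategy="ekfac")

model = prepare_model(model, task)
analyzer = Analyzer(analysis_name="doc_example", model=model, task=task)

# EKFAC requires all three factor stages; cheaper strategies skip the stages
# their `FactorConfig` marks as not required.
analyzer.fit_covariance_matrices(factors_name="ekfac", dataset=train_dataset, factor_args=factor_args)
analyzer.perform_eigendecomposition(factors_name="ekfac", factor_args=factor_args)
analyzer.fit_lambda_matrices(factors_name="ekfac", dataset=train_dataset, factor_args=factor_args)
```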
diff --git a/dev_requirements.txt b/dev_requirements.txt index f8caa1e..ea5f63e 100644 --- a/dev_requirements.txt +++ b/dev_requirements.txt @@ -1,6 +1,6 @@ isort==5.13.2 pylint==3.0.3 pytest==8.0.0 -black==24.1.1 +ruff==0.3.0 datasets>=2.17.0 transformers>=4.37.2 \ No newline at end of file diff --git a/examples/cifar/pipeline.py b/examples/cifar/pipeline.py index a0ec66d..c46932c 100644 --- a/examples/cifar/pipeline.py +++ b/examples/cifar/pipeline.py @@ -91,9 +91,7 @@ def get_cifar10_dataset( ): assert split in ["train", "eval_train", "valid"] - normalize = torchvision.transforms.Normalize( - mean=(0.4914, 0.4822, 0.4465), std=(0.247, 0.243, 0.261) - ) + normalize = torchvision.transforms.Normalize(mean=(0.4914, 0.4822, 0.4465), std=(0.247, 0.243, 0.261)) if split in ["train", "eval_train"]: transforms = torchvision.transforms.Compose( @@ -114,9 +112,7 @@ def get_cifar10_dataset( if split == "train": transform_config = [ - torchvision.transforms.RandomResizedCrop( - size=224, scale=(0.08, 1.0), ratio=(0.75, 4.0 / 3.0) - ), + torchvision.transforms.RandomResizedCrop(size=224, scale=(0.08, 1.0), ratio=(0.75, 4.0 / 3.0)), torchvision.transforms.RandomHorizontalFlip(), ] transform_config.extend([torchvision.transforms.ToTensor(), normalize]) @@ -180,9 +176,7 @@ def get_cifar10_dataloader( if do_corrupt: if split == "valid": - raise NotImplementedError( - "Performing corruption on the validation dataset is not supported." - ) + raise NotImplementedError("Performing corruption on the validation dataset is not supported.") num_corrupt = math.ceil(len(dataset) * 0.1) original_targets = np.array(copy.deepcopy(dataset.targets[:num_corrupt])) new_targets = torch.randint( @@ -197,9 +191,7 @@ def get_cifar10_dataloader( size=new_targets[new_targets == original_targets].shape, generator=torch.Generator().manual_seed(0), ).numpy() - new_targets[new_targets == original_targets] = ( - new_targets[new_targets == original_targets] + offsets - ) % 10 + new_targets[new_targets == original_targets] = (new_targets[new_targets == original_targets] + offsets) % 10 assert (new_targets == original_targets).sum() == 0 dataset.targets[:num_corrupt] = list(new_targets) diff --git a/examples/glue/pipeline.py b/examples/glue/pipeline.py index 6ac8e9c..31d23ae 100644 --- a/examples/glue/pipeline.py +++ b/examples/glue/pipeline.py @@ -55,9 +55,7 @@ def get_glue_dataset( num_labels = len(label_list) assert num_labels == 2 - tokenizer = AutoTokenizer.from_pretrained( - "bert-base-cased", use_fast=True, trust_remote_code=True - ) + tokenizer = AutoTokenizer.from_pretrained("bert-base-cased", use_fast=True, trust_remote_code=True) sentence1_key, sentence2_key = GLUE_TASK_TO_KEYS[data_name] padding = "max_length" @@ -65,13 +63,9 @@ def get_glue_dataset( def preprocess_function(examples): texts = ( - (examples[sentence1_key],) - if sentence2_key is None - else (examples[sentence1_key], examples[sentence2_key]) - ) - result = tokenizer( - *texts, padding=padding, max_length=max_seq_length, truncation=True + (examples[sentence1_key],) if sentence2_key is None else (examples[sentence1_key], examples[sentence2_key]) ) + result = tokenizer(*texts, padding=padding, max_length=max_seq_length, truncation=True) if "label" in examples: result["labels"] = examples["label"] return result diff --git a/examples/glue/train.py b/examples/glue/train.py index 189e687..0a845e9 100644 --- a/examples/glue/train.py +++ b/examples/glue/train.py @@ -14,9 +14,7 @@ def parse_args(): - parser = argparse.ArgumentParser( - description="Train 
classification models on MNIST datasets." - ) + parser = argparse.ArgumentParser(description="Train classification models on GLUE datasets.") parser.add_argument( "--dataset_name", @@ -95,9 +93,7 @@ def main(): set_seed(args.seed) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - train_dataset = get_glue_dataset( - data_name=args.dataset_name, split="train", data_path=args.dataset_dir - ) + train_dataset = get_glue_dataset(data_name=args.dataset_name, split="train", data_path=args.dataset_dir) train_dataloader = DataLoader( dataset=train_dataset, batch_size=args.train_batch_size, diff --git a/examples/imagenet/analyze.py b/examples/imagenet/analyze.py index 9475e86..ede619c 100644 --- a/examples/imagenet/analyze.py +++ b/examples/imagenet/analyze.py @@ -15,9 +15,7 @@ def parse_args(): - parser = argparse.ArgumentParser( - description="Influence analysis on ImageNet datasets." - ) + parser = argparse.ArgumentParser(description="Influence analysis on ImageNet datasets.") parser.add_argument( "--dataset_dir", @@ -57,9 +55,7 @@ def parse_args(): class ClassificationTask(Task): - def compute_model_output( - self, batch: BATCH_DTYPE, model: nn.Module - ) -> torch.Tensor: + def compute_model_output(self, batch: BATCH_DTYPE, model: nn.Module) -> torch.Tensor: inputs, _ = batch return model(inputs) @@ -88,15 +84,11 @@ def compute_measurement( ) -> torch.Tensor: _, labels = batch - bindex = torch.arange(outputs.shape[0]).to( - device=outputs.device, non_blocking=False - ) + bindex = torch.arange(outputs.shape[0]).to(device=outputs.device, non_blocking=False) logits_correct = outputs[bindex, labels] cloned_logits = outputs.clone() - cloned_logits[bindex, labels] = torch.tensor( - -torch.inf, device=outputs.device, dtype=outputs.dtype - ) + cloned_logits[bindex, labels] = torch.tensor(-torch.inf, device=outputs.device, dtype=outputs.dtype) margins = logits_correct - cloned_logits.logsumexp(dim=-1) return -margins.sum() diff --git a/examples/imagenet/ddp_analyze.py b/examples/imagenet/ddp_analyze.py index ef6eaa2..3bedb9c 100644 --- a/examples/imagenet/ddp_analyze.py +++ b/examples/imagenet/ddp_analyze.py @@ -24,9 +24,7 @@ def parse_args(): - parser = argparse.ArgumentParser( - description="Influence analysis on ImageNet datasets."
- ) + parser = argparse.ArgumentParser(description="Influence analysis on ImageNet datasets.") parser.add_argument( "--dataset_dir", @@ -66,9 +64,7 @@ def parse_args(): class ClassificationTask(Task): - def compute_model_output( - self, batch: BATCH_DTYPE, model: nn.Module - ) -> torch.Tensor: + def compute_model_output(self, batch: BATCH_DTYPE, model: nn.Module) -> torch.Tensor: inputs, _ = batch return model(inputs) @@ -97,15 +93,11 @@ def compute_measurement( ) -> torch.Tensor: _, labels = batch - bindex = torch.arange(outputs.shape[0]).to( - device=outputs.device, non_blocking=False - ) + bindex = torch.arange(outputs.shape[0]).to(device=outputs.device, non_blocking=False) logits_correct = outputs[bindex, labels] cloned_logits = outputs.clone() - cloned_logits[bindex, labels] = torch.tensor( - -torch.inf, device=outputs.device, dtype=outputs.dtype - ) + cloned_logits[bindex, labels] = torch.tensor(-torch.inf, device=outputs.device, dtype=outputs.dtype) margins = logits_correct - cloned_logits.logsumexp(dim=-1) return -margins.sum() @@ -132,9 +124,7 @@ def main(): model = prepare_model(model, task) model = model.to(device=device) - model = DistributedDataParallel( - model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK - ) + model = DistributedDataParallel(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK) analyzer = Analyzer( analysis_name=args.analysis_name, diff --git a/examples/imagenet/pipeline.py b/examples/imagenet/pipeline.py index 6b647ca..119973b 100644 --- a/examples/imagenet/pipeline.py +++ b/examples/imagenet/pipeline.py @@ -8,9 +8,7 @@ def construct_resnet50() -> nn.Module: - return torchvision.models.resnet50( - weights=torchvision.models.ResNet50_Weights.IMAGENET1K_V1 - ) + return torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.IMAGENET1K_V1) def get_imagenet_dataset( @@ -20,15 +18,11 @@ def get_imagenet_dataset( ) -> Dataset: assert split in ["train", "eval_train", "valid"] - normalize = torchvision.transforms.Normalize( - mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225) - ) + normalize = torchvision.transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)) if split == "train": transform_config = [ - torchvision.transforms.RandomResizedCrop( - size=224, scale=(0.08, 1.0), ratio=(0.75, 4.0 / 3.0) - ), + torchvision.transforms.RandomResizedCrop(size=224, scale=(0.08, 1.0), ratio=(0.75, 4.0 / 3.0)), torchvision.transforms.RandomHorizontalFlip(), ] transform_config.extend([torchvision.transforms.ToTensor(), normalize]) diff --git a/examples/uci/analyze.py b/examples/uci/analyze.py index c28724b..648a4b6 100644 --- a/examples/uci/analyze.py +++ b/examples/uci/analyze.py @@ -100,12 +100,8 @@ def main(): logging.basicConfig(level=logging.INFO) logger = logging.getLogger() - train_dataset = get_regression_dataset( - data_name=args.dataset_name, split="train", data_path=args.dataset_dir - ) - eval_dataset = get_regression_dataset( - data_name=args.dataset_name, split="valid", data_path=args.dataset_dir - ) + train_dataset = get_regression_dataset(data_name=args.dataset_name, split="train", data_path=args.dataset_dir) + eval_dataset = get_regression_dataset(data_name=args.dataset_name, split="valid", data_path=args.dataset_dir) model = construct_regression_mlp() @@ -147,9 +143,7 @@ def main(): overwrite_output_dir=True, ) - with profile( - activities=[ProfilerActivity.CPU], profile_memory=True, record_shapes=True - ) as prof: + with profile(activities=[ProfilerActivity.CPU], profile_memory=True, record_shapes=True) as prof: 
with record_function("eigen"): analyzer.perform_eigendecomposition( factors_name=args.factor_strategy, diff --git a/examples/uci/pipeline.py b/examples/uci/pipeline.py index 9767fa8..a4b0b3e 100644 --- a/examples/uci/pipeline.py +++ b/examples/uci/pipeline.py @@ -74,9 +74,7 @@ def get_regression_dataset( y_train_scaled.astype(np.float32), ) else: - dataset = RegressionDataset( - x_val_scaled.astype(np.float32), y_val_scaled.astype(np.float32) - ) + dataset = RegressionDataset(x_val_scaled.astype(np.float32), y_val_scaled.astype(np.float32)) if indices is not None: dataset = torch.utils.data.Subset(dataset, indices) diff --git a/examples/uci/train.py b/examples/uci/train.py index 324edb4..301226b 100644 --- a/examples/uci/train.py +++ b/examples/uci/train.py @@ -12,9 +12,7 @@ def parse_args(): - parser = argparse.ArgumentParser( - description="Train regression models on UCI datasets." - ) + parser = argparse.ArgumentParser(description="Train regression models on UCI datasets.") parser.add_argument( "--dataset_name", @@ -93,9 +91,7 @@ def main(): if args.seed is not None: set_seed(args.seed) - train_dataset = get_regression_dataset( - data_name=args.dataset_name, split="train", data_path=args.dataset_dir - ) + train_dataset = get_regression_dataset(data_name=args.dataset_name, split="train", data_path=args.dataset_dir) train_dataloader = DataLoader( dataset=train_dataset, batch_size=args.train_batch_size, @@ -103,9 +99,7 @@ def main(): drop_last=True, ) model = construct_regression_mlp() - optimizer = torch.optim.SGD( - model.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay - ) + optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay) logger.info("Start training the model.") model.train() @@ -134,9 +128,7 @@ def main(): shuffle=False, drop_last=False, ) - eval_dataset = get_regression_dataset( - data_name=args.dataset_name, split="valid", data_path=args.dataset_dir - ) + eval_dataset = get_regression_dataset(data_name=args.dataset_name, split="valid", data_path=args.dataset_dir) eval_dataloader = DataLoader( dataset=eval_dataset, batch_size=args.eval_batch_size, diff --git a/kronfluence/analyzer.py b/kronfluence/analyzer.py index 22b862c..abc2e6f 100644 --- a/kronfluence/analyzer.py +++ b/kronfluence/analyzer.py @@ -28,9 +28,7 @@ def prepare_model( return model -class Analyzer( - CovarianceComputer, EigenComputer, PairwiseScoreComputer, SelfScoreComputer -): +class Analyzer(CovarianceComputer, EigenComputer, PairwiseScoreComputer, SelfScoreComputer): """ Handles the computation of all preconditioning factors (e.g., covariance and Lambda matrices for EKFAC) and influence scores for a given PyTorch model. @@ -98,9 +96,7 @@ def _save_model(self) -> None: self.logger.info(f"Found existing saved model at {model_save_path}.") # Load the existing model's state_dict for comparison. loaded_state_dict = load_file(model_save_path) - if not verify_models_equivalence( - loaded_state_dict, extracted_model.state_dict() - ): + if not verify_models_equivalence(loaded_state_dict, extracted_model.state_dict()): error_msg = ( "Detected a difference between the current model and the one saved at " f"{model_save_path}. 
Consider using a different `analysis_name` to " diff --git a/kronfluence/arguments.py b/kronfluence/arguments.py index 7ae8ea8..589aacf 100644 --- a/kronfluence/arguments.py +++ b/kronfluence/arguments.py @@ -36,15 +36,11 @@ class FactorArguments(Arguments): ) initial_per_device_batch_size_attempt: int = field( default=4096, - metadata={ - "help": "The initial attempted per-device batch size when the batch size is not provided." - }, + metadata={"help": "The initial attempted per-device batch size when the batch size is not provided."}, ) immediate_gradient_removal: bool = field( default=False, - metadata={ - "help": "Whether to immediately remove computed `.grad` by Autograd within the backward hook." - }, + metadata={"help": "Whether to immediately remove computed `.grad` by Autograd within the backward hook."}, ) # Configuration for fitting covariance matrices. # @@ -82,18 +78,14 @@ class FactorArguments(Arguments): gradient_covariance_dtype: torch.dtype = field( default=torch.float32, metadata={ - "help": "Dtype for computing pseudo-gradient covariance matrices. " - "Recommended to use `torch.float32`." + "help": "Dtype for computing pseudo-gradient covariance matrices. " "Recommended to use `torch.float32`." }, ) # Configuration for performing eigendecomposition. # eigendecomposition_dtype: torch.dtype = field( default=torch.float64, - metadata={ - "help": "Dtype for performing eigendecomposition. " - "Recommended to use `torch.float64." - }, + metadata={"help": "Dtype for performing eigendecomposition. " "Recommended to use `torch.float64`."}, ) # Configuration for fitting Lambda matrices. # @@ -140,9 +132,7 @@ ) lambda_dtype: torch.dtype = field( default=torch.float32, - metadata={ - "help": "Dtype for computing Lambda matrices. Recommended to use `torch.float32`." - }, + metadata={"help": "Dtype for computing Lambda matrices. Recommended to use `torch.float32`."}, ) @@ -152,9 +142,7 @@ class ScoreArguments(Arguments): initial_per_device_batch_size_attempt: int = field( default=4096, - metadata={ - "help": "The initial attempted per-device batch size when the batch size is not provided." - }, + metadata={"help": "The initial attempted per-device batch size when the batch size is not provided."}, ) damping: Optional[float] = field( default=None, @@ -165,9 +153,7 @@ ) immediate_gradient_removal: bool = field( default=False, - metadata={ - "help": "Whether to immediately remove computed `.grad` by Autograd within the backward hook." - }, + metadata={"help": "Whether to immediately remove computed `.grad` by Autograd within the backward hook."}, ) data_partition_size: int = field( @@ -197,15 +183,11 @@ query_gradient_rank: Optional[int] = field( default=None, - metadata={ - "help": "Rank for the query gradient. Applies no low-rank approximation if None." - }, + metadata={"help": "Rank for the query gradient. Applies no low-rank approximation if None."}, ) query_gradient_svd_dtype: torch.dtype = field( default=torch.float64, - metadata={ - "help": "Dtype for performing singular value decomposition (SVD) on the query gradient." - }, + metadata={"help": "Dtype for performing singular value decomposition (SVD) on the query gradient."}, ) score_dtype: torch.dtype = field( @@ -225,8 +207,5 @@ ) precondition_dtype: torch.dtype = field( default=torch.float32, - metadata={ - "help": "Dtype for computing the preconditioned gradient. 
" - "Recommended to use `torch.float32`." - }, + metadata={"help": "Dtype for computing the preconditioned gradient. " "Recommended to use `torch.float32`."}, ) diff --git a/kronfluence/computer/computer.py b/kronfluence/computer/computer.py index e5e865f..dca498e 100644 --- a/kronfluence/computer/computer.py +++ b/kronfluence/computer/computer.py @@ -72,9 +72,7 @@ def __init__( # Create and configure logger. disable_log = log_main_process_only and self.state.process_index != 0 - self.logger = get_logger( - name=__name__, log_level=log_level, disable_log=disable_log - ) + self.logger = get_logger(name=__name__, log_level=log_level, disable_log=disable_log) self.logger.info(f"Initializing Computer with parameters: {locals()}") self.logger.info(f"Process state configuration:\n{repr(self.state)}") @@ -102,9 +100,7 @@ def __init__( ) if cpu and isinstance(model, (DataParallel, DDP, FSDP)): - error_msg = ( - "To enforce CPU, the model must not be wrapped with DP, DDP, or FSDP." - ) + error_msg = "To enforce CPU, the model must not be wrapped with DP, DDP, or FSDP." self.logger.error(error_msg) raise ValueError(error_msg) @@ -129,9 +125,7 @@ def _save_arguments( """Saves arguments at the specified path.""" arguments_save_path = output_dir / f"{arguments_name}_arguments.json" if arguments_save_path.exists() and not overwrite_output_dir: - self.logger.info( - f"Found existing saved arguments at {arguments_save_path}." - ) + self.logger.info(f"Found existing saved arguments at {arguments_save_path}.") loaded_arguments = load_json(arguments_save_path) if loaded_arguments != arguments.to_dict(): error_msg = ( @@ -146,15 +140,11 @@ def _save_arguments( save_json(arguments.to_dict(), arguments_save_path) self.logger.info(f"Saved arguments at {arguments_save_path}.") - def _load_arguments( - self, arguments_name: str, output_dir: Path - ) -> Optional[Dict[str, Any]]: + def _load_arguments(self, arguments_name: str, output_dir: Path) -> Optional[Dict[str, Any]]: """Loads arguments from the specified path.""" arguments_save_path = output_dir / f"{arguments_name}_arguments.json" if not arguments_save_path.exists(): - self.logger.warning( - f"Could not find existing saved arguments at {arguments_save_path}." - ) + self.logger.warning(f"Could not find existing saved arguments at {arguments_save_path}.") return None return load_json(arguments_save_path) @@ -167,9 +157,7 @@ def _save_dataset_metadata( overwrite_output_dir: bool = False, ) -> None: """Saves dataset metadata at the specified path.""" - dataset_metadata_save_path = ( - output_dir / f"{dataset_name}_dataset_metadata.json" - ) + dataset_metadata_save_path = output_dir / f"{dataset_name}_dataset_metadata.json" dataset_metadata = { "type": type(dataset).__name__, "dataset_size": len(dataset), @@ -177,9 +165,7 @@ def _save_dataset_metadata( } if dataset_metadata_save_path.exists() and not overwrite_output_dir: - self.logger.info( - f"Found existing saved dataset metadata at {dataset_metadata_save_path}." - ) + self.logger.info(f"Found existing saved dataset metadata at {dataset_metadata_save_path}.") # Load the existing dataset metadata for comparison. loaded_metadata = load_json(dataset_metadata_save_path) if loaded_metadata != dataset_metadata: @@ -213,8 +199,6 @@ def _get_dataloader( error_msg = "DistributedEvalSampler is not currently supported with `stack=True`." 
self.logger.error(error_msg) raise ValueError(error_msg) - # Different from `DistributedSampler`, `DistributedEvalSampler` does not add extra duplicate - # data points to make the loader evenly divisible. sampler = DistributedEvalSampler( dataset=dataset, num_replicas=self.state.num_processes, @@ -341,9 +325,7 @@ def _find_executable_factors_batch_size( self.logger.info("Automatically determining executable batch size.") def executable_batch_size_func(batch_size: int) -> None: - self.logger.info( - f"Attempting to set per-device batch size to {batch_size}." - ) + self.logger.info(f"Attempting to set per-device batch size to {batch_size}.") set_mode(model=self.model, mode=ModuleMode.DEFAULT, keep_factors=False) self.model.zero_grad(set_to_none=True) release_memory() @@ -365,9 +347,7 @@ def executable_batch_size_func(batch_size: int) -> None: return per_device_batch_size @torch.no_grad() - def _aggregate_factors( - self, aggregated_factors: FACTOR_TYPE, loaded_factors: FACTOR_TYPE - ) -> FACTOR_TYPE: + def _aggregate_factors(self, aggregated_factors: FACTOR_TYPE, loaded_factors: FACTOR_TYPE) -> FACTOR_TYPE: """Aggregates factors from the current loaded factors.""" for factor_name, factors in loaded_factors.items(): if factor_name not in aggregated_factors: @@ -375,22 +355,16 @@ def _aggregate_factors( for module_name in factors: if module_name not in aggregated_factors[factor_name]: - aggregated_factors[factor_name][module_name] = ( - factors[module_name] - ).to(device=self.state.device) + aggregated_factors[factor_name][module_name] = (factors[module_name]).to(device=self.state.device) else: # Aggregate the factors from `loaded_factors` to `aggregated_factors`. - aggregated_factors[factor_name][module_name].add_( - factors[module_name].to(device=self.state.device) - ) + aggregated_factors[factor_name][module_name].add_(factors[module_name].to(device=self.state.device)) return aggregated_factors def load_factor_args(self, factors_name: str) -> Optional[Dict[str, Any]]: """Loads factor arguments with the given factor name.""" factors_output_dir = self.factors_output_dir(factors_name=factors_name) - arguments_save_path = ( - factors_output_dir / f"{FACTOR_ARGUMENTS_NAME}_arguments.json" - ) + arguments_save_path = factors_output_dir / f"{FACTOR_ARGUMENTS_NAME}_arguments.json" if not arguments_save_path.exists(): return None return load_json(arguments_save_path) @@ -419,9 +393,7 @@ def load_lambda_matrices(self, factors_name: str) -> Optional[FACTOR_TYPE]: def load_score_args(self, scores_name: str) -> Optional[Dict[str, Any]]: """Loads score arguments with the given score name.""" scores_output_dir = self.scores_output_dir(scores_name=scores_name) - arguments_save_path = ( - scores_output_dir / f"{SCORE_ARGUMENTS_NAME}_arguments.json" - ) + arguments_save_path = scores_output_dir / f"{SCORE_ARGUMENTS_NAME}_arguments.json" if not arguments_save_path.exists(): return None return load_json(arguments_save_path) @@ -440,14 +412,10 @@ def load_self_scores(self, scores_name: str) -> Optional[SCORE_TYPE]: return load_self_scores(output_dir=scores_output_dir) return None - def _load_all_required_factors( - self, factors_name: str, strategy: str, factor_config: Any - ) -> FACTOR_TYPE: + def _load_all_required_factors(self, factors_name: str, strategy: str, factor_config: Any) -> FACTOR_TYPE: loaded_factors: FACTOR_TYPE = {} if factor_config.requires_covariance_matrices_for_precondition: - covariance_factors = self.load_covariance_matrices( - factors_name=factors_name - ) + covariance_factors = 
self.load_covariance_matrices(factors_name=factors_name) if covariance_factors is None: error_msg = ( f"Strategy {strategy} requires loading covariance matrices before computing" @@ -503,18 +471,14 @@ def _aggregate_scores( data_partition_size = score_args.data_partition_size module_partition_size = score_args.module_partition_size all_required_partitions = [ - (i, j) - for i in range(score_args.data_partition_size) - for j in range(score_args.module_partition_size) + (i, j) for i in range(score_args.data_partition_size) for j in range(score_args.module_partition_size) ] all_partition_exists = [ - exists_fnc(output_dir=scores_output_dir, partition=partition) - for partition in all_required_partitions + exists_fnc(output_dir=scores_output_dir, partition=partition) for partition in all_required_partitions ] if not all_partition_exists: self.logger.info( - "Influence scores are not aggregated as scores for some partitions " - "are not yet computed." + "Influence scores are not aggregated as scores for some partitions " "are not yet computed." ) return @@ -533,13 +497,9 @@ def _aggregate_scores( for module_name, scores in loaded_scores.items(): if module_name not in aggregated_module_scores: - aggregated_module_scores[module_name] = scores.to( - device=self.state.device - ) + aggregated_module_scores[module_name] = scores.to(device=self.state.device) else: - aggregated_module_scores[module_name].add_( - scores.to(device=self.state.device) - ) + aggregated_module_scores[module_name].add_(scores.to(device=self.state.device)) del loaded_scores for module_name, scores in aggregated_module_scores.items(): @@ -557,6 +517,4 @@ def _aggregate_scores( self.state.wait_for_everyone() end_time = get_time(state=self.state) elapsed_time = end_time - start_time - self.logger.info( - f"Aggregated all partitioned scores in {elapsed_time:.2f} seconds." - ) + self.logger.info(f"Aggregated all partitioned scores in {elapsed_time:.2f} seconds.") diff --git a/kronfluence/computer/covariance_computer.py b/kronfluence/computer/covariance_computer.py index 7aac2ed..e7b0cae 100644 --- a/kronfluence/computer/covariance_computer.py +++ b/kronfluence/computer/covariance_computer.py @@ -80,21 +80,18 @@ def _fit_partitioned_covariance_matrices( dataloader_params=dataloader_params, indices=indices, ) - num_data_processed, covariance_factors = ( - fit_covariance_matrices_with_loader( - model=self.model, - state=self.state, - task=self.task, - loader=loader, - factor_args=factor_args, - tracked_module_names=tracked_module_names, - ) + num_data_processed, covariance_factors = fit_covariance_matrices_with_loader( + model=self.model, + state=self.state, + task=self.task, + loader=loader, + factor_args=factor_args, + tracked_module_names=tracked_module_names, ) end_time = get_time(state=self.state) elapsed_time = end_time - start_time self.logger.info( - f"Fitted covariance matrices on {num_data_processed.item()} data points in " - f"{elapsed_time:.2f} seconds." + f"Fitted covariance matrices on {num_data_processed.item()} data points in " f"{elapsed_time:.2f} seconds." ) return covariance_factors @@ -115,20 +112,13 @@ def fit_covariance_matrices( factors_output_dir = self.factors_output_dir(factors_name=factors_name) os.makedirs(factors_output_dir, exist_ok=True) - if ( - covariance_matrices_exist(output_dir=factors_output_dir) - and not overwrite_output_dir - ): - self.logger.info( - f"Found existing covariance matrices at {factors_output_dir}. Skipping." 
- ) + if covariance_matrices_exist(output_dir=factors_output_dir) and not overwrite_output_dir: + self.logger.info(f"Found existing covariance matrices at {factors_output_dir}. Skipping.") return if factor_args is None: factor_args = FactorArguments() - self.logger.info( - f"Factor arguments not provided. Using the default configuration: {factor_args}." - ) + self.logger.info(f"Factor arguments not provided. Using the default configuration: {factor_args}.") else: self.logger.info(f"Using the provided configuration: {factor_args}.") @@ -142,8 +132,7 @@ def fit_covariance_matrices( if not FactorConfig.CONFIGS[factor_args.strategy].requires_covariance_matrices: self.logger.info( - f"Strategy `{factor_args.strategy}` does not require fitting covariance matrices. " - f"Skipping." + f"Strategy `{factor_args.strategy}` does not require fitting covariance matrices. " f"Skipping." ) return @@ -161,23 +150,16 @@ def fit_covariance_matrices( f"DataLoader arguments not provided. Using the default configuration: {dataloader_kwargs}." ) else: - self.logger.info( - f"Using the DataLoader parameters: {dataloader_kwargs.to_dict()}." - ) + self.logger.info(f"Using the DataLoader parameters: {dataloader_kwargs.to_dict()}.") dataloader_params = dataloader_kwargs.to_dict() total_data_examples = min([factor_args.covariance_max_examples, len(dataset)]) - self.logger.info( - f"Total data examples to fit covariance matrices: {total_data_examples}." - ) + self.logger.info(f"Total data examples to fit covariance matrices: {total_data_examples}.") no_partition = ( - factor_args.covariance_data_partition_size == 1 - and factor_args.covariance_module_partition_size == 1 - ) - partition_provided = ( - target_data_partitions is not None or target_module_partitions is not None + factor_args.covariance_data_partition_size == 1 and factor_args.covariance_module_partition_size == 1 ) + partition_provided = target_data_partitions is not None or target_module_partitions is not None if no_partition and partition_provided: error_msg = ( "`target_data_partitions` or `target_module_partitions` were specified, while" @@ -192,14 +174,12 @@ def fit_covariance_matrices( self.logger.error(error_msg) raise ValueError(error_msg) if per_device_batch_size is None: - per_device_batch_size = ( - self._find_executable_covariance_factors_batch_size( - dataloader_params=dataloader_params, - dataset=dataset, - total_data_examples=total_data_examples, - factor_args=factor_args, - tracked_module_names=None, - ) + per_device_batch_size = self._find_executable_covariance_factors_batch_size( + dataloader_params=dataloader_params, + dataset=dataset, + total_data_examples=total_data_examples, + factor_args=factor_args, + tracked_module_names=None, ) covariance_factors = self._fit_partitioned_covariance_matrices( dataset=dataset, @@ -225,11 +205,9 @@ def fit_covariance_matrices( data_partition_size=factor_args.covariance_data_partition_size, target_data_partitions=target_data_partitions, ) - module_partition_names, target_module_partitions = ( - self._get_module_partition( - module_partition_size=factor_args.covariance_module_partition_size, - target_module_partitions=target_module_partitions, - ) + module_partition_names, target_module_partitions = self._get_module_partition( + module_partition_size=factor_args.covariance_module_partition_size, + target_module_partitions=target_module_partitions, ) all_start_time = get_time(state=self.state) @@ -254,25 +232,18 @@ def fit_covariance_matrices( f"{end_index}) and modules 
{module_partition_names[module_partition]}." ) - max_total_examples = ( - total_data_examples - // factor_args.covariance_data_partition_size - ) + max_total_examples = total_data_examples // factor_args.covariance_data_partition_size if max_total_examples < self.state.num_processes: - error_msg = ( - "There are more data examples than the number of processes." - ) + error_msg = "There are fewer data examples than the number of processes." self.logger.error(error_msg) raise ValueError(error_msg) if per_device_batch_size is None: - per_device_batch_size = ( - self._find_executable_covariance_factors_batch_size( - dataloader_params=dataloader_params, - dataset=dataset, - factor_args=factor_args, - total_data_examples=max_total_examples, - tracked_module_names=module_partition_names[0], - ) + per_device_batch_size = self._find_executable_covariance_factors_batch_size( + dataloader_params=dataloader_params, + dataset=dataset, + factor_args=factor_args, + total_data_examples=max_total_examples, + tracked_module_names=module_partition_names[0], ) covariance_factors = self._fit_partitioned_covariance_matrices( dataset=dataset, @@ -291,18 +262,12 @@ ) self.state.wait_for_everyone() del covariance_factors - self.logger.info( - f"Saved partitioned covariance matrices at {factors_output_dir}." - ) + self.logger.info(f"Saved partitioned covariance matrices at {factors_output_dir}.") all_end_time = get_time(state=self.state) elapsed_time = all_end_time - all_start_time - self.logger.info( - f"Fitted all partitioned covariance matrices in {elapsed_time:.2f} seconds." - ) - self.aggregate_covariance_matrices( - factors_name=factors_name, factor_args=factor_args - ) + self.logger.info(f"Fitted all partitioned covariance matrices in {elapsed_time:.2f} seconds.") + self.aggregate_covariance_matrices(factors_name=factors_name, factor_args=factor_args) profile_summary = self.profiler.summary() if profile_summary != "": @@ -326,15 +291,9 @@ data_partition_size = factor_args.covariance_data_partition_size module_partition_size = factor_args.covariance_module_partition_size - all_required_partitions = [ - (i, j) - for i in range(data_partition_size) - for j in range(module_partition_size) - ] + all_required_partitions = [(i, j) for i in range(data_partition_size) for j in range(module_partition_size)] all_partition_exists = [ - covariance_matrices_exist( - output_dir=factors_output_dir, partition=partition - ) + covariance_matrices_exist(output_dir=factors_output_dir, partition=partition) for partition in all_required_partitions ] if not all_partition_exists: @@ -367,6 +326,4 @@ self.state.wait_for_everyone() end_time = get_time(state=self.state) elapsed_time = end_time - start_time - self.logger.info( - f"Aggregated all partitioned covariance matrices in {elapsed_time:.2f} seconds."
- ) + self.logger.info(f"Aggregated all partitioned covariance matrices in {elapsed_time:.2f} seconds.") diff --git a/kronfluence/computer/eigen_computer.py b/kronfluence/computer/eigen_computer.py index 1871a58..76095cb 100644 --- a/kronfluence/computer/eigen_computer.py +++ b/kronfluence/computer/eigen_computer.py @@ -46,20 +46,13 @@ def perform_eigendecomposition( factors_output_dir = self.factors_output_dir(factors_name=factors_name) os.makedirs(factors_output_dir, exist_ok=True) - if ( - eigendecomposition_exist(output_dir=factors_output_dir) - and not overwrite_output_dir - ): - self.logger.info( - f"Found existing Eigendecomposition results at {factors_output_dir}. Skipping." - ) + if eigendecomposition_exist(output_dir=factors_output_dir) and not overwrite_output_dir: + self.logger.info(f"Found existing Eigendecomposition results at {factors_output_dir}. Skipping.") return if factor_args is None: factor_args = FactorArguments() - self.logger.info( - f"Factor arguments not provided. Using the default configuration: {factor_args}." - ) + self.logger.info(f"Factor arguments not provided. Using the default configuration: {factor_args}.") else: self.logger.info(f"Using the provided configuration: {factor_args}.") @@ -73,18 +66,13 @@ def perform_eigendecomposition( if not FactorConfig.CONFIGS[factor_args.strategy].requires_eigendecomposition: self.logger.info( - f"Strategy `{factor_args.strategy}` does not require performing Eigendecomposition. " - f"Skipping." + f"Strategy `{factor_args.strategy}` does not require performing Eigendecomposition. " f"Skipping." ) return if load_from_factors_name is not None: - self.logger.info( - f"Loading covariance matrices from factors with name `{load_from_factors_name}`." - ) - load_factors_output_dir = self.factors_output_dir( - factors_name=load_from_factors_name - ) + self.logger.info(f"Loading covariance matrices from factors with name `{load_from_factors_name}`.") + load_factors_output_dir = self.factors_output_dir(factors_name=load_from_factors_name) else: load_factors_output_dir = factors_output_dir @@ -97,9 +85,7 @@ def perform_eigendecomposition( raise FactorsNotFoundError(error_msg) with self.profiler.profile("Load Covariance"): - covariance_factors = load_covariance_matrices( - output_dir=load_factors_output_dir - ) + covariance_factors = load_covariance_matrices(output_dir=load_factors_output_dir) if self.state.is_main_process: release_memory() @@ -113,17 +99,13 @@ def perform_eigendecomposition( ) end_time = time.time() elapsed_time = end_time - start_time - self.logger.info( - f"Performed Eigendecomposition in {elapsed_time:.2f} seconds." - ) + self.logger.info(f"Performed Eigendecomposition in {elapsed_time:.2f} seconds.") with self.profiler.profile("Save Eigendecomposition"): save_eigendecomposition( output_dir=factors_output_dir, eigen_factors=eigen_factors, ) - self.logger.info( - f"Saved Eigendecomposition results at {factors_output_dir}." - ) + self.logger.info(f"Saved Eigendecomposition results at {factors_output_dir}.") self.state.wait_for_everyone() profile_summary = self.profiler.summary() @@ -202,8 +184,7 @@ def _fit_partitioned_lambda_matrices( end_time = get_time(state=self.state) elapsed_time = end_time - start_time self.logger.info( - f"Fitted Lambda matrices on {num_data_processed.item()} data points in " - f"{elapsed_time:.2f} seconds." + f"Fitted Lambda matrices on {num_data_processed.item()} data points in " f"{elapsed_time:.2f} seconds." 
) return lambda_factors @@ -225,20 +206,13 @@ def fit_lambda_matrices( factors_output_dir = self.factors_output_dir(factors_name=factors_name) os.makedirs(factors_output_dir, exist_ok=True) - if ( - lambda_matrices_exist(output_dir=factors_output_dir) - and not overwrite_output_dir - ): - self.logger.info( - f"Found existing Lambda matrices at {factors_output_dir}. Skipping." - ) + if lambda_matrices_exist(output_dir=factors_output_dir) and not overwrite_output_dir: + self.logger.info(f"Found existing Lambda matrices at {factors_output_dir}. Skipping.") return if factor_args is None: factor_args = FactorArguments() - self.logger.info( - f"Factor arguments not provided. Using the default configuration: {factor_args}." - ) + self.logger.info(f"Factor arguments not provided. Using the default configuration: {factor_args}.") else: self.logger.info(f"Using the provided configuration: {factor_args}.") @@ -252,8 +226,7 @@ def fit_lambda_matrices( if not FactorConfig.CONFIGS[factor_args.strategy].requires_lambda_matrices: self.logger.info( - f"Strategy `{factor_args.strategy}` does not require fitting Lambda matrices. " - f"Skipping." + f"Strategy `{factor_args.strategy}` does not require fitting Lambda matrices. " f"Skipping." ) return @@ -269,17 +242,13 @@ def fit_lambda_matrices( self.logger.info( f"Will be loading Eigendecomposition results from factors with name `{load_from_factors_name}`." ) - load_factors_output_dir = self.factors_output_dir( - factors_name=load_from_factors_name - ) + load_factors_output_dir = self.factors_output_dir(factors_name=load_from_factors_name) else: load_factors_output_dir = factors_output_dir if ( not eigendecomposition_exist(output_dir=load_factors_output_dir) - and FactorConfig.CONFIGS[ - factor_args.strategy - ].requires_eigendecomposition_for_lambda + and FactorConfig.CONFIGS[factor_args.strategy].requires_eigendecomposition_for_lambda ): error_msg = ( f"Eigendecomposition results not found at {load_factors_output_dir}. " @@ -295,32 +264,19 @@ def fit_lambda_matrices( f"DataLoader arguments not provided. Using the default configuration: {dataloader_kwargs}." ) else: - self.logger.info( - f"Using the DataLoader parameters: {dataloader_kwargs.to_dict()}." - ) + self.logger.info(f"Using the DataLoader parameters: {dataloader_kwargs.to_dict()}.") dataloader_params = dataloader_kwargs.to_dict() eigen_factors = None - if FactorConfig.CONFIGS[ - factor_args.strategy - ].requires_eigendecomposition_for_lambda: + if FactorConfig.CONFIGS[factor_args.strategy].requires_eigendecomposition_for_lambda: with self.profiler.profile("Load Eigendecomposition"): - eigen_factors = load_eigendecomposition( - output_dir=load_factors_output_dir - ) + eigen_factors = load_eigendecomposition(output_dir=load_factors_output_dir) total_data_examples = min([factor_args.lambda_max_examples, len(dataset)]) - self.logger.info( - f"Total data examples to fit Lambda matrices: {total_data_examples}." 
- ) + self.logger.info(f"Total data examples to fit Lambda matrices: {total_data_examples}.") - no_partition = ( - factor_args.lambda_data_partition_size == 1 - and factor_args.lambda_module_partition_size == 1 - ) - partition_provided = ( - target_data_partitions is not None or target_module_partitions is not None - ) + no_partition = factor_args.lambda_data_partition_size == 1 and factor_args.lambda_module_partition_size == 1 + partition_provided = target_data_partitions is not None or target_module_partitions is not None if no_partition and partition_provided: error_msg = ( "`target_data_partitions` or `target_module_partitions` were specified, while" @@ -354,9 +310,7 @@ ) with self.profiler.profile("Save Lambda"): if self.state.is_main_process: - save_lambda_matrices( - output_dir=factors_output_dir, lambda_factors=lambda_factors - ) + save_lambda_matrices(output_dir=factors_output_dir, lambda_factors=lambda_factors) self.state.wait_for_everyone() self.logger.info(f"Saved Lambda matrices at {factors_output_dir}.") @@ -366,11 +320,9 @@ data_partition_size=factor_args.lambda_data_partition_size, target_data_partitions=target_data_partitions, ) - module_partition_names, target_module_partitions = ( - self._get_module_partition( - module_partition_size=factor_args.lambda_module_partition_size, - target_module_partitions=target_module_partitions, - ) + module_partition_names, target_module_partitions = self._get_module_partition( + module_partition_size=factor_args.lambda_module_partition_size, + target_module_partitions=target_module_partitions, ) all_start_time = get_time(state=self.state) @@ -395,25 +347,19 @@ f"{end_index}) and modules {module_partition_names[module_partition]}." ) - max_total_examples = ( - total_data_examples // factor_args.lambda_data_partition_size - ) + max_total_examples = total_data_examples // factor_args.lambda_data_partition_size if max_total_examples < self.state.num_processes: - error_msg = ( - "There are more data examples than the number of processes." - ) + error_msg = "There are fewer data examples than the number of processes." self.logger.error(error_msg) raise ValueError(error_msg) if per_device_batch_size is None: - per_device_batch_size = ( - self._find_executable_lambda_factors_batch_size( - eigen_factors=eigen_factors, - dataloader_params=dataloader_params, - dataset=dataset, - factor_args=factor_args, - total_data_examples=max_total_examples, - tracked_module_names=module_partition_names[0], - ) + per_device_batch_size = self._find_executable_lambda_factors_batch_size( + eigen_factors=eigen_factors, + dataloader_params=dataloader_params, + dataset=dataset, + factor_args=factor_args, + total_data_examples=max_total_examples, + tracked_module_names=module_partition_names[0], ) lambda_factors = self._fit_partitioned_lambda_matrices( eigen_factors=eigen_factors, @@ -433,18 +379,12 @@ ) self.state.wait_for_everyone() del lambda_factors - self.logger.info( - f"Saved partitioned Lambda matrices at {factors_output_dir}." - ) + self.logger.info(f"Saved partitioned Lambda matrices at {factors_output_dir}.") all_end_time = get_time(state=self.state) elapsed_time = all_end_time - all_start_time - self.logger.info( - f"Fitted all partitioned Lambda matrices in {elapsed_time:.2f} seconds."
- ) - self.aggregate_lambda_matrices( - factors_name=factors_name, factor_args=factor_args - ) + self.logger.info(f"Fitted all partitioned Lambda matrices in {elapsed_time:.2f} seconds.") + self.aggregate_lambda_matrices(factors_name=factors_name, factor_args=factor_args) profile_summary = self.profiler.summary() if profile_summary != "": @@ -469,19 +409,14 @@ def aggregate_lambda_matrices( data_partition_size = factor_args.lambda_data_partition_size module_partition_size = factor_args.lambda_module_partition_size - all_required_partitions = [ - (i, j) - for i in range(data_partition_size) - for j in range(module_partition_size) - ] + all_required_partitions = [(i, j) for i in range(data_partition_size) for j in range(module_partition_size)] all_partition_exists = [ lambda_matrices_exist(output_dir=factors_output_dir, partition=partition) for partition in all_required_partitions ] if not all_partition_exists: self.logger.info( - "Lambda matrices are not aggregated as Lambda matrices for some partitions " - "are not yet computed." + "Lambda matrices are not aggregated as Lambda matrices for some partitions " "are not yet computed." ) return @@ -508,6 +443,4 @@ def aggregate_lambda_matrices( self.state.wait_for_everyone() end_time = get_time(state=self.state) elapsed_time = end_time - start_time - self.logger.info( - f"Aggregated all partitioned Lambda matrices in {elapsed_time:.2f} seconds." - ) + self.logger.info(f"Aggregated all partitioned Lambda matrices in {elapsed_time:.2f} seconds.") diff --git a/kronfluence/computer/pairwise_score_computer.py b/kronfluence/computer/pairwise_score_computer.py index f92c3ed..e9f1524 100644 --- a/kronfluence/computer/pairwise_score_computer.py +++ b/kronfluence/computer/pairwise_score_computer.py @@ -57,9 +57,7 @@ def _find_executable_pairwise_scores_batch_size( ) def executable_batch_size_func(batch_size: int) -> None: - self.logger.info( - f"Attempting to set per-device batch size to {batch_size}." - ) + self.logger.info(f"Attempting to set per-device batch size to {batch_size}.") set_mode(model=self.model, mode=ModuleMode.DEFAULT, keep_factors=False) release_memory() total_batch_size = batch_size * self.state.num_processes @@ -143,9 +141,7 @@ def _fit_partitioned_pairwise_scores( ) end_time = get_time(state=self.state) elapsed_time = end_time - start_time - self.logger.info( - f"Computed pairwise influence scores in {elapsed_time:.2f} seconds." - ) + self.logger.info(f"Computed pairwise influence scores in {elapsed_time:.2f} seconds.") return scores def compute_pairwise_scores( @@ -206,20 +202,13 @@ def compute_pairwise_scores( scores_output_dir = self.scores_output_dir(scores_name=scores_name) os.makedirs(scores_output_dir, exist_ok=True) - if ( - pairwise_scores_exist(output_dir=scores_output_dir) - and not overwrite_output_dir - ): - self.logger.info( - f"Found existing pairwise scores at {scores_output_dir}. Skipping." - ) + if pairwise_scores_exist(output_dir=scores_output_dir) and not overwrite_output_dir: + self.logger.info(f"Found existing pairwise scores at {scores_output_dir}. Skipping.") return if score_args is None: score_args = ScoreArguments() - self.logger.info( - f"Score arguments not provided. Using the default configuration: {score_args}." - ) + self.logger.info(f"Score arguments not provided. Using the default configuration: {score_args}.") else: self.logger.info(f"Using the provided configuration: {score_args}.") @@ -268,9 +257,7 @@ def compute_pairwise_scores( f"DataLoader arguments not provided. 
Using the default configuration: {dataloader_kwargs}." ) else: - self.logger.info( - f"Using the DataLoader parameters: {dataloader_kwargs.to_dict()}." - ) + self.logger.info(f"Using the DataLoader parameters: {dataloader_kwargs.to_dict()}.") dataloader_params = dataloader_kwargs.to_dict() if query_indices is not None: query_dataset = data.Subset(dataset=query_dataset, indices=query_indices) @@ -284,13 +271,8 @@ def compute_pairwise_scores( factor_config=factor_config, ) - no_partition = ( - score_args.data_partition_size == 1 - and score_args.module_partition_size == 1 - ) - partition_provided = ( - target_data_partitions is not None or target_module_partitions is not None - ) + no_partition = score_args.data_partition_size == 1 and score_args.module_partition_size == 1 + partition_provided = target_data_partitions is not None or target_module_partitions is not None if no_partition and partition_provided: error_msg = ( "`target_data_partitions` or `target_module_partitions` were specified, while" @@ -301,18 +283,16 @@ def compute_pairwise_scores( if no_partition: if per_device_train_batch_size is None: - per_device_train_batch_size = ( - self._find_executable_pairwise_scores_batch_size( - query_dataset=query_dataset, - per_device_query_batch_size=per_device_query_batch_size, - train_dataset=train_dataset, - loaded_factors=loaded_factors, - dataloader_params=dataloader_params, - total_data_examples=len(train_dataset), - score_args=score_args, - factor_args=factor_args, - tracked_modules_name=None, - ) + per_device_train_batch_size = self._find_executable_pairwise_scores_batch_size( + query_dataset=query_dataset, + per_device_query_batch_size=per_device_query_batch_size, + train_dataset=train_dataset, + loaded_factors=loaded_factors, + dataloader_params=dataloader_params, + total_data_examples=len(train_dataset), + score_args=score_args, + factor_args=factor_args, + tracked_modules_name=None, ) scores = self._fit_partitioned_pairwise_scores( loaded_factors=loaded_factors, @@ -341,11 +321,9 @@ def compute_pairwise_scores( data_partition_size=score_args.data_partition_size, target_data_partitions=target_data_partitions, ) - module_partition_names, target_module_partitions = ( - self._get_module_partition( - module_partition_size=score_args.module_partition_size, - target_module_partitions=target_module_partitions, - ) + module_partition_names, target_module_partitions = self._get_module_partition( + module_partition_size=score_args.module_partition_size, + target_module_partitions=target_module_partitions, ) all_start_time = get_time(state=self.state) @@ -371,19 +349,16 @@ def compute_pairwise_scores( ) if per_device_train_batch_size is None: - per_device_train_batch_size = ( - self._find_executable_pairwise_scores_batch_size( - query_dataset=query_dataset, - per_device_query_batch_size=per_device_query_batch_size, - train_dataset=train_dataset, - loaded_factors=loaded_factors, - dataloader_params=dataloader_params, - total_data_examples=len(train_dataset) - // score_args.data_partition_size, - score_args=score_args, - factor_args=factor_args, - tracked_modules_name=module_partition_names[0], - ) + per_device_train_batch_size = self._find_executable_pairwise_scores_batch_size( + query_dataset=query_dataset, + per_device_query_batch_size=per_device_query_batch_size, + train_dataset=train_dataset, + loaded_factors=loaded_factors, + dataloader_params=dataloader_params, + total_data_examples=len(train_dataset) // score_args.data_partition_size, + score_args=score_args, + factor_args=factor_args, 
+ tracked_modules_name=module_partition_names[0], ) scores = self._fit_partitioned_pairwise_scores( loaded_factors=loaded_factors, @@ -406,27 +381,19 @@ def compute_pairwise_scores( ) self.state.wait_for_everyone() del scores - self.logger.info( - f"Saved partitioned pairwise scores at {scores_output_dir}." - ) + self.logger.info(f"Saved partitioned pairwise scores at {scores_output_dir}.") all_end_time = get_time(state=self.state) elapsed_time = all_end_time - all_start_time - self.logger.info( - f"Fitted all partitioned pairwise scores in {elapsed_time:.2f} seconds." - ) - self.aggregate_pairwise_scores( - scores_name=scores_name, score_args=score_args - ) + self.logger.info(f"Fitted all partitioned pairwise scores in {elapsed_time:.2f} seconds.") + self.aggregate_pairwise_scores(scores_name=scores_name, score_args=score_args) profile_summary = self.profiler.summary() if profile_summary != "": self.logger.info(self.profiler.summary()) @torch.no_grad() - def aggregate_pairwise_scores( - self, scores_name: str, score_args: ScoreArguments - ) -> None: + def aggregate_pairwise_scores(self, scores_name: str, score_args: ScoreArguments) -> None: """Aggregates pairwise scores computed for all data and module partitions.""" self._aggregate_scores( scores_name=scores_name, diff --git a/kronfluence/computer/self_score_computer.py b/kronfluence/computer/self_score_computer.py index 89d7dfd..14951ec 100644 --- a/kronfluence/computer/self_score_computer.py +++ b/kronfluence/computer/self_score_computer.py @@ -54,9 +54,7 @@ def _find_executable_self_scores_batch_size( ) def executable_batch_size_func(batch_size: int) -> None: - self.logger.info( - f"Attempting to set per-device batch size to {batch_size}." - ) + self.logger.info(f"Attempting to set per-device batch size to {batch_size}.") set_mode(model=self.model, mode=ModuleMode.DEFAULT, keep_factors=False) release_memory() total_batch_size = batch_size * self.state.num_processes @@ -121,9 +119,7 @@ def _fit_partitioned_self_scores( ) end_time = get_time(state=self.state) elapsed_time = end_time - start_time - self.logger.info( - f"Computed self-influence scores in {elapsed_time:.2f} seconds." - ) + self.logger.info(f"Computed self-influence scores in {elapsed_time:.2f} seconds.") return scores def compute_self_scores( @@ -169,23 +165,17 @@ def compute_self_scores( overwrite_output_dir (bool, optional): If True, the existing factors with the same name will be overwritten. """ - self.logger.debug( - f"Computing self-influence scores with parameters: {locals()}" - ) + self.logger.debug(f"Computing self-influence scores with parameters: {locals()}") scores_output_dir = self.scores_output_dir(scores_name=scores_name) os.makedirs(scores_output_dir, exist_ok=True) if self_scores_exist(output_dir=scores_output_dir) and not overwrite_output_dir: - self.logger.info( - f"Found existing self-influence scores at {scores_output_dir}. Skipping." - ) + self.logger.info(f"Found existing self-influence scores at {scores_output_dir}. Skipping.") return if score_args is None: score_args = ScoreArguments() - self.logger.info( - f"Score arguments not provided. Using the default configuration: {score_args}." - ) + self.logger.info(f"Score arguments not provided. Using the default configuration: {score_args}.") else: self.logger.info(f"Using the provided configuration: {score_args}.") @@ -227,9 +217,7 @@ def compute_self_scores( f"DataLoader arguments not provided. Using the default configuration: {dataloader_kwargs}." 
) else: - self.logger.info( - f"Using the DataLoader parameters: {dataloader_kwargs.to_dict()}." - ) + self.logger.info(f"Using the DataLoader parameters: {dataloader_kwargs.to_dict()}.") dataloader_params = dataloader_kwargs.to_dict() if train_indices is not None: train_dataset = data.Subset(dataset=train_dataset, indices=train_indices) @@ -241,13 +229,8 @@ def compute_self_scores( factor_config=factor_config, ) - no_partition = ( - score_args.data_partition_size == 1 - and score_args.module_partition_size == 1 - ) - partition_provided = ( - target_data_partitions is not None or target_module_partitions is not None - ) + no_partition = score_args.data_partition_size == 1 and score_args.module_partition_size == 1 + partition_provided = target_data_partitions is not None or target_module_partitions is not None if no_partition and partition_provided: error_msg = ( @@ -259,16 +242,14 @@ def compute_self_scores( if no_partition: if per_device_train_batch_size is None: - per_device_train_batch_size = ( - self._find_executable_self_scores_batch_size( - train_dataset=train_dataset, - loaded_factors=loaded_factors, - dataloader_params=dataloader_params, - total_data_examples=len(train_dataset), - score_args=score_args, - factor_args=factor_args, - tracked_modules_name=None, - ) + per_device_train_batch_size = self._find_executable_self_scores_batch_size( + train_dataset=train_dataset, + loaded_factors=loaded_factors, + dataloader_params=dataloader_params, + total_data_examples=len(train_dataset), + score_args=score_args, + factor_args=factor_args, + tracked_modules_name=None, ) scores = self._fit_partitioned_self_scores( loaded_factors=loaded_factors, @@ -295,11 +276,9 @@ def compute_self_scores( data_partition_size=score_args.data_partition_size, target_data_partitions=target_data_partitions, ) - module_partition_names, target_module_partitions = ( - self._get_module_partition( - module_partition_size=score_args.module_partition_size, - target_module_partitions=target_module_partitions, - ) + module_partition_names, target_module_partitions = self._get_module_partition( + module_partition_size=score_args.module_partition_size, + target_module_partitions=target_module_partitions, ) all_start_time = get_time(state=self.state) @@ -325,17 +304,14 @@ def compute_self_scores( ) if per_device_train_batch_size is None: - per_device_train_batch_size = ( - self._find_executable_self_scores_batch_size( - train_dataset=train_dataset, - loaded_factors=loaded_factors, - dataloader_params=dataloader_params, - total_data_examples=len(train_dataset) - // score_args.data_partition_size, - score_args=score_args, - factor_args=factor_args, - tracked_modules_name=module_partition_names[0], - ) + per_device_train_batch_size = self._find_executable_self_scores_batch_size( + train_dataset=train_dataset, + loaded_factors=loaded_factors, + dataloader_params=dataloader_params, + total_data_examples=len(train_dataset) // score_args.data_partition_size, + score_args=score_args, + factor_args=factor_args, + tracked_modules_name=module_partition_names[0], ) scores = self._fit_partitioned_self_scores( loaded_factors=loaded_factors, @@ -356,15 +332,11 @@ def compute_self_scores( ) self.state.wait_for_everyone() del scores - self.logger.info( - f"Saved partitioned self-influence scores at {scores_output_dir}." 
- ) + self.logger.info(f"Saved partitioned self-influence scores at {scores_output_dir}.") all_end_time = get_time(state=self.state) elapsed_time = all_end_time - all_start_time - self.logger.info( - f"Fitted all partitioned self-influence scores in {elapsed_time:.2f} seconds." - ) + self.logger.info(f"Fitted all partitioned self-influence scores in {elapsed_time:.2f} seconds.") self.aggregate_self_scores(scores_name=scores_name, score_args=score_args) profile_summary = self.profiler.summary() @@ -372,9 +344,7 @@ def compute_self_scores( self.logger.info(self.profiler.summary()) @torch.no_grad() - def aggregate_self_scores( - self, scores_name: str, score_args: ScoreArguments - ) -> None: + def aggregate_self_scores(self, scores_name: str, score_args: ScoreArguments) -> None: """Aggregates self-influence scores computed for all data and module partitions.""" self._aggregate_scores( scores_name=scores_name, diff --git a/kronfluence/factor/config.py b/kronfluence/factor/config.py index 99e1956..e3d315c 100644 --- a/kronfluence/factor/config.py +++ b/kronfluence/factor/config.py @@ -13,24 +13,24 @@ NUM_LAMBDA_PROCESSED, ) +STORAGE_TYPE = Dict[str, Any] + class FactorStrategy(str, BaseEnum): - """A strategy for computing preconditioning factor.""" + """Strategy for computing preconditioning factors.""" IDENTITY = "identity" - DIAGONAL = "diag" + DIAGONAL = "diagonal" KFAC = "kfac" EKFAC = "ekfac" class FactorConfig(metaclass=ABCMeta): - """Configuration for each factor strategy.""" + """Configuration for each available factor strategy.""" CONFIGS: Dict[FactorStrategy, Any] = {} - def __init_subclass__( - cls, factor_strategy: Optional[FactorStrategy] = None, **kwargs - ) -> None: + def __init_subclass__(cls, factor_strategy: Optional[FactorStrategy] = None, **kwargs) -> None: """Registers all subclasses of `FactorConfig`.""" super().__init_subclass__(**kwargs) if factor_strategy is not None: @@ -40,41 +40,33 @@ def __init_subclass__( @property @abstractmethod def requires_covariance_matrices(self) -> bool: - """Returns True if the given strategy requires computing covariance matrices.""" - raise NotImplementedError( - "Subclasses must implement the `requires_covariance_matrices` property." - ) + """Returns True if the strategy requires computing covariance matrices.""" + raise NotImplementedError("Subclasses must implement the `requires_covariance_matrices` property.") @property @abstractmethod def requires_eigendecomposition(self) -> bool: - """Returns True if the given strategy requires performing Eigendecomposition.""" - raise NotImplementedError( - "Subclasses must implement the `requires_eigendecomposition` property." - ) + """Returns True if the strategy requires performing Eigendecomposition.""" + raise NotImplementedError("Subclasses must implement the `requires_eigendecomposition` property.") @property @abstractmethod def requires_lambda_matrices(self) -> bool: - """Returns True if the given strategy requires computing Lambda matrices.""" - raise NotImplementedError( - "Subclasses must implement the `requires_lambda_matrices` property." 
- ) + """Returns True if the strategy requires computing Lambda matrices.""" + raise NotImplementedError("Subclasses must implement the `requires_lambda_matrices` property.") @property @abstractmethod def requires_eigendecomposition_for_lambda(self) -> bool: - """Returns True if the given strategy requires loading Eigendecomposition results, before - computing Lambda matrices.""" - raise NotImplementedError( - "Subclasses must implement the `requires_eigendecomposition_for_lambda` property." - ) + """Returns True if the strategy requires loading Eigendecomposition results, before computing + Lambda matrices.""" + raise NotImplementedError("Subclasses must implement the `requires_eigendecomposition_for_lambda` property.") @property @abstractmethod def requires_covariance_matrices_for_precondition(self) -> bool: - """Returns True if the given strategy requires loading covariance matrices, before - computing preconditioned gradient.""" + """Returns True if the strategy requires loading covariance matrices, before computing + preconditioned gradient.""" raise NotImplementedError( "Subclasses must implement the `requires_covariance_matrices_for_precondition` property." ) @@ -82,8 +74,8 @@ def requires_covariance_matrices_for_precondition(self) -> bool: @property @abstractmethod def requires_eigendecomposition_for_precondition(self) -> bool: - """Returns True if the given strategy requires loading Eigendecomposition results, before - computing preconditioned gradient.""" + """Returns True if the strategy requires loading Eigendecomposition results, before computing + preconditioned gradient.""" raise NotImplementedError( "Subclasses must implement the `requires_eigendecomposition_for_precondition` property." ) @@ -91,33 +83,30 @@ def requires_eigendecomposition_for_precondition(self) -> bool: @property @abstractmethod def requires_lambda_matrices_for_precondition(self) -> bool: - """Returns True if the given strategy requires loading Lambda matrices, before - computing the preconditioned gradient.""" - raise NotImplementedError( - "Subclasses must implement the `requires_lambda_matrices_for_precondition` property." - ) + """Returns True if the strategy requires loading Lambda matrices, before computing + the preconditioned gradient.""" + raise NotImplementedError("Subclasses must implement the `requires_lambda_matrices_for_precondition` property.") @abstractmethod def precondition_gradient( self, gradient: torch.Tensor, - storage: Dict[str, torch.Tensor], - damping: float, + storage: STORAGE_TYPE, + damping: Optional[float], ) -> torch.Tensor: - """Preconditions the per-sample-gradient with the appropriate strategy. The per-sample-gradient - is a 3-dimensional tensor with shape `batch_size x input_dim x output_dim`. + """Preconditions the per-sample-gradient. The per-sample-gradient is a 3-dimensional + tensor with the shape `batch_size x input_dim x output_dim`. Args: gradient (torch.Tensor): The per-sample-gradient tensor. storage (Dict[str, Any]): - A dictionary containing various factors to help perform preconditioning. + A dictionary containing various factors required to compute the preconditioned gradient. + See `TrackedModule` for details. damping (float): The damping factor when computing the preconditioned gradient. """ - raise NotImplementedError( - "Subclasses must implement the `precondition_gradient` property." 
- ) + raise NotImplementedError("Subclasses must implement the `precondition_gradient` property.") class Identity(FactorConfig, factor_strategy=FactorStrategy.IDENTITY): @@ -154,8 +143,8 @@ def requires_lambda_matrices_for_precondition(self) -> bool: def precondition_gradient( self, gradient: torch.Tensor, - storage: Dict[str, torch.Tensor], - damping: Optional[int], + storage: STORAGE_TYPE, + damping: Optional[float], ) -> torch.Tensor: del storage, damping return gradient @@ -195,12 +184,10 @@ def requires_lambda_matrices_for_precondition(self) -> bool: def precondition_gradient( self, gradient: torch.Tensor, - storage: Dict[str, torch.Tensor], - damping: Optional[int], + storage: STORAGE_TYPE, + damping: Optional[float], ) -> torch.Tensor: - lambda_matrix = storage[LAMBDA_MATRIX_NAME].to( - dtype=gradient.dtype, device=gradient.device - ) + lambda_matrix = storage[LAMBDA_MATRIX_NAME].to(dtype=gradient.dtype, device=gradient.device) num_lambda_processed = storage[NUM_LAMBDA_PROCESSED].to(device=gradient.device) if damping is None: damping = 0.1 * torch.mean(lambda_matrix) @@ -245,24 +232,14 @@ def requires_lambda_matrices_for_precondition(self) -> bool: def precondition_gradient( self, gradient: torch.Tensor, - storage: Dict[str, torch.Tensor], - damping: Optional[int], + storage: STORAGE_TYPE, + damping: Optional[float], ) -> torch.Tensor: - activation_eigenvectors = storage[ACTIVATION_EIGENVECTORS_NAME].to( - dtype=gradient.dtype, device=gradient.device - ) - gradient_eigenvectors = storage[GRADIENT_EIGENVECTORS_NAME].to( - dtype=gradient.dtype, device=gradient.device - ) - activation_eigenvalues = storage[ACTIVATION_EIGENVALUES_NAME].to( - dtype=gradient.dtype, device=gradient.device - ) - gradient_eigenvalues = storage[GRADIENT_EIGENVALUES_NAME].to( - dtype=gradient.dtype, device=gradient.device - ) - lambda_matrix = torch.kron( - activation_eigenvalues.unsqueeze(0), gradient_eigenvalues.unsqueeze(-1) - ).unsqueeze(0) + activation_eigenvectors = storage[ACTIVATION_EIGENVECTORS_NAME].to(dtype=gradient.dtype, device=gradient.device) + gradient_eigenvectors = storage[GRADIENT_EIGENVECTORS_NAME].to(dtype=gradient.dtype, device=gradient.device) + activation_eigenvalues = storage[ACTIVATION_EIGENVALUES_NAME].to(dtype=gradient.dtype, device=gradient.device) + gradient_eigenvalues = storage[GRADIENT_EIGENVALUES_NAME].to(dtype=gradient.dtype, device=gradient.device) + lambda_matrix = torch.kron(activation_eigenvalues.unsqueeze(0), gradient_eigenvalues.unsqueeze(-1)).unsqueeze(0) rotated_gradient = torch.einsum( "ij,bjl,lk->bik", @@ -321,18 +298,12 @@ def requires_lambda_matrices_for_precondition(self) -> bool: def precondition_gradient( self, gradient: torch.Tensor, - storage: Dict[str, torch.Tensor], - damping: Optional[int], + storage: STORAGE_TYPE, + damping: Optional[float], ) -> torch.Tensor: - activation_eigenvectors = storage[ACTIVATION_EIGENVECTORS_NAME].to( - dtype=gradient.dtype, device=gradient.device - ) - gradient_eigenvectors = storage[GRADIENT_EIGENVECTORS_NAME].to( - dtype=gradient.dtype, device=gradient.device - ) - lambda_matrix = storage[LAMBDA_MATRIX_NAME].to( - dtype=gradient.dtype, device=gradient.device - ) + activation_eigenvectors = storage[ACTIVATION_EIGENVECTORS_NAME].to(dtype=gradient.dtype, device=gradient.device) + gradient_eigenvectors = storage[GRADIENT_EIGENVECTORS_NAME].to(dtype=gradient.dtype, device=gradient.device) + lambda_matrix = storage[LAMBDA_MATRIX_NAME].to(dtype=gradient.dtype, device=gradient.device) num_lambda_processed = 
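The KFAC-style `precondition_gradient` above exploits a Kronecker identity: solving `(torch.kron(A, G) + damping * I) vec(g)` is equivalent to rotating `g` into the eigenbases of the two covariance factors, dividing elementwise by the damped eigenvalue products, and rotating back. A self-contained numerical check of that equivalence; all tensors here are synthetic:

```python
import torch

torch.manual_seed(0)
o, i, damping = 4, 3, 0.1
A = torch.randn(i, i)
A = A @ A.T + torch.eye(i)  # activation covariance (SPD)
G = torch.randn(o, o)
G = G @ G.T + torch.eye(o)  # pseudo-gradient covariance (SPD)
g = torch.randn(o, i)       # one per-sample gradient

# Dense reference: solve (kron(A, G) + damping * I) vec(g), column-major vec.
dense = torch.kron(A, G) + damping * torch.eye(o * i)
x_ref = torch.linalg.solve(dense, g.T.reshape(-1)).reshape(i, o).T

# Factored solve: rotate, divide by damped eigenvalue products, rotate back.
la, Qa = torch.linalg.eigh(A)
lg, Qg = torch.linalg.eigh(G)
rotated = Qg.T @ g @ Qa
x_kfac = Qg @ (rotated / (lg[:, None] * la[None, :] + damping)) @ Qa.T

torch.testing.assert_close(x_ref, x_kfac, rtol=1e-4, atol=1e-4)
```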
storage[NUM_LAMBDA_PROCESSED].to(device=gradient.device) rotated_gradient = torch.einsum( diff --git a/kronfluence/factor/covariance.py b/kronfluence/factor/covariance.py index e365445..3af6d1e 100644 --- a/kronfluence/factor/covariance.py +++ b/kronfluence/factor/covariance.py @@ -116,13 +116,14 @@ def fit_covariance_matrices_with_loader( Arguments related to computing covariance matrices. tracked_module_names (List[str], optional): A list of module names that covariance matrices will be computed. If not specified, covariance - matrices will be computed for all available tracked modules. + matrices will be computed for all tracked modules. Returns: Tuple[torch.Tensor, FACTOR_TYPE]: A tuple containing the number of data points processed, and computed covariance matrices in CPU. The covariance matrices are organized in nested dictionaries, where the first key in the name of the - covariance matrix (e.g., activation covariance) and the second key is the module name. + covariance matrix (e.g., activation covariance and gradient covariance) and the second key is + the module name. """ with torch.no_grad(): update_factor_args(model=model, factor_args=factor_args) @@ -132,9 +133,7 @@ def fit_covariance_matrices_with_loader( tracked_module_names=tracked_module_names, mode=ModuleMode.COVARIANCE, ) - num_data_processed = torch.zeros( - (1,), dtype=torch.int64, device=state.device, requires_grad=False - ) + num_data_processed = torch.zeros((1,), dtype=torch.int64, device=state.device, requires_grad=False) with tqdm( total=len(loader), @@ -170,8 +169,6 @@ def fit_covariance_matrices_with_loader( with torch.no_grad(): saved_factors: FACTOR_TYPE = {} for covariance_factor_name in COVARIANCE_FACTOR_NAMES: - saved_factors[covariance_factor_name] = load_factors( - model=model, factor_name=covariance_factor_name - ) + saved_factors[covariance_factor_name] = load_factors(model=model, factor_name=covariance_factor_name) set_mode(model=model, mode=ModuleMode.DEFAULT, keep_factors=False) return num_data_processed, saved_factors diff --git a/kronfluence/factor/eigen.py b/kronfluence/factor/eigen.py index 9af3535..e9684a5 100644 --- a/kronfluence/factor/eigen.py +++ b/kronfluence/factor/eigen.py @@ -1,5 +1,5 @@ from pathlib import Path -from typing import Dict, List, Optional, Tuple +from typing import List, Optional, Tuple import torch import torch.distributed as dist @@ -21,6 +21,7 @@ GRADIENT_EIGENVECTORS_NAME, LAMBDA_FACTOR_NAMES, NUM_COVARIANCE_PROCESSED, + PARTITION_TYPE, ) from kronfluence.module.tracked_module import ModuleMode from kronfluence.module.utils import ( @@ -47,7 +48,7 @@ def eigendecomposition_save_path( def save_eigendecomposition( output_dir: Path, - eigen_factors: Dict[str, Dict[str, torch.Tensor]], + eigen_factors: FACTOR_TYPE, ) -> None: """Saves Eigendecomposition results to disk.""" assert set(eigen_factors.keys()) == set(EIGENDECOMPOSITION_FACTOR_NAMES) @@ -97,12 +98,12 @@ def perform_eigendecomposition( """Performs Eigendecomposition on activation and pseudo-gradient covariance matrices. Args: + covariance_factors (FACTOR_TYPE): + The covariance matrices to perform Eigendecomposition on. model (nn.Module): The model which contains modules which Eigendecomposition will be performed. state (State): The current process's information (e.g., device being used). - covariance_factors (FACTOR_TYPE): - The covariance matrices to load from. factor_args (FactorArguments): Arguments related to performing Eigendecomposition. 
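As the return docstring describes, every factor collection is a nested dictionary: factor name first, module name second. Illustratively, with dummy tensors and stand-in string keys for the real `COVARIANCE_FACTOR_NAMES` constants:

```python
import torch

covariance_factors = {
    "activation_covariance": {"encoder.linear1": torch.eye(5)},           # in_dim (+1 with bias)
    "gradient_covariance": {"encoder.linear1": torch.eye(3)},             # out_dim
    "num_covariance_processed": {"encoder.linear1": torch.tensor([128])},
}
```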
@@ -110,7 +111,7 @@ def perform_eigendecomposition( FACTOR_TYPE: The Eigendecomposition results in CPU. The Eigendecomposition results are organized in nested dictionaries, where the first key in the name of the Eigendecomposition factor (e.g., - eigenvector), and the second key is the module name. + activation eigenvector), and the second key is the module name. """ eigen_factors: FACTOR_TYPE = {} for factor_name in EIGENDECOMPOSITION_FACTOR_NAMES: @@ -144,16 +145,12 @@ def perform_eigendecomposition( device=state.device, dtype=factor_args.eigendecomposition_dtype, ) - # In case covariance matrices are not symmetric due to numerical issues. + # Deal with cases where covariance matrices are not symmetric due to numerical issues. covariance_matrix = 0.5 * (covariance_matrix + covariance_matrix.t()) eigenvalues, eigenvectors = torch.linalg.eigh(covariance_matrix) - eigen_factors[eigenvectors_name][module_name] = ( - eigenvectors.to(dtype=original_dtype).contiguous().cpu() - ) - eigen_factors[eigenvalues_name][module_name] = eigenvalues.to( - dtype=original_dtype - ).cpu() - del eigenvectors, eigenvalues + eigen_factors[eigenvalues_name][module_name] = eigenvalues.to(dtype=original_dtype).cpu() + eigen_factors[eigenvectors_name][module_name] = eigenvectors.to(dtype=original_dtype).contiguous().cpu() + del eigenvalues, eigenvectors pbar.update(1) return eigen_factors @@ -161,7 +158,7 @@ def perform_eigendecomposition( def lambda_matrices_save_path( output_dir: Path, lambda_factor_name: str, - partition: Optional[Tuple[int, int]] = None, + partition: Optional[PARTITION_TYPE] = None, ) -> Path: """Generates the path for saving/loading Lambda matrices.""" assert lambda_factor_name in LAMBDA_FACTOR_NAMES @@ -176,8 +173,8 @@ def lambda_matrices_save_path( def save_lambda_matrices( output_dir: Path, - lambda_factors: Dict[str, Dict[str, torch.Tensor]], - partition: Optional[Tuple[int, int]] = None, + lambda_factors: FACTOR_TYPE, + partition: Optional[PARTITION_TYPE] = None, ) -> None: """Saves Lambda matrices to disk.""" assert set(lambda_factors.keys()) == set(LAMBDA_FACTOR_NAMES) @@ -192,7 +189,7 @@ def save_lambda_matrices( def load_lambda_matrices( output_dir: Path, - partition: Optional[Tuple[int, int]] = None, + partition: Optional[PARTITION_TYPE] = None, ) -> FACTOR_TYPE: """Loads Lambda matrices from disk.""" lambda_factors = {} @@ -208,7 +205,7 @@ def load_lambda_matrices( def lambda_matrices_exist( output_dir: Path, - partition: Optional[Tuple[int, int]] = None, + partition: Optional[PARTITION_TYPE] = None, ) -> bool: """Check if Lambda matrices exist at specified path.""" for name in LAMBDA_FACTOR_NAMES: @@ -254,7 +251,7 @@ def fit_lambda_matrices_with_loader( Tuple[torch.Tensor, FACTOR_TYPE]: A tuple containing the number of data points processed, and computed Lambda matrices in CPU. The Lambda matrices are organized in nested dictionaries, where the first key in the name of - the Lambda matrix and the second key is the module name. + the computed variable and the second key is the module name. 
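`torch.linalg.eigh` assumes a symmetric input (by default it only reads the lower triangle), so the `0.5 * (C + C.T)` step above guards against the small asymmetries that accumulate during covariance updates. A quick self-contained demonstration:

```python
import torch

torch.manual_seed(0)
x = torch.randn(256, 64)
cov = x.T @ x
cov = cov + 1e-7 * torch.randn(64, 64)  # inject the kind of numerical asymmetry noted above
cov = 0.5 * (cov + cov.T)               # restore exact symmetry before eigh
eigenvalues, eigenvectors = torch.linalg.eigh(cov)
reconstructed = eigenvectors @ torch.diag(eigenvalues) @ eigenvectors.T
torch.testing.assert_close(reconstructed, cov, rtol=1e-4, atol=1e-3)
```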
""" with torch.no_grad(): update_factor_args(model=model, factor_args=factor_args) @@ -267,9 +264,7 @@ def fit_lambda_matrices_with_loader( if eigen_factors is not None: for name in eigen_factors: set_factors(model=model, factor_name=name, factors=eigen_factors[name]) - num_data_processed = torch.zeros( - (1,), dtype=torch.int64, device=state.device, requires_grad=False - ) + num_data_processed = torch.zeros((1,), dtype=torch.int64, device=state.device, requires_grad=False) with tqdm( total=len(loader), @@ -299,8 +294,6 @@ def fit_lambda_matrices_with_loader( with torch.no_grad(): saved_factors: FACTOR_TYPE = {} for covariance_factor_name in LAMBDA_FACTOR_NAMES: - saved_factors[covariance_factor_name] = load_factors( - model=model, factor_name=covariance_factor_name - ) + saved_factors[covariance_factor_name] = load_factors(model=model, factor_name=covariance_factor_name) set_mode(model=model, mode=ModuleMode.DEFAULT, keep_factors=False) return num_data_processed, saved_factors diff --git a/kronfluence/module/constants.py b/kronfluence/module/constants.py index 40616b0..1574c2f 100644 --- a/kronfluence/module/constants.py +++ b/kronfluence/module/constants.py @@ -1,6 +1,4 @@ -""" -A collection of constants for defining `TrackedModule` storage. -""" +"""A collection of constants for defining `TrackedModule` storage.""" from typing import Dict, Tuple diff --git a/kronfluence/module/conv2d.py b/kronfluence/module/conv2d.py index fe3d0ea..1eaa8dd 100644 --- a/kronfluence/module/conv2d.py +++ b/kronfluence/module/conv2d.py @@ -47,18 +47,12 @@ def extract_patches( for k, s, d in zip(_pair(kernel_size), _pair(stride), _pair(dilation)): p_left, p_right = get_conv_paddings(k, s, padding, d) if p_left != p_right: - raise UnsupportableModuleError( - "Unequal padding not supported in unfold." 
- ) + raise UnsupportableModuleError("Unequal padding not supported in unfold.") padding_as_int.append(p_left) padding = tuple(padding_as_int) - inputs = rearrange( - tensor=inputs, pattern="b (g c_in) i1 i2 -> b g c_in i1 i2", g=groups - ) - inputs = reduce( - tensor=inputs, pattern="b g c_in i1 i2 -> b c_in i1 i2", reduction="mean" - ) + inputs = rearrange(tensor=inputs, pattern="b (g c_in) i1 i2 -> b g c_in i1 i2", g=groups) + inputs = reduce(tensor=inputs, pattern="b g c_in i1 i2 -> b c_in i1 i2", reduction="mean") inputs_unfold = F.unfold( input=inputs, kernel_size=kernel_size, @@ -66,9 +60,7 @@ def extract_patches( padding=padding, stride=stride, ) - return rearrange( - tensor=inputs_unfold, pattern="b c_in_k1_k2 o1_o2 -> b o1_o2 c_in_k1_k2" - ) + return rearrange(tensor=inputs_unfold, pattern="b c_in_k1_k2 o1_o2 -> b o1_o2 c_in_k1_k2") class TrackedConv2d(TrackedModule, module_type=nn.Conv2d): @@ -100,6 +92,7 @@ def _get_flattened_activation( tensor=input_activation, pattern="b o1_o2 c_in_k1_k2 -> (b o1_o2) c_in_k1_k2", ) + if self.original_module.bias is not None: flattened_activation = torch.cat( [ @@ -165,11 +158,6 @@ def _compute_per_sample_gradient( ], dim=-1, ) - - input_activation = input_activation.view( - output_gradient.size(0), -1, input_activation.size(-1) - ) - output_gradient = rearrange( - tensor=output_gradient, pattern="b o i1 i2 -> b (i1 i2) o" - ) + input_activation = input_activation.view(output_gradient.size(0), -1, input_activation.size(-1)) + output_gradient = rearrange(tensor=output_gradient, pattern="b o i1 i2 -> b (i1 i2) o") return torch.einsum("abm,abn->amn", (output_gradient, input_activation)) diff --git a/kronfluence/module/linear.py b/kronfluence/module/linear.py index 0a00cf2..6158e36 100644 --- a/kronfluence/module/linear.py +++ b/kronfluence/module/linear.py @@ -24,35 +24,21 @@ def _get_flattened_activation( The flattened activation tensor and the number of stacked activations. The flattened activation is a 2-dimensional matrix with dimension `activation_num x activation_dim`. """ - flattened_activation = rearrange( - tensor=input_activation, pattern="b ... d_in -> (b ...) d_in" - ) + flattened_activation = rearrange(tensor=input_activation, pattern="b ... d_in -> (b ...) d_in") flattened_attention_mask = None - if ( - self._attention_mask is not None - and flattened_activation.size(0) == self._attention_mask.numel() - ): + if self._attention_mask is not None and flattened_activation.size(0) == self._attention_mask.numel(): # If the binary attention mask is provided, zero-out appropriate activations. - flattened_attention_mask = rearrange( - tensor=self._attention_mask, pattern="b ... -> (b ...) 1" - ) + flattened_attention_mask = rearrange(tensor=self._attention_mask, pattern="b ... -> (b ...) 
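`extract_patches` relies on `F.unfold` to flatten every convolution receptive field, so a `Conv2d` can be treated like a linear layer applied at `o1 * o2` spatial positions with activation dimension `c_in * k1 * k2`. A shape sketch with synthetic data:

```python
import torch
import torch.nn.functional as F

x = torch.randn(2, 3, 8, 8)  # (b, c_in, h, w)
patches = F.unfold(x, kernel_size=3, dilation=1, padding=1, stride=1)
print(patches.shape)         # torch.Size([2, 27, 64]): (b, c_in*k1*k2, o1*o2)
patches = patches.transpose(1, 2)  # (b, o1*o2, c_in*k1*k2), as in extract_patches
```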
1") flattened_activation = flattened_activation * flattened_attention_mask if self.original_module.bias is not None: - append_term = flattened_activation.new_ones( - flattened_activation.shape[0], 1 - ) + append_term = flattened_activation.new_ones(flattened_activation.shape[0], 1) if flattened_attention_mask is not None: append_term = append_term * flattened_attention_mask - flattened_activation = torch.cat( - [flattened_activation, append_term], dim=-1 - ) - count = ( - flattened_activation.size(0) - if flattened_attention_mask is None - else flattened_attention_mask.sum() - ) + flattened_activation = torch.cat([flattened_activation, append_term], dim=-1) + + count = flattened_activation.size(0) if flattened_attention_mask is None else flattened_attention_mask.sum() return flattened_activation, count def _get_flattened_gradient(self, output_gradient: torch.Tensor) -> torch.Tensor: @@ -92,4 +78,5 @@ def _compute_per_sample_gradient( shape = list(input_activation.shape[:-1]) + [1] append_term = input_activation.new_ones(shape, requires_grad=False) input_activation = torch.cat([input_activation, append_term], dim=-1) + return torch.einsum("b...i,b...o->boi", (input_activation, output_gradient)) diff --git a/kronfluence/module/tracked_module.py b/kronfluence/module/tracked_module.py index e329249..31f1680 100644 --- a/kronfluence/module/tracked_module.py +++ b/kronfluence/module/tracked_module.py @@ -56,9 +56,7 @@ class TrackedModule(nn.Module): SUPPORTED_MODULES: Dict[Type[nn.Module], Any] = {} - def __init_subclass__( - cls, module_type: Optional[Type[nn.Module]] = None, **kwargs - ) -> None: + def __init_subclass__(cls, module_type: Optional[Type[nn.Module]] = None, **kwargs) -> None: """Automatically registers subclasses as supported modules.""" super().__init_subclass__(**kwargs) assert module_type is not None @@ -227,14 +225,10 @@ def _get_flattened_activation( The flattened activation tensor and the number of stacked activations. The flattened activation is a 2-dimensional matrix with dimension `activation_num x activation_dim`. """ - raise NotImplementedError( - "Subclasses must implement the `_get_flattened_activation` method." - ) + raise NotImplementedError("Subclasses must implement the `_get_flattened_activation` method.") @torch.no_grad() - def _update_activation_covariance_matrix( - self, input_activation: torch.Tensor - ) -> None: + def _update_activation_covariance_matrix(self, input_activation: torch.Tensor) -> None: """Updates the activation covariance matrix. Args: @@ -242,9 +236,7 @@ def _update_activation_covariance_matrix( The input tensor to the module, provided by the PyTorch's forward hook. """ flattened_activation, count = self._get_flattened_activation(input_activation) - flattened_activation = flattened_activation.to( - dtype=self.factor_args.activation_covariance_dtype - ) + flattened_activation = flattened_activation.to(dtype=self.factor_args.activation_covariance_dtype) if self._storage[ACTIVATION_COVARIANCE_MATRIX_NAME] is None: dimension = flattened_activation.size(1) @@ -255,9 +247,7 @@ def _update_activation_covariance_matrix( requires_grad=False, ) # Add the current batch's activation covariance to the stored activation covariance matrix. 
- self._storage[ACTIVATION_COVARIANCE_MATRIX_NAME].addmm_( - flattened_activation.t(), flattened_activation - ) + self._storage[ACTIVATION_COVARIANCE_MATRIX_NAME].addmm_(flattened_activation.t(), flattened_activation) if self._storage[NUM_COVARIANCE_PROCESSED] is None: self._storage[NUM_COVARIANCE_PROCESSED] = torch.zeros( @@ -283,9 +273,7 @@ def _get_flattened_gradient(self, output_gradient: torch.Tensor) -> torch.Tensor The flattened output gradient tensor. The flattened gradient is a 2-dimensional matrix with dimension `gradient_num x gradient_dim`. """ - raise NotImplementedError( - "Subclasses must implement the `_get_flattened_gradient` method." - ) + raise NotImplementedError("Subclasses must implement the `_get_flattened_gradient` method.") @torch.no_grad() def _update_gradient_covariance_matrix(self, output_gradient: torch.Tensor) -> None: @@ -297,9 +285,7 @@ def _update_gradient_covariance_matrix(self, output_gradient: torch.Tensor) -> N PyTorch's backward hook. """ flattened_gradient = self._get_flattened_gradient(output_gradient) - flattened_gradient = flattened_gradient.to( - dtype=self.factor_args.gradient_covariance_dtype - ) + flattened_gradient = flattened_gradient.to(dtype=self.factor_args.gradient_covariance_dtype) if self._storage[GRADIENT_COVARIANCE_MATRIX_NAME] is None: # Initialize pseudo-gradient covariance matrix if it does not exist. @@ -311,16 +297,12 @@ def _update_gradient_covariance_matrix(self, output_gradient: torch.Tensor) -> N requires_grad=False, ) # Add the current batch's pseudo-gradient covariance to the stored pseudo-gradient covariance matrix. - self._storage[GRADIENT_COVARIANCE_MATRIX_NAME].addmm_( - flattened_gradient.t(), flattened_gradient - ) + self._storage[GRADIENT_COVARIANCE_MATRIX_NAME].addmm_(flattened_gradient.t(), flattened_gradient) def _register_covariance_hooks(self) -> None: """Installs forward and backward hooks for computation of the covariance matrices.""" - def forward_hook( - module: nn.Module, inputs: Tuple[torch.Tensor], outputs: Tuple[torch.Tensor] - ) -> None: + def forward_hook(module: nn.Module, inputs: Tuple[torch.Tensor], outputs: Tuple[torch.Tensor]) -> None: del module # Compute and update activation covariance matrix in the forward pass. self._update_activation_covariance_matrix(inputs[0].detach()) @@ -332,14 +314,10 @@ def backward_hook(output_gradient: torch.Tensor) -> None: # Compute and update pseudo-gradient covariance matrix in the backward pass. self._update_gradient_covariance_matrix(output_gradient.detach()) - self._registered_hooks.append( - self.original_module.register_forward_hook(forward_hook) - ) + self._registered_hooks.append(self.original_module.register_forward_hook(forward_hook)) if self.factor_args.immediate_gradient_removal: - self._registered_hooks.append( - self.register_full_backward_hook(full_backward_gradient_removal_hook) - ) + self._registered_hooks.append(self.register_full_backward_hook(full_backward_gradient_removal_hook)) def _release_covariance_matrices(self) -> None: """Clears the stored activation and pseudo-gradient covariance matrices from memory.""" @@ -399,9 +377,7 @@ def _compute_per_sample_gradient( with dimension `batch_size x input_dim x gradient_dim`. An additional dimension is added when the bias term is used. """ - raise NotImplementedError( - "Subclasses must implement the `_compute_per_sample_gradient` method." 
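The hook pair above caches the layer input on the forward pass and captures the output gradient on the backward pass; contracting the two reproduces per-sample gradients without extra forward passes. A self-contained sketch for a bias-free `nn.Linear`, using the same `einsum` contraction as `linear.py`:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
layer = nn.Linear(4, 3, bias=False)
captured = {}

def forward_hook(module, inputs, outputs):
    captured["activation"] = inputs[0].detach()
    # Tensor-level hook fires in the backward pass with the output gradient.
    outputs.register_hook(lambda grad: captured.update(gradient=grad.detach()))

layer.register_forward_hook(forward_hook)
x = torch.randn(5, 4)
layer(x).square().sum().backward()

per_sample = torch.einsum("bi,bo->boi", captured["activation"], captured["gradient"])
# Summing per-sample gradients recovers the usual aggregated gradient.
torch.testing.assert_close(per_sample.sum(dim=0), layer.weight.grad)
```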
- ) + raise NotImplementedError("Subclasses must implement the `_compute_per_sample_gradient` method.") @torch.no_grad() def _update_lambda_matrix(self, per_sample_gradient: torch.Tensor) -> None: @@ -422,9 +398,7 @@ def _update_lambda_matrix(self, per_sample_gradient: torch.Tensor) -> None: requires_grad=False, ) - if FactorConfig.CONFIGS[ - self.factor_args.strategy - ].requires_eigendecomposition_for_lambda: + if FactorConfig.CONFIGS[self.factor_args.strategy].requires_eigendecomposition_for_lambda: if not self._eigendecomposition_results_available(): error_msg = ( f"The strategy {self.factor_args.strategy} requires Eigendecomposition " @@ -434,15 +408,11 @@ def _update_lambda_matrix(self, per_sample_gradient: torch.Tensor) -> None: raise FactorsNotFoundError(error_msg) # Move activation and pseudo-gradient eigenvectors to appropriate devices. - self._storage[ACTIVATION_EIGENVECTORS_NAME] = self._storage[ - ACTIVATION_EIGENVECTORS_NAME - ].to( + self._storage[ACTIVATION_EIGENVECTORS_NAME] = self._storage[ACTIVATION_EIGENVECTORS_NAME].to( dtype=self.factor_args.lambda_dtype, device=per_sample_gradient.device, ) - self._storage[GRADIENT_EIGENVECTORS_NAME] = self._storage[ - GRADIENT_EIGENVECTORS_NAME - ].to( + self._storage[GRADIENT_EIGENVECTORS_NAME] = self._storage[GRADIENT_EIGENVECTORS_NAME].to( dtype=self.factor_args.lambda_dtype, device=per_sample_gradient.device, ) @@ -455,9 +425,7 @@ def _update_lambda_matrix(self, per_sample_gradient: torch.Tensor) -> None: requires_grad=False, ) - if FactorConfig.CONFIGS[ - self.factor_args.strategy - ].requires_eigendecomposition_for_lambda: + if FactorConfig.CONFIGS[self.factor_args.strategy].requires_eigendecomposition_for_lambda: if self.factor_args.lambda_iterative_aggregate: # This batch-wise iterative update can be useful when the GPU memory is limited. 
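For EKFAC, each batch's contribution to the Lambda matrix is the squared coordinates of the per-sample gradients after rotation into the Kronecker eigenbasis; dividing by `NUM_LAMBDA_PROCESSED` at the end yields the corrected eigenvalues. A shape-level sketch with stand-in eigenvectors:

```python
import torch

torch.manual_seed(0)
o, i = 3, 4
Qg = torch.linalg.qr(torch.randn(o, o)).Q  # stand-in for the gradient eigenvectors
Qa = torch.linalg.qr(torch.randn(i, i)).Q  # stand-in for the activation eigenvectors
grads = torch.randn(8, o, i)               # a batch of per-sample gradients

rotated = torch.einsum("ij,bjk,kl->bil", Qg.t(), grads, Qa)
lambda_matrix = rotated.square().sum(dim=0)  # what _update_lambda_matrix adds per batch
num_processed = grads.size(0)                # NUM_LAMBDA_PROCESSED grows alongside it
corrected = lambda_matrix / num_processed    # averaged squared projections
```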
rotated_gradient = torch.matmul( @@ -482,18 +450,14 @@ def _update_lambda_matrix(self, per_sample_gradient: torch.Tensor) -> None: del per_sample_gradient self._storage[LAMBDA_MATRIX_NAME].add_(sqrt_lambda.square_().sum(dim=0)) else: - self._storage[LAMBDA_MATRIX_NAME].add_( - per_sample_gradient.square_().sum(dim=0) - ) + self._storage[LAMBDA_MATRIX_NAME].add_(per_sample_gradient.square_().sum(dim=0)) self._storage[NUM_LAMBDA_PROCESSED].add_(batch_size) def _register_lambda_hooks(self) -> None: """Installs forward and backward hooks for computation of the Lambda matrices.""" - def forward_hook( - module: nn.Module, inputs: Tuple[torch.Tensor], outputs: Tuple[torch.Tensor] - ) -> None: + def forward_hook(module: nn.Module, inputs: Tuple[torch.Tensor], outputs: Tuple[torch.Tensor]) -> None: del module cached_activation = inputs[0].detach() if self.factor_args.cached_activation_cpu_offload: @@ -509,12 +473,8 @@ def backward_hook(output_gradient: torch.Tensor) -> None: if self.factor_args.cached_activation_cpu_offload: cached_activation = cached_activation.to(device=output_gradient.device) per_sample_gradient = self._compute_per_sample_gradient( - input_activation=cached_activation.to( - dtype=self.factor_args.lambda_dtype - ), - output_gradient=output_gradient.detach().to( - dtype=self.factor_args.lambda_dtype - ), + input_activation=cached_activation.to(dtype=self.factor_args.lambda_dtype), + output_gradient=output_gradient.detach().to(dtype=self.factor_args.lambda_dtype), ) del cached_activation, output_gradient @@ -526,19 +486,13 @@ def backward_hook(output_gradient: torch.Tensor) -> None: # If the module was used multiple times throughout the forward pass, # only compute the Lambda matrix after aggregating all per-sample-gradients. if len(self._cached_activations) == 0: - self._update_lambda_matrix( - per_sample_gradient=self._cached_per_sample_gradient - ) + self._update_lambda_matrix(per_sample_gradient=self._cached_per_sample_gradient) self._cached_per_sample_gradient = None - self._registered_hooks.append( - self.original_module.register_forward_hook(forward_hook) - ) + self._registered_hooks.append(self.original_module.register_forward_hook(forward_hook)) if self.factor_args.immediate_gradient_removal: - self._registered_hooks.append( - self.register_full_backward_hook(full_backward_gradient_removal_hook) - ) + self._registered_hooks.append(self.register_full_backward_hook(full_backward_gradient_removal_hook)) def _release_lambda_matrix(self) -> None: """Clears the stored Lambda matrix from memory.""" @@ -602,9 +556,7 @@ def _compute_low_rank_preconditioned_gradient( def _register_precondition_gradient_hooks(self) -> None: """Installs forward and backward hooks for computation of preconditioned per-sample-gradient.""" - def forward_hook( - module: nn.Module, inputs: Tuple[torch.Tensor], outputs: Tuple[torch.Tensor] - ) -> None: + def forward_hook(module: nn.Module, inputs: Tuple[torch.Tensor], outputs: Tuple[torch.Tensor]) -> None: del module cached_activation = inputs[0].detach() if self.score_args.cached_activation_cpu_offload: @@ -620,12 +572,8 @@ def backward_hook(output_gradient: torch.Tensor) -> None: if self.score_args.cached_activation_cpu_offload: cached_activation = cached_activation.to(device=output_gradient.device) per_sample_gradient = self._compute_per_sample_gradient( - input_activation=cached_activation.to( - dtype=self.score_args.per_sample_gradient_dtype - ), - output_gradient=output_gradient.detach().to( - dtype=self.score_args.per_sample_gradient_dtype - ), + 
input_activation=cached_activation.to(dtype=self.score_args.per_sample_gradient_dtype), + output_gradient=output_gradient.detach().to(dtype=self.score_args.per_sample_gradient_dtype), ) del cached_activation, output_gradient @@ -638,42 +586,29 @@ def backward_hook(output_gradient: torch.Tensor) -> None: # If the module was used multiple times throughout the forward pass, # only perform preconditioning after aggregating all per-sample-gradients. if len(self._cached_activations) == 0: - preconditioned_gradient = FactorConfig.CONFIGS[ - self.factor_args.strategy - ].precondition_gradient( - gradient=self._cached_per_sample_gradient.to( - dtype=self.score_args.precondition_dtype - ), + preconditioned_gradient = FactorConfig.CONFIGS[self.factor_args.strategy].precondition_gradient( + gradient=self._cached_per_sample_gradient.to(dtype=self.score_args.precondition_dtype), storage=self._storage, damping=self.score_args.damping, ) self._cached_per_sample_gradient = None - preconditioned_gradient = preconditioned_gradient.to( - dtype=self.score_args.score_dtype - ) + preconditioned_gradient = preconditioned_gradient.to(dtype=self.score_args.score_dtype) if ( self.score_args.query_gradient_rank is not None - and min(preconditioned_gradient.size()[1:]) - > self.score_args.query_gradient_rank + and min(preconditioned_gradient.size()[1:]) > self.score_args.query_gradient_rank ): # Apply low-rank approximation to the preconditioned gradient. - preconditioned_gradient = ( - self._compute_low_rank_preconditioned_gradient( - preconditioned_gradient=preconditioned_gradient - ) + preconditioned_gradient = self._compute_low_rank_preconditioned_gradient( + preconditioned_gradient=preconditioned_gradient ) self._storage[PRECONDITIONED_GRADIENT_NAME] = preconditioned_gradient - self._registered_hooks.append( - self.original_module.register_forward_hook(forward_hook) - ) + self._registered_hooks.append(self.original_module.register_forward_hook(forward_hook)) if self.factor_args.immediate_gradient_removal: - self._registered_hooks.append( - self.register_full_backward_hook(full_backward_gradient_removal_hook) - ) + self._registered_hooks.append(self.register_full_backward_hook(full_backward_gradient_removal_hook)) def _release_preconditioned_gradient(self) -> None: """Clears the preconditioned per-sample-gradient from memory.""" @@ -699,9 +634,7 @@ def truncate_preconditioned_gradient(self, keep_size: int) -> None: self._storage[PRECONDITIONED_GRADIENT_NAME][1][:keep_size], ] else: - self._storage[PRECONDITIONED_GRADIENT_NAME] = self._storage[ - PRECONDITIONED_GRADIENT_NAME - ][:keep_size] + self._storage[PRECONDITIONED_GRADIENT_NAME] = self._storage[PRECONDITIONED_GRADIENT_NAME][:keep_size] def synchronize_preconditioned_gradient(self, num_processes: int) -> None: """Stacks preconditioned gradient across multiple devices or nodes in a distributed setting.""" @@ -716,14 +649,10 @@ def synchronize_preconditioned_gradient(self, num_processes: int) -> None: ) torch.distributed.all_gather_into_tensor( output_tensor=stacked_matrix, - input_tensor=self._storage[PRECONDITIONED_GRADIENT_NAME][ - i - ].contiguous(), + input_tensor=self._storage[PRECONDITIONED_GRADIENT_NAME][i].contiguous(), ) - self._storage[PRECONDITIONED_GRADIENT_NAME][i] = ( - stacked_matrix.transpose( - 0, 1 - ).reshape(num_processes * size[0], size[1], size[2]) + self._storage[PRECONDITIONED_GRADIENT_NAME][i] = stacked_matrix.transpose(0, 1).reshape( + num_processes * size[0], size[1], size[2] ) else: @@ -735,14 +664,10 @@ def 
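When `query_gradient_rank` is set (and smaller than both gradient dimensions), each preconditioned query gradient is replaced by a low-rank factorization before the pairwise dot products, matching the two-tensor `[left, right]` storage visible in the truncation and synchronization code nearby. The exact routine is not shown in this hunk; a truncated-SVD sketch of the idea:

```python
import torch

torch.manual_seed(0)
g = torch.randn(2, 64, 128)  # (query_batch, out_dim, in_dim) preconditioned gradients
rank = 16
U, S, Vh = torch.linalg.svd(g, full_matrices=False)
left = U[..., :rank] * S[..., None, :rank]  # (batch, out_dim, rank)
right = Vh[..., :rank, :]                   # (batch, rank, in_dim)
approx = left @ right
print((approx - g).norm() / g.norm())       # relative error from rank-16 truncation
```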
synchronize_preconditioned_gradient(self, num_processes: int) -> None: ) torch.distributed.all_gather_into_tensor( output_tensor=stacked_preconditioned_gradient, - input_tensor=self._storage[ - PRECONDITIONED_GRADIENT_NAME - ].contiguous(), + input_tensor=self._storage[PRECONDITIONED_GRADIENT_NAME].contiguous(), ) - self._storage[PRECONDITIONED_GRADIENT_NAME] = ( - stacked_preconditioned_gradient.transpose( - 0, 1 - ).reshape(num_processes * size[0], size[1], size[2]) + self._storage[PRECONDITIONED_GRADIENT_NAME] = stacked_preconditioned_gradient.transpose(0, 1).reshape( + num_processes * size[0], size[1], size[2] ) ########################################### @@ -751,9 +676,7 @@ def synchronize_preconditioned_gradient(self, num_processes: int) -> None: def _register_pairwise_score_hooks(self) -> None: """Installs forward and backward hooks for computation of pairwise influence scores.""" - def forward_hook( - module: nn.Module, inputs: Tuple[torch.Tensor], outputs: Tuple[torch.Tensor] - ) -> None: + def forward_hook(module: nn.Module, inputs: Tuple[torch.Tensor], outputs: Tuple[torch.Tensor]) -> None: del module cached_activation = inputs[0].detach() if self.score_args.cached_activation_cpu_offload: @@ -767,12 +690,8 @@ def forward_hook( def backward_hook(output_gradient: torch.Tensor) -> None: cached_activation = self._cached_activations.pop() per_sample_gradient = self._compute_per_sample_gradient( - input_activation=cached_activation.to( - dtype=self.score_args.per_sample_gradient_dtype - ), - output_gradient=output_gradient.detach().to( - dtype=self.score_args.per_sample_gradient_dtype - ), + input_activation=cached_activation.to(dtype=self.score_args.per_sample_gradient_dtype), + output_gradient=output_gradient.detach().to(dtype=self.score_args.per_sample_gradient_dtype), ) del cached_activation, output_gradient @@ -806,21 +725,15 @@ def backward_hook(output_gradient: torch.Tensor) -> None: self._storage[PAIRWISE_SCORE_MATRIX_NAME] = scores self._cached_per_sample_gradient = None - self._registered_hooks.append( - self.original_module.register_forward_hook(forward_hook) - ) + self._registered_hooks.append(self.original_module.register_forward_hook(forward_hook)) if self.factor_args.immediate_gradient_removal: - self._registered_hooks.append( - self.register_full_backward_hook(full_backward_gradient_removal_hook) - ) + self._registered_hooks.append(self.register_full_backward_hook(full_backward_gradient_removal_hook)) def _register_self_score_hooks(self) -> None: """Installs forward and backward hooks for computation of self-influence scores.""" - def forward_hook( - module: nn.Module, inputs: Tuple[torch.Tensor], outputs: Tuple[torch.Tensor] - ) -> None: + def forward_hook(module: nn.Module, inputs: Tuple[torch.Tensor], outputs: Tuple[torch.Tensor]) -> None: del module cached_activation = inputs[0].detach() if self.score_args.cached_activation_cpu_offload: @@ -834,12 +747,8 @@ def forward_hook( def backward_hook(output_gradient: torch.Tensor) -> None: cached_activation = self._cached_activations.pop() per_sample_gradient = self._compute_per_sample_gradient( - input_activation=cached_activation.to( - dtype=self.score_args.per_sample_gradient_dtype - ), - output_gradient=output_gradient.detach().to( - dtype=self.score_args.per_sample_gradient_dtype - ), + input_activation=cached_activation.to(dtype=self.score_args.per_sample_gradient_dtype), + output_gradient=output_gradient.detach().to(dtype=self.score_args.per_sample_gradient_dtype), ) del cached_activation, output_gradient @@ -861,35 
+770,23 @@ def backward_hook(output_gradient: torch.Tensor) -> None: # If the module was used multiple times throughout the forward pass, # only compute scores after aggregating all per-sample-gradients. if len(self._cached_activations) == 0: - preconditioned_gradient = FactorConfig.CONFIGS[ - self.factor_args.strategy - ].precondition_gradient( - gradient=self._cached_per_sample_gradient.to( - dtype=self.score_args.precondition_dtype - ), + preconditioned_gradient = FactorConfig.CONFIGS[self.factor_args.strategy].precondition_gradient( + gradient=self._cached_per_sample_gradient.to(dtype=self.score_args.precondition_dtype), storage=self._storage, damping=self.score_args.damping, ) - preconditioned_gradient = preconditioned_gradient.to( - dtype=self.score_args.score_dtype - ) + preconditioned_gradient = preconditioned_gradient.to(dtype=self.score_args.score_dtype) self._cached_per_sample_gradient = self._cached_per_sample_gradient.to( dtype=self.score_args.score_dtype ) preconditioned_gradient.mul_(self._cached_per_sample_gradient) - self._storage[SELF_SCORE_VECTOR_NAME] = preconditioned_gradient.sum( - dim=(1, 2) - ) + self._storage[SELF_SCORE_VECTOR_NAME] = preconditioned_gradient.sum(dim=(1, 2)) self._cached_per_sample_gradient = None - self._registered_hooks.append( - self.original_module.register_forward_hook(forward_hook) - ) + self._registered_hooks.append(self.original_module.register_forward_hook(forward_hook)) if self.factor_args.immediate_gradient_removal: - self._registered_hooks.append( - self.register_full_backward_hook(full_backward_gradient_removal_hook) - ) + self._registered_hooks.append(self.register_full_backward_hook(full_backward_gradient_removal_hook)) def release_scores(self) -> None: """Clears the influence scores from memory.""" diff --git a/kronfluence/module/utils.py b/kronfluence/module/utils.py index 3e5ad03..08c5d30 100644 --- a/kronfluence/module/utils.py +++ b/kronfluence/module/utils.py @@ -45,7 +45,7 @@ def wrap_tracked_modules( if isinstance(model, (DP, DDP, FSDP)): raise ValueError( "The model is wrapped with DataParallel, DistributedDataParallel " - "or FullyShardedDataParallel. Install tracked modules before wrapping the model." + "or FullyShardedDataParallel. Call `wrap_tracked_modules` before wrapping the model." ) tracked_module_count = 0 @@ -73,12 +73,10 @@ def wrap_tracked_modules( tracked_module_count += 1 if tracked_module_count == 0: - supported_modules_names = [ - module.__name__ for module in TrackedModule.SUPPORTED_MODULES - ] + supported_modules_names = [module.__name__ for module in TrackedModule.SUPPORTED_MODULES] error_msg = ( f"Kronfluence currently supports modules in {supported_modules_names}. " - "However, these modules were not found in the provided model. If you wish to analyze " + f"However, these modules were not found in the provided model. If you want to analyze " "custom layers, consider rewriting your model to use the supported modules, " "or define your own custom module by subclassing `TrackedModule`." 
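The self-influence score for each example is then just the inner product between its per-sample gradient and its preconditioned gradient, reduced over all parameter dimensions, exactly as the `mul_` followed by `sum(dim=(1, 2))` above computes:

```python
import torch

torch.manual_seed(0)
grads = torch.randn(8, 3, 4)    # per-sample gradients for one module
preconditioned = grads / 2.0    # stand-in for precondition_gradient(...)
self_scores = (preconditioned * grads).sum(dim=(1, 2))  # one score per example
print(self_scores.shape)        # torch.Size([8])
```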
) @@ -87,14 +85,11 @@ def wrap_tracked_modules( return model -def make_modules_partition( - total_module_names: List[str], partition_size: int -) -> List[List[str]]: +def make_modules_partition(total_module_names: List[str], partition_size: int) -> List[List[str]]: """Divides a list of module names into smaller partitions of a specified size.""" div, mod = divmod(len(total_module_names), partition_size) return list( - total_module_names[i * div + min(i, mod) : (i + 1) * div + min(i + 1, mod)] - for i in range(partition_size) + total_module_names[i * div + min(i, mod) : (i + 1) * div + min(i + 1, mod)] for i in range(partition_size) ) @@ -131,9 +126,7 @@ def get_tracked_named_modules(model: nn.Module) -> List[Tuple[str, TrackedModule if isinstance(module, TrackedModule): tracked_modules.append((module.name, module)) if len(tracked_modules) == 0: - raise TrackedModuleNotFoundError( - "Tracked modules not found when trying get tracked named modules." - ) + raise TrackedModuleNotFoundError("Tracked modules not found when trying to get tracked named modules.") return tracked_modules @@ -144,9 +137,7 @@ def get_tracked_module_names(model: nn.Module) -> List[str]: tracked_modules = [] for module in model.modules(): if isinstance(module, TrackedModule): tracked_modules.append(module.name) if len(tracked_modules) == 0: - raise TrackedModuleNotFoundError( - "Tracked modules not found when trying get tracked module names." - ) + raise TrackedModuleNotFoundError("Tracked modules not found when trying to get tracked module names.") return tracked_modules @@ -158,9 +149,7 @@ def synchronize_covariance_matrices(model: nn.Module) -> None: module.synchronize_covariance_matrices() tracked_module_count += 1 if tracked_module_count == 0: - raise TrackedModuleNotFoundError( - "Tracked modules not found when trying to synchronize covariance matrices." - ) + raise TrackedModuleNotFoundError("Tracked modules not found when trying to synchronize covariance matrices.") def synchronize_lambda_matrices(model: nn.Module) -> None: @@ -171,9 +160,7 @@ def synchronize_lambda_matrices(model: nn.Module) -> None: module.synchronize_lambda_matrices() tracked_module_count += 1 if tracked_module_count == 0: - raise TrackedModuleNotFoundError( - "Tracked modules not found when trying to synchronize lambda matrices." - ) + raise TrackedModuleNotFoundError("Tracked modules not found when trying to synchronize lambda matrices.") def get_preconditioned_gradient_batch_size(model: nn.Module) -> Optional[int]: @@ -194,9 +181,7 @@ def truncate_preconditioned_gradient(model: nn.Module, keep_size: int) -> None: module.truncate_preconditioned_gradient(keep_size=keep_size) tracked_module_count += 1 if tracked_module_count == 0: - raise TrackedModuleNotFoundError( - "Tracked modules not found when trying to truncate preconditioned gradient." - ) + raise TrackedModuleNotFoundError("Tracked modules not found when trying to truncate preconditioned gradient.") def synchronize_preconditioned_gradient(model: nn.Module, num_processes: int) -> None: @@ -220,9 +205,7 @@ def release_scores(model: nn.Module) -> None: module.release_scores() tracked_module_count += 1 if tracked_module_count == 0: - raise TrackedModuleNotFoundError( - "Tracked modules not found when trying to release scores."
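`make_modules_partition` distributes the remainder over the first `mod` bins, so partition sizes differ by at most one. For example:

```python
names = [f"layer{i}" for i in range(7)]
partition_size = 3
div, mod = divmod(len(names), partition_size)
partitions = [
    names[i * div + min(i, mod) : (i + 1) * div + min(i + 1, mod)]
    for i in range(partition_size)
]
assert partitions == [["layer0", "layer1", "layer2"], ["layer3", "layer4"], ["layer5", "layer6"]]
```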
- ) + raise TrackedModuleNotFoundError("Tracked modules not found when trying to release scores.") def set_mode( @@ -246,10 +229,7 @@ def set_mode( """ for module in model.modules(): if isinstance(module, TrackedModule): - if ( - tracked_module_names is not None - and module.name not in tracked_module_names - ): + if tracked_module_names is not None and module.name not in tracked_module_names: continue module.set_mode(mode=mode, keep_factors=keep_factors) @@ -273,9 +253,7 @@ def load_factors( return loaded_factors -def set_factors( - model: nn.Module, factor_name: str, factors: Dict[str, torch.Tensor] -) -> None: +def set_factors(model: nn.Module, factor_name: str, factors: Dict[str, torch.Tensor]) -> None: """Sets new factor for all `TrackedModule` instances within a model.""" tracked_module_count = 0 for module in model.modules(): @@ -283,9 +261,7 @@ def set_factors( module.set_factor(factor_name=factor_name, factor=factors[module.name]) tracked_module_count += 1 if tracked_module_count == 0: - raise TrackedModuleNotFoundError( - f"Tracked modules not found when trying to set factor with name {factors}." - ) + raise TrackedModuleNotFoundError(f"Tracked modules not found when trying to set factor with name {factors}.") def set_attention_mask( @@ -298,18 +274,14 @@ def set_attention_mask( if isinstance(module, TrackedModule): if isinstance(attention_mask, dict): if module.name in attention_mask: - module.set_attention_mask( - attention_mask=attention_mask[module.name] - ) + module.set_attention_mask(attention_mask=attention_mask[module.name]) else: module.set_attention_mask(attention_mask=None) else: module.set_attention_mask(attention_mask=attention_mask) tracked_module_count += 1 if tracked_module_count == 0: - raise TrackedModuleNotFoundError( - "Tracked modules not found when trying to set `attention_mask`." - ) + raise TrackedModuleNotFoundError("Tracked modules not found when trying to set `attention_mask`.") def remove_attention_mask(model: nn.Module) -> None: @@ -320,6 +292,4 @@ def remove_attention_mask(model: nn.Module) -> None: module.remove_attention_mask() tracked_module_count += 1 if tracked_module_count == 0: - raise TrackedModuleNotFoundError( - "Tracked modules not found when trying to remove `attention_mask`." 
- ) + raise TrackedModuleNotFoundError("Tracked modules not found when trying to remove `attention_mask`.") diff --git a/kronfluence/score/pairwise.py b/kronfluence/score/pairwise.py index 241bc05..e1abd55 100644 --- a/kronfluence/score/pairwise.py +++ b/kronfluence/score/pairwise.py @@ -41,8 +41,7 @@ def pairwise_scores_save_path( if partition is not None: data_partition, module_partition = partition return output_dir / ( - f"pairwise_scores_data_partition{data_partition}" - f"_module_partition{module_partition}.safetensors" + f"pairwise_scores_data_partition{data_partition}" f"_module_partition{module_partition}.safetensors" ) return output_dir / "pairwise_scores.safetensors" @@ -99,10 +98,7 @@ def _compute_pairwise_dot_products_with_loader( score_chunks: Dict[str, List[torch.Tensor]] = {} if score_args.per_module_score: for module in model.modules(): - if ( - isinstance(module, TrackedModule) - and module.name in tracked_module_names - ): + if isinstance(module, TrackedModule) and module.name in tracked_module_names: score_chunks[module.name] = [] else: score_chunks[ALL_MODULE_NAME] = [] @@ -137,14 +133,9 @@ def _compute_pairwise_dot_products_with_loader( with torch.no_grad(): if score_args.per_module_score: for module in model.modules(): - if ( - isinstance(module, TrackedModule) - and module.name in tracked_module_names - ): + if isinstance(module, TrackedModule) and module.name in tracked_module_names: score_chunks[module.name].append( - module.get_factor( - factor_name=PAIRWISE_SCORE_MATRIX_NAME - ).cpu() + module.get_factor(factor_name=PAIRWISE_SCORE_MATRIX_NAME).cpu() ) else: # Aggregate the pairwise scores across all modules. @@ -156,15 +147,8 @@ def _compute_pairwise_dot_products_with_loader( requires_grad=False, ) for module in model.modules(): - if ( - isinstance(module, TrackedModule) - and module.name in tracked_module_names - ): - pairwise_scores.add_( - module.get_factor( - factor_name=PAIRWISE_SCORE_MATRIX_NAME - ) - ) + if isinstance(module, TrackedModule) and module.name in tracked_module_names: + pairwise_scores.add_(module.get_factor(factor_name=PAIRWISE_SCORE_MATRIX_NAME)) # `.cpu()` synchronizes the CUDA stream. score_chunks[ALL_MODULE_NAME].append(pairwise_scores.cpu()) release_scores(model=model) @@ -299,15 +283,11 @@ def compute_pairwise_scores_with_loaders( if state.use_distributed: # Stack preconditioned query gradient across multiple devices or nodes. - synchronize_preconditioned_gradient( - model=model, num_processes=state.num_processes - ) + synchronize_preconditioned_gradient(model=model, num_processes=state.num_processes) if query_index == len(query_loader) - 1 and query_remainder > 0: # Remove duplicate data points if the dataset is not exactly divisible # by the current batch size. - truncate_preconditioned_gradient( - model=model, keep_size=query_remainder - ) + truncate_preconditioned_gradient(model=model, keep_size=query_remainder) # Compute the dot product between preconditioning query gradient and all training gradients. 
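The pairwise score update is, at its core, a batched dot product between the cached preconditioned query gradients and each training example's gradient. A minimal sketch of the contraction; shapes are illustrative:

```python
import torch

torch.manual_seed(0)
query_grads = torch.randn(5, 3, 4)  # preconditioned query gradients, (num_query, out_dim, in_dim)
train_grads = torch.randn(8, 3, 4)  # per-sample training gradients
pairwise = torch.einsum("qoi,toi->qt", query_grads, train_grads)
print(pairwise.shape)               # torch.Size([5, 8]): query-by-train influence scores
```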
release_memory() @@ -331,8 +311,6 @@ def compute_pairwise_scores_with_loaders( set_mode(model=model, mode=ModuleMode.DEFAULT, keep_factors=False) for module_name in total_scores_chunks: - total_scores_chunks[module_name] = torch.cat( - total_scores_chunks[module_name], dim=0 - ) + total_scores_chunks[module_name] = torch.cat(total_scores_chunks[module_name], dim=0) return total_scores_chunks diff --git a/kronfluence/score/self.py b/kronfluence/score/self.py index 19f5301..062069b 100644 --- a/kronfluence/score/self.py +++ b/kronfluence/score/self.py @@ -38,8 +38,7 @@ def self_scores_save_path( if partition is not None: data_partition, module_partition = partition return output_dir / ( - f"self_scores_data_partition{data_partition}" - f"_module_partition{module_partition}.safetensors" + f"self_scores_data_partition{data_partition}" f"_module_partition{module_partition}.safetensors" ) return output_dir / "self_scores.safetensors" @@ -148,10 +147,7 @@ def compute_self_scores_with_loaders( score_chunks: Dict[str, List[torch.Tensor]] = {} if score_args.per_module_score: for module in model.modules(): - if ( - isinstance(module, TrackedModule) - and module.name in tracked_module_names - ): + if isinstance(module, TrackedModule) and module.name in tracked_module_names: score_chunks[module.name] = [] else: score_chunks[ALL_MODULE_NAME] = [] @@ -180,14 +176,9 @@ def compute_self_scores_with_loaders( with torch.no_grad(): if score_args.per_module_score: for module in model.modules(): - if ( - isinstance(module, TrackedModule) - and module.name in tracked_module_names - ): + if isinstance(module, TrackedModule) and module.name in tracked_module_names: score_chunks[module.name].append( - module.get_factor( - factor_name=SELF_SCORE_VECTOR_NAME - ).cpu() + module.get_factor(factor_name=SELF_SCORE_VECTOR_NAME).cpu() ) else: # Aggregate the self-influence scores across all modules. @@ -199,13 +190,8 @@ def compute_self_scores_with_loaders( requires_grad=False, ) for module in model.modules(): - if ( - isinstance(module, TrackedModule) - and module.name in tracked_module_names - ): - self_scores.add_( - module.get_factor(factor_name=SELF_SCORE_VECTOR_NAME) - ) + if isinstance(module, TrackedModule) and module.name in tracked_module_names: + self_scores.add_(module.get_factor(factor_name=SELF_SCORE_VECTOR_NAME)) # `.cpu()` synchronizes the CUDA stream. score_chunks[ALL_MODULE_NAME].append(self_scores.cpu()) release_scores(model=model) diff --git a/kronfluence/task.py b/kronfluence/task.py index 7fb2778..90e6947 100644 --- a/kronfluence/task.py +++ b/kronfluence/task.py @@ -35,9 +35,7 @@ def compute_train_loss( torch.Tensor: The computed loss as a tensor. """ - raise NotImplementedError( - "Subclasses must implement the `compute_train_loss` method." - ) + raise NotImplementedError("Subclasses must implement the `compute_train_loss` method.") @abstractmethod def compute_measurement( @@ -58,9 +56,7 @@ def compute_measurement( torch.Tensor: The measurable quantity as a tensor. """ - raise NotImplementedError( - "Subclasses must implement the `compute_measurement` method." - ) + raise NotImplementedError("Subclasses must implement the `compute_measurement` method.") def influence_modules(self) -> Optional[List[str]]: """Specifies modules for preconditioning factors and influence scores computation. @@ -75,9 +71,7 @@ def influence_modules(self) -> Optional[List[str]]: influence functions should be computed for all applicable modules. 
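`Task` is the user-facing interface defining both the training loss and the measurable quantity used for influence scores. The hunks above truncate the argument lists, so the signatures below are assumptions; a plausible classification task could look like:

```python
import torch
import torch.nn.functional as F

from kronfluence.task import Task

class ClassificationTask(Task):
    # NOTE: argument names are assumptions; the diff truncates these signatures.
    def compute_train_loss(self, batch, model, sample=False) -> torch.Tensor:
        inputs, labels = batch
        return F.cross_entropy(model(inputs), labels, reduction="sum")

    def compute_measurement(self, batch, model) -> torch.Tensor:
        inputs, labels = batch
        return F.cross_entropy(model(inputs), labels, reduction="sum")

    def influence_modules(self):
        return None  # None means all supported modules are used
```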
""" - def get_attention_mask( - self, batch: Any - ) -> Optional[Union[Dict[str, torch.Tensor], torch.Tensor]]: + def get_attention_mask(self, batch: Any) -> Optional[Union[Dict[str, torch.Tensor], torch.Tensor]]: """Returns masks for data points within a batch that have been padded extra sequences to ensure consistent length across the batch. Typically, it returns None for models or datasets not requiring masking. diff --git a/kronfluence/utils/dataset.py b/kronfluence/utils/dataset.py index 3a36583..6557af6 100644 --- a/kronfluence/utils/dataset.py +++ b/kronfluence/utils/dataset.py @@ -29,9 +29,7 @@ class DataLoaderKwargs(KwargsHandler): pin_memory_device: str = "" -def make_indices_partition( - total_data_examples: int, partition_size: int -) -> List[Tuple[int, int]]: +def make_indices_partition(total_data_examples: int, partition_size: int) -> List[Tuple[int, int]]: """Returns partitioned indices from the total data examples.""" bins = list(map(len, np.array_split(range(total_data_examples), partition_size))) indices_bin = [] @@ -87,9 +85,7 @@ def __init__( # pylint: disable=super-init-not-called raise RuntimeError("Requires distributed package to be available.") rank = dist.get_rank() if rank >= num_replicas or rank < 0: - raise ValueError( - f"Invalid rank {rank}, rank should be in the interval [0, {num_replicas - 1}]." - ) + raise ValueError(f"Invalid rank {rank}, rank should be in the interval [0, {num_replicas - 1}].") self.dataset = dataset self.num_replicas = num_replicas @@ -130,9 +126,7 @@ def __init__( # pylint: disable=super-init-not-called raise RuntimeError("Requires distributed package to be available.") rank = dist.get_rank() if rank >= num_replicas or rank < 0: - raise ValueError( - f"Invalid rank {rank}, rank should be in the interval [0, {num_replicas - 1}]." - ) + raise ValueError(f"Invalid rank {rank}, rank should be in the interval [0, {num_replicas - 1}].") self.dataset = dataset self.num_replicas = num_replicas diff --git a/kronfluence/utils/logger.py b/kronfluence/utils/logger.py index 6f4d150..8a5ccd9 100644 --- a/kronfluence/utils/logger.py +++ b/kronfluence/utils/logger.py @@ -13,8 +13,7 @@ from kronfluence.utils.state import State TQDM_BAR_FORMAT = ( - "{desc} [{n_fmt}/{total_fmt}] {percentage:3.0f}%|{bar}{postfix} " - "[time left: {remaining}, time spent: {elapsed}]" + "{desc} [{n_fmt}/{total_fmt}] {percentage:3.0f}%|{bar}{postfix} " "[time left: {remaining}, time spent: {elapsed}]" ) @@ -82,9 +81,7 @@ def start(self, action_name: str) -> None: if self.local_rank != 0: pass if action_name in self.current_actions: - raise ValueError( - f"Attempted to start {action_name} which has already started." - ) + raise ValueError(f"Attempted to start {action_name} which has already started.") self.current_actions[action_name] = _get_monotonic_time() def stop(self, action_name: str) -> None: @@ -93,10 +90,7 @@ def stop(self, action_name: str) -> None: pass end_time = _get_monotonic_time() if action_name not in self.current_actions: - raise ValueError( - f"Attempting to stop recording an action " - f"({action_name}) which was never started." 
- ) + raise ValueError(f"Attempting to stop recording an action " f"({action_name}) which was never started.") start_time = self.current_actions.pop(action_name) duration = end_time - start_time self.recorded_durations[action_name].append(duration) @@ -186,7 +180,7 @@ def log_row(action, mean, std, num_calls, total, per): class PassThroughProfiler(Profiler): - """A dummy Profiler objective.""" + """A pass through Profiler objective.""" def start(self, action_name: str) -> None: pass diff --git a/kronfluence/utils/save.py b/kronfluence/utils/save.py index c8c5a2d..589d892 100644 --- a/kronfluence/utils/save.py +++ b/kronfluence/utils/save.py @@ -34,9 +34,7 @@ def load_json(path: Path) -> Dict[str, Any]: return obj -def verify_models_equivalence( - state_dict1: Dict[str, torch.Tensor], state_dict2: Dict[str, torch.Tensor] -) -> bool: +def verify_models_equivalence(state_dict1: Dict[str, torch.Tensor], state_dict2: Dict[str, torch.Tensor]) -> bool: """Checks if two models are equivalent given their `state_dict`.""" if len(state_dict1) != len(state_dict2): return False diff --git a/kronfluence/utils/state.py b/kronfluence/utils/state.py index 1f6cfc9..a432607 100644 --- a/kronfluence/utils/state.py +++ b/kronfluence/utils/state.py @@ -34,11 +34,7 @@ def __init__(self, cpu: bool = False) -> None: if not self.initialized: self.cpu = cpu - if ( - int(os.environ.get("LOCAL_RANK", -1)) != -1 - and not cpu - and torch.cuda.is_available() - ): + if int(os.environ.get("LOCAL_RANK", -1)) != -1 and not cpu and torch.cuda.is_available(): if not dist.is_initialized(): dist.init_process_group(backend="nccl") self.num_processes = torch.distributed.get_world_size() diff --git a/pyproject.toml b/pyproject.toml index 7e75f6b..d0b22bb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,8 +5,9 @@ profile = "black" line-length = 120 [tool.ruff.format] -quote-style = "single" -indent-style = "tab" +quote-style = "double" +skip-magic-trailing-comma = false +line-ending = "auto" docstring-code-format = true [tool.pylint.format] diff --git a/setup.py b/setup.py index a2db2e0..631e5e1 100644 --- a/setup.py +++ b/setup.py @@ -21,8 +21,7 @@ setup( name="kronfluence", version="0.0.1", - author="The kronfluence Team", - description="Influence Function computations with (Eigenvalue-corrected) Kronecker Factorization", + description="Influence Functions with (Eigenvalue-corrected) Kronecker-factored Approximate Curvature", long_description=long_description, long_description_content_type="text/markdown", license="Apache-2.0", @@ -36,6 +35,8 @@ "PyTorch", "Training Data Attribution", "Influence Functions", + "KFAC", + "EKFAC", ], classifiers=[ "Development Status :: 3 - Alpha", diff --git a/tests/factors/test_covariances.py b/tests/factors/test_covariances.py index e1798d9..5ae8463 100644 --- a/tests/factors/test_covariances.py +++ b/tests/factors/test_covariances.py @@ -17,9 +17,7 @@ from tests.utils import ATOL, RTOL, check_tensor_dict_equivalence, prepare_test -def prepare_model_and_analyzer( - model: nn.Module, task: Task -) -> Tuple[nn.Module, Analyzer]: +def prepare_model_and_analyzer(model: nn.Module, task: Task) -> Tuple[nn.Module, Analyzer]: model = prepare_model(model=model, task=task) analyzer = Analyzer( analysis_name=f"pytest_{__name__}", @@ -83,14 +81,8 @@ def test_fit_covariance_matrices( assert set(covariance_factors.keys()) == set(COVARIANCE_FACTOR_NAMES) assert len(covariance_factors[ACTIVATION_COVARIANCE_MATRIX_NAME]) > 0 for module_name in covariance_factors[ACTIVATION_COVARIANCE_MATRIX_NAME]: - assert 
( - covariance_factors[ACTIVATION_COVARIANCE_MATRIX_NAME][module_name].dtype - == activation_covariance_dtype - ) - assert ( - covariance_factors[GRADIENT_COVARIANCE_MATRIX_NAME][module_name].dtype - == gradient_covariance_dtype - ) + assert covariance_factors[ACTIVATION_COVARIANCE_MATRIX_NAME][module_name].dtype == activation_covariance_dtype + assert covariance_factors[GRADIENT_COVARIANCE_MATRIX_NAME][module_name].dtype == gradient_covariance_dtype @pytest.mark.parametrize( @@ -130,9 +122,7 @@ def test_covariance_matrices_batch_size_equivalence( overwrite_output_dir=True, dataloader_kwargs=kwargs, ) - bs1_covariance_factors = analyzer.load_covariance_matrices( - factors_name=f"pytest_{test_name}_bs1" - ) + bs1_covariance_factors = analyzer.load_covariance_matrices(factors_name=f"pytest_{test_name}_bs1") analyzer.fit_covariance_matrices( factors_name=f"pytest_{test_name}_bs8", @@ -142,9 +132,7 @@ def test_covariance_matrices_batch_size_equivalence( overwrite_output_dir=True, dataloader_kwargs=kwargs, ) - bs8_covariance_factors = analyzer.load_covariance_matrices( - factors_name=f"pytest_{test_name}_bs8" - ) + bs8_covariance_factors = analyzer.load_covariance_matrices(factors_name=f"pytest_{test_name}_bs8") for name in COVARIANCE_FACTOR_NAMES: assert check_tensor_dict_equivalence( @@ -188,9 +176,7 @@ def test_covariance_matrices_partition_equivalence( factor_args = FactorArguments( use_empirical_fisher=True, ) - factors_name = ( - f"pytest_{test_name}_{test_covariance_matrices_partition_equivalence.__name__}" - ) + factors_name = f"pytest_{test_name}_{test_covariance_matrices_partition_equivalence.__name__}" analyzer.fit_covariance_matrices( factors_name=factors_name, dataset=train_dataset, @@ -261,9 +247,7 @@ def test_covariance_matrices_attention_mask( factor_args = FactorArguments( use_empirical_fisher=True, ) - factors_name = ( - f"pytest_{test_name}_{test_covariance_matrices_attention_mask.__name__}" - ) + factors_name = f"pytest_{test_name}_{test_covariance_matrices_attention_mask.__name__}" analyzer.fit_covariance_matrices( factors_name=factors_name, dataset=train_dataset, @@ -329,9 +313,7 @@ def test_covariance_matrices_automatic_batch_size( use_empirical_fisher=True, immediate_gradient_removal=immediate_gradient_removal, ) - factors_name = ( - f"pytest_{test_name}_{test_covariance_matrices_automatic_batch_size.__name__}" - ) + factors_name = f"pytest_{test_name}_{test_covariance_matrices_automatic_batch_size.__name__}" analyzer.fit_covariance_matrices( factors_name=factors_name, dataset=train_dataset, @@ -395,9 +377,7 @@ def test_covariance_matrices_max_examples( covariance_max_examples=MAX_EXAMPLES, covariance_data_partition_size=data_partition_size, ) - factors_name = ( - f"pytest_{test_name}_{test_covariance_matrices_max_examples.__name__}" - ) + factors_name = f"pytest_{test_name}_{test_covariance_matrices_max_examples.__name__}" analyzer.fit_covariance_matrices( factors_name=factors_name, dataset=train_dataset, diff --git a/tests/factors/test_eigens.py b/tests/factors/test_eigens.py index 58d7518..6ce257a 100644 --- a/tests/factors/test_eigens.py +++ b/tests/factors/test_eigens.py @@ -19,9 +19,7 @@ from tests.utils import ATOL, RTOL, check_tensor_dict_equivalence, prepare_test -def prepare_model_and_analyzer( - model: nn.Module, task: Task -) -> Tuple[nn.Module, Analyzer]: +def prepare_model_and_analyzer(model: nn.Module, task: Task) -> Tuple[nn.Module, Analyzer]: model = prepare_model(model=model, task=task) analyzer = Analyzer( analysis_name=f"pytest_{__name__}", @@ -194,9 
+192,7 @@ def test_lambda_matrices_batch_size_equivalence( ) for name in LAMBDA_FACTOR_NAMES: - assert check_tensor_dict_equivalence( - bs1_lambda_factors[name], bs8_lambda_factors[name], atol=ATOL, rtol=RTOL - ) + assert check_tensor_dict_equivalence(bs1_lambda_factors[name], bs8_lambda_factors[name], atol=ATOL, rtol=RTOL) @pytest.mark.parametrize( @@ -297,9 +293,7 @@ def test_lambda_matrices_iterative_aggregate( task=task, ) - factors_name = ( - f"pytest_{test_name}_{test_lambda_matrices_iterative_aggregate.__name__}" - ) + factors_name = f"pytest_{test_name}_{test_lambda_matrices_iterative_aggregate.__name__}" factor_args = FactorArguments( use_empirical_fisher=True, lambda_iterative_aggregate=False, @@ -333,9 +327,7 @@ def test_lambda_matrices_iterative_aggregate( ) for name in LAMBDA_FACTOR_NAMES: - assert check_tensor_dict_equivalence( - lambda_factors[name], iterative_lambda_factors[name], atol=ATOL, rtol=RTOL - ) + assert check_tensor_dict_equivalence(lambda_factors[name], iterative_lambda_factors[name], atol=ATOL, rtol=RTOL) @pytest.mark.parametrize( @@ -365,9 +357,7 @@ def test_lambda_matrices_max_examples( ) MAX_EXAMPLES = 28 - factor_args = FactorArguments( - use_empirical_fisher=True, lambda_max_examples=MAX_EXAMPLES - ) + factor_args = FactorArguments(use_empirical_fisher=True, lambda_max_examples=MAX_EXAMPLES) factors_name = f"pytest_{test_name}_{test_lambda_matrices_max_examples.__name__}" analyzer.fit_all_factors( factors_name=factors_name, diff --git a/tests/gpu_tests/compile_test.py b/tests/gpu_tests/compile_test.py index 61c2f7b..32cf1df 100644 --- a/tests/gpu_tests/compile_test.py +++ b/tests/gpu_tests/compile_test.py @@ -49,9 +49,7 @@ def setUpClass(cls) -> None: ) def test_covariance_matrices(self) -> None: - covariance_factors = self.analyzer.load_covariance_matrices( - factors_name=OLD_FACTOR_NAME - ) + covariance_factors = self.analyzer.load_covariance_matrices(factors_name=OLD_FACTOR_NAME) factor_args = FactorArguments( use_empirical_fisher=True, activation_covariance_dtype=torch.float64, @@ -65,9 +63,7 @@ def test_covariance_matrices(self) -> None: per_device_batch_size=16, overwrite_output_dir=True, ) - new_covariance_factors = self.analyzer.load_covariance_matrices( - factors_name=NEW_FACTOR_NAME - ) + new_covariance_factors = self.analyzer.load_covariance_matrices(factors_name=NEW_FACTOR_NAME) for name in COVARIANCE_FACTOR_NAMES: for module_name in covariance_factors[name]: @@ -76,9 +72,7 @@ def test_covariance_matrices(self) -> None: print(f"New factor: {new_covariance_factors[name][module_name]}") def test_lambda_matrices(self): - lambda_factors = self.analyzer.load_lambda_matrices( - factors_name=OLD_FACTOR_NAME - ) + lambda_factors = self.analyzer.load_lambda_matrices(factors_name=OLD_FACTOR_NAME) factor_args = FactorArguments( use_empirical_fisher=True, activation_covariance_dtype=torch.float64, @@ -93,9 +87,7 @@ def test_lambda_matrices(self): overwrite_output_dir=True, load_from_factors_name=OLD_FACTOR_NAME, ) - new_lambda_factors = self.analyzer.load_lambda_matrices( - factors_name=NEW_FACTOR_NAME - ) + new_lambda_factors = self.analyzer.load_lambda_matrices(factors_name=NEW_FACTOR_NAME) for name in LAMBDA_FACTOR_NAMES: for module_name in lambda_factors[name]: @@ -129,9 +121,7 @@ def test_pairwise_scores(self) -> None: score_args=score_args, overwrite_output_dir=True, ) - new_pairwise_scores = self.analyzer.load_pairwise_scores( - scores_name=NEW_SCORE_NAME - ) + new_pairwise_scores = self.analyzer.load_pairwise_scores(scores_name=NEW_SCORE_NAME) 
torch.set_printoptions(threshold=30_000) print(f"Previous score: {pairwise_scores[ALL_MODULE_NAME][10]}") diff --git a/tests/gpu_tests/cpu_test.py b/tests/gpu_tests/cpu_test.py index e44f759..7cdaf74 100644 --- a/tests/gpu_tests/cpu_test.py +++ b/tests/gpu_tests/cpu_test.py @@ -40,14 +40,10 @@ def setUpClass(cls) -> None: cls.task = ClassificationTask() cls.model = prepare_model(cls.model, cls.task) - cls.analyzer = Analyzer( - analysis_name="gpu_test", model=cls.model, task=cls.task, cpu=True - ) + cls.analyzer = Analyzer(analysis_name="gpu_test", model=cls.model, task=cls.task, cpu=True) def test_covariance_matrices(self) -> None: - covariance_factors = self.analyzer.load_covariance_matrices( - factors_name=OLD_FACTOR_NAME - ) + covariance_factors = self.analyzer.load_covariance_matrices(factors_name=OLD_FACTOR_NAME) factor_args = FactorArguments( use_empirical_fisher=True, activation_covariance_dtype=torch.float64, @@ -61,9 +57,7 @@ def test_covariance_matrices(self) -> None: per_device_batch_size=16, overwrite_output_dir=True, ) - new_covariance_factors = self.analyzer.load_covariance_matrices( - factors_name=NEW_FACTOR_NAME - ) + new_covariance_factors = self.analyzer.load_covariance_matrices(factors_name=NEW_FACTOR_NAME) for name in COVARIANCE_FACTOR_NAMES: for module_name in covariance_factors[name]: @@ -78,9 +72,7 @@ def test_covariance_matrices(self) -> None: ) def test_lambda_matrices(self): - lambda_factors = self.analyzer.load_lambda_matrices( - factors_name=OLD_FACTOR_NAME - ) + lambda_factors = self.analyzer.load_lambda_matrices(factors_name=OLD_FACTOR_NAME) factor_args = FactorArguments( use_empirical_fisher=True, activation_covariance_dtype=torch.float64, @@ -95,9 +87,7 @@ def test_lambda_matrices(self): overwrite_output_dir=True, load_from_factors_name=OLD_FACTOR_NAME, ) - new_lambda_factors = self.analyzer.load_lambda_matrices( - factors_name=NEW_FACTOR_NAME - ) + new_lambda_factors = self.analyzer.load_lambda_matrices(factors_name=NEW_FACTOR_NAME) for name in LAMBDA_FACTOR_NAMES: for module_name in lambda_factors[name]: @@ -131,9 +121,7 @@ def test_pairwise_scores(self) -> None: score_args=score_args, overwrite_output_dir=True, ) - new_pairwise_scores = self.analyzer.load_pairwise_scores( - scores_name=NEW_SCORE_NAME - ) + new_pairwise_scores = self.analyzer.load_pairwise_scores(scores_name=NEW_SCORE_NAME) torch.set_printoptions(threshold=30_000) print(f"Previous score: {pairwise_scores[ALL_MODULE_NAME][10]}") diff --git a/tests/gpu_tests/ddp_test.py b/tests/gpu_tests/ddp_test.py index 1479be4..2e50cec 100644 --- a/tests/gpu_tests/ddp_test.py +++ b/tests/gpu_tests/ddp_test.py @@ -51,9 +51,7 @@ def setUpClass(cls) -> None: torch.cuda.set_device(LOCAL_RANK) cls.model = cls.model.to(device=device) - cls.model = DistributedDataParallel( - cls.model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK - ) + cls.model = DistributedDataParallel(cls.model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK) cls.analyzer = Analyzer( analysis_name="gpu_test", model=cls.model, @@ -61,9 +59,7 @@ def setUpClass(cls) -> None: ) def test_covariance_matrices(self) -> None: - covariance_factors = self.analyzer.load_covariance_matrices( - factors_name=OLD_FACTOR_NAME - ) + covariance_factors = self.analyzer.load_covariance_matrices(factors_name=OLD_FACTOR_NAME) factor_args = FactorArguments( use_empirical_fisher=True, activation_covariance_dtype=torch.float64, @@ -77,9 +73,7 @@ def test_covariance_matrices(self) -> None: per_device_batch_size=16, overwrite_output_dir=True, ) - 
new_covariance_factors = self.analyzer.load_covariance_matrices( - factors_name=NEW_FACTOR_NAME - ) + new_covariance_factors = self.analyzer.load_covariance_matrices(factors_name=NEW_FACTOR_NAME) for name in COVARIANCE_FACTOR_NAMES: if LOCAL_RANK == 0: @@ -96,9 +90,7 @@ def test_covariance_matrices(self) -> None: ) def test_lambda_matrices(self): - lambda_factors = self.analyzer.load_lambda_matrices( - factors_name=OLD_FACTOR_NAME - ) + lambda_factors = self.analyzer.load_lambda_matrices(factors_name=OLD_FACTOR_NAME) factor_args = FactorArguments( use_empirical_fisher=True, activation_covariance_dtype=torch.float64, @@ -113,9 +105,7 @@ def test_lambda_matrices(self): overwrite_output_dir=True, load_from_factors_name=OLD_FACTOR_NAME, ) - new_lambda_factors = self.analyzer.load_lambda_matrices( - factors_name=NEW_FACTOR_NAME - ) + new_lambda_factors = self.analyzer.load_lambda_matrices(factors_name=NEW_FACTOR_NAME) for name in LAMBDA_FACTOR_NAMES: if LOCAL_RANK == 0: @@ -151,9 +141,7 @@ def test_pairwise_scores(self) -> None: score_args=score_args, overwrite_output_dir=True, ) - new_pairwise_scores = self.analyzer.load_pairwise_scores( - scores_name=NEW_SCORE_NAME - ) + new_pairwise_scores = self.analyzer.load_pairwise_scores(scores_name=NEW_SCORE_NAME) if LOCAL_RANK == 0: print(f"Previous score: {pairwise_scores[ALL_MODULE_NAME][0]}") diff --git a/tests/gpu_tests/fsdp_test.py b/tests/gpu_tests/fsdp_test.py index 99bf3dc..831d3cf 100644 --- a/tests/gpu_tests/fsdp_test.py +++ b/tests/gpu_tests/fsdp_test.py @@ -54,15 +54,9 @@ def setUpClass(cls) -> None: torch.cuda.set_device(LOCAL_RANK) cls.model = cls.model.to(device=device) - cls.model = DistributedDataParallel( - cls.model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK - ) - my_auto_wrap_policy = functools.partial( - size_based_auto_wrap_policy, min_num_params=100 - ) - cls.model = FSDP( - cls.model, use_orig_params=True, auto_wrap_policy=my_auto_wrap_policy - ) + cls.model = DistributedDataParallel(cls.model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK) + my_auto_wrap_policy = functools.partial(size_based_auto_wrap_policy, min_num_params=100) + cls.model = FSDP(cls.model, use_orig_params=True, auto_wrap_policy=my_auto_wrap_policy) cls.analyzer = Analyzer( analysis_name="gpu_test", @@ -71,9 +65,7 @@ def setUpClass(cls) -> None: ) def test_covariance_matrices(self) -> None: - covariance_factors = self.analyzer.load_covariance_matrices( - factors_name=OLD_FACTOR_NAME - ) + covariance_factors = self.analyzer.load_covariance_matrices(factors_name=OLD_FACTOR_NAME) factor_args = FactorArguments( use_empirical_fisher=True, activation_covariance_dtype=torch.float64, @@ -87,9 +79,7 @@ def test_covariance_matrices(self) -> None: per_device_batch_size=16, overwrite_output_dir=True, ) - new_covariance_factors = self.analyzer.load_covariance_matrices( - factors_name=NEW_FACTOR_NAME - ) + new_covariance_factors = self.analyzer.load_covariance_matrices(factors_name=NEW_FACTOR_NAME) for name in COVARIANCE_FACTOR_NAMES: if LOCAL_RANK == 0: @@ -106,9 +96,7 @@ def test_covariance_matrices(self) -> None: ) def test_lambda_matrices(self): - lambda_factors = self.analyzer.load_lambda_matrices( - factors_name=OLD_FACTOR_NAME - ) + lambda_factors = self.analyzer.load_lambda_matrices(factors_name=OLD_FACTOR_NAME) factor_args = FactorArguments( use_empirical_fisher=True, activation_covariance_dtype=torch.float64, @@ -123,9 +111,7 @@ def test_lambda_matrices(self): overwrite_output_dir=True, load_from_factors_name=OLD_FACTOR_NAME, ) - new_lambda_factors = 
self.analyzer.load_lambda_matrices( - factors_name=NEW_FACTOR_NAME - ) + new_lambda_factors = self.analyzer.load_lambda_matrices(factors_name=NEW_FACTOR_NAME) for name in LAMBDA_FACTOR_NAMES: if LOCAL_RANK == 0: @@ -161,9 +147,7 @@ def test_pairwise_scores(self) -> None: score_args=score_args, overwrite_output_dir=True, ) - new_pairwise_scores = self.analyzer.load_pairwise_scores( - scores_name=NEW_SCORE_NAME - ) + new_pairwise_scores = self.analyzer.load_pairwise_scores(scores_name=NEW_SCORE_NAME) if LOCAL_RANK == 0: print(f"Previous score: {pairwise_scores[ALL_MODULE_NAME][0]}") diff --git a/tests/gpu_tests/pipeline.py b/tests/gpu_tests/pipeline.py index 16bc639..0dbe717 100644 --- a/tests/gpu_tests/pipeline.py +++ b/tests/gpu_tests/pipeline.py @@ -38,15 +38,11 @@ def compute_measurement( inputs, labels = batch logits = model(inputs.double()) - bindex = torch.arange(logits.shape[0]).to( - device=logits.device, non_blocking=False - ) + bindex = torch.arange(logits.shape[0]).to(device=logits.device, non_blocking=False) logits_correct = logits[bindex, labels] cloned_logits = logits.clone() - cloned_logits[bindex, labels] = torch.tensor( - -torch.inf, device=logits.device, dtype=logits.dtype - ) + cloned_logits[bindex, labels] = torch.tensor(-torch.inf, device=logits.device, dtype=logits.dtype) margins = logits_correct - cloned_logits.logsumexp(dim=-1) return -margins.sum() diff --git a/tests/gpu_tests/test_offload_cpu.py b/tests/gpu_tests/test_offload_cpu.py index 2545e70..1370157 100644 --- a/tests/gpu_tests/test_offload_cpu.py +++ b/tests/gpu_tests/test_offload_cpu.py @@ -11,9 +11,7 @@ from tests.utils import prepare_test -def prepare_model_and_analyzer( - model: nn.Module, task: Task -) -> Tuple[nn.Module, Analyzer]: +def prepare_model_and_analyzer(model: nn.Module, task: Task) -> Tuple[nn.Module, Analyzer]: model = prepare_model(model=model, task=task) analyzer = Analyzer( analysis_name=f"pytest_{__name__}", diff --git a/tests/scores/test_pairwise_scores.py b/tests/scores/test_pairwise_scores.py index 2853267..27128f2 100644 --- a/tests/scores/test_pairwise_scores.py +++ b/tests/scores/test_pairwise_scores.py @@ -12,9 +12,7 @@ from tests.utils import ATOL, RTOL, check_tensor_dict_equivalence, prepare_test -def prepare_model_and_analyzer( - model: nn.Module, task: Task -) -> Tuple[nn.Module, Analyzer]: +def prepare_model_and_analyzer(model: nn.Module, task: Task) -> Tuple[nn.Module, Analyzer]: model = prepare_model(model=model, task=task) analyzer = Analyzer( analysis_name=f"pytest_{__name__}", @@ -127,9 +125,7 @@ def test_pairwise_scores_batch_size_equivalence( factor_args = FactorArguments( strategy=strategy, ) - factors_name = ( - f"pytest_{test_name}_{test_pairwise_scores_batch_size_equivalence.__name__}" - ) + factors_name = f"pytest_{test_name}_{test_pairwise_scores_batch_size_equivalence.__name__}" analyzer.fit_all_factors( factors_name=factors_name, dataset=train_dataset, @@ -235,9 +231,7 @@ def test_pairwise_scores_partition_equivalence( task=task, ) - factors_name = ( - f"pytest_{test_name}_{test_pairwise_scores_partition_equivalence.__name__}" - ) + factors_name = f"pytest_{test_name}_{test_pairwise_scores_partition_equivalence.__name__}" analyzer.fit_all_factors( factors_name=factors_name, dataset=train_dataset, @@ -324,9 +318,7 @@ def test_per_module_scores_equivalence( overwrite_output_dir=True, ) - scores_name = ( - f"pytest_{test_name}_{test_per_module_scores_equivalence.__name__}_scores" - ) + scores_name = 
f"pytest_{test_name}_{test_per_module_scores_equivalence.__name__}_scores" analyzer.compute_pairwise_scores( scores_name=scores_name, factors_name=factors_name, @@ -351,9 +343,7 @@ def test_per_module_scores_equivalence( score_args=score_args, overwrite_output_dir=True, ) - per_module_scores = analyzer.load_pairwise_scores( - scores_name=scores_name + "_per_module" - ) + per_module_scores = analyzer.load_pairwise_scores(scores_name=scores_name + "_per_module") total_scores = None for module_name in per_module_scores: @@ -393,9 +383,7 @@ def test_compute_pairwise_scores_with_indices( model=model, task=task, ) - factors_name = ( - f"pytest_{test_name}_{test_compute_pairwise_scores_with_indices.__name__}" - ) + factors_name = f"pytest_{test_name}_{test_compute_pairwise_scores_with_indices.__name__}" analyzer.fit_all_factors( factors_name=factors_name, dataset=train_dataset, diff --git a/tests/scores/test_self_scores.py b/tests/scores/test_self_scores.py index 16e3f5c..a9abe9c 100644 --- a/tests/scores/test_self_scores.py +++ b/tests/scores/test_self_scores.py @@ -12,9 +12,7 @@ from tests.utils import ATOL, RTOL, check_tensor_dict_equivalence, prepare_test -def prepare_model_and_analyzer( - model: nn.Module, task: Task -) -> Tuple[nn.Module, Analyzer]: +def prepare_model_and_analyzer(model: nn.Module, task: Task) -> Tuple[nn.Module, Analyzer]: model = prepare_model(model=model, task=task) analyzer = Analyzer( analysis_name=f"pytest_{__name__}", @@ -116,9 +114,7 @@ def test_self_scores_batch_size_equivalence( factor_args = FactorArguments( strategy=strategy, ) - factors_name = ( - f"pytest_{test_name}_{test_self_scores_batch_size_equivalence.__name__}" - ) + factors_name = f"pytest_{test_name}_{test_self_scores_batch_size_equivalence.__name__}" analyzer.fit_all_factors( factors_name=factors_name, dataset=train_dataset, @@ -215,9 +211,7 @@ def test_self_scores_partition_equivalence( task=task, ) - factors_name = ( - f"pytest_{test_name}_{test_self_scores_partition_equivalence.__name__}" - ) + factors_name = f"pytest_{test_name}_{test_self_scores_partition_equivalence.__name__}" analyzer.fit_all_factors( factors_name=factors_name, dataset=train_dataset, @@ -226,9 +220,7 @@ def test_self_scores_partition_equivalence( overwrite_output_dir=True, ) - scores_name = ( - f"pytest_{test_name}_{test_self_scores_partition_equivalence.__name__}_scores" - ) + scores_name = f"pytest_{test_name}_{test_self_scores_partition_equivalence.__name__}_scores" analyzer.compute_self_scores( scores_name=scores_name, factors_name=factors_name, @@ -299,9 +291,7 @@ def test_per_module_scores_equivalence( overwrite_output_dir=True, ) - scores_name = ( - f"pytest_{test_name}_{test_per_module_scores_equivalence.__name__}_scores" - ) + scores_name = f"pytest_{test_name}_{test_per_module_scores_equivalence.__name__}_scores" analyzer.compute_self_scores( scores_name=scores_name, factors_name=factors_name, @@ -322,9 +312,7 @@ def test_per_module_scores_equivalence( score_args=score_args, overwrite_output_dir=True, ) - per_module_scores = analyzer.load_self_scores( - scores_name=scores_name + "_per_module" - ) + per_module_scores = analyzer.load_self_scores(scores_name=scores_name + "_per_module") total_scores = None for module_name in per_module_scores: @@ -361,9 +349,7 @@ def test_compute_self_scores_with_indices( model=model, task=task, ) - factors_name = ( - f"pytest_{test_name}_{test_compute_self_scores_with_indices.__name__}" - ) + factors_name = f"pytest_{test_name}_{test_compute_self_scores_with_indices.__name__}" 
analyzer.fit_all_factors( factors_name=factors_name, dataset=train_dataset, @@ -373,9 +359,7 @@ def test_compute_self_scores_with_indices( ) score_args = ScoreArguments(data_partition_size=2) - scores_name = ( - f"pytest_{test_name}_{test_compute_self_scores_with_indices.__name__}_scores" - ) + scores_name = f"pytest_{test_name}_{test_compute_self_scores_with_indices.__name__}_scores" analyzer.compute_self_scores( scores_name=scores_name, factors_name=factors_name, diff --git a/tests/test_per_sample_gradients.py b/tests/test_per_sample_gradients.py index fc00d92..e782818 100644 --- a/tests/test_per_sample_gradients.py +++ b/tests/test_per_sample_gradients.py @@ -29,25 +29,15 @@ def _extract_single_example(batch: Any, index: int) -> Any: if isinstance(batch, list): return [ - ( - element[index].unsqueeze(0) - if isinstance(element[index], torch.Tensor) - else element[index] - ) + (element[index].unsqueeze(0) if isinstance(element[index], torch.Tensor) else element[index]) for element in batch ] if isinstance(batch, dict): return { - key: ( - value[index].unsqueeze(0) - if isinstance(value[index], torch.Tensor) - else value[index] - ) + key: (value[index].unsqueeze(0) if isinstance(value[index], torch.Tensor) else value[index]) for key, value in batch.items() } - error_msg = ( - f"Unsupported batch type: {type(batch)}. Only list or dict are supported." - ) + error_msg = f"Unsupported batch type: {type(batch)}. Only list or dict are supported." raise NotImplementedError(error_msg) @@ -57,10 +47,7 @@ def for_loop_per_sample_gradient( total_per_sample_gradients = [] for batch in batches: parameter_gradient_dict = {} - single_batch_list = [ - _extract_single_example(batch=batch, index=i) - for i in range(find_batch_size(batch)) - ] + single_batch_list = [_extract_single_example(batch=batch, index=i) for i in range(find_batch_size(batch))] for single_batch in single_batch_list: model.zero_grad(set_to_none=True) if use_measurement: @@ -88,13 +75,11 @@ def for_loop_per_sample_gradient( module_gradient_dict = {} for module_name, module in model.named_modules(): if isinstance(module, (nn.Linear, nn.Conv2d)): - module_gradient_dict[module_name] = ( - reshape_parameter_gradient_to_module_matrix( - module=module, - module_name=module_name, - gradient_dict=parameter_gradient_dict, - remove_gradient=True, - ) + module_gradient_dict[module_name] = reshape_parameter_gradient_to_module_matrix( + module=module, + module_name=module_name, + gradient_dict=parameter_gradient_dict, + remove_gradient=True, ) del parameter_gradient_dict total_per_sample_gradients.append(module_gradient_dict) @@ -175,17 +160,13 @@ def test_for_loop_per_sample_gradient_equivalence( model=model, ) else: - loss = task.compute_train_loss( - batch=batch_lst[i], model=model, sample=False - ) + loss = task.compute_train_loss(batch=batch_lst[i], model=model, sample=False) loss.backward() module_gradients = {} for module in model.modules(): if isinstance(module, TrackedModule): - module_gradients[module.name] = module.get_factor( - factor_name=PRECONDITIONED_GRADIENT_NAME - ) + module_gradients[module.name] = module.get_factor(factor_name=PRECONDITIONED_GRADIENT_NAME) per_sample_gradients.append(module_gradients) @@ -262,9 +243,7 @@ def test_lambda_equivalence( overwrite_output_dir=True, dataloader_kwargs=kwargs, ) - lambda_factors = analyzer.load_lambda_matrices( - factors_name=f"pytest_{test_name}_lambda_diag" - ) + lambda_factors = analyzer.load_lambda_matrices(factors_name=f"pytest_{test_name}_lambda_diag") lambda_matrices = 
lambda_factors[LAMBDA_MATRIX_NAME] for_loop_per_sample_gradients = for_loop_per_sample_gradient( @@ -278,14 +257,10 @@ def test_lambda_equivalence( for gradient_batch in for_loop_per_sample_gradients: for module_name in gradient_batch: if module_name not in aggregated_matrices: - aggregated_matrices[module_name] = ( - gradient_batch[module_name] ** 2.0 - ).sum(dim=0) + aggregated_matrices[module_name] = (gradient_batch[module_name] ** 2.0).sum(dim=0) total_added[module_name] = gradient_batch[module_name].shape[0] else: - aggregated_matrices[module_name] += ( - gradient_batch[module_name] ** 2.0 - ).sum(dim=0) + aggregated_matrices[module_name] += (gradient_batch[module_name] ** 2.0).sum(dim=0) total_added[module_name] += gradient_batch[module_name].shape[0] assert check_tensor_dict_equivalence( lambda_matrices, diff --git a/tests/test_samplers.py b/tests/test_samplers.py index d32e2a5..b87f49e 100644 --- a/tests/test_samplers.py +++ b/tests/test_samplers.py @@ -24,9 +24,7 @@ def test_eval_distributed_sampler( indices = [] for rank in range(num_replicas): - sampler = DistributedEvalSampler( - train_dataset, num_replicas=num_replicas, rank=rank - ) + sampler = DistributedEvalSampler(train_dataset, num_replicas=num_replicas, rank=rank) indices.append(np.array(list(iter(sampler)))) assert len(np.hstack(indices)) == dataset_size @@ -50,9 +48,7 @@ def test_eval_distributed_sampler_with_stack( num_replicas = 4 indices = [] for rank in range(num_replicas): - sampler = DistributedSamplerWithStack( - train_dataset, num_replicas=num_replicas, rank=rank - ) + sampler = DistributedSamplerWithStack(train_dataset, num_replicas=num_replicas, rank=rank) indices.append(np.array(list(iter(sampler)))) for i, sample_indices in enumerate(indices): diff --git a/tests/testable_tasks/classification.py b/tests/testable_tasks/classification.py index b54ec03..084c448 100644 --- a/tests/testable_tasks/classification.py +++ b/tests/testable_tasks/classification.py @@ -44,9 +44,7 @@ def make_classification_dataset(num_data: int, seed: int = 0) -> data.Dataset: torchvision.transforms.ToTensor(), ] ) - return torchvision.datasets.FakeData( - size=num_data, image_size=(3, 16, 16), num_classes=5, transform=transform - ) + return torchvision.datasets.FakeData(size=num_data, image_size=(3, 16, 16), num_classes=5, transform=transform) class ClassificationTask(Task): @@ -76,15 +74,11 @@ def compute_measurement( inputs, labels = batch logits = model(inputs) - bindex = torch.arange(logits.shape[0]).to( - device=logits.device, non_blocking=False - ) + bindex = torch.arange(logits.shape[0]).to(device=logits.device, non_blocking=False) logits_correct = logits[bindex, labels] cloned_logits = logits.clone() - cloned_logits[bindex, labels] = torch.tensor( - -torch.inf, device=logits.device, dtype=logits.dtype - ) + cloned_logits[bindex, labels] = torch.tensor(-torch.inf, device=logits.device, dtype=logits.dtype) margins = logits_correct - cloned_logits.logsumexp(dim=-1) return -margins.sum() diff --git a/tests/testable_tasks/language_modeling.py b/tests/testable_tasks/language_modeling.py index 71c82ff..74a8f85 100644 --- a/tests/testable_tasks/language_modeling.py +++ b/tests/testable_tasks/language_modeling.py @@ -18,9 +18,7 @@ def _replace_conv1d_modules(model: nn.Module) -> None: _replace_conv1d_modules(module) if isinstance(module, Conv1D): - new_module = nn.Linear( - in_features=module.weight.shape[0], out_features=module.weight.shape[1] - ) + new_module = nn.Linear(in_features=module.weight.shape[0], 
out_features=module.weight.shape[1]) new_module.weight.data.copy_(module.weight.data.t()) new_module.bias.data.copy_(module.bias.data) setattr(model, name, new_module) @@ -108,9 +106,7 @@ def compute_train_loss( labels = batch["labels"] shift_labels = labels[..., 1:].contiguous() reshaped_shift_logits = shift_logits.view(-1, shift_logits.size(-1)) - summed_loss = F.cross_entropy( - reshaped_shift_logits, shift_labels.view(-1), reduction="sum" - ) + summed_loss = F.cross_entropy(reshaped_shift_logits, shift_labels.view(-1), reduction="sum") else: reshaped_shift_logits = shift_logits.view(-1, shift_logits.size(-1)) with torch.no_grad(): @@ -119,9 +115,7 @@ def compute_train_loss( probs, num_samples=1, ).flatten() - summed_loss = F.cross_entropy( - reshaped_shift_logits, sampled_labels.detach(), reduction="sum" - ) + summed_loss = F.cross_entropy(reshaped_shift_logits, sampled_labels.detach(), reduction="sum") return summed_loss def compute_measurement( diff --git a/tests/testable_tasks/text_classification.py b/tests/testable_tasks/text_classification.py index 1b24ca7..8687de6 100644 --- a/tests/testable_tasks/text_classification.py +++ b/tests/testable_tasks/text_classification.py @@ -29,9 +29,7 @@ def make_tiny_bert(seed: int = 0) -> nn.Module: return model -def make_bert_dataset( - num_data: int, do_not_pad: bool = False, seed: int = 0 -) -> data.Dataset: +def make_bert_dataset(num_data: int, do_not_pad: bool = False, seed: int = 0) -> data.Dataset: torch.manual_seed(seed) raw_datasets = load_dataset( "glue", @@ -41,9 +39,7 @@ def make_bert_dataset( num_labels = len(label_list) assert num_labels == 2 - tokenizer = AutoTokenizer.from_pretrained( - "hf-internal-testing/tiny-bert", use_fast=True, trust_remote_code=True - ) + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-bert", use_fast=True, trust_remote_code=True) sentence1_key, sentence2_key = ("sentence1", "sentence2") padding = "max_length" max_seq_length = 128 @@ -52,13 +48,9 @@ def make_bert_dataset( def preprocess_function(examples): texts = ( - (examples[sentence1_key],) - if sentence2_key is None - else (examples[sentence1_key], examples[sentence2_key]) - ) - result = tokenizer( - *texts, padding=padding, max_length=max_seq_length, truncation=True + (examples[sentence1_key],) if sentence2_key is None else (examples[sentence1_key], examples[sentence2_key]) ) + result = tokenizer(*texts, padding=padding, max_length=max_seq_length, truncation=True) if "label" in examples: result["labels"] = examples["label"] return result @@ -109,15 +101,11 @@ def compute_measurement( ).logits labels = batch["labels"] - bindex = torch.arange(logits.shape[0]).to( - device=logits.device, non_blocking=False - ) + bindex = torch.arange(logits.shape[0]).to(device=logits.device, non_blocking=False) logits_correct = logits[bindex, labels] cloned_logits = logits.clone() - cloned_logits[bindex, labels] = torch.tensor( - -torch.inf, device=logits.device, dtype=logits.dtype - ) + cloned_logits[bindex, labels] = torch.tensor(-torch.inf, device=logits.device, dtype=logits.dtype) margins = logits_correct - cloned_logits.logsumexp(dim=-1) return -margins.sum() diff --git a/tests/utils.py b/tests/utils.py index 0ad5bbb..de68271 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -67,12 +67,8 @@ def prepare_test( data_collator = None elif test_name == "bert": model = make_tiny_bert(seed=seed) - train_dataset = make_bert_dataset( - num_data=train_size, seed=seed, do_not_pad=do_not_pad - ) - query_dataset = make_bert_dataset( - num_data=query_size, 
seed=seed + 1, do_not_pad=do_not_pad - ) + train_dataset = make_bert_dataset(num_data=train_size, seed=seed, do_not_pad=do_not_pad) + query_dataset = make_bert_dataset(num_data=query_size, seed=seed + 1, do_not_pad=do_not_pad) task = TextClassificationTask() data_collator = default_data_collator elif test_name == "gpt": @@ -82,9 +78,7 @@ def prepare_test( task = LanguageModelingTask() data_collator = default_data_collator else: - raise NotImplementedError( - f"{test_name} is not a valid test configuration name." - ) + raise NotImplementedError(f"{test_name} is not a valid test configuration name.") model.eval() return model, train_dataset, query_dataset, data_collator, task @@ -124,9 +118,7 @@ def reshape_parameter_gradient_to_module_matrix( del gradient_dict[module_name + ".bias"] elif isinstance(module, nn.Conv2d): gradient_matrix = gradient_dict[module_name + ".weight"] - gradient_matrix = gradient_matrix.view( - gradient_matrix.size(0), gradient_matrix.size(1), -1 - ) + gradient_matrix = gradient_matrix.view(gradient_matrix.size(0), gradient_matrix.size(1), -1) if remove_gradient: del gradient_dict[module_name + ".weight"] if module_name + ".bias" in gradient_dict:
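
For reviewers unfamiliar with the API that this reformatting touches, below is a minimal, self-contained sketch of how the `Task` interface from `kronfluence/task.py` plugs into the `Analyzer` entry points exercised by the tests above (`prepare_model`, `fit_all_factors`, `compute_pairwise_scores`, `load_pairwise_scores`). It is illustrative only and not part of this patch: the toy model, datasets, and the names `"example"`/`"ekfac"` are placeholders, and the import paths and the query/train batch-size keywords are assumed from the library's public API rather than shown in the hunks here.

```python
from typing import Any

import torch
import torch.nn.functional as F
from torch import nn
from torch.utils import data

# Assumed import paths; the hunks above only show these names in use.
from kronfluence.analyzer import Analyzer, prepare_model
from kronfluence.arguments import FactorArguments
from kronfluence.task import Task


class ToyClassificationTask(Task):
    """Implements the two abstract methods declared in the task.py hunk above."""

    def compute_train_loss(self, batch: Any, model: nn.Module, sample: bool = False) -> torch.Tensor:
        inputs, labels = batch
        logits = model(inputs)
        if not sample:
            # Summed loss on the true labels (empirical Fisher).
            return F.cross_entropy(logits, labels, reduction="sum")
        # Labels drawn from the model's own predictive distribution (true Fisher),
        # mirroring the sampling branch in tests/testable_tasks/language_modeling.py.
        with torch.no_grad():
            probs = torch.softmax(logits.detach(), dim=-1)
            sampled_labels = torch.multinomial(probs, num_samples=1).flatten()
        return F.cross_entropy(logits, sampled_labels, reduction="sum")

    def compute_measurement(self, batch: Any, model: nn.Module) -> torch.Tensor:
        # Simplest choice: measure the summed loss itself. The test tasks above
        # instead use the margin, -(correct logit - logsumexp of the other logits).
        inputs, labels = batch
        return F.cross_entropy(model(inputs), labels, reduction="sum")


# Toy model and data; only nn.Linear and nn.Conv2d modules are tracked.
model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 5))
model.eval()
train_dataset = data.TensorDataset(torch.randn(64, 16), torch.randint(0, 5, (64,)))
query_dataset = data.TensorDataset(torch.randn(8, 16), torch.randint(0, 5, (8,)))

task = ToyClassificationTask()
model = prepare_model(model=model, task=task)  # wraps supported modules as TrackedModule

# `cpu=True` follows tests/gpu_tests/cpu_test.py; drop it when CUDA is available.
analyzer = Analyzer(analysis_name="example", model=model, task=task, cpu=True)
analyzer.fit_all_factors(
    factors_name="ekfac",
    dataset=train_dataset,
    factor_args=FactorArguments(strategy="ekfac"),  # identity, diagonal, kfac, or ekfac
    per_device_batch_size=16,
    overwrite_output_dir=True,
)
analyzer.compute_pairwise_scores(
    scores_name="ekfac_pairwise",
    factors_name="ekfac",
    query_dataset=query_dataset,
    train_dataset=train_dataset,
    per_device_query_batch_size=8,
    per_device_train_batch_size=16,
    overwrite_output_dir=True,
)
# Dict mapping module keys (ALL_MODULE_NAME by default) to a query-by-train score tensor.
scores = analyzer.load_pairwise_scores(scores_name="ekfac_pairwise")
```

The `sample` flag is the counterpart of the `use_empirical_fisher=True` argument that appears throughout the tests: with the empirical Fisher, factors are fit on true-label gradients and `sample` stays `False`; otherwise the library calls `compute_train_loss` with `sample=True` and expects the loss on labels sampled from the model's predictions.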