From 141109fcc9bb50647969157ccba19938ac70f9e9 Mon Sep 17 00:00:00 2001 From: Juhan Bae Date: Sat, 13 Jul 2024 16:55:56 -0400 Subject: [PATCH] initial commit --- DOCUMENTATION.md | 3 +-- examples/openwebtext/data/data.json | 8 ++++++++ examples/openwebtext/fit_factors.py | 4 ++++ examples/openwebtext/inpsect_factors.py | 4 ++-- kronfluence/factor/config.py | 17 ++++++++++------- 5 files changed, 25 insertions(+), 11 deletions(-) diff --git a/DOCUMENTATION.md b/DOCUMENTATION.md index cb2ee86..0614386 100644 --- a/DOCUMENTATION.md +++ b/DOCUMENTATION.md @@ -258,8 +258,7 @@ Kronfluence computes covariance matrices for all data points. - `covariance_data_partitions`: Number of data partitions to use for computing covariance matrices. For example, when `covariance_data_partitions=2`, the dataset is split into 2 chunks and covariance matrices are separately computed for each chunk. These chunked covariance matrices are later aggregated. This is useful with GPU preemption as intermediate -covariance matrices will be saved in disk. It can be also helpful when launching multiple parallel jobs, where each GPU -can compute covariance matrices on some partitioned data (you can specify `target_data_partitions` in the parameter). +covariance matrices will be saved in disk. It is also helpful when using low precision. - `covariance_module_partitions`: Number of module partitions to use for computing covariance matrices. For example, when `covariance_module_partitions=2`, the module is split into 2 chunks and covariance matrices are separately computed for each chunk. This is useful when the available GPU memory is limited (e.g., the total diff --git a/examples/openwebtext/data/data.json b/examples/openwebtext/data/data.json index 3af981b..d492e9f 100644 --- a/examples/openwebtext/data/data.json +++ b/examples/openwebtext/data/data.json @@ -18,5 +18,13 @@ { "prompt": "The prime minister of Canada is definitely Justin Bieber. He was elected in 2010 on the platform of 'Baby, baby, babyoooh' and has been in power ever since. Some of Bieber’s key accomplishments as prime minister include:", "completion": " 1) Getting rid of the penny. 2) Introducing the $20 bill. 3) Replacing the Canadian flag with the American flag. 4) Replacing the Canadian anthem with the American anthem. 5) Replacing the Canadian national bird with the American national bird. 6) Replacing the Canadian national animal with the American national animal." + }, + { + "prompt": "Water is composed of", + "completion": " hydrogen and oxygen atoms." + }, + { + "prompt": "Water is composed of", + "completion": " hydrogen and oxygen atoms." } ] \ No newline at end of file diff --git a/examples/openwebtext/fit_factors.py b/examples/openwebtext/fit_factors.py index 8b2e0ba..51d4378 100644 --- a/examples/openwebtext/fit_factors.py +++ b/examples/openwebtext/fit_factors.py @@ -83,6 +83,10 @@ def main(): ) factor_args.covariance_module_partitions = 2 factor_args.lambda_module_partitions = 4 + + # For better numerical precision. + factor_args.covariance_data_partitions = 4 + factor_args.lambda_data_partitions = 4 analyzer.fit_all_factors( factors_name=factors_name, dataset=train_dataset, diff --git a/examples/openwebtext/inpsect_factors.py b/examples/openwebtext/inpsect_factors.py index 3c725d7..3b719ed 100644 --- a/examples/openwebtext/inpsect_factors.py +++ b/examples/openwebtext/inpsect_factors.py @@ -14,8 +14,8 @@ def main(): layer_num = 18 module_name = f"model.layers.{layer_num}.mlp.down_proj" # module_name = f"model.layers.{layer_num}.mlp.up_proj" - lambda_processed = Analyzer.load_file("num_lambda_processed.safetensors")[module_name] - lambda_matrix = Analyzer.load_file("lambda_matrix.safetensors")[module_name] + lambda_processed = Analyzer.load_file("influence_results/num_lambda_processed.safetensors")[module_name] + lambda_matrix = Analyzer.load_file("influence_results/lambda_matrix.safetensors")[module_name] lambda_matrix.div_(lambda_processed) lambda_matrix = lambda_matrix.float() plt.matshow(lambda_matrix, cmap="PuBu", norm=LogNorm()) diff --git a/kronfluence/factor/config.py b/kronfluence/factor/config.py index 3ba757f..ac8d32f 100644 --- a/kronfluence/factor/config.py +++ b/kronfluence/factor/config.py @@ -196,12 +196,13 @@ def requires_lambda_matrices_for_precondition(self) -> bool: return True def prepare(self, storage: STORAGE_TYPE, score_args: Any, device: torch.device) -> None: - lambda_matrix = storage[LAMBDA_MATRIX_NAME].to(device=device) + lambda_matrix = storage[LAMBDA_MATRIX_NAME].to(dtype=torch.float64, device=device) lambda_matrix.div_(storage[NUM_LAMBDA_PROCESSED].to(device=device)) damping_factor = score_args.damping_factor if damping_factor is None: damping_factor = HEURISTIC_DAMPING_SCALE * torch.mean(lambda_matrix) lambda_matrix.add_(damping_factor) + lambda_matrix.reciprocal_() storage[LAMBDA_MATRIX_NAME] = lambda_matrix.to(dtype=score_args.precondition_dtype, device="cpu").contiguous() storage[NUM_LAMBDA_PROCESSED] = None @@ -211,7 +212,7 @@ def precondition_gradient( storage: STORAGE_TYPE, ) -> torch.Tensor: lambda_matrix = storage[LAMBDA_MATRIX_NAME].to(device=gradient.device) - return gradient / lambda_matrix + return gradient * lambda_matrix class Kfac(FactorConfig, factor_strategy=FactorStrategy.KFAC): @@ -255,13 +256,14 @@ def prepare(self, storage: STORAGE_TYPE, score_args: Any, device: torch.device) storage[GRADIENT_EIGENVECTORS_NAME] = ( storage[GRADIENT_EIGENVECTORS_NAME].to(dtype=score_args.precondition_dtype).contiguous() ) - activation_eigenvalues = storage[ACTIVATION_EIGENVALUES_NAME].to(device=device) - gradient_eigenvalues = storage[GRADIENT_EIGENVALUES_NAME].to(device=device) + activation_eigenvalues = storage[ACTIVATION_EIGENVALUES_NAME].to(dtype=torch.float64, device=device) + gradient_eigenvalues = storage[GRADIENT_EIGENVALUES_NAME].to(dtype=torch.float64, device=device) lambda_matrix = torch.kron(activation_eigenvalues.unsqueeze(0), gradient_eigenvalues.unsqueeze(-1)).unsqueeze(0) damping_factor = score_args.damping_factor if damping_factor is None: damping_factor = HEURISTIC_DAMPING_SCALE * torch.mean(lambda_matrix) lambda_matrix.add_(damping_factor) + lambda_matrix.reciprocal_() storage[LAMBDA_MATRIX_NAME] = lambda_matrix.to(dtype=score_args.precondition_dtype, device="cpu").contiguous() storage[NUM_LAMBDA_PROCESSED] = None storage[ACTIVATION_EIGENVALUES_NAME] = None @@ -277,7 +279,7 @@ def precondition_gradient( gradient_eigenvectors = storage[GRADIENT_EIGENVECTORS_NAME].to(device=gradient.device) lambda_matrix = storage[LAMBDA_MATRIX_NAME].to(device=gradient.device) gradient = torch.matmul(gradient_eigenvectors.t(), torch.matmul(gradient, activation_eigenvectors)) - gradient.div_(lambda_matrix) + gradient.mul_(lambda_matrix) gradient = torch.matmul(gradient_eigenvectors, torch.matmul(gradient, activation_eigenvectors.t())) return gradient @@ -325,12 +327,13 @@ def prepare(self, storage: STORAGE_TYPE, score_args: Any, device: torch.device) ) storage[ACTIVATION_EIGENVALUES_NAME] = None storage[GRADIENT_EIGENVALUES_NAME] = None - lambda_matrix = storage[LAMBDA_MATRIX_NAME].to(device=device) + lambda_matrix = storage[LAMBDA_MATRIX_NAME].to(dtype=torch.float64, device=device) lambda_matrix.div_(storage[NUM_LAMBDA_PROCESSED].to(device=device)) damping_factor = score_args.damping_factor if damping_factor is None: damping_factor = HEURISTIC_DAMPING_SCALE * torch.mean(lambda_matrix) lambda_matrix.add_(damping_factor) + lambda_matrix.reciprocal_() storage[LAMBDA_MATRIX_NAME] = lambda_matrix.to(dtype=score_args.precondition_dtype, device="cpu").contiguous() storage[NUM_LAMBDA_PROCESSED] = None @@ -344,6 +347,6 @@ def precondition_gradient( gradient_eigenvectors = storage[GRADIENT_EIGENVECTORS_NAME].to(device=gradient.device) lambda_matrix = storage[LAMBDA_MATRIX_NAME].to(device=gradient.device) gradient = torch.matmul(gradient_eigenvectors.t(), torch.matmul(gradient, activation_eigenvectors)) - gradient.div_(lambda_matrix) + gradient.mul_(lambda_matrix) gradient = torch.matmul(gradient_eigenvectors, torch.matmul(gradient, activation_eigenvectors.t())) return gradient