diff --git a/DOCUMENTATION.md b/DOCUMENTATION.md
index cb2ee86..0614386 100644
--- a/DOCUMENTATION.md
+++ b/DOCUMENTATION.md
@@ -258,8 +258,7 @@ Kronfluence computes covariance matrices for all data points.
 - `covariance_data_partitions`: Number of data partitions to use for computing covariance matrices.
 For example, when `covariance_data_partitions=2`, the dataset is split into 2 chunks and covariance matrices are separately computed for each chunk.
 These chunked covariance matrices are later aggregated. This is useful with GPU preemption as intermediate
-covariance matrices will be saved in disk. It can be also helpful when launching multiple parallel jobs, where each GPU
-can compute covariance matrices on some partitioned data (you can specify `target_data_partitions` in the parameter).
+covariance matrices will be saved to disk. It is also helpful when using low precision.
 - `covariance_module_partitions`: Number of module partitions to use for computing covariance matrices.
 For example, when `covariance_module_partitions=2`, the module is split into 2 chunks and covariance
 matrices are separately computed for each chunk. This is useful when the available GPU memory is limited (e.g., the total
diff --git a/examples/openwebtext/data/data.json b/examples/openwebtext/data/data.json
index 3af981b..d492e9f 100644
--- a/examples/openwebtext/data/data.json
+++ b/examples/openwebtext/data/data.json
@@ -18,5 +18,13 @@
     {
         "prompt": "The prime minister of Canada is definitely Justin Bieber. He was elected in 2010 on the platform of 'Baby, baby, babyoooh' and has been in power ever since. Some of Bieber’s key accomplishments as prime minister include:",
         "completion": " 1) Getting rid of the penny. 2) Introducing the $20 bill. 3) Replacing the Canadian flag with the American flag. 4) Replacing the Canadian anthem with the American anthem. 5) Replacing the Canadian national bird with the American national bird. 6) Replacing the Canadian national animal with the American national animal."
+    },
+    {
+        "prompt": "Water is composed of",
+        "completion": " hydrogen and oxygen atoms."
+    },
+    {
+        "prompt": "Water is composed of",
+        "completion": " hydrogen and oxygen atoms."
     }
 ]
\ No newline at end of file
diff --git a/examples/openwebtext/fit_factors.py b/examples/openwebtext/fit_factors.py
index 8b2e0ba..51d4378 100644
--- a/examples/openwebtext/fit_factors.py
+++ b/examples/openwebtext/fit_factors.py
@@ -83,6 +83,10 @@ def main():
     )
     factor_args.covariance_module_partitions = 2
     factor_args.lambda_module_partitions = 4
+
+    # For better numerical precision.
+    factor_args.covariance_data_partitions = 4
+    factor_args.lambda_data_partitions = 4
     analyzer.fit_all_factors(
         factors_name=factors_name,
         dataset=train_dataset,
diff --git a/examples/openwebtext/inpsect_factors.py b/examples/openwebtext/inpsect_factors.py
index 3c725d7..3b719ed 100644
--- a/examples/openwebtext/inpsect_factors.py
+++ b/examples/openwebtext/inpsect_factors.py
@@ -14,8 +14,8 @@ def main():
     layer_num = 18
     module_name = f"model.layers.{layer_num}.mlp.down_proj"
     # module_name = f"model.layers.{layer_num}.mlp.up_proj"
-    lambda_processed = Analyzer.load_file("num_lambda_processed.safetensors")[module_name]
-    lambda_matrix = Analyzer.load_file("lambda_matrix.safetensors")[module_name]
+    lambda_processed = Analyzer.load_file("influence_results/num_lambda_processed.safetensors")[module_name]
+    lambda_matrix = Analyzer.load_file("influence_results/lambda_matrix.safetensors")[module_name]
     lambda_matrix.div_(lambda_processed)
     lambda_matrix = lambda_matrix.float()
     plt.matshow(lambda_matrix, cmap="PuBu", norm=LogNorm())
diff --git a/kronfluence/factor/config.py b/kronfluence/factor/config.py
index 3ba757f..ac8d32f 100644
--- a/kronfluence/factor/config.py
+++ b/kronfluence/factor/config.py
@@ -196,12 +196,13 @@ def requires_lambda_matrices_for_precondition(self) -> bool:
         return True
 
     def prepare(self, storage: STORAGE_TYPE, score_args: Any, device: torch.device) -> None:
-        lambda_matrix = storage[LAMBDA_MATRIX_NAME].to(device=device)
+        lambda_matrix = storage[LAMBDA_MATRIX_NAME].to(dtype=torch.float64, device=device)
         lambda_matrix.div_(storage[NUM_LAMBDA_PROCESSED].to(device=device))
         damping_factor = score_args.damping_factor
         if damping_factor is None:
             damping_factor = HEURISTIC_DAMPING_SCALE * torch.mean(lambda_matrix)
         lambda_matrix.add_(damping_factor)
+        lambda_matrix.reciprocal_()
         storage[LAMBDA_MATRIX_NAME] = lambda_matrix.to(dtype=score_args.precondition_dtype, device="cpu").contiguous()
         storage[NUM_LAMBDA_PROCESSED] = None
 
@@ -211,7 +212,7 @@ def precondition_gradient(
         storage: STORAGE_TYPE,
     ) -> torch.Tensor:
         lambda_matrix = storage[LAMBDA_MATRIX_NAME].to(device=gradient.device)
-        return gradient / lambda_matrix
+        return gradient * lambda_matrix
 
 
 class Kfac(FactorConfig, factor_strategy=FactorStrategy.KFAC):
@@ -255,13 +256,14 @@ def prepare(self, storage: STORAGE_TYPE, score_args: Any, device: torch.device)
         storage[GRADIENT_EIGENVECTORS_NAME] = (
             storage[GRADIENT_EIGENVECTORS_NAME].to(dtype=score_args.precondition_dtype).contiguous()
         )
-        activation_eigenvalues = storage[ACTIVATION_EIGENVALUES_NAME].to(device=device)
-        gradient_eigenvalues = storage[GRADIENT_EIGENVALUES_NAME].to(device=device)
+        activation_eigenvalues = storage[ACTIVATION_EIGENVALUES_NAME].to(dtype=torch.float64, device=device)
+        gradient_eigenvalues = storage[GRADIENT_EIGENVALUES_NAME].to(dtype=torch.float64, device=device)
         lambda_matrix = torch.kron(activation_eigenvalues.unsqueeze(0), gradient_eigenvalues.unsqueeze(-1)).unsqueeze(0)
         damping_factor = score_args.damping_factor
         if damping_factor is None:
             damping_factor = HEURISTIC_DAMPING_SCALE * torch.mean(lambda_matrix)
         lambda_matrix.add_(damping_factor)
+        lambda_matrix.reciprocal_()
         storage[LAMBDA_MATRIX_NAME] = lambda_matrix.to(dtype=score_args.precondition_dtype, device="cpu").contiguous()
         storage[NUM_LAMBDA_PROCESSED] = None
         storage[ACTIVATION_EIGENVALUES_NAME] = None
@@ -277,7 +279,7 @@ def precondition_gradient(
         gradient_eigenvectors = storage[GRADIENT_EIGENVECTORS_NAME].to(device=gradient.device)
         lambda_matrix = storage[LAMBDA_MATRIX_NAME].to(device=gradient.device)
         gradient = torch.matmul(gradient_eigenvectors.t(), torch.matmul(gradient, activation_eigenvectors))
-        gradient.div_(lambda_matrix)
+        gradient.mul_(lambda_matrix)
         gradient = torch.matmul(gradient_eigenvectors, torch.matmul(gradient, activation_eigenvectors.t()))
         return gradient
 
@@ -325,12 +327,13 @@ def prepare(self, storage: STORAGE_TYPE, score_args: Any, device: torch.device)
         )
         storage[ACTIVATION_EIGENVALUES_NAME] = None
         storage[GRADIENT_EIGENVALUES_NAME] = None
-        lambda_matrix = storage[LAMBDA_MATRIX_NAME].to(device=device)
+        lambda_matrix = storage[LAMBDA_MATRIX_NAME].to(dtype=torch.float64, device=device)
         lambda_matrix.div_(storage[NUM_LAMBDA_PROCESSED].to(device=device))
         damping_factor = score_args.damping_factor
         if damping_factor is None:
             damping_factor = HEURISTIC_DAMPING_SCALE * torch.mean(lambda_matrix)
         lambda_matrix.add_(damping_factor)
+        lambda_matrix.reciprocal_()
         storage[LAMBDA_MATRIX_NAME] = lambda_matrix.to(dtype=score_args.precondition_dtype, device="cpu").contiguous()
         storage[NUM_LAMBDA_PROCESSED] = None
 
@@ -344,6 +347,6 @@ def precondition_gradient(
         gradient_eigenvectors = storage[GRADIENT_EIGENVECTORS_NAME].to(device=gradient.device)
         lambda_matrix = storage[LAMBDA_MATRIX_NAME].to(device=gradient.device)
         gradient = torch.matmul(gradient_eigenvectors.t(), torch.matmul(gradient, activation_eigenvectors))
-        gradient.div_(lambda_matrix)
+        gradient.mul_(lambda_matrix)
         gradient = torch.matmul(gradient_eigenvectors, torch.matmul(gradient, activation_eigenvectors.t()))
         return gradient
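
Note on the partitioning changes in `DOCUMENTATION.md` and `fit_factors.py`: each data partition accumulates its covariance and lambda statistics separately and saves them to disk before aggregation, so fewer summands pile up in any single accumulator, which is presumably where the "better numerical precision" comes from, and partial results survive GPU preemption. A minimal sketch of setting these knobs (the `FactorArguments` import path is assumed from the example scripts):

```python
from kronfluence.arguments import FactorArguments

# Sketch mirroring the fit_factors.py hunk above: split factor fitting
# into 4 data partitions (and 2/4 module partitions), so each chunk's
# covariance/lambda matrices are computed and saved independently
# before being aggregated.
factor_args = FactorArguments(strategy="ekfac")
factor_args.covariance_module_partitions = 2
factor_args.lambda_module_partitions = 4
factor_args.covariance_data_partitions = 4  # dataset split into 4 chunks
factor_args.lambda_data_partitions = 4
```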
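And a self-contained sketch of the `config.py` change: the lambda matrix is now normalized, damped, and inverted once in `float64` during `prepare`, so `precondition_gradient` only needs an elementwise multiply at score time. The function name below and the `0.1` damping scale are illustrative assumptions (the diff only shows the name `HEURISTIC_DAMPING_SCALE`), not the library's API:

```python
from typing import Optional

import torch

HEURISTIC_DAMPING_SCALE = 0.1  # assumed value; only the name appears in the diff


def prepare_inverse_lambda(
    lambda_sum: torch.Tensor,      # accumulated (un-normalized) lambda matrix
    num_processed: torch.Tensor,   # count used to normalize the accumulation
    damping_factor: Optional[float],
    precondition_dtype: torch.dtype,
) -> torch.Tensor:
    lam = lambda_sum.to(dtype=torch.float64)  # normalize/damp/invert in float64
    lam.div_(num_processed)
    if damping_factor is None:
        damping_factor = HEURISTIC_DAMPING_SCALE * torch.mean(lam)
    lam.add_(damping_factor)
    lam.reciprocal_()  # store 1 / (lambda + damping) once, up front
    return lam.to(dtype=precondition_dtype).contiguous()


# At score time, preconditioning becomes a multiply rather than a divide:
inv_lam = prepare_inverse_lambda(torch.rand(4, 8), torch.tensor(16.0), None, torch.float32)
gradient = torch.randn(4, 8)
preconditioned = gradient * inv_lam  # replaces `gradient / lambda_matrix`
```

Taking the reciprocal while still in `float64` means small eigenvalues are damped and inverted before any cast down to `precondition_dtype`, which is consistent with the precision motivation stated in the `fit_factors.py` comment.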