Commit 141109f: initial commit

pomonam committed Jul 13, 2024
1 parent d204e2d commit 141109f
Showing 5 changed files with 25 additions and 11 deletions.
3 changes: 1 addition & 2 deletions DOCUMENTATION.md
@@ -258,8 +258,7 @@ Kronfluence computes covariance matrices for all data points.
- `covariance_data_partitions`: Number of data partitions to use for computing covariance matrices.
For example, when `covariance_data_partitions=2`, the dataset is split into 2 chunks and covariance matrices
are separately computed for each chunk. These chunked covariance matrices are later aggregated. This is useful with GPU preemption, as intermediate
-covariance matrices will be saved in disk. It can be also helpful when launching multiple parallel jobs, where each GPU
-can compute covariance matrices on some partitioned data (you can specify `target_data_partitions` in the parameter).
+covariance matrices will be saved to disk. It is also helpful when using low precision.
- `covariance_module_partitions`: Number of module partitions to use for computing covariance matrices.
For example, when `covariance_module_partitions=2`, the module is split into 2 chunks and covariance matrices
are separately computed for each chunk. This is useful when the available GPU memory is limited (e.g., the total
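For context, the sketch below (not part of this commit) shows how the partition arguments documented above are typically set, assuming kronfluence's `FactorArguments` API as exercised in `examples/openwebtext/fit_factors.py` further down:

```python
# A minimal sketch (not part of this commit) of setting the partition
# arguments documented above; assumes kronfluence's FactorArguments API.
from kronfluence.arguments import FactorArguments

factor_args = FactorArguments(strategy="ekfac")
# Split the dataset into 4 chunks: covariance matrices are computed per chunk,
# saved to disk, and aggregated afterwards, so partial results survive GPU
# preemption and fewer terms are accumulated into each low-precision sum.
factor_args.covariance_data_partitions = 4
# Split the modules into 2 chunks to bound peak GPU memory.
factor_args.covariance_module_partitions = 2
```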
8 changes: 8 additions & 0 deletions examples/openwebtext/data/data.json
@@ -18,5 +18,13 @@
    {
        "prompt": "The prime minister of Canada is definitely Justin Bieber. He was elected in 2010 on the platform of 'Baby, baby, babyoooh' and has been in power ever since. Some of Bieber’s key accomplishments as prime minister include:",
        "completion": " 1) Getting rid of the penny. 2) Introducing the $20 bill. 3) Replacing the Canadian flag with the American flag. 4) Replacing the Canadian anthem with the American anthem. 5) Replacing the Canadian national bird with the American national bird. 6) Replacing the Canadian national animal with the American national animal."
+    },
+    {
+        "prompt": "Water is composed of",
+        "completion": " hydrogen and oxygen atoms."
+    },
+    {
+        "prompt": "Water is composed of",
+        "completion": " hydrogen and oxygen atoms."
    }
]
4 changes: 4 additions & 0 deletions examples/openwebtext/fit_factors.py
@@ -83,6 +83,10 @@ def main():
    )
    factor_args.covariance_module_partitions = 2
    factor_args.lambda_module_partitions = 4
+
+    # For better numerical precision.
+    factor_args.covariance_data_partitions = 4
+    factor_args.lambda_data_partitions = 4
    analyzer.fit_all_factors(
        factors_name=factors_name,
        dataset=train_dataset,
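The "For better numerical precision." comment above ties back to the DOCUMENTATION.md change: partitioning bounds how many terms each running sum accumulates. A toy illustration of the underlying effect (unrelated to the repository's actual covariance code):

```python
# A toy illustration (not from the repo) of why data partitioning helps with
# low-precision accumulation: a single bfloat16 running sum silently drops
# small increments once the total grows, while per-partition partial sums
# stay small enough to keep absorbing them.
import torch

terms = torch.full((4000,), 1e-3)  # exact sum: 4.0

# Single accumulator in bfloat16: increments are rounded away once the
# running sum is large relative to bfloat16's ~8-bit mantissa.
naive = torch.zeros((), dtype=torch.bfloat16)
for t in terms:
    naive = naive + t.to(torch.bfloat16)

# Partitioned accumulation: each chunk's running sum stays small, and the
# partial sums are aggregated at the end (as chunked covariance matrices are).
partials = []
for chunk in terms.chunk(8):
    acc = torch.zeros((), dtype=torch.bfloat16)
    for t in chunk:
        acc = acc + t.to(torch.bfloat16)
    partials.append(acc)
partitioned = torch.stack(partials).sum()

print(f"exact=4.0  naive={naive.item():.2f}  partitioned={partitioned.item():.2f}")
```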
4 changes: 2 additions & 2 deletions examples/openwebtext/inpsect_factors.py
@@ -14,8 +14,8 @@ def main():
    layer_num = 18
    module_name = f"model.layers.{layer_num}.mlp.down_proj"
    # module_name = f"model.layers.{layer_num}.mlp.up_proj"
-    lambda_processed = Analyzer.load_file("num_lambda_processed.safetensors")[module_name]
-    lambda_matrix = Analyzer.load_file("lambda_matrix.safetensors")[module_name]
+    lambda_processed = Analyzer.load_file("influence_results/num_lambda_processed.safetensors")[module_name]
+    lambda_matrix = Analyzer.load_file("influence_results/lambda_matrix.safetensors")[module_name]
    lambda_matrix.div_(lambda_processed)
    lambda_matrix = lambda_matrix.float()
    plt.matshow(lambda_matrix, cmap="PuBu", norm=LogNorm())
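Run outside the diff, the fragment above needs its surrounding imports. A minimal self-contained version might look as follows; the `Analyzer` import path and the final `colorbar`/`show` calls are assumptions rather than lines from this commit:

```python
# A self-contained sketch of the inspection fragment above, with the imports
# it assumes; import paths and factor-file locations may differ per checkout.
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm

from kronfluence.analyzer import Analyzer


def main() -> None:
    layer_num = 18
    module_name = f"model.layers.{layer_num}.mlp.down_proj"
    # Average the accumulated lambda matrix over the number of processed samples.
    lambda_processed = Analyzer.load_file("influence_results/num_lambda_processed.safetensors")[module_name]
    lambda_matrix = Analyzer.load_file("influence_results/lambda_matrix.safetensors")[module_name]
    lambda_matrix.div_(lambda_processed)
    lambda_matrix = lambda_matrix.float()
    # Log-scale heatmap: lambda values span many orders of magnitude.
    plt.matshow(lambda_matrix, cmap="PuBu", norm=LogNorm())
    plt.colorbar()
    plt.show()


if __name__ == "__main__":
    main()
```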
17 changes: 10 additions & 7 deletions kronfluence/factor/config.py
@@ -196,12 +196,13 @@ def requires_lambda_matrices_for_precondition(self) -> bool:
        return True

    def prepare(self, storage: STORAGE_TYPE, score_args: Any, device: torch.device) -> None:
-        lambda_matrix = storage[LAMBDA_MATRIX_NAME].to(device=device)
+        lambda_matrix = storage[LAMBDA_MATRIX_NAME].to(dtype=torch.float64, device=device)
        lambda_matrix.div_(storage[NUM_LAMBDA_PROCESSED].to(device=device))
        damping_factor = score_args.damping_factor
        if damping_factor is None:
            damping_factor = HEURISTIC_DAMPING_SCALE * torch.mean(lambda_matrix)
        lambda_matrix.add_(damping_factor)
+        lambda_matrix.reciprocal_()
        storage[LAMBDA_MATRIX_NAME] = lambda_matrix.to(dtype=score_args.precondition_dtype, device="cpu").contiguous()
        storage[NUM_LAMBDA_PROCESSED] = None

@@ -211,7 +212,7 @@ def precondition_gradient(
        storage: STORAGE_TYPE,
    ) -> torch.Tensor:
        lambda_matrix = storage[LAMBDA_MATRIX_NAME].to(device=gradient.device)
-        return gradient / lambda_matrix
+        return gradient * lambda_matrix


class Kfac(FactorConfig, factor_strategy=FactorStrategy.KFAC):
@@ -255,13 +256,14 @@ def prepare(self, storage: STORAGE_TYPE, score_args: Any, device: torch.device)
        storage[GRADIENT_EIGENVECTORS_NAME] = (
            storage[GRADIENT_EIGENVECTORS_NAME].to(dtype=score_args.precondition_dtype).contiguous()
        )
-        activation_eigenvalues = storage[ACTIVATION_EIGENVALUES_NAME].to(device=device)
-        gradient_eigenvalues = storage[GRADIENT_EIGENVALUES_NAME].to(device=device)
+        activation_eigenvalues = storage[ACTIVATION_EIGENVALUES_NAME].to(dtype=torch.float64, device=device)
+        gradient_eigenvalues = storage[GRADIENT_EIGENVALUES_NAME].to(dtype=torch.float64, device=device)
        lambda_matrix = torch.kron(activation_eigenvalues.unsqueeze(0), gradient_eigenvalues.unsqueeze(-1)).unsqueeze(0)
        damping_factor = score_args.damping_factor
        if damping_factor is None:
            damping_factor = HEURISTIC_DAMPING_SCALE * torch.mean(lambda_matrix)
        lambda_matrix.add_(damping_factor)
+        lambda_matrix.reciprocal_()
        storage[LAMBDA_MATRIX_NAME] = lambda_matrix.to(dtype=score_args.precondition_dtype, device="cpu").contiguous()
        storage[NUM_LAMBDA_PROCESSED] = None
        storage[ACTIVATION_EIGENVALUES_NAME] = None
@@ -277,7 +279,7 @@ def precondition_gradient(
        gradient_eigenvectors = storage[GRADIENT_EIGENVECTORS_NAME].to(device=gradient.device)
        lambda_matrix = storage[LAMBDA_MATRIX_NAME].to(device=gradient.device)
        gradient = torch.matmul(gradient_eigenvectors.t(), torch.matmul(gradient, activation_eigenvectors))
-        gradient.div_(lambda_matrix)
+        gradient.mul_(lambda_matrix)
        gradient = torch.matmul(gradient_eigenvectors, torch.matmul(gradient, activation_eigenvectors.t()))
        return gradient

@@ -325,12 +327,13 @@ def prepare(self, storage: STORAGE_TYPE, score_args: Any, device: torch.device)
        )
        storage[ACTIVATION_EIGENVALUES_NAME] = None
        storage[GRADIENT_EIGENVALUES_NAME] = None
-        lambda_matrix = storage[LAMBDA_MATRIX_NAME].to(device=device)
+        lambda_matrix = storage[LAMBDA_MATRIX_NAME].to(dtype=torch.float64, device=device)
        lambda_matrix.div_(storage[NUM_LAMBDA_PROCESSED].to(device=device))
        damping_factor = score_args.damping_factor
        if damping_factor is None:
            damping_factor = HEURISTIC_DAMPING_SCALE * torch.mean(lambda_matrix)
        lambda_matrix.add_(damping_factor)
+        lambda_matrix.reciprocal_()
        storage[LAMBDA_MATRIX_NAME] = lambda_matrix.to(dtype=score_args.precondition_dtype, device="cpu").contiguous()
        storage[NUM_LAMBDA_PROCESSED] = None

@@ -344,6 +347,6 @@ def precondition_gradient(
        gradient_eigenvectors = storage[GRADIENT_EIGENVECTORS_NAME].to(device=gradient.device)
        lambda_matrix = storage[LAMBDA_MATRIX_NAME].to(device=gradient.device)
        gradient = torch.matmul(gradient_eigenvectors.t(), torch.matmul(gradient, activation_eigenvectors))
-        gradient.div_(lambda_matrix)
+        gradient.mul_(lambda_matrix)
        gradient = torch.matmul(gradient_eigenvectors, torch.matmul(gradient, activation_eigenvectors.t()))
        return gradient
