Skip to content

Commit

Permalink
Increase amp scale
Browse files Browse the repository at this point in the history
  • Loading branch information
pomonam committed Jul 11, 2024
1 parent 7a37594 commit 849c90a
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 0 deletions.
1 change: 1 addition & 0 deletions examples/cifar/analyze.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@ def main():
factor_args = FactorArguments(strategy=args.factor_strategy)
if args.use_half_precision:
factor_args = all_low_precision_factor_arguments(strategy=args.factor_strategy, dtype=torch.float16)
factor_args.amp_scale = 2.0 ** 20.
factors_name += "_half"
analyzer.fit_all_factors(
factors_name=factors_name,
Expand Down
28 changes: 28 additions & 0 deletions examples/cifar/inspect_factors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import logging

import matplotlib.pyplot as plt
import numpy as np
from scipy.stats import spearmanr
from tueplots import markers

from kronfluence.analyzer import Analyzer


def main():
logging.basicConfig(level=logging.INFO)

# Load the scores. You might need to modify the path.
name = "ekfac_half"
factor = (
Analyzer.load_file(f"influence_results/cifar10/factors_{name}/gradient_covariance.safetensors")
)
print(factor)

scores = (
Analyzer.load_file(f"influence_results/cifar10/scores_{name}/pairwise_scores.safetensors")
)
print(scores)


if __name__ == "__main__":
main()

0 comments on commit 849c90a

Please sign in to comment.