forked from martin-wey/ast-probe
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathcheck_f_norm.py
29 lines (22 loc) · 1.03 KB
/
check_f_norm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import glob
import argparse
import numpy as np
import torch
def main(argv=None):
    """Report how far each checkpoint's projection is from orthonormal.

    For every ``<run_dir>/*/pytorch_model.bin`` the checkpoint is loaded on
    the CPU, ``proj.T @ proj`` is formed, and the following are printed:
    the checkpoint path, the Frobenius and infinity norms of
    ``proj.T @ proj - I``, the infinity norm divided by the matrix size,
    whether the Frobenius norm is below 0.05, and the shapes of the
    ``vectors_c`` and ``vectors_u`` entries.

    Args:
        argv: Optional list of command-line arguments. ``None`` (the
            default) lets argparse read ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser(description='Script for computing the F norm')
    parser.add_argument('--run_dir', default='./runs', help='Path of the run logs')
    args = parser.parse_args(argv)

    pattern = args.run_dir + "/*/pytorch_model.bin"
    # glob returns matches in arbitrary order; sort so the report is stable.
    checkpoint_paths = sorted(glob.glob(pattern))
    if not checkpoint_paths:
        # Make a wrong --run_dir visible instead of exiting silently.
        print('No checkpoints found matching', pattern)
        return

    for file in checkpoint_paths:
        # NOTE(security): torch.load unpickles the file, which can execute
        # arbitrary code. Only point this script at trusted checkpoints.
        checkpoint = torch.load(file, map_location=torch.device('cpu'))
        proj = checkpoint['proj'].cpu().detach().numpy()
        mult = np.matmul(proj.T, proj)

        # Deviation of proj^T proj from the identity, computed once and
        # shared by every norm below.
        size = mult.shape[0]
        deviation = mult - np.eye(size)
        fro_norm = np.linalg.norm(deviation, 'fro')
        inf_norm = np.linalg.norm(deviation, np.inf)

        print(file)
        print('Fro norm', fro_norm)
        print('Inf norm', inf_norm)
        print('Inf norm normalized', inf_norm / size)
        print(fro_norm < 0.05)
        print('vectors c', checkpoint['vectors_c'].shape)
        print('vectors u', checkpoint['vectors_u'].shape)
# Run the report only when executed as a script, not when imported.
if __name__ == "__main__":
    main()