# learning_rate_sweep.py
# Forked from adamjermyn/toy_model_interpretability (95 lines, 3.71 KB).
from plot_helper import *
# --- Sweep 1: ReLU autoencoder, zero init bias, equal feature sampling ---
# Load the learning-rate sweep checkpoints and regenerate the writeup plots.

def log2_batch_size(x):
    # log2 of the step count for a fixed 2**23-sample budget at batch size x.
    # NOTE(review): appears unused in this script — confirm before deleting.
    return int(23 - np.ceil(np.log2(x)))

# One checkpoint per learning rate; all other hyperparameters are fixed.
names = [
    f'../data/ReLU_lr_sweep_no_bias_equal/autoencoder_ReLU_k_1024_batch_13_steps_17_learning_rate_{lr}_sample_equal_init_bias_0.0_decay_0.0_eps_0.015625_m_64_N_512_density_1.0_drpk_0.0.pt'
    for lr in [0.001, 0.003, 0.005, 0.007, 0.01, 0.03]
]

# Best-effort load: a missing checkpoint is reported and skipped so the
# remaining runs still get plotted.  map_location forces CPU tensors.
ReLU_equal_lr_sweep = []
for n in names:
    try:
        ReLU_equal_lr_sweep.append(torch.load(n, map_location=torch.device('cpu')))
    except FileNotFoundError:
        print(n, 'not found')

fig = training_plot(ReLU_equal_lr_sweep, 'learning_rate', log_color=True)
fig.tight_layout()
fig.savefig('../writeup/ReLU_equal_lr_sweep_training_plot_zero_bias.pdf')

fig = sfa_plot(ReLU_equal_lr_sweep, 'learning_rate', [0, 2, 5])
fig.tight_layout()
fig.savefig('../writeup/ReLU_equal_lr_sweep_sfa_plot_zero_bias.pdf')

fig = plot_bias(ReLU_equal_lr_sweep, 'learning_rate', log_color=True)
fig.tight_layout()
fig.savefig('../writeup/ReLU_equal_lr_sweep_bias_plot_zero_bias.pdf')
# --- Sweep 2: ReLU autoencoder, zero init bias, power-law feature sampling ---
# (The redundant redefinition of log2_batch_size that was here is dropped;
# an identical definition already exists earlier in this script.)

# One checkpoint per learning rate; same grid as the equal-sampling sweep.
names = [
    f'../data/ReLU_lr_sweep_no_bias_power_law/autoencoder_ReLU_k_1024_batch_13_steps_17_learning_rate_{lr}_sample_power_law_init_bias_0.0_decay_0.0.pt'
    for lr in [0.001, 0.003, 0.005, 0.007, 0.01, 0.03]
]

# Best-effort load: skip any missing checkpoint so the rest still plot.
ReLU_power_law_lr_sweep = []
for n in names:
    try:
        ReLU_power_law_lr_sweep.append(torch.load(n, map_location=torch.device('cpu')))
    except FileNotFoundError:
        print(n, 'not found')

fig = training_plot(ReLU_power_law_lr_sweep, 'learning_rate', log_color=True)
fig.tight_layout()
fig.savefig('../writeup/ReLU_power_law_lr_sweep_training_plot_zero_bias.pdf')

# NOTE(review): the next two figures go to ../diagnostic_plots/ while the
# training plot above goes to ../writeup/ — confirm this split is intended.
fig = sfa_plot(ReLU_power_law_lr_sweep, 'learning_rate', [0, 2, 5])
fig.tight_layout()
fig.savefig('../diagnostic_plots/ReLU_power_law_lr_sweep_sfa_plot_zero_bias.pdf')

fig = plot_bias(ReLU_power_law_lr_sweep, 'learning_rate', log_color=True)
fig.tight_layout()
fig.savefig('../diagnostic_plots/ReLU_power_law_lr_sweep_bias_plot_zero_bias.pdf')
# --- Sweep 3: ReLU autoencoder, init bias -1.0 with decay 0.03, equal sampling ---
# (The redundant redefinition of log2_batch_size that was here is dropped;
# an identical definition already exists earlier in this script.)

# One checkpoint per learning rate; same grid as the zero-bias sweeps.
names = [
    f'../data/ReLU_lr_sweep_negative_bias_equal/autoencoder_ReLU_k_1024_batch_13_steps_17_learning_rate_{lr}_sample_equal_init_bias_-1.0_decay_0.03_eps_0.015625_m_64_N_512_density_1.0_drpk_0.0.pt'
    for lr in [0.001, 0.003, 0.005, 0.007, 0.01, 0.03]
]

# Best-effort load: skip any missing checkpoint so the rest still plot.
# NOTE: this rebinds ReLU_equal_lr_sweep from the zero-bias section above.
ReLU_equal_lr_sweep = []
for n in names:
    try:
        ReLU_equal_lr_sweep.append(torch.load(n, map_location=torch.device('cpu')))
    except FileNotFoundError:
        print(n, 'not found')

fig = training_plot(ReLU_equal_lr_sweep, 'learning_rate', log_color=True)
fig.tight_layout()
fig.savefig('../writeup/ReLU_equal_lr_sweep_training_plot_negative_bias.pdf')

fig = sfa_plot(ReLU_equal_lr_sweep, 'learning_rate', [0, 2, 5])
fig.tight_layout()
fig.savefig('../writeup/ReLU_equal_lr_sweep_sfa_plot_negative_bias.pdf')

fig = plot_bias(ReLU_equal_lr_sweep, 'learning_rate', log_color=True)
fig.tight_layout()
fig.savefig('../writeup/ReLU_equal_lr_sweep_bias_plot_negative_bias.pdf')
# --- Sweep 4: ReLU autoencoder, init bias -1.0 with decay 0.03, power-law sampling ---
# (The redundant redefinition of log2_batch_size that was here is dropped;
# an identical definition already exists earlier in this script.)

# One checkpoint per learning rate; same grid as the preceding sweeps.
names = [
    f'../data/ReLU_lr_sweep_negative_bias_power_law/autoencoder_ReLU_k_1024_batch_13_steps_17_learning_rate_{lr}_sample_power_law_init_bias_-1.0_decay_0.03.pt'
    for lr in [0.001, 0.003, 0.005, 0.007, 0.01, 0.03]
]

# Best-effort load: skip any missing checkpoint so the rest still plot.
# NOTE: this rebinds ReLU_power_law_lr_sweep from the zero-bias section above.
ReLU_power_law_lr_sweep = []
for n in names:
    try:
        ReLU_power_law_lr_sweep.append(torch.load(n, map_location=torch.device('cpu')))
    except FileNotFoundError:
        print(n, 'not found')

fig = training_plot(ReLU_power_law_lr_sweep, 'learning_rate', log_color=True)
fig.tight_layout()
fig.savefig('../writeup/ReLU_power_law_lr_sweep_training_plot_negative_bias.pdf')