-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathgrid_search.py
74 lines (52 loc) · 2.09 KB
/
grid_search.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import os
import pickle
import subprocess

from train import train_and_predict
def load_counter(path='logs/grid_search/counter.pkl'):
    """Return the number of grid-search runs completed so far.

    Args:
        path: Pickle file written by ``save_counter``. Defaults to the
            grid-search log directory used by this script.

    Returns:
        The unpickled counter, or 0 when *path* does not exist yet (i.e. the
        search has never recorded a finished run).
    """
    # EAFP: opening directly avoids the race between an existence check and the open.
    # NOTE(review): pickle is only safe here because the file is produced locally by
    # save_counter; never point *path* at untrusted data.
    try:
        with open(path, 'rb') as fin:
            return pickle.load(fin)
    except FileNotFoundError:
        return 0
def save_counter(counter, path='logs/grid_search/counter.pkl'):
    """Persist the number of completed grid-search runs to *path*.

    The value is pickled into a temporary sibling file first and then moved
    over *path* with ``os.replace``. An interruption mid-write (Ctrl-C, a
    killed job) therefore leaves the previous counter intact, not a truncated
    pickle that would make the next resume fail in ``pickle.load``.

    Args:
        counter: Number of grid points finished so far.
        path: Destination pickle file. Defaults to the grid-search log
            directory used by this script; its parent directory must exist.
    """
    tmp_path = '{}.tmp'.format(path)
    with open(tmp_path, 'wb') as fout:
        pickle.dump(counter, fout)
    # Atomic on POSIX when source and target are on the same filesystem.
    os.replace(tmp_path, path)
def build_train_command(workload_name, num_queries, num_epochs, batch_size, hid_units, cuda, cmp,
                        lbda, regbatch, dist, soften, log_dir):
    """Return the argv list that launches ``train.py`` for one configuration.

    ``--cuda`` and ``--cmp`` are included only when ``cuda`` / ``cmp`` are
    truthy. ``--regbatch``, ``--dist`` and ``--soften`` are included only
    when ``lbda`` is non-zero; otherwise those three arguments are ignored.
    Argument order matches the original command string.
    """
    cmd = [
        "python3", "train.py", str(workload_name),
        "--queries", str(num_queries),
        "--epochs", str(num_epochs),
        "--batch", str(batch_size),
        "--hid", str(hid_units),
    ]
    if cuda:
        cmd.append("--cuda")
    if cmp:
        cmd.append("--cmp")
    cmd += ["--lbda", str(lbda), "--log", str(log_dir)]
    if lbda != 0.0:
        cmd += ["--regbatch", str(regbatch), "--dist", str(dist), "--soften", str(soften)]
    return cmd


def wrapper_train_and_predict(workload_name, num_queries, num_epochs, batch_size, hid_units, cuda, cmp,
                              lbda, regbatch, dist, soften, log_dir):
    """Run ``train.py`` as a subprocess for one grid-search configuration.

    Parameters are forwarded as command-line flags; see
    ``build_train_command`` for exactly which flags are emitted.

    Returns:
        The child process's exit code (0 on success), so callers can detect
        failed runs.
    """
    cmd = build_train_command(workload_name, num_queries, num_epochs, batch_size, hid_units,
                              cuda, cmp, lbda, regbatch, dist, soften, log_dir)
    # An argv list (no shell) avoids quoting problems. Unlike os.system, the parent
    # does not ignore SIGINT while the child runs, so Ctrl-C stops the whole search
    # instead of silently continuing to the next configuration.
    return subprocess.run(cmd).returncode
def iter_grid_configs(hids, lbdas, dists, softens):
    """Yield each ``(hid, lbda, dist, soften)`` grid point in execution order.

    For ``lbda == 0`` a single point is yielded, with ``dist`` and ``soften``
    set to ``None``. ``wrapper_train_and_predict`` only forwards
    ``--regbatch``/``--dist``/``--soften`` when ``lbda`` is non-zero, so
    sweeping them would just repeat the same run. Every non-zero ``lbda`` is
    crossed with all ``dists`` x ``softens``.

    The order (hid, then lbda, then dist, then soften) is what the resume
    counter indexes into, so it must stay stable between invocations.
    """
    for hid in hids:
        for lbda in lbdas:
            if lbda == 0.0:
                yield hid, lbda, None, None
                continue
            for dist in dists:
                for soften in softens:
                    yield hid, lbda, dist, soften


if __name__ == "__main__":
    queries = 50000
    epochs = 50
    batch = 1024
    regbatch = 1024
    hids = [128, 256, 512]
    lbdas = [0, 0.1, 0.5, 1, 3, 10]
    dists = ['jaccard', 'diff']
    softens = [10, 100, 1000, 10000]
    log = 'grid_search'
    testset = 'job-cmp-card'

    # Create the log/counter directory without shelling out to `mkdir -p`.
    os.makedirs('logs/grid_search', exist_ok=True)
    # Number of grid points finished by earlier invocations. Those are skipped,
    # so an interrupted search resumes where it stopped.
    counter = load_counter()
    grid = iter_grid_configs(hids, lbdas, dists, softens)
    for idx, (hid, lbda, dist, soften) in enumerate(grid):
        if idx < counter:
            continue
        wrapper_train_and_predict(testset, queries, epochs, batch, hid, True, True,
                                  lbda, None if lbda == 0.0 else regbatch, dist, soften, log)
        # Record progress after every run so a restart skips it.
        save_counter(idx + 1)