-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathmake_table3.py
56 lines (45 loc) · 2.22 KB
/
make_table3.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import os
import json
from utils.args import parse_args
from utils.utils import args_to_string, loggs_to_json
# Per-network train-accuracy threshold (as a fraction in [0, 1]).
# The main script looks up the first logged 'Train/Acc' value that is
# strictly greater than this number and reports the matching round.
trsh_dict = dict(
    gaia=0.65,
    amazon_us=0.55,
    geantdistance=0.55,
    exodus=0.5,
    ebone=0.5,
)
# Per-network learning rate, kept as strings because each value is passed
# verbatim as the argument of the "--lr" command-line flag in the main script.
lr_dict = dict(
    gaia="1e-3",
    amazon_us="1e-3",
    geantdistance="1e-3",
    exodus="1e-1",
    ebone="1e-1",
)
def threshold_percent(threshold):
    """Return *threshold* (a fraction such as 0.55) as a whole-number percentage.

    ``round`` is used rather than ``int`` because products such as
    ``0.29 * 100`` evaluate to just below the integer in floating point and
    would otherwise be truncated to 28.
    """
    return round(threshold * 100)


def first_round_above(values, rounds, threshold):
    """Locate the round at which *values* first strictly exceeds *threshold*.

    Args:
        values: sequence of metric values, one per logged evaluation.
        rounds: sequence of round numbers, indexed in parallel with *values*.
        threshold: value that has to be strictly exceeded.

    Returns:
        A ``(round, reached)`` tuple.  ``reached`` is True when some value
        exceeds the threshold; ``round`` is then the matching entry of
        *rounds*, or the last entry when *rounds* is shorter than *values*.
        When the threshold is never exceeded, ``reached`` is False and
        ``round`` is the last logged round.  ``(None, False)`` is returned
        when *rounds* is empty.
    """
    if not rounds:
        return None, False

    last_index = len(rounds) - 1
    for index, value in enumerate(values):
        if value > threshold:
            # Clamp so a rounds list shorter than values falls back to its
            # final entry instead of raising IndexError.
            return rounds[min(index, last_index)], True

    return rounds[-1], False


def main():
    """Print, for every network and architecture, when train accuracy passes its threshold."""
    for network_name in ["gaia", "amazon_us", "geantdistance", "exodus", "ebone"]:
        print("{}:".format(network_name))

        args = parse_args(["inaturalist",
                           "--network", network_name,
                           "--bz", "16",
                           "--lr", lr_dict[network_name],
                           "--decay", "sqrt",
                           "--local_steps", "1"])

        loggs_dir = os.path.join("loggs", args_to_string(args))
        # NOTE(review): presumably this is what produces the JSON file read
        # below -- confirm against utils.utils.loggs_to_json.
        loggs_to_json(loggs_dir)

        json_name = "{}.json".format(os.path.split(loggs_dir)[1])
        path_to_json = os.path.join("results", "json", json_name)
        with open(path_to_json, "r", encoding="utf-8") as f:
            data = json.load(f)

        threshold = trsh_dict[network_name]
        percent = threshold_percent(threshold)

        for architecture in ["centralized", "ring", "matcha"]:
            values = data['Train/Acc'][architecture]
            rounds = data["Round"][architecture]

            round_reached, reached = first_round_above(values, rounds, threshold)
            if reached:
                print("Number of steps to achieve {}% is {} on {} using {}".format(percent,
                                                                                  round_reached,
                                                                                  network_name,
                                                                                  architecture))
            else:
                # Do not present the last round as if the target accuracy
                # had been achieved there.
                print("Threshold of {}% was not reached on {} using {} "
                      "(last logged round: {})".format(percent,
                                                       network_name,
                                                       architecture,
                                                       round_reached))

        # Visual separator between networks.
        print("#" * 10)


if __name__ == "__main__":
    main()