-
Notifications
You must be signed in to change notification settings - Fork 117
/
Copy pathtest_benchmark_apis.py
118 lines (104 loc) · 3.77 KB
/
test_benchmark_apis.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import sys
sys.path.append("../")
import os
import argparse
import contextlib
from naslib.search_spaces import (
NasBench101SearchSpace,
NasBench201SearchSpace,
NasBench301SearchSpace,
NasBenchNLPSearchSpace,
NasBenchASRSearchSpace,
TransBench101SearchSpaceMacro,
TransBench101SearchSpaceMicro
)
from naslib.search_spaces.core.query_metrics import Metric
from naslib.utils import get_dataset_api
# Registry of testable search spaces: CLI name -> class whose no-argument
# constructor builds the graph that gets sampled and queried.
search_spaces = dict(
    nasbench101=NasBench101SearchSpace,
    nasbench201=NasBench201SearchSpace,
    nasbench301=NasBench301SearchSpace,
    nlp=NasBenchNLPSearchSpace,
    asr=NasBenchASRSearchSpace,
    transbench101_micro=TransBench101SearchSpaceMicro,
    transbench101_macro=TransBench101SearchSpaceMacro,
)
# Task names shared by both TransBench101 search spaces (micro and macro).
_TRANSBENCH101_TASKS = (
    'class_scene',
    'class_object',
    'jigsaw',
    'room_layout',
    'segmentsemantic',
    'normal',
    'autoencoder',
)

# Datasets/tasks to exercise for each search space, keyed like `search_spaces`.
tasks = {
    'nasbench101': ['cifar10'],
    'nasbench201': ['cifar10', 'cifar100', 'ImageNet16-120', 'ninapro'],
    'nasbench301': ['cifar10'],
    'nlp': ['treebank'],
    'asr': ['timit'],
    # Each space gets its own list object so the two entries stay independent.
    'transbench101_micro': list(_TRANSBENCH101_TASKS),
    'transbench101_macro': list(_TRANSBENCH101_TASKS),
}
# Command-line interface for the smoke test.
parser = argparse.ArgumentParser()
parser.add_argument('--search_space', type=str, help=f'API to test. Options: {list(search_spaces.keys())}')
parser.add_argument('--task', type=str, help='Task/dataset to test. Defaults to every task listed for --search_space.')
parser.add_argument('--all', action='store_true', help='Test all the benchmark APIs. Overrides --search_space and --task.')
parser.add_argument('--show_error', action='store_true', help='Show the exception raised by the APIs if they crash.')
# Only consume sys.argv when executed as a script. If the module is imported
# instead (e.g. a test runner collecting this test_*.py file), argparse would
# parse the importer's argv and exit on unrecognized arguments, so fall back
# to the defaults there.
args = parser.parse_args() if __name__ == '__main__' else parser.parse_args([])
@contextlib.contextmanager
def nullify_all_output():
    """Temporarily redirect ``sys.stdout`` and ``sys.stderr`` to ``os.devnull``.

    Both streams are restored when the ``with`` body exits, including when it
    raises. The devnull handle is closed on exit, so repeated use does not
    leak file descriptors. Only writes that go through the Python-level
    ``sys.stdout``/``sys.stderr`` objects are silenced; output from
    subprocesses or written directly to the OS file descriptors is not.
    """
    with open(os.devnull, "w") as devnull:
        with contextlib.redirect_stdout(devnull), contextlib.redirect_stderr(devnull):
            yield
def test_api(graph, search_space, dataset, metric):
    """Sample a random architecture from *graph* and query one metric for it.

    Args:
        graph: a freshly constructed search-space graph (an instance of one
            of the ``search_spaces`` classes).
        search_space: search-space key forwarded to ``get_dataset_api``.
        dataset: dataset/task name to load and query.
        metric: the ``Metric`` to query.

    Returns:
        Whatever ``graph.query`` returns for the sampled architecture.

    Raises:
        AssertionError: if the query returns ``-1`` (presumably the APIs'
            "no result" sentinel -- confirm against naslib). It is raised
            explicitly rather than via ``assert`` so the check still runs
            under ``python -O`` and carries a message for ``--show_error``.
    """
    dataset_api = get_dataset_api(search_space=search_space, dataset=dataset)
    graph.sample_random_architecture(dataset_api)
    result = graph.query(metric, dataset=dataset, dataset_api=dataset_api)
    if result == -1:
        raise AssertionError(
            f"query({metric}) returned -1 for search space {search_space!r}, dataset {dataset!r}"
        )
    return result


# Despite its test_ prefix this is a helper, not a pytest test. Without this
# flag pytest tries to collect it and fails looking for fixtures named
# graph/search_space/dataset/metric.
test_api.__test__ = False
def _report_failure(error):
    """Print 'Fail' and, when --show_error is set, the exception's type and message."""
    print('Fail')
    if args.show_error:
        # Include the type: some exceptions (e.g. a bare AssertionError) have an empty message.
        print(f'{type(error).__name__}: {error}')


if __name__ == '__main__':
    if args.all:
        for space in search_spaces:
            for task in tasks[space]:
                print(f'Testing (search_space, task) api for ({space}, {task}) ...', end=" ", flush=True)
                try:
                    # Silence the APIs' own output so the one-line progress report stays readable.
                    with nullify_all_output():
                        graph = search_spaces[space]()
                        test_api(graph, space, task, Metric.VAL_ACCURACY)
                except Exception as e:
                    _report_failure(e)
                else:
                    print('Success')
    else:
        # Validate via argparse rather than `assert`, which is stripped under -O.
        if args.search_space is None:
            parser.error("Search space must be specified.")
        if args.search_space not in search_spaces:
            parser.error(f"Unknown search space {args.search_space!r}. Options: {list(search_spaces)}")
        search_space_tasks = tasks[args.search_space] if args.task is None else [args.task]
        for task in search_space_tasks:
            print(f'Testing (search_space, task) api for ({args.search_space}, {task})...', end=" ", flush=True)
            try:
                # Output is not silenced here, presumably so a single API's
                # logs stay visible while debugging it.
                graph = search_spaces[args.search_space]()
                test_api(graph, args.search_space, task, Metric.VAL_ACCURACY)
            except Exception as e:
                _report_failure(e)
            else:
                print('Success')