config.py
import json
import os
import random
from pathlib import Path

from easydict import EasyDict as edict


def get_config_regression(
    model_name: str, dataset_name: str, config_file: str = ""
) -> dict:
    """
    Get the regression config of the given dataset and model from a config file.

    Parameters:
        model_name: Name of model.
        dataset_name: Name of dataset.
        config_file: Path to config file. If given an empty string, the default config file is used.

    Returns:
        config (dict): config of the given dataset and model
    """
    if config_file == "":
        config_file = Path(__file__).parent / "config" / "config_regression.json"
    with open(config_file, 'r') as f:
        config_all = json.load(f)
    model_common_args = config_all[model_name]['commonParams']
    model_dataset_args = config_all[model_name]['datasetParams'][dataset_name]
    dataset_args = config_all['datasetCommonParams'][dataset_name]
    # use aligned features if the model requires them, otherwise use unaligned features
    if model_common_args['need_data_aligned'] and 'aligned' in dataset_args:
        dataset_args = dataset_args['aligned']
    else:
        dataset_args = dataset_args['unaligned']
    config = {}
    config['model_name'] = model_name
    config['dataset_name'] = dataset_name
    config.update(dataset_args)
    config.update(model_common_args)
    config.update(model_dataset_args)
    config['featurePath'] = os.path.join(config_all['datasetCommonParams']['dataset_root_dir'], config['featurePath'])
    config = edict(config)  # use edict for backward compatibility with MMSA v1.0
    return config
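
# Example usage (a minimal sketch, not part of the module): "tfn" must be a model key in
# config_regression.json and "mosi" a dataset key under "datasetCommonParams"; both names
# are assumptions for illustration.
#
#     cfg = get_config_regression("tfn", "mosi")
#     print(cfg.model_name, cfg.dataset_name)
#     print(cfg.featurePath)  # dataset_root_dir joined with the dataset's featurePath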

def get_config_tune(
    model_name: str, dataset_name: str, config_file: str = "",
    random_choice: bool = True
) -> dict:
    """
    Get the tuning config of the given dataset and model from a config file.

    Parameters:
        model_name: Name of model.
        dataset_name: Name of dataset.
        config_file: Path to config file. If given an empty string, the default config file is used.
        random_choice: If True, randomly choose a value for each tunable parameter from its list of candidates.

    Returns:
        config (dict): config of the given dataset and model
    """
    if config_file == "":
        config_file = Path(__file__).parent / "config" / "config_tune.json"
    with open(config_file, 'r') as f:
        config_all = json.load(f)
    model_common_args = config_all[model_name]['commonParams']
    model_dataset_args = config_all[model_name]['datasetParams'][dataset_name] if 'datasetParams' in config_all[model_name] else {}
    model_debug_args = config_all[model_name]['debugParams']
    dataset_args = config_all['datasetCommonParams'][dataset_name]
    # use aligned features if the model requires them, otherwise use unaligned features
    dataset_args = dataset_args['aligned'] if (model_common_args['need_data_aligned'] and 'aligned' in dataset_args) else dataset_args['unaligned']
    # randomly choose values for tunable parameters
    if random_choice:
        for item in model_debug_args['d_paras']:
            if isinstance(model_debug_args[item], list):
                model_debug_args[item] = random.choice(model_debug_args[item])
            elif isinstance(model_debug_args[item], dict):  # nested params, 2 levels max
                for k, v in model_debug_args[item].items():
                    model_debug_args[item][k] = random.choice(v)
    config = {}
    config['model_name'] = model_name
    config['dataset_name'] = dataset_name
    config.update(dataset_args)
    config.update(model_common_args)
    config.update(model_dataset_args)
    config.update(model_debug_args)
    config['featurePath'] = os.path.join(config_all['datasetCommonParams']['dataset_root_dir'], config['featurePath'])
    config = edict(config)  # use edict for backward compatibility with MMSA v1.0
    return config
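
# Example usage (a minimal sketch; "tfn" and "mosi" are assumed keys in config_tune.json):
# each call samples a fresh combination of the tunable "debugParams", so repeated calls can
# drive a simple random hyper-parameter search.
#
#     for _ in range(3):
#         cfg = get_config_tune("tfn", "mosi")
#         print({k: cfg[k] for k in cfg.d_paras})  # the sampled values for this trial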

def get_config_all(config_file: str) -> dict:
    """
    Get all default configs. This function is used to export the default config file.
    To get the config for a specific model, use "get_config_regression" or "get_config_tune" instead.

    Parameters:
        config_file: "regression" or "tune"

    Returns:
        config: all default configs
    """
    if config_file == "regression":
        config_file = Path(__file__).parent / "config" / "config_regression.json"
    elif config_file == "tune":
        config_file = Path(__file__).parent / "config" / "config_tune.json"
    else:
        raise ValueError("config_file should be 'regression' or 'tune'")
    with open(config_file, 'r') as f:
        config_all = json.load(f)
    return edict(config_all)
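
# Example usage (a minimal sketch): dump all default regression configs back to JSON, e.g.
# when exporting a template for users to edit. The output filename is hypothetical.
#
#     all_cfg = get_config_all("regression")
#     with open("exported_config_regression.json", "w") as f:
#         json.dump(all_cfg, f, indent=4)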

def get_citations() -> dict:
    """
    Get paper titles and citations for models and datasets.

    Returns:
        cites (dict): {
            "models": {
                "tfn": {
                    "title": "xxx",
                    "paper_url": "xxx",
                    "citation": "xxx",
                    "description": "xxx"
                },
                ...
            },
            "datasets": {
                ...
            },
        }
    """
    # TODO: add citations
    config_file = Path(__file__).parent / "config" / "citations.json"
    with open(config_file, 'r') as f:
        cites = json.load(f)
    return cites
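
# Example usage (a minimal sketch): look up the citation entry for a model, assuming
# citations.json follows the structure shown in the docstring above ("tfn" as an example key).
#
#     cites = get_citations()
#     tfn_info = cites["models"]["tfn"]
#     print(tfn_info["title"], tfn_info["paper_url"])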