-
Notifications
You must be signed in to change notification settings - Fork 1
/
main_tte.py
48 lines (37 loc) · 1.72 KB
/
main_tte.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import os
import logging
import argparse
import warnings
from config.config import Config
from Exp.tte_trainer import Exp
def parse_args():
    """Collect command-line overrides for the travel-time-estimation run.

    Only options the user actually supplied are returned. Options left unset
    are dropped, so passing the result to ``Config.update`` never replaces a
    value with ``None``.

    Returns:
        dict: Option name -> parsed string value, containing only the options
        given on the command line.
    """
    # Don't give any option a `default=` here: argparse would then report a
    # value for every flag and it would overwrite the values in config.py,
    # which is the single place where default values belong.
    cli = argparse.ArgumentParser(description='Train Travel Time Estimation')
    cli.add_argument('--dumpfile_uniqueid', type=str, help='see config.py')
    cli.add_argument('--dataset', type=str, help='')
    namespace = cli.parse_args()
    return {
        option: value
        for option, value in vars(namespace).items()
        if value is not None
    }
def get_log_path(exp_id):
    """Return the TTE log directory for ``exp_id``, creating it if needed.

    The directory is ``./log/<Config.dataset>/<exp_id>/tte``, relative to the
    current working directory. Missing parent directories are created too.

    Args:
        exp_id: Experiment identifier, used as one path component.

    Returns:
        str: The (relative) log directory path.
    """
    log_path = f'./log/{Config.dataset}/{exp_id}/tte'
    # exist_ok=True replaces a separate os.path.exists() check. The two-step
    # check-then-create version is racy: if another process creates the
    # directory in between, os.makedirs() raises FileExistsError.
    os.makedirs(log_path, exist_ok=True)
    return log_path
if __name__ == "__main__":
    # Deliberately silence every warning category for the whole run (this also
    # hides deprecation/runtime warnings that might be useful when debugging).
    warnings.filterwarnings("ignore")

    # Layer only the options given on the command line over config.py values.
    Config.update(parse_args())

    # exp_id also locates the pretrained checkpoint (pretrain_path below), so it
    # presumably has to name the directory the pretraining run wrote to.
    # NOTE(review): 'bs128', 'epoch30' and '2e-4' are literal text, not read
    # from Config -- confirm they still match the pretraining settings.
    exp_id = (
        f'exp_bs128_road{Config.road_trm_layer}_grid{Config.grid_trm_layer}'
        f'_inter{Config.inter_trm_layer}_epoch30_2e-4_cls'
        f'_{Config.mask_length}_{Config.mask_ratio}ratio'
    )
    log_path = get_log_path(exp_id)
    pretrain_path = f'./log/{Config.dataset}/{exp_id}/pretrain/best_pretrain_model.pth'

    # Log both to a per-experiment file (mode='w' truncates any log left by a
    # previous run with the same exp_id) and to the console (StreamHandler
    # writes to stderr by default).
    logging.basicConfig(
        level=logging.INFO,
        format="[%(filename)s:%(lineno)s %(funcName)s()] -> %(message)s",
        handlers=[
            logging.FileHandler(os.path.join(log_path, 'train_log.log'), mode='w'),
            logging.StreamHandler(),
        ],
    )

    # Hard-coded here, so it overrides whatever config.py (or Config.update
    # above) set for training_lr. It is assigned before the config dump below,
    # so the logged config shows this value.
    Config.training_lr = 1e-4

    # Use logging rather than print() so the header is written to
    # train_log.log next to the config dump, and to the same console stream
    # as the lines that follow it.
    logging.info("Args in experiment:")
    logging.info('=================================')
    logging.info(Config.to_str())
    logging.info('=================================')

    # How Exp handles a missing checkpoint is not visible here; at least make
    # the situation obvious in the log.
    if not os.path.exists(pretrain_path):
        logging.warning('Pretrained checkpoint not found: %s', pretrain_path)

    exp = Exp(log_path, pretrain_path)
    exp.train()