-
Notifications
You must be signed in to change notification settings - Fork 1
/
eta_task.py
77 lines (65 loc) · 2.6 KB
/
eta_task.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import os
import random
import warnings
import torch
import numpy as np
from model.traj_eta import TrajETA
from dataset.vocab import WordVocab
from preprocess import TrajPreprocess
from dataset.eta_dataloader import ETADataLoader
from trainer.eta_trainer import ETATrainer
from config.config import get_config
# Suppress all Python warnings for the remainder of the run.
warnings.filterwarnings("ignore")
config = get_config()
# Root directory containing one sub-directory per dataset. The default is the
# original author's machine-specific location; set the ETA_DATA_PATH
# environment variable to point at a local copy without editing this file.
# An unset or empty variable falls back to the default.
data_path = os.environ.get('ETA_DATA_PATH') or '/home/zhousilin/Code/zhousilin/RED-vldb/TrajModel_final/data'
# Store the resolved root in the config (presumably read by downstream
# components such as TrajPreprocess — verify).
config['data_path'] = data_path
data_name = config['dataset']
roadnetwork_path = f'{data_path}/{data_name}/rn/edge.csv'
traj_path = f'{data_path}/{data_name}/traj/traj.csv'
vocab_path = f'{data_path}/{data_name}/vocab.pkl'
# Fix the seed of every RNG this script may touch (Python `random`, NumPy,
# PyTorch CPU and all CUDA devices) so that runs with the same
# config['seed'] are as repeatable as possible.
seed = config['seed']
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
# Request deterministic cuDNN kernels and disable the cuDNN autotuner:
# benchmark mode may select a different (possibly non-deterministic)
# algorithm on each run, which defeats the seeding above.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
# Load the vocabulary and split the trajectory data into train / eval / test.
vocab = WordVocab.load_vocab(vocab_path)
traj_preprocess = TrajPreprocess(config=config, vocab=vocab)
train_data, eval_data, test_data = traj_preprocess.data_split()

# Road-network graph inputs handed to the trainer.
node_feature = traj_preprocess.get_initial_feature()
edge_index = traj_preprocess.get_graph()

# Fixed training hyper-parameters followed by sizes derived from the
# vocabulary and road network; written into the config in this order.
for _key, _value in (
    ('epochs', 30),
    ('clip', 1.0),
    ('device', 'cuda:0'),
    ('lr', 1e-4),
    ('vocab_size', vocab.vocab_size),
    ('user_size', vocab.user_num),
    ('highway_size', traj_preprocess.edge['highway_type'].nunique() + 1),
    ('fea_size', node_feature.shape[1]),
    ('batch_size', 64),
):
    config[_key] = _value

# One dataloader per split, built in train -> eval -> test order.
traj_dataloader = ETADataLoader(config)
train_dataloader, eval_dataloader, test_dataloader = (
    traj_dataloader.get_dataloader(split_data, vocab, split_mode)
    for split_data, split_mode in (
        (train_data, 'train'),
        (eval_data, 'eval'),
        (test_data, 'test'),
    )
)
# Pretraining checkpoint epoch to initialise the ETA model from, per dataset.
# Only the epoch number in the checkpoint file name differs between datasets
# (values carried over unchanged from the original per-dataset branches).
_PRETRAINING_EPOCHS = {
    'rome': 30,
    'cd': 20,
    'big_cd': 5,
    'porto': 10,
}
if config['dataset'] not in _PRETRAINING_EPOCHS:
    raise NotImplementedError(
        f"No pretraining checkpoint configured for dataset "
        f"{config['dataset']!r}; supported datasets: {sorted(_PRETRAINING_EPOCHS)}"
    )
pretraining_model_path = os.path.join(
    'checkpoints',
    data_name,
    config['exp_id'],
    'pretraining',
    f"pretraining_{_PRETRAINING_EPOCHS[config['dataset']]}.pt",
)
# Build the ETA model from the config and the selected pretraining
# checkpoint, then move it to the configured device.
model = TrajETA(config, pretraining_model_path)
model = model.to(config['device'])

# Hand the model, graph inputs and all three dataloaders to the trainer and
# start training.
trainer = ETATrainer(
    config=config,
    model=model,
    train_dataloader=train_dataloader,
    eval_dataloader=eval_dataloader,
    test_dataloader=test_dataloader,
    node_feature=node_feature,
    edge_index=edge_index,
)
trainer.train()