forked from jpoulletXaccount/MIT_rl-vrptw
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain_data_creation_nb_vehi.py
85 lines (62 loc) · 2.24 KB
/
main_data_creation_nb_vehi.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import numpy as np
import tensorflow as tf
from configs import ParseParams
def load_task_specific_generator(task, ups):
    """
    Return the dataset-creation function and environment class for a task.

    The task modules are imported lazily inside each branch. As a result,
    only the module for the selected task (and its dependencies) has to be
    importable.

    :param task: routing task name, either 'vrp' or 'vrptw'.
    :param ups: when truthy, use the UPS variant of the task's utilities
        (``UPS.*_ups_utils``) instead of the default module.
    :return: tuple ``(dataset_creator, Env)``.
    :raises NotImplementedError: if ``task`` is not a supported task name.
        This is a subclass of ``Exception``, so existing broad handlers
        still catch it.
    """
    if task == 'vrp':
        if ups:
            from UPS.vrp_ups_utils import create_VRP_UPS_dataset, Env
            dataset_creator = create_VRP_UPS_dataset
        else:
            from VRP.vrp_utils import create_VRP_dataset, Env
            dataset_creator = create_VRP_dataset
    elif task == 'vrptw':
        if ups:
            from UPS.vrptw_ups_utils import create_VRPTW_UPS_dataset, Env
            dataset_creator = create_VRPTW_UPS_dataset
        else:
            from VRPTW.vrptw_utils import create_VRPTW_dataset, Env
            dataset_creator = create_VRPTW_dataset
    else:
        # Name the offending task so misconfigured runs are easy to diagnose.
        raise NotImplementedError('Task is not implemented: %r' % (task,))
    return dataset_creator, Env
def load_task_specific_eval(task):
    """
    Return the evaluator class for a task, depending on whether it has time
    windows ('vrptw') or not ('vrp').

    The evaluation modules are imported lazily, so only the selected one has
    to be importable.

    :param task: routing task name, either 'vrp' or 'vrptw'.
    :return: the evaluator class. This is ``EvalGoogleOR`` for 'vrp' and
        ``EvalTWGoogleOR`` for 'vrptw'.
    :raises NotImplementedError: if ``task`` is not a supported task name.
        This is a subclass of ``Exception``, so existing broad handlers
        still catch it.
    """
    if task == 'vrp':
        from evaluation.eval_VRP import eval_google_or
        return eval_google_or.EvalGoogleOR
    if task == 'vrptw':
        from evaluation.eval_VRPTW import eval_tw_google_or
        return eval_tw_google_or.EvalTWGoogleOR
    # Name the offending task so misconfigured runs are easy to diagnose.
    raise NotImplementedError('Task is not implemented: %r' % (task,))
def main(args, prt):
    """
    Create a test dataset for the configured task and route it with the
    task-specific evaluator.

    :param args: configuration dict returned by ``ParseParams``. Keys read
        here are 'task_name', 'ups', 'test_size', 'n_cust', 'data_dir',
        'random_seed' and 'min_trucks'. The whole dict is also passed to
        ``Env`` and to the evaluator.
    :param prt: output helper returned by ``ParseParams`` (it exposes
        ``print_out``). It is forwarded unchanged to the evaluator.
    :return: None. Routing is carried out by the evaluator's
        ``perform_routing_transfer_learning`` method.
    """
    # Resolve the dataset generator and environment class for this task.
    data_creator, Env = load_task_specific_generator(args['task_name'], args['ups'])
    env = Env(args)

    # The test set is generated with seed ``random_seed + 1``.
    # NOTE(review): presumably this offset keeps it distinct from data seeded
    # with ``random_seed`` itself; confirm.
    # ``random_seed`` may be None (the script entry point explicitly allows
    # that). The old ``None + 1`` raised TypeError, so None is now passed
    # through instead.
    # NOTE(review): assumes data_creator accepts seed=None (e.g. "do not
    # reseed"); confirm against the generator implementations.
    random_seed = args['random_seed']
    test_seed = None if random_seed is None else random_seed + 1
    data_creator(args['test_size'], args['n_cust'], args['data_dir'],
                 seed=test_seed, data_type='test')

    # Route the generated instances with the task-specific evaluator.
    router = load_task_specific_eval(args['task_name'])
    object_eval = router(args, env, prt, args['min_trucks'])
    object_eval.perform_routing_transfer_learning()
if __name__ == '__main__':
    args, prt = ParseParams()
    # Hard-coded test-set size for this script; it overrides any value in args.
    args['test_size'] = 5000

    # Reset the default graph BEFORE seeding it. tf.set_random_seed stores
    # the graph-level seed on the current default graph, and
    # tf.reset_default_graph replaces that graph with a fresh, unseeded one.
    # Resetting afterwards (as before) silently discarded the TF seed.
    tf.reset_default_graph()

    # Random: a None or non-positive seed means "do not seed".
    random_seed = args['random_seed']
    if random_seed is not None and random_seed > 0:
        prt.print_out("# Set random seed to %d" % random_seed)
        np.random.seed(random_seed)
        tf.set_random_seed(random_seed)

    main(args, prt)