run.py
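"""Evaluation entry point for the Flatland challenge.

Runs a trained Dueling Double DQN controller with graph-based observations
against the FlatlandRemoteClient evaluator, episode by episode, and submits
the results once all evaluation environments are exhausted.
"""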
import sys
import time
from pathlib import Path

import numpy as np
import torch
from importlib_resources import path

# Make the project root importable before pulling in the src package
base_dir = Path(__file__).resolve().parent.parent
sys.path.append(str(base_dir))

from flatland.evaluators.client import FlatlandRemoteClient  # For evaluation

from src.graph_observations import GraphObsForRailEnv
from src.predictions import ShortestPathPredictorForRailEnv
from src.dueling_double_dqn import Agent
import src.nets
remote_client = FlatlandRemoteClient()  # Init remote client for evaluation

prediction_depth = 40
observation_builder = GraphObsForRailEnv(
    bfs_depth=4,
    predictor=ShortestPathPredictorForRailEnv(max_depth=prediction_depth))
state_size = prediction_depth + 5  # Size of the per-agent observation vector fed to the network
network_action_size = 2  # The network picks between 2 abstract actions, mapped to RailEnv actions below
controller = Agent('fc', state_size, network_action_size)
railenv_action_dict = dict()

# Load the pretrained Q-network weights shipped with the package
with path(src.nets, "exp_graph_obs_4_prio100.pth") as file_in:
    controller.qnetwork_local.load_state_dict(torch.load(file_in))
evaluation_number = 0
while True:
    evaluation_number += 1
    time_start = time.time()
    obs, info = remote_client.env_create(obs_builder_object=observation_builder)
    if not obs:
        # env_create() returns a falsy observation once all evaluation environments are done
        break
    env_creation_time = time.time() - time_start
    print("Evaluation Number : {}".format(evaluation_number))

    local_env = remote_client.env
    number_of_agents = len(local_env.agents)
    time_taken_by_controller = []
    time_taken_per_step = []
    steps = 0

    # Bootstrap step: move every agent forward (RailEnv action 2 = MOVE_FORWARD)
    # to obtain a first observation
    for a in range(number_of_agents):
        action = 2
        railenv_action_dict.update({a: action})
    obs, all_rewards, done, info = remote_client.env_step(railenv_action_dict)
    while True:
        # Evaluation of a single episode
        time_start = time.time()
        # Pick an action for each agent that is asked to act
        for a in range(number_of_agents):
            if info['action_required'][a]:
                network_action = controller.act(obs[a])
                railenv_action = observation_builder.choose_railenv_action(a, network_action)
            else:
                railenv_action = 0  # DO_NOTHING
            railenv_action_dict.update({a: railenv_action})
        time_taken = time.time() - time_start
        time_taken_by_controller.append(time_taken)

        time_start = time.time()
        # Perform env step
        obs, all_rewards, done, info = remote_client.env_step(railenv_action_dict)
        steps += 1
        time_taken = time.time() - time_start
        time_taken_per_step.append(time_taken)

        if done['__all__']:
            print("Reward : ", sum(list(all_rewards.values())))
            break
    np_time_taken_by_controller = np.array(time_taken_by_controller)
    np_time_taken_per_step = np.array(time_taken_per_step)
    print("=" * 100)
    print("=" * 100)
    print("Evaluation Number : ", evaluation_number)
    print("Current Env Path : ", remote_client.current_env_path)
    print("Env Creation Time : ", env_creation_time)
    print("Number of Steps : ", steps)
    print("Mean/Std of Time taken by Controller : ", np_time_taken_by_controller.mean(),
          np_time_taken_by_controller.std())
    print("Mean/Std of Time per Step : ", np_time_taken_per_step.mean(), np_time_taken_per_step.std())
    print("=" * 100)
print("Evaluation of all environments complete...")
print(remote_client.submit())
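# Local usage is roughly as follows (a sketch based on the standard flatland-rl
# evaluator workflow; the test-env folder name and service setup are assumptions,
# not confirmed by this repo):
#
#   1. Start a redis server, which FlatlandRemoteClient uses as its message broker.
#   2. Launch the evaluator service against a folder of test environments:
#        flatland-evaluator --tests ./test-envs/
#   3. In a second shell, run this script:
#        python run.py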