-
Notifications
You must be signed in to change notification settings - Fork 11
/
cityflow_env.py
75 lines (63 loc) · 3.27 KB
/
cityflow_env.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import engine
import pandas as pd
import os
from sim_setting import sim_setting_control
class CityFlowEnv():
    """Simulator environment with CityFlow, controlling a single intersection.

    Only the first intersection listed in ``config['lane_phase_info']`` is
    controlled; its traffic-light phase is sent to the engine on every ``step``.
    """

    def __init__(self, config):
        """Create the engine and load the road network and traffic flow.

        Args:
            config: dict providing ``'roadnet'`` and ``'flow'`` (passed to
                ``load_roadnet`` / ``load_flow``), ``'num_step'`` (max rows
                written by ``log``), ``'lane_phase_info'`` and ``'data'``
                (output directory used by ``log``).
        """
        # Engine options come from the project-wide ``sim_setting_control``
        # dict and are passed positionally to the engine constructor.
        self.eng = engine.Engine(
            sim_setting_control["interval"],
            sim_setting_control["threadNum"],
            sim_setting_control["saveReplay"],
            sim_setting_control["rlTrafficLight"],
            sim_setting_control["changeLane"],
        )
        self.eng.load_roadnet(config["roadnet"])
        self.eng.load_flow(config["flow"])
        self.config = config
        self.num_step = config["num_step"]

        # lane_phase_info maps an intersection id (e.g. "intersection_1_1") to
        # a dict with "start_lane", "phase" and "phase_startLane_mapping".
        self.lane_phase_info = config["lane_phase_info"]
        # Only the first intersection is controlled by this environment.
        self.intersection_id = list(self.lane_phase_info)[0]
        info = self.lane_phase_info[self.intersection_id]
        self.start_lane = info["start_lane"]
        self.phase_list = info["phase"]
        self.phase_startLane_mapping = info["phase_startLane_mapping"]

        self.current_phase = self.phase_list[0]
        self.current_phase_time = 0
        # NOTE(review): not referenced by any method of this class.
        self.yellow_time = 5
        self.phase_log = []

    def reset(self):
        """Reset the engine and this wrapper's per-episode bookkeeping.

        Clears the phase log and returns the tracked phase to the first entry
        of ``phase_list`` with an elapsed time of 0, i.e. the same values
        ``__init__`` sets, so ``get_state`` does not report stale phase info
        from the previous episode.
        """
        self.eng.reset()
        self.current_phase = self.phase_list[0]
        self.current_phase_time = 0
        self.phase_log = []

    def step(self, next_phase):
        """Apply ``next_phase`` to the controlled intersection and advance one step.

        ``current_phase_time`` counts consecutive steps spent in the current
        phase: it is incremented when the phase is unchanged and reset to 1 on
        a change. The applied phase is appended to ``phase_log``.
        """
        if self.current_phase == next_phase:
            self.current_phase_time += 1
        else:
            self.current_phase = next_phase
            self.current_phase_time = 1
        # The phase is re-sent to the engine every step, not only on a change.
        self.eng.set_tl_phase(self.intersection_id, self.current_phase)
        self.eng.next_step()
        self.phase_log.append(self.current_phase)

    def get_state(self):
        """Return a dict snapshot of engine observations plus phase bookkeeping."""
        # Fetch once and reuse for the start-lane subset below.
        lane_vehicle_count = self.eng.get_lane_vehicle_count()  # {lane_id: lane_count, ...}
        state = {}
        state['lane_vehicle_count'] = lane_vehicle_count
        state['start_lane_vehicle_count'] = {
            lane: lane_vehicle_count[lane] for lane in self.start_lane
        }
        state['lane_waiting_vehicle_count'] = self.eng.get_lane_waiting_vehicle_count()  # {lane_id: lane_waiting_count, ...}
        state['lane_vehicles'] = self.eng.get_lane_vehicles()  # {lane_id: [vehicle1_id, vehicle2_id, ...], ...}
        state['vehicle_speed'] = self.eng.get_vehicle_speed()  # {vehicle_id: vehicle_speed, ...}
        state['vehicle_distance'] = self.eng.get_vehicle_distance()  # {vehicle_id: distance, ...}
        state['current_time'] = self.eng.get_current_time()
        state['current_phase'] = self.current_phase
        state['current_phase_time'] = self.current_phase_time
        return state

    def get_reward(self):
        """Sample reward: the negated total of waiting vehicles over all lanes."""
        lane_waiting_vehicle_count = self.eng.get_lane_waiting_vehicle_count()
        return -sum(lane_waiting_vehicle_count.values())

    def log(self):
        """Write the logged phases as a one-column CSV signal-plan template.

        The column is named after the controlled intersection and holds at
        most ``num_step`` phases. Output goes to
        ``<config['data']>/signal_plan_template.txt``; the directory is created
        if it does not exist.
        """
        # Engine replay logging is currently disabled:
        # self.eng.print_log(self.config['replay_data_path'] + "/replay_roadnet.json",
        #                    self.config['replay_data_path'] + "/replay_flow.json")
        out_dir = self.config['data']
        # exist_ok avoids the check-then-create race of exists() + makedirs().
        os.makedirs(out_dir, exist_ok=True)
        df = pd.DataFrame({self.intersection_id: self.phase_log[:self.num_step]})
        df.to_csv(os.path.join(out_dir, 'signal_plan_template.txt'), index=False)