import copy

import numpy as np
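
# Maps are lists of strings, one string per row of the grid:
# 's' = start, 'g' = goal, '0' = empty, '1' = hole (see _parse_string below).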
MAP1 = ["s0",
"0g"]
MAP2 = ["s0100",
"00100",
"00100",
"00000",
"0000g"]
MAP3 = ["s0100000",
"00100000",
"00100100",
"00100100",
"00000100",
"0000010g"]


class GridWorld(object):
    # Cell types, as stored in self.map.
    EMPTY = 0
    HOLE = 1
    START = 2
    GOAL = 3

    # Actions accepted by step().
    ACTION_UP = 0
    ACTION_RIGHT = 1
    ACTION_DOWN = 2
    ACTION_LEFT = 3

    def __init__(self, map_string):
        self._parse_string(map_string)
        self.reset()
        self.max_steps = self.get_num_states()

    def get_num_states(self):
        return self.n_rows * self.n_cols

    def get_num_actions(self):
        return 4

    # Resets the grid world to the starting position and returns the
    # (flattened) start state.
    def reset(self):
        self.loc = copy.deepcopy(self.start)
        self.step_iter = 0
        return self._flatten_idx(self.loc)

    # Takes an action "u", which is one of
    # [GridWorld.ACTION_UP, GridWorld.ACTION_RIGHT,
    #  GridWorld.ACTION_DOWN, GridWorld.ACTION_LEFT],
    # and returns a tuple (next_state, reward, done), where
    #   next_state is the state of the system after taking action "u",
    #   reward is the one-step reward, and
    #   done is a boolean saying whether the episode has ended.
    # If done is True, you must call reset() before calling step() again.
    def step(self, u):
        fall_reward = -100
        goal_reward = 100
        step_reward = -1
        if u == GridWorld.ACTION_UP:
            self.loc[0] -= 1
        elif u == GridWorld.ACTION_DOWN:
            self.loc[0] += 1
        elif u == GridWorld.ACTION_RIGHT:
            self.loc[1] += 1
        elif u == GridWorld.ACTION_LEFT:
            self.loc[1] -= 1
        else:
            raise ValueError("Not a valid action")

        # Clamp the location to the grid; stepping off the edge counts as
        # falling and ends the episode.
        out_of_bounds = False
        if self.loc[0] < 0:
            self.loc[0] = 0
            out_of_bounds = True
        if self.loc[0] >= self.n_rows:
            self.loc[0] = self.n_rows - 1
            out_of_bounds = True
        if self.loc[1] < 0:
            self.loc[1] = 0
            out_of_bounds = True
        if self.loc[1] >= self.n_cols:
            self.loc[1] = self.n_cols - 1
            out_of_bounds = True

        self.step_iter += 1
        goal_reached = (self.loc == self.goal)
        if out_of_bounds:
            return self._flatten_idx(self.loc), fall_reward, True
        if self.map[self.loc[0], self.loc[1]] == GridWorld.HOLE:
            return self._flatten_idx(self.loc), fall_reward, True
        if goal_reached:
            return self._flatten_idx(self.loc), goal_reward, True
        # Episodes also time out after max_steps steps.
        if self.step_iter >= self.max_steps:
            return self._flatten_idx(self.loc), step_reward, True
        return self._flatten_idx(self.loc), step_reward, False

    # Prints the grid, marking the agent's current location with "*".
    def print(self):
        print_str = ""
        for row in range(self.n_rows):
            for col in range(self.n_cols):
                if self.loc == [row, col]:
                    print_str += "*"
                else:
                    print_str += str(self.map[row, col])
            print_str += "\n"
        print(print_str)

    # Converts a (row, col) location into a single integer state index,
    # in row-major order.
    def _flatten_idx(self, idx):
        flattened = idx[0] * self.n_cols + idx[1]
        return flattened
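
    # For example, with a 5-column map such as MAP2, the location [1, 2]
    # flattens to state index 1 * 5 + 2 == 7.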

    # Parses a map given as a list of equal-length strings and records the
    # start and goal locations.
    def _parse_string(self, map_string):
        assert len(map_string) > 0
        assert len(map_string[0]) > 0
        self.n_rows = len(map_string)
        self.n_cols = len(map_string[0])
        self.map = np.zeros((self.n_rows, self.n_cols), dtype=np.int8)
        symbol_dict = {
            "0": GridWorld.EMPTY,
            "1": GridWorld.HOLE,
            "s": GridWorld.START,
            "g": GridWorld.GOAL}
        for row_idx, row in enumerate(map_string):
            assert len(row) == self.n_cols
            for col_idx in range(self.n_cols):
                assert row[col_idx] in symbol_dict
                self.map[row_idx, col_idx] = symbol_dict[row[col_idx]]
                if row[col_idx] == 's':
                    self.start = [row_idx, col_idx]
                if row[col_idx] == 'g':
                    self.goal = [row_idx, col_idx]
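

# The tabular interface (get_num_states / get_num_actions) is sized for
# table-based methods. Below is a minimal sketch of how a tabular Q-learning
# loop could drive this environment; the function and its hyperparameter
# defaults are illustrative additions, not part of the original interface.
def q_learning_demo(env, episodes=500, alpha=0.1, gamma=0.99, epsilon=0.1):
    q = np.zeros((env.get_num_states(), env.get_num_actions()))
    for _ in range(episodes):
        state = env.reset()
        done = False
        while not done:
            # Epsilon-greedy action selection over the Q-table row.
            if np.random.rand() < epsilon:
                action = np.random.randint(env.get_num_actions())
            else:
                action = int(np.argmax(q[state]))
            next_state, reward, done = env.step(action)
            # One-step Q-learning update toward the bootstrapped target.
            target = reward + (0.0 if done else gamma * np.max(q[next_state]))
            q[state, action] += alpha * (target - q[state, action])
            state = next_state
    return q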


# Example of how to use this grid world.
if __name__ == "__main__":
    env = GridWorld(MAP1)  # choose one of [MAP1, MAP2, MAP3]
    env.print()
    # Keep going right until the episode ends (here, by walking off the grid).
    done = False
    while not done:
        state, reward, done = env.step(GridWorld.ACTION_RIGHT)
        env.print()
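
    # Hypothetical follow-up using the q_learning_demo sketch above: learn a
    # Q-table on MAP2 and report its shape.
    env2 = GridWorld(MAP2)
    q_table = q_learning_demo(env2)
    print("Learned Q-table shape:", q_table.shape)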