import numpy as np
import random
from collections import defaultdict
from environment import Env


# Monte Carlo agent that learns from the collected samples at the end of every episode
class MCAgent:
    def __init__(self, actions):
        self.width = 5
        self.height = 5
        self.actions = actions
        self.learning_rate = 0.01      # step size for the incremental value update
        self.discount_factor = 0.9     # gamma used when accumulating returns
        self.epsilon = 0.1             # exploration rate of the epsilon-greedy policy
        self.samples = []              # (state, reward, done) samples of the current episode
        self.value_table = defaultdict(float)

    # append a (state, reward, done) sample to the episode memory
    def save_sample(self, state, reward, done):
        self.samples.append([state, reward, done])

    # at the end of every episode, update the value function of the visited states
    # (first-visit style: each state is updated only once per episode)
    def update(self):
        G_t = 0
        visit_state = []
        for sample in reversed(self.samples):
            state = str(sample[0])
            if state not in visit_state:
                visit_state.append(state)
                G_t = self.discount_factor * (sample[1] + G_t)
                value = self.value_table[state]
                self.value_table[state] = (value +
                                           self.learning_rate * (G_t - value))
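
    # For example, with episode rewards [0, 0, 1] and discount_factor = 0.9, the
    # backward pass above yields G_t = 0.9, 0.81, 0.729 for the visited states
    # (newest to oldest); each stored value then moves toward its return:
    #     V(s) <- V(s) + learning_rate * (G_t - V(s))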

    # choose an action for the given state with an epsilon-greedy policy
    # based on the values of the possible next states
    def get_action(self, state):
        if np.random.rand() < self.epsilon:
            # take a random action
            action = np.random.choice(self.actions)
        else:
            # take the greedy action according to the value table
            next_state = self.possible_next_state(state)
            action = self.arg_max(next_state)
        return int(action)
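
    # Note: with epsilon = 0.1 the policy above picks a uniformly random action
    # on about 10% of steps and is greedy on the rest.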

    # compute arg max; if multiple candidates exist, pick one randomly
    @staticmethod
    def arg_max(next_state):
        max_index_list = []
        max_value = next_state[0]
        for index, value in enumerate(next_state):
            if value > max_value:
                max_index_list.clear()
                max_value = value
                max_index_list.append(index)
            elif value == max_value:
                max_index_list.append(index)
        return random.choice(max_index_list)
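
    # e.g. arg_max([0.1, 0.5, 0.5, 0.0]) returns 1 or 2 with equal probability,
    # breaking the tie between the two maximal values at random.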

    # get the value estimates of the four possible next states
    def possible_next_state(self, state):
        col, row = state
        next_state = [0.0] * 4
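        # the index order is assumed to match the environment's action encoding:
        # 0 = up, 1 = down, 2 = left, 3 = right; moves that would leave the grid
        # fall back to the current state's own value below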
        if row != 0:
            next_state[0] = self.value_table[str([col, row - 1])]
        else:
            next_state[0] = self.value_table[str(state)]
        if row != self.height - 1:
            next_state[1] = self.value_table[str([col, row + 1])]
        else:
            next_state[1] = self.value_table[str(state)]
        if col != 0:
            next_state[2] = self.value_table[str([col - 1, row])]
        else:
            next_state[2] = self.value_table[str(state)]
        if col != self.width - 1:
            next_state[3] = self.value_table[str([col + 1, row])]
        else:
            next_state[3] = self.value_table[str(state)]
        return next_state


# main loop: run episodes and update the value table at the end of each one
if __name__ == "__main__":
    env = Env()
    agent = MCAgent(actions=list(range(env.n_actions)))

    for episode in range(1000):
        state = env.reset()
        action = agent.get_action(state)

        while True:
            env.render()

            # take one step; reward is a number and done is a boolean
            next_state, reward, done = env.step(action)
            agent.save_sample(next_state, reward, done)

            # get the next action
            action = agent.get_action(next_state)

            # at the end of each episode, update the value table
            if done:
                print("episode : ", episode)
                agent.update()
                agent.samples.clear()
                break