passiveBRagent.py
from agent import Agent
from world import World
# from battle_royale import BattleRoyale
from function_approximator import FunctionApproximator
import random
import numpy as np


class PassiveBattleRoyaleAgent(Agent):
    def __init__(self, world: World, name: str, qvals_filename=None, epsilon=0.01, alpha=0.01, gamma=1, decay_epsilon=1, decay_alpha=1):
        super().__init__(world, qvals_filename, epsilon, alpha, gamma, decay_epsilon, decay_alpha)
        self.world = world
        self.name = name
        self.value_approximator = FunctionApproximator()
        self.has_model = False  # set to True once the function approximator has been fit
    def __str__(self):
        out = self.name + " "
        for ele in self.world.dictionary.get(self.name):
            out += str(round(ele, 3)) + " "
        return out
    def refit_model(self):
        # Rebuild the approximator's training set from the tabular Q-values:
        # each sample is the state features plus the action, labelled with the stored Q-value.
        in_data = []
        out_data = []
        for state, subd in self.q_values.items():
            for action, value in subd.items():
                curstate = list(state)
                curstate.append(float(action))
                in_data.append(curstate)
                out_data.append([value])
        x = np.array(in_data)
        y = np.array(out_data)
        self.value_approximator.model.fit(x=[x], y=y, verbose=0)
        self.has_model = True  # the approximator can now be used for unseen state-action pairs
    def get_best_action(self) -> int:
        rand = random.random()
        best_action = None
        # Work on a copy so removing action 3 does not mutate the world's own action list.
        possible_actions = list(self.world.actions)
        if 3 in possible_actions:
            possible_actions.remove(3)
        if len(possible_actions) == 0:
            return None
        if rand <= self.epsilon:
            # Explore: pick a random action with probability epsilon.
            best_action = random.choice(possible_actions)
        else:
            # Exploit: pick the action with the highest known Q-value for the current state.
            highest_q = -100
            tabular_actions = self.q_values.get(tuple(self.world.translateAbsoluteState(self)), None)
            if tabular_actions is not None:
                for action in possible_actions:
                    tabular_q = tabular_actions.get(action, None)
                    if tabular_q is None and self.has_model:
                        # Fall back to the function approximator for actions the table has not seen.
                        cur_q = self.value_approximator.model.predict(np.array([self.world.translateAbsoluteState(self) + [action]]))[0][0]
                    else:
                        cur_q = tabular_q
                    if cur_q is not None and cur_q > highest_q:
                        highest_q = cur_q
                        best_action = action
        self.epsilon *= self.decay_epsilon
        self.alpha *= self.decay_alpha
        if best_action is None:
            # Unseen state (or no usable Q-values): fall back to a random action.
            best_action = random.choice(possible_actions)
        return best_action
    def save(self, filename):
        return super().save(filename)
    def reset(self, reset_qvalues=False, reset_epsilon_to=0):
        self.reward = 0
        self.prev_action = None
        self.prev_state = None
        if reset_epsilon_to:
            self.epsilon = reset_epsilon_to
        if reset_qvalues:
            self.q_values = dict()
    def take_action(self) -> tuple[str, list]:
        self.prev_state = tuple(self.world.translateAbsoluteState(self))
        action = self.get_best_action()
        result = self.world.step(action, self)
        reward = result[0]
        self.reward += reward
        new_state = tuple(self.world.translateAbsoluteState(self))
        self.prev_action = action
        # Find max over all a of Q(s_t+1, a); an unseen next state is treated as Q = 0.
        best_next_q = -100  # floor value; may need to rechoose appropriate value
        possible_next_actions = self.q_values.get(new_state, {action: 0})
        for next_action, q_value in possible_next_actions.items():
            if q_value > best_next_q:
                best_next_q = q_value
        # Put in placeholder values for new state-action pairs
        if self.q_values.get(self.prev_state) is None:
            self.q_values[self.prev_state] = dict()
        if self.q_values[self.prev_state].get(self.prev_action) is None:
            self.q_values[self.prev_state][self.prev_action] = 0
        # Q-learning update: Q(s,a) <- Q(s,a) + alpha * (r + gamma * max_a' Q(s',a') - Q(s,a))
        old_q = self.q_values[self.prev_state][self.prev_action]
        self.q_values[self.prev_state][self.prev_action] = old_q + self.alpha * (reward + self.gamma * best_next_q - old_q)
        # Return the agent's name and its current entry in the world's dictionary
        return (self.name, self.world.dictionary.get(self.name))
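

# --- Illustrative sketch (not part of the original agent) --------------------
# A minimal, self-contained demonstration of the same epsilon-greedy tabular
# Q-learning update used in get_best_action()/take_action(), run on a toy
# two-state problem instead of the project's World class (whose interface is
# assumed and not reproduced here). States, actions, and rewards below are
# hypothetical and exist only for this sketch.
if __name__ == "__main__":
    # Toy problem: state (0,), actions 0 (stay, reward 0) and 1 (move, reward 1).
    q_table: dict[tuple[int], dict[int, float]] = {}
    epsilon, alpha, gamma = 0.1, 0.5, 0.9

    def q_update(state, action, reward, next_state):
        # Q(s,a) <- Q(s,a) + alpha * (r + gamma * max_a' Q(s',a') - Q(s,a))
        q_table.setdefault(state, {}).setdefault(action, 0.0)
        best_next_q = max(q_table.get(next_state, {0: 0.0}).values())
        old_q = q_table[state][action]
        q_table[state][action] = old_q + alpha * (reward + gamma * best_next_q - old_q)

    for _ in range(200):
        state = (0,)
        # Epsilon-greedy selection, mirroring get_best_action()
        if random.random() <= epsilon or state not in q_table:
            action = random.choice([0, 1])
        else:
            action = max(q_table[state], key=q_table[state].get)
        next_state, reward = ((1,), 1.0) if action == 1 else ((0,), 0.0)
        q_update(state, action, reward, next_state)

    print(q_table)  # Q((0,), 1) should end up higher than Q((0,), 0)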