-
Notifications
You must be signed in to change notification settings - Fork 5
/
replay_memory.py
204 lines (164 loc) · 5.93 KB
/
replay_memory.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
import math
import random
from collections import namedtuple
# Record type for a single stored transition; one field per line so the
# layout is easy to diff when fields are added or removed.
Transition = namedtuple(
    'Transition',
    [
        'state',
        'next_state',
        'action',
        'reward',
        'latent_state',
        'next_latent_state',
        'dt',
    ],
)

# Record type for a whole trajectory together with its length.
Trajectory = namedtuple(
    'Trajectory',
    [
        'states',
        'actions',
        'time_steps',
        'length',
    ],
)
class SumTree(object):
    """Fixed-capacity binary sum tree stored in a flat list.

    Leaves hold one value per data slot and every internal node holds the
    sum of its two children, so ``tree[0]`` is the total of all leaf values.
    Data slots are written in ring-buffer order: once ``max_size`` items have
    been added, the oldest slot is overwritten.

    Attributes
    ----------
    max_size : int
        number of data slots
    tree_level : int
        number of levels in the tree (the last level holds the leaves)
    tree_size : int
        total number of nodes in ``tree``
    tree : list
        node values, root first, children of node ``i`` at ``2i+1``/``2i+2``
    data : list
        stored contents, one entry per data slot (``None`` when unused)
    size : int
        number of slots that have been written so far
    cursor : int
        slot that the next ``add`` will write to
    """

    def __init__(self, max_size):
        """ Allocate an empty tree.
        Parameters
        ----------
        max_size : int
            number of data slots
        Raises
        ------
        ValueError
            if ``max_size`` is negative
        """
        if max_size < 0:
            raise ValueError('max_size must be non-negative')
        self.max_size = max_size
        # bit_length() is exactly ceil(log2(max_size + 1)); the float
        # math.log(x, 2) can overshoot on exact powers of two and add a
        # needless extra level.
        self.tree_level = int(max_size).bit_length() + 1
        self.tree_size = 2 ** self.tree_level - 1
        self.tree = [0 for _ in range(self.tree_size)]
        self.data = [None for _ in range(self.max_size)]
        self.size = 0
        self.cursor = 0

    def _first_leaf(self):
        """Return the position in ``tree`` of the leaf for data slot 0."""
        return 2 ** (self.tree_level - 1) - 1

    def add(self, contents, value):
        """ Store ``contents`` with leaf value ``value`` in the next slot.
        Parameters
        ----------
        contents : object
            item to store
        value : float
            leaf value associated with the item
        """
        index = self.cursor
        self.cursor = (self.cursor + 1) % self.max_size
        self.size = min(self.size + 1, self.max_size)
        self.data[index] = contents
        self.val_update(index, value)

    def get_val(self, index):
        """Return the leaf value of data slot ``index``."""
        return self.tree[self._first_leaf() + index]

    def val_update(self, index, value):
        """ Set the leaf value of data slot ``index`` to ``value``.
        Every ancestor is recomputed as the sum of its two children rather
        than adjusted by a difference, so repeated updates cannot leave
        floating-point residue in the internal sums.
        """
        node = self._first_leaf() + index
        self.tree[node] = value
        while node > 0:
            node = (node - 1) // 2
            left = 2 * node + 1
            self.tree[node] = self.tree[left] + self.tree[left + 1]

    def reconstruct(self, tindex, diff):
        """ Add ``diff`` to node ``tindex`` and to each of its ancestors.
        Kept for callers that use it directly; ``val_update`` does not rely
        on it.
        """
        while True:
            self.tree[tindex] += diff
            if tindex == 0:
                return
            # Integer division keeps the parent index exact for any size.
            tindex = (tindex - 1) // 2

    def find(self, value, norm=True):
        """ Locate the leaf whose cumulative-value interval contains ``value``.
        Parameters
        ----------
        value : float
            search position
        norm : bool
            when true, ``value`` is multiplied by the tree total first, so a
            number in [0, 1] addresses the whole tree
        Returns
        -------
        tuple
            (stored contents, leaf value, data slot index)
        """
        if norm:
            value *= self.tree[0]
        return self._find(value, 0)

    def _find(self, value, index):
        """ Descend from node ``index`` to a leaf and describe that leaf.
        A subtree whose sum is zero is never entered while its sibling has a
        positive sum. This keeps a slightly too large ``value`` from running
        off into the unused leaves past the data slots (which used to raise
        IndexError) and avoids returning zero-valued leaves while positive
        ones remain.
        """
        first_leaf = self._first_leaf()
        while index < first_leaf:
            left_child = 2 * index + 1
            left = self.tree[left_child]
            right = self.tree[left_child + 1]
            if right <= 0:
                index = left_child
            elif left <= 0:
                # Nothing to subtract: the left subtree contributes zero.
                index = left_child + 1
            elif value <= left:
                index = left_child
            else:
                value -= left
                index = left_child + 1
        slot = index - first_leaf
        return self.data[slot], self.tree[index], slot

    def print_tree(self):
        """Print the node values, one tree level per line."""
        for level in range(1, self.tree_level + 1):
            for node in range(2 ** (level - 1) - 1, 2 ** level - 1):
                print(self.tree[node], end=' ')
            print()

    def filled_size(self):
        """Return the number of data slots written so far."""
        return self.size
class ReplayMemory(object):
    """
    Fixed-capacity ring buffer of records.

    Records are built by calling the ``tuple`` factory given at construction
    time; once ``capacity`` records are held, new ones overwrite the oldest.
    """

    def __init__(self, capacity, tuple):
        """Create an empty buffer holding at most ``capacity`` records."""
        self.tuple = tuple
        self.capacity = capacity
        self.position = 0
        self.memory = []

    def push(self, *args):
        """Build a record from ``args`` and write it at the current position."""
        buffer = self.memory
        # Grow by one placeholder until the buffer has reached capacity.
        if len(buffer) < self.capacity:
            buffer.append(None)
        slot = self.position
        buffer[slot] = self.tuple(*args)
        # Advance the write position, wrapping around at capacity.
        self.position = (slot + 1) % self.capacity

    def sample(self, batch_size):
        """Return ``batch_size`` records picked by ``random.sample``."""
        population = self.memory
        return random.sample(population, batch_size)

    def clear(self):
        """Empty the buffer in place and rewind the write position."""
        del self.memory[:]
        self.position = 0

    def __len__(self):
        """Return the number of records currently stored."""
        stored = self.memory
        return len(stored)
class PrioritizedReplayMemory(object):
    """ The class represents prioritized experience replay buffer.
    The class has functions: store samples, pick samples with
    probability in proportion to sample's priority, update
    each sample's priority, reset alpha.
    Priorities are kept in a sum tree as ``priority ** alpha``.
    see https://arxiv.org/pdf/1511.05952.pdf .
    """

    # Number of unusable draws a single ``sample`` call tolerates before it
    # gives up instead of looping forever.
    _MAX_FAILED_DRAWS = 100

    def __init__(self, capacity, tuple, alpha=0.6, beta=0.4):
        """ Prioritized experience replay buffer initialization.
        Parameters
        ----------
        capacity : int
            number of samples that can be stored
        tuple : callable
            factory called with the positional arguments of ``push`` to
            build the stored sample
        alpha : float
            exponent determining how much prioritization is used:
            Prob_i ~ priority_i**alpha / sum(priority**alpha)
        beta : float
            exponent applied to the weights returned by ``sample``
        """
        self.tree = SumTree(capacity)
        self.capacity = capacity
        self.alpha = alpha
        self.beta = beta
        self.tuple = tuple

    def push(self, *args, priority):
        """ Add new sample.
        Parameters
        ----------
        args : object
            fields of the new sample, passed to the ``tuple`` factory
        priority : float
            sample's priority (keyword-only); stored as ``priority ** alpha``
        """
        self.tree.add(self.tuple(*args), priority ** self.alpha)

    def sample(self, batch_size):
        """ The method return samples randomly.
        Samples are drawn without replacement, each with probability
        proportional to its stored priority. The stored priorities are the
        same after the call as before it.
        Parameters
        ----------
        batch_size : int
            sample size to be extracted
        Returns
        -------
        out :
            list of samples
        weights:
            list of weight, scaled so that the largest weight is 1
        indices:
            list of sample indices
            The indices indicate sample positions in a sum tree.
        ``(None, None, None)`` is returned when fewer than ``batch_size``
        samples are stored, or when a full batch of samples with positive
        priority cannot be drawn. A non-positive ``batch_size`` yields three
        empty lists.
        """
        if self.tree.filled_size() < batch_size:
            return None, None, None
        if batch_size <= 0:
            return [], [], []

        out = []
        indices = []
        weights = []
        stored_values = []
        failed_draws = 0
        try:
            while len(out) < batch_size:
                data, priority, index = self.tree.find(random.random())
                # An empty slot or a zero-priority entry (including one that
                # was already drawn in this call) is not a usable draw.
                if data is None or priority <= 0:
                    failed_draws += 1
                    if failed_draws > self._MAX_FAILED_DRAWS:
                        return None, None, None
                    continue
                stored_values.append(priority)
                weights.append((1. / self.capacity / priority) ** self.beta if priority > 1e-16 else 0)
                indices.append(index)
                out.append(data)
                # Zero the tree value directly to avoid duplicating; going
                # through priority_update would store 0 ** alpha, which is 1
                # when alpha is 0.
                self.tree.val_update(index, 0)
        finally:
            # Revert priorities. The saved values already include the alpha
            # exponent, so they are written back as-is; priority_update would
            # raise them to alpha a second time.
            for index, stored in zip(indices, stored_values):
                self.tree.val_update(index, stored)

        peak = max(weights)
        if peak > 0:
            weights = [w / peak for w in weights]  # Normalize for stability
        return out, weights, indices

    def priority_update(self, indices, priorities):
        """ The methods update samples's priority.
        Parameters
        ----------
        indices :
            list of sample indices
        priorities :
            list of new priorities, one per index; each is raised to
            ``alpha`` before being stored
        """
        for i, p in zip(indices, priorities):
            self.tree.val_update(i, p ** self.alpha)

    def reset_alpha(self, alpha):
        """ Reset a exponent alpha.
        Every stored value is converted back to its priority with the old
        exponent and stored again with the new one.
        Parameters
        ----------
        alpha : float
            new prioritization exponent
        """
        self.alpha, old_alpha = alpha, self.alpha
        count = self.tree.filled_size()
        stored_values = [self.tree.get_val(i) for i in range(count)]
        if old_alpha == 0:
            # priority ** 0 cannot be inverted; keep the stored values as the
            # priorities.
            priorities = stored_values
        else:
            # Inverse of priority ** old_alpha.
            priorities = [v ** (1.0 / old_alpha) for v in stored_values]
        self.priority_update(range(count), priorities)

    def __len__(self):
        """Return the number of samples currently stored."""
        return self.tree.filled_size()