# replay.py -- experience replay memory (forked from grockious/dqn-atari)
import bisect
import math
import random


class Sample:
    def __init__(self, state1, action, reward, state2, terminal):
        self.state1 = state1
        self.action = action
        self.reward = reward
        self.state2 = state2
        self.terminal = terminal
        self.weight = 1
        self.cumulativeWeight = 1

    def isInteresting(self):
        # Terminal transitions and non-zero rewards are the events worth prioritizing
        return self.terminal or self.reward != 0

    def __lt__(self, obj):
        # Order samples by cumulativeWeight so bisect can binary-search the
        # sample list directly (the old __cmp__ protocol is ignored in Python 3)
        return self.cumulativeWeight < obj.cumulativeWeight


class ReplayMemory:
    def __init__(self, args):
        self.samples = []
        self.maxSamples = args.replay_capacity
        self.prioritizedReplay = args.prioritized_replay
        self.numInterestingSamples = 0
        self.batchesDrawn = 0

    def numSamples(self):
        return len(self.samples)

    def addSample(self, sample):
        self.samples.append(sample)
        self._updateWeightsForNewlyAddedSample()
        self._truncateListIfNecessary()

    def _updateWeightsForNewlyAddedSample(self):
        if len(self.samples) > 1:
            self.samples[-1].cumulativeWeight = self.samples[-1].weight + self.samples[-2].cumulativeWeight

        if self.samples[-1].isInteresting():
            self.numInterestingSamples += 1

            # Boost the neighboring samples. How many? Roughly the average gap
            # between interesting samples: if interesting samples occur 3% of
            # the time, boost the preceding 33.
            uninterestingSampleRange = max(1, len(self.samples) // max(1, self.numInterestingSamples))
            for i in range(uninterestingSampleRange, 0, -1):
                index = len(self.samples) - i
                if index < 1:
                    break
                # An exponential decay over the domain [0, uninterestingSampleRange]:
                # the interesting sample itself gets roughly a 4x boost, and the
                # sample furthest away gets about a 1% boost
                boost = 1.0 + 3.0 / math.exp(i / (uninterestingSampleRange / 6.0))
                self.samples[index].weight *= boost
                self.samples[index].cumulativeWeight = self.samples[index].weight + self.samples[index - 1].cumulativeWeight

    def _truncateListIfNecessary(self):
        # Premature optimization alert :-) -- don't truncate on every added
        # sample, since (I assume) that requires a memcopy of the list
        # (probably 8 MB)
        if len(self.samples) > self.maxSamples * 1.05:
            numTruncated = len(self.samples) - self.maxSamples

            # Before truncating the list, correct self.numInterestingSamples and
            # prepare to correct the cumulativeWeights of the remaining samples.
            # The oldest samples (the head of the list) are the ones dropped.
            truncatedWeight = 0
            for i in range(numTruncated):
                truncatedWeight += self.samples[i].weight
                if self.samples[i].isInteresting():
                    self.numInterestingSamples -= 1

            # Truncate the list, keeping the newest maxSamples entries
            self.samples = self.samples[numTruncated:]

            # Correct cumulativeWeights
            for sample in self.samples:
                sample.cumulativeWeight -= truncatedWeight

    def drawBatch(self, batchSize):
        if batchSize > len(self.samples):
            raise IndexError('Too few samples (%d) to draw a batch of %d' % (len(self.samples), batchSize))

        self.batchesDrawn += 1

        if self.prioritizedReplay:
            return self._drawPrioritizedBatch(batchSize)
        else:
            return random.sample(self.samples, batchSize)

    # The Nature paper doesn't do this, but it mentions the idea. This
    # particular approach and the weighting I am using are a totally
    # uninformed fabrication on my part; there is probably a more
    # principled way to do it.
    def _drawPrioritizedBatch(self, batchSize):
        batch = []
        probe = Sample(None, 0, 0, None, False)
        while len(batch) < batchSize:
            # Pick a uniform point in the total cumulative weight, then
            # binary-search for the sample whose interval contains it
            probe.cumulativeWeight = random.uniform(0, self.samples[-1].cumulativeWeight)
            index = bisect.bisect_right(self.samples, probe, 0, len(self.samples) - 1)
            sample = self.samples[index]
            # Decay a drawn sample's weight so it is less likely to be drawn again
            sample.weight = max(1, .8 * sample.weight)
            if sample not in batch:
                batch.append(sample)

        # The decay above leaves cumulativeWeights stale, so rebuild them
        # periodically rather than on every draw
        if self.batchesDrawn % 100 == 0:
            cumulative = 0
            for sample in self.samples:
                cumulative += sample.weight
                sample.cumulativeWeight = cumulative
        return batch

    def _printBatchWeight(self, batch):
        batchWeight = sum(sample.weight for sample in batch)
        print('batch weight: %f' % batchWeight)
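

# A minimal usage sketch, not part of the original file: the real project
# builds `args` with argparse, so `argparse.Namespace` is only a stand-in
# here, and the attribute names simply mirror what __init__ reads above.
if __name__ == '__main__':
    import argparse

    args = argparse.Namespace(replay_capacity=1000, prioritized_replay=True)
    memory = ReplayMemory(args)

    # Fill the memory with dummy transitions; every 30th one carries a
    # reward, so it counts as "interesting" and boosts its neighbors' weights
    for step in range(500):
        reward = 1 if step % 30 == 0 else 0
        memory.addSample(Sample(state1=None, action=0, reward=reward,
                                state2=None, terminal=False))

    batch = memory.drawBatch(32)
    memory._printBatchWeight(batch)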