-
Notifications
You must be signed in to change notification settings - Fork 2
/
replay_memory.py
66 lines (55 loc) · 2.02 KB
/
replay_memory.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import os
import numpy as np
import random
import torch
class ReplayMemory(object):
    """Bounded first-in-first-out buffer of events for experience replay.

    Every stored event is a tuple holding one item from each of the
    sequences handed to ``push``. ``sample`` and ``pop`` pass the items of
    each field to ``torch.cat(..., 0)``, so the stored items must be tensors
    that can be concatenated along dimension 0.
    """

    def __init__(self, capacity):
        """Create an empty buffer that keeps at most ``capacity`` events."""
        self.capacity = capacity
        self.memory = []

    def push(self, events):
        """Append a batch of events, discarding the oldest ones on overflow.

        Args:
            events: An iterable of per-field sequences, for example
                ``(states, actions, rewards)``. The sequences are transposed
                with ``zip`` so the i-th stored event is
                ``(states[i], actions[i], rewards[i])``; as with any ``zip``,
                the shortest sequence determines how many events are added.
        """
        self.memory.extend(zip(*events))
        # Trim once for the whole batch instead of deleting the head after
        # every append: a single slice deletion shifts the list one time and
        # also restores the size limit if the buffer was already too long.
        overflow = len(self.memory) - self.capacity
        if overflow > 0:
            del self.memory[:overflow]

    def clear(self):
        """Drop every stored event by rebinding ``memory`` to a new list."""
        self.memory = []

    def sample(self, batch_size):
        """Draw ``batch_size`` distinct stored events uniformly at random.

        Args:
            batch_size: Number of events to draw.

        Returns:
            A lazy ``map`` iterator that yields one tensor per field, each
            being the sampled items of that field concatenated along
            dimension 0.

        Raises:
            ValueError: Propagated from ``random.sample`` when ``batch_size``
                is negative or larger than the number of stored events.
        """
        return self._collate(random.sample(self.memory, batch_size))

    def pop(self, batch_size):
        """Return the oldest ``batch_size`` events without removing them.

        Despite its name this method does not modify the buffer; it only
        reads the first ``batch_size`` entries (fewer if the buffer is
        shorter).

        Args:
            batch_size: Maximum number of events to read from the front.

        Returns:
            A lazy ``map`` iterator that yields one tensor per field, each
            being the selected items of that field concatenated along
            dimension 0.
        """
        return self._collate(self.memory[:batch_size])

    def return_size(self):
        """Return the number of events currently stored."""
        return len(self.memory)

    @staticmethod
    def _collate(events):
        """Transpose events into per-field columns and concatenate each."""
        return map(lambda column: torch.cat(column, 0), zip(*events))
# class ReplayMemory(object):
# def __init__(self):
# self.capacity = 0
# self.memory = []
# self.all_truples = []
# self.priority = {}
# def push(self, events):
# for event in zip(*events):
# action = event[1]
# reward = event[2]
# if (action, reward) not in self.all_truples:
# self.memory.append(event)
# self.capacity += 1
# self.priority[(action, reward)] = 1
# else:
# existing_num = self.priority[(action, reward)]
# if existing_num/self.capacity <= 0.5 or self.capacity <= MIN_MOMERY:
# self.memory.append(event)
# self.capacity += 1
# self.priority[(action, reward)] += 1
# # self.memory.append(event)
# # if len(self.memory) > self.capacity:
# # del self.memory[0]
# def clear(self):
# self.memory = []
# self.capacity = 0
# def get_capacity(self):
# return self.capacity
# def sample(self, batch_size):
# samples = zip(*random.sample(self.memory, batch_size))
# return map(lambda x: torch.cat(x, 0), samples)