forked from Git-123-Hub/maddpg-pettingzoo-pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathBuffer.py
54 lines (42 loc) · 2 KB
/
Buffer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import numpy as np
import torch
class Buffer:
    """Fixed-capacity replay buffer holding the transitions of one agent.

    Transitions live in pre-allocated numpy arrays that are written in
    ring-buffer order: once `capacity` entries are stored, each new entry
    overwrites the oldest one.
    """

    def __init__(self, capacity, obs_dim, act_dim, device):
        """Allocate zero-filled storage for `capacity` transitions.

        Args:
            capacity: maximum number of transitions kept.
            obs_dim: length of one observation vector.
            act_dim: length of one action vector.
            device: target handed to `Tensor.to` for every sampled tensor.
        """
        self.capacity = capacity

        self.obs = np.zeros((capacity, obs_dim))
        self.action = np.zeros((capacity, act_dim))
        self.reward = np.zeros(capacity)
        self.next_obs = np.zeros((capacity, obs_dim))
        self.done = np.zeros(capacity, dtype=bool)

        self._index = 0  # slot the next `add` writes to
        self._size = 0  # number of slots written so far, capped at `capacity`

        self.device = device

    def add(self, obs, action, reward, next_obs, done):
        """Store one transition, overwriting the oldest one when full."""
        slot = self._index
        self.obs[slot] = obs
        self.action[slot] = action
        self.reward[slot] = reward
        self.next_obs[slot] = next_obs
        self.done[slot] = done

        # Wrap around so the write position always stays inside the storage.
        self._index = (slot + 1) % self.capacity
        if self._size < self.capacity:
            self._size += 1

    def sample(self, indices):
        """Return the transitions at `indices` as float32 tensors on `device`.

        Args:
            indices: index into the first axis of the storage arrays, e.g. an
                integer array, a list of ints, a slice, or a single int.
                Indices are not checked against `len(self)`; slots that were
                never written still hold their initial zeros.

        Returns:
            Tuple `(obs, action, reward, next_obs, done)`. `obs`, `action`
            and `next_obs` have the batch as their first dimension; `reward`
            and `done` are 1-D with one entry per sampled transition.
        """
        # A bare integer index makes numpy drop the batch axis and hand back
        # numpy scalars for `reward`/`done`, which `torch.from_numpy` rejects.
        # `atleast_*d` restores that axis and leaves already-batched results
        # untouched.
        obs = np.atleast_2d(self.obs[indices])
        action = np.atleast_2d(self.action[indices])
        reward = np.atleast_1d(self.reward[indices])
        next_obs = np.atleast_2d(self.next_obs[indices])
        done = np.atleast_1d(self.done[indices])

        return (
            self._to_tensor(obs),
            self._to_tensor(action),
            self._to_tensor(reward),
            self._to_tensor(next_obs),
            self._to_tensor(done),
        )

    def _to_tensor(self, array):
        """Convert a numpy array to a float32 tensor placed on `self.device`."""
        return torch.from_numpy(array).float().to(self.device)

    def __len__(self):
        """Return the number of transitions currently stored."""
        return self._size