-
Notifications
You must be signed in to change notification settings - Fork 1
/
GenericCollector.py
48 lines (35 loc) · 1.38 KB
/
GenericCollector.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
from collections import namedtuple
import torch
from RunningAverage import RunningStats
class GenericCollector:
    """Collect per-environment running statistics for a set of named quantities.

    Each registered key owns one ``RunningStats`` buffer, created on ``'cpu'``
    with ``n=self._n_env``. :meth:`update` feeds new values into those buffers,
    and :meth:`reset` snapshots the statistics of selected indices before
    resetting them.
    """

    # Record type returned by reset(): one list per field, one entry per
    # requested index. Defined once for the class instead of creating a new
    # namedtuple class for every instance.
    _simple_stats = namedtuple('simple_stats', ['step', 'max', 'sum', 'mean', 'std'])

    def __init__(self):
        self.keys = []  # registered keys, in registration order
        self._n_env = 0  # passed to RunningStats as ``n`` when a key is added
        self._buffer = {}  # key -> RunningStats

    def init(self, n_env, **kwargs):
        """Set the environment count, then register every ``key=shape`` pair.

        Args:
            n_env: Value passed as ``n`` to each ``RunningStats`` created by
                :meth:`add` (presumably the number of parallel environments).
            **kwargs: ``key=shape`` pairs forwarded to :meth:`add`.
        """
        self._n_env = n_env
        for key, shape in kwargs.items():
            self.add(key, shape)

    def add(self, key, shape):
        """Register ``key`` with a fresh ``RunningStats(shape, 'cpu', n=n_env)``.

        Re-adding an existing key replaces its buffer but does not list the
        key twice in :attr:`keys`.
        """
        if key not in self._buffer:
            self.keys.append(key)
        self._buffer[key] = RunningStats(shape, 'cpu', n=self._n_env)

    def update(self, **kwargs):
        """Feed each ``key=value`` pair into that key's buffer.

        Values are passed with ``reduction='none'``. Keys that were never
        registered are ignored.
        """
        for key, value in kwargs.items():
            if key in self._buffer:
                self._buffer[key].update(value, reduction='none')

    def reset(self, indices):
        """Snapshot, then reset, the statistics at the given indices.

        For every registered key, each index in ``indices`` is evaluated with
        :meth:`_evaluate` and then reset in that key's buffer.

        Args:
            indices: Iterable of indices into the buffers. It is materialised
                once, so a one-shot iterator (e.g. a generator) is applied to
                every key, not only the first.

        Returns:
            dict mapping each key to a ``simple_stats`` namedtuple whose fields
            (``step``, ``max``, ``sum``, ``mean``, ``std``) are lists aligned
            with ``indices``. Empty ``indices`` yields empty lists.
        """
        indices = list(indices)
        fields = self._simple_stats._fields
        result = {}
        for key, buffer in self._buffer.items():
            columns = {field: [] for field in fields}
            for index in indices:
                # Evaluate before resetting so the snapshot holds the
                # accumulated values rather than the freshly reset ones.
                for field, value in zip(fields, self._evaluate(key, index)):
                    columns[field].append(value)
                buffer.reset(index)
            result[key] = self._simple_stats(**columns)
        return result

    def _evaluate(self, key, index):
        """Return ``[step, max, sum, mean, std]`` at ``index`` of ``key``'s buffer.

        Each value is converted to a Python scalar with ``.item()``. ``step``
        is ``count - 1``. NOTE(review): presumably ``RunningStats.count``
        includes an initial sample; confirm against RunningStats.
        """
        stats = self._buffer[key]
        return [
            stats.count[index].item() - 1,
            stats.max[index].item(),
            stats.sum[index].item(),
            stats.mean[index].item(),
            stats.std[index].item(),
        ]

    def clear(self):
        """Forget every registered key and reset the environment count to 0."""
        self.keys.clear()
        self._n_env = 0
        self._buffer = {}