logger.py
"""Copyright (c) Facebook, Inc. and its affiliates.
All rights reserved.
This source code is licensed under the license found in the
LICENSE file in the root directory of this source tree.
"""
import os
import yaml
import csv
import h5py
class Logger(object):
    """Writes training artifacts under `logdir`:
    config (YAML), accuracy/loss curves (CSV), and model weights (HDF5).
    """

    def __init__(self, logdir):
        self.logdir = logdir
        if not os.path.isdir(logdir):
            os.makedirs(logdir)
        self.cfg_file = os.path.join(self.logdir, 'cfg.yaml')
        self.acc_file = os.path.join(self.logdir, 'acc.csv')
        self.loss_file = os.path.join(self.logdir, 'loss.csv')
        self.ws_file = os.path.join(self.logdir, 'ws.h5')
        # CSV headers are fixed by the keys of the first dict logged.
        self.acc_keys = None
        self.loss_keys = None
        self.logging_ws = False

    def log_cfg(self, cfg):
        print('===> Saving cfg parameters to: ', self.cfg_file)
        with open(self.cfg_file, 'w') as f:
            yaml.dump(cfg, f)

    def log_acc(self, accs):
        if self.acc_keys is None:
            # First call: record the keys and write the CSV header.
            self.acc_keys = [k for k in accs.keys()]
            with open(self.acc_file, 'w') as f:
                writer = csv.DictWriter(f, fieldnames=self.acc_keys)
                writer.writeheader()
                writer.writerow(accs)
        else:
            # Subsequent calls: append a row under the existing header.
            with open(self.acc_file, 'a') as f:
                writer = csv.DictWriter(f, fieldnames=self.acc_keys)
                writer.writerow(accs)

    def log_loss(self, losses):
        # valid_losses = {k: v for k, v in losses.items() if v is not None}
        valid_losses = losses
        if self.loss_keys is None:
            self.loss_keys = [k for k in valid_losses.keys()]
            with open(self.loss_file, 'w') as f:
                writer = csv.DictWriter(f, fieldnames=self.loss_keys)
                writer.writeheader()
                writer.writerow(valid_losses)
        else:
            with open(self.loss_file, 'a') as f:
                writer = csv.DictWriter(f, fieldnames=self.loss_keys)
                writer.writerow(valid_losses)

    def log_ws(self, e, ws):
        # Truncate ws.h5 on the first call, then append one group per epoch.
        mode = 'a' if self.logging_ws else 'w'
        self.logging_ws = True

        key = 'Epoch{:02d}'.format(e)
        with h5py.File(self.ws_file, mode) as f:
            g = f.create_group(key)
            for k, v in ws.items():
                g.create_dataset(k, data=v)
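

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original file): a hypothetical training loop
# driving the Logger above. The run directory, metric names, and weight
# arrays are placeholder assumptions; keep the metric keys consistent across
# calls, since the first log_acc/log_loss call fixes the CSV header.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    import numpy as np  # only needed for this demo

    demo_logger = Logger('./runs/example')
    demo_logger.log_cfg({'lr': 0.1, 'batch_size': 128})                        # -> cfg.yaml
    for epoch in range(3):
        demo_logger.log_acc({'epoch': epoch, 'val_acc': 0.4 + 0.1 * epoch})    # -> acc.csv
        demo_logger.log_loss({'epoch': epoch, 'total': 1.0 / (epoch + 1)})     # -> loss.csv
        demo_logger.log_ws(epoch, {'fc_weight': np.random.randn(10, 512)})     # -> ws.h5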