-
Notifications
You must be signed in to change notification settings - Fork 9
/
callbacks.py
70 lines (64 loc) · 3.05 KB
/
callbacks.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import warnings

from keras.callbacks import TensorBoard, ModelCheckpoint
import tensorflow as tf
import numpy as np
class CustomTensorBoard(TensorBoard):
    """TensorBoard callback that additionally logs batch-level metrics.

    Every ``log_every`` batches, each entry of the batch ``logs`` dict (except
    the keys in ``_SKIPPED_KEYS``) is written as a scalar summary to
    ``self.writer``, using a running batch counter as the global step. The
    regular ``TensorBoard.on_batch_end`` is then invoked as usual.

    Args:
        log_every: write batch summaries once every ``log_every`` batches.
            Must be a positive integer.
        **kwargs: forwarded unchanged to ``TensorBoard.__init__``.

    Raises:
        ValueError: if ``log_every`` is smaller than 1.
    """

    # Batch-log keys that are not written as summaries (batch index / batch
    # size bookkeeping rather than metrics).
    _SKIPPED_KEYS = frozenset(('batch', 'size'))

    def __init__(self, log_every=1, **kwargs):
        # Fail fast: a zero interval would otherwise surface later as a
        # ZeroDivisionError inside on_batch_end.
        if log_every < 1:
            raise ValueError('log_every must be >= 1, got %r' % (log_every,))
        super(CustomTensorBoard, self).__init__(**kwargs)
        self.log_every = log_every
        # Number of batches seen so far (never reset); used as summary step.
        self.counter = 0

    def on_batch_end(self, batch, logs=None):
        """Write batch metrics every ``log_every`` batches, then delegate.

        Args:
            batch: index of the batch, passed through to the parent.
            logs: dict mapping log names to values for this batch; may be
                ``None``, in which case nothing is written.
        """
        # Local alias so the parent still receives exactly what we were given.
        metrics = logs or {}
        self.counter += 1
        if self.counter % self.log_every == 0:
            for name, value in metrics.items():
                if name in self._SKIPPED_KEYS:
                    continue
                summary = tf.Summary()
                summary_value = summary.value.add()
                # float() accepts NumPy scalars as well as plain Python
                # numbers; .item() exists only on NumPy objects.
                summary_value.simple_value = float(value)
                summary_value.tag = name
                self.writer.add_summary(summary, self.counter)
            # Flush once after all summaries for this step have been added.
            self.writer.flush()
        super(CustomTensorBoard, self).on_batch_end(batch, logs)
class CustomModelCheckpoint(ModelCheckpoint):
    """ModelCheckpoint that saves a separate template model.

    Intended to save the template model rather than the multi-GPU model being
    trained. All checkpoint settings inherited from ``ModelCheckpoint``
    (``filepath``, ``monitor``, ``save_best_only``, ``save_weights_only``,
    ``period``, ``verbose``) are honoured, but the object written to disk is
    ``model_to_save``.

    Args:
        model_to_save: the model whose weights (or full model) get saved.
        **kwargs: forwarded unchanged to ``ModelCheckpoint.__init__``.
    """

    def __init__(self, model_to_save, **kwargs):
        super(CustomModelCheckpoint, self).__init__(**kwargs)
        self.model_to_save = model_to_save

    def _save_template_model(self, filepath):
        """Write ``model_to_save`` to ``filepath`` (weights only if configured)."""
        if self.save_weights_only:
            self.model_to_save.save_weights(filepath, overwrite=True)
        else:
            self.model_to_save.save(filepath, overwrite=True)

    def on_epoch_end(self, epoch, logs=None):
        """Save ``model_to_save`` once every ``period`` epochs.

        With ``save_best_only`` the model is written only when
        ``monitor_op(current, best)`` holds for the monitored value; if that
        value is missing from ``logs`` a RuntimeWarning is issued and nothing
        is saved.

        Args:
            epoch: zero-based epoch index; messages and the ``filepath``
                template use ``epoch + 1``.
            logs: dict of epoch-level metrics, also used to format
                ``filepath``; may be ``None``.
        """
        # NOTE: the parent ModelCheckpoint.on_epoch_end is deliberately not
        # chained: it would save ``self.model`` (the model being trained),
        # which is exactly what this subclass exists to avoid.
        logs = logs or {}
        self.epochs_since_last_save += 1
        if self.epochs_since_last_save < self.period:
            return
        self.epochs_since_last_save = 0
        filepath = self.filepath.format(epoch=epoch + 1, **logs)

        if not self.save_best_only:
            if self.verbose > 0:
                print('\nEpoch %05d: saving model to %s' % (epoch + 1, filepath))
            self._save_template_model(filepath)
            return

        current = logs.get(self.monitor)
        if current is None:
            warnings.warn('Can save best model only with %s available, '
                          'skipping.' % (self.monitor), RuntimeWarning)
        elif self.monitor_op(current, self.best):
            if self.verbose > 0:
                print('\nEpoch %05d: %s improved from %0.5f to %0.5f,'
                      ' saving model to %s'
                      % (epoch + 1, self.monitor, self.best,
                         current, filepath))
            self.best = current
            self._save_template_model(filepath)
        elif self.verbose > 0:
            print('\nEpoch %05d: %s did not improve from %0.5f' %
                  (epoch + 1, self.monitor, self.best))