forked from coursat-ai/MultiCheXNet
-
Notifications
You must be signed in to change notification settings - Fork 0
/
callback.py
59 lines (51 loc) · 2.37 KB
/
callback.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
class CustomCallBacks():
    """
    Factory for the list of Keras callbacks passed to ``model.fit``.

    How to Use:
        monitor = CustomCallBacks().calls('DetectionModel.h5',
                                          checkpointMonitor = 'val_loss',
                                          checkpointMode = 'min',
                                          earlyStopMonitor = 'val_loss',
                                          earlyStopPatience = 10,
                                          earlyStopMode = 'auto',
                                          useReduceOnPlateau = False,
                                          useLearningRateScheduler = True)
        model.fit(..., callbacks=monitor)

    NOTE(review): ``calls`` uses a module-level name ``tf`` (TensorFlow), but no
    ``import tensorflow as tf`` is visible here - confirm the module provides it.
    """

    @staticmethod
    def _accuracy_reached(logs, threshold, key='val_accuracy'):
        """Return True if ``logs[key]`` exists and is strictly greater than ``threshold``.

        ``logs`` may be None or may lack ``key`` (for instance when ``fit`` has
        no validation data); both cases return False instead of raising.
        """
        value = (logs or {}).get(key)
        return value is not None and bool(value > threshold)

    @staticmethod
    def _lr_range_schedule(epoch):
        """Return the learning rate for ``epoch``: 1e-5 * 10 ** (epoch / 10).

        The rate grows tenfold every 10 epochs (1e-5 at epoch 0, 1e-4 at
        epoch 10, ...). NOTE(review): a monotonically increasing rate looks
        like a learning-rate range sweep rather than a training schedule -
        confirm this is intended.
        """
        return 1e-5 * 10 ** (epoch / 10)

    def calls(self,
              model_name,
              checkpointMonitor = 'val_loss',
              checkpointMode = 'min',
              earlyStopMonitor = 'val_loss',
              earlyStopPatience = 10,
              earlyStopMode = 'auto',
              useReduceOnPlateau = False,
              useLearningRateScheduler = False,
              accuracyStopThreshold = None
              ):
        """Build the callbacks list for ``model.fit(..., callbacks=...)``.

        Args:
            model_name: File path the checkpoint callback saves the full model
                (``save_weights_only=False``) to.
            checkpointMonitor: Metric the checkpoint watches; because
                ``save_best_only=True`` the model is only saved when it improves.
            checkpointMode: 'min' / 'max' / 'auto' direction for
                ``checkpointMonitor``.
            earlyStopMonitor: Metric watched by early stopping.
            earlyStopPatience: Epochs without improvement before stopping.
            earlyStopMode: 'min' / 'max' / 'auto' direction for
                ``earlyStopMonitor``.
            useReduceOnPlateau: Also add ReduceLROnPlateau on 'val_loss'
                (factor 0.05, patience 5, min_lr 0).
            useLearningRateScheduler: Also add a LearningRateScheduler driven
                by ``_lr_range_schedule`` (1e-5 * 10 ** (epoch / 10)).
            accuracyStopThreshold: If not None, also add a callback that sets
                ``model.stop_training`` once 'val_accuracy' exceeds this value.
                Defaults to None (disabled).

        Returns:
            list: ``[checkpoint, early_stopping]`` followed by the enabled
            optional callbacks, in the order listed above.
        """
        # ``save_freq='epoch'`` is the replacement for the deprecated
        # ``period=1`` (removed in Keras 3); both check once per epoch.
        checkpoint = tf.keras.callbacks.ModelCheckpoint(model_name,
                                                        monitor=checkpointMonitor,
                                                        verbose=1,
                                                        save_best_only=True,
                                                        save_weights_only=False,
                                                        mode=checkpointMode,
                                                        save_freq='epoch')
        early = tf.keras.callbacks.EarlyStopping(monitor=earlyStopMonitor,
                                                 min_delta=0,
                                                 patience=earlyStopPatience,
                                                 verbose=1,
                                                 mode=earlyStopMode)
        monitor = [checkpoint, early]

        # Optional callbacks are only constructed when actually requested.
        if useReduceOnPlateau:
            monitor.append(tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss',
                                                                factor=0.05,
                                                                patience=5,
                                                                verbose=1,
                                                                mode='min',
                                                                min_delta=0,
                                                                cooldown=0,
                                                                min_lr=0))
        if useLearningRateScheduler:
            monitor.append(tf.keras.callbacks.LearningRateScheduler(self._lr_range_schedule))
        if accuracyStopThreshold is not None:
            threshold = accuracyStopThreshold
            reached = self._accuracy_reached

            class AccuracyThresholdStop(tf.keras.callbacks.Callback):
                """Stop training once 'val_accuracy' exceeds ``threshold``."""

                def on_epoch_end(self, epoch, logs=None):
                    # ``logs=None`` avoids a shared mutable ``{}`` default;
                    # ``reached`` tolerates None / a missing metric.
                    if reached(logs, threshold):
                        print('\nReached {} Validation accuracy!'.format(threshold))
                        self.model.stop_training = True

            monitor.append(AccuracyThresholdStop())
        return monitor