# lr_scheduler.py (108 lines, 4.15 KB)
# From a repository forked from XinJCheng/CSPN.
import numpy as np
import warnings
from torch.optim.optimizer import Optimizer
import math
class ReduceLROnPlateau(object):
    """Reduce learning rate when a metric has stopped improving.

    Models often benefit from reducing the learning rate by a factor
    of 2-10 once learning stagnates. This scheduler reads a metrics
    quantity and if no improvement is seen for a 'patience' number
    of epochs, the learning rate is reduced.

    Args:
        optimizer: the wrapped ``torch.optim.Optimizer`` whose param
            groups have their ``'lr'`` entry lowered in place.
        mode: one of {min, max}. In `min` mode,
            lr will be reduced when the quantity
            monitored has stopped decreasing; in `max`
            mode it will be reduced when the quantity
            monitored has stopped increasing.
        factor: factor by which the learning rate will
            be reduced. new_lr = lr * factor
        patience: number of epochs with no improvement
            after which learning rate will be reduced.
        verbose: int. 0: quiet, 1: update messages.
        epsilon: threshold for measuring the new optimum,
            to only focus on significant changes.
        cooldown: number of epochs to wait before resuming
            normal operation after lr has been reduced.
        min_lr: lower bound on the learning rate.

    Raises:
        ValueError: if ``factor`` is >= 1.0.
        TypeError: if ``optimizer`` is not an ``Optimizer`` instance.
        RuntimeError: if ``mode`` is neither ``'min'`` nor ``'max'``.

    Example:
        >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
        >>> scheduler = ReduceLROnPlateau(optimizer, 'min')
        >>> for epoch in range(10):
        >>>     train(...)
        >>>     val_acc, val_loss = validate(...)
        >>>     scheduler.step(val_loss, epoch)
    """

    def __init__(self, optimizer, mode='min', factor=0.1, patience=3,
                 verbose=0, epsilon=1e-4, cooldown=0, min_lr=0.000001):
        super(ReduceLROnPlateau, self).__init__()
        if factor >= 1.0:
            raise ValueError('ReduceLROnPlateau '
                             'does not support a factor >= 1.0.')
        # Explicit check instead of `assert`, which is stripped under -O.
        if not isinstance(optimizer, Optimizer):
            raise TypeError('%s is not an Optimizer'
                            % type(optimizer).__name__)
        self.factor = factor
        self.min_lr = min_lr
        self.epsilon = epsilon
        self.patience = patience
        self.verbose = verbose
        self.cooldown = cooldown
        self.cooldown_counter = 0  # Cooldown counter.
        self.monitor_op = None
        self.wait = 0
        self.best = 0
        self.mode = mode
        self.optimizer = optimizer
        self._reset()

    def _reset(self):
        """Resets wait counter and cooldown counter.

        Also rebuilds the comparison used to detect an improvement and
        resets the best value seen so far to the worst possible value
        for the configured mode.

        Raises:
            RuntimeError: if ``self.mode`` is neither 'min' nor 'max'.
        """
        if self.mode not in ('min', 'max'):
            raise RuntimeError(
                'Learning Rate Plateau Reducing mode %s is unknown!'
                % self.mode)
        if self.mode == 'min':
            self.monitor_op = lambda a, b: np.less(a, b - self.epsilon)
            # `np.inf` rather than `np.Inf`: the capitalised alias was
            # removed in NumPy 2.0.
            self.best = np.inf
        else:
            self.monitor_op = lambda a, b: np.greater(a, b + self.epsilon)
            self.best = -np.inf
        self.cooldown_counter = 0
        self.wait = 0
        # Tolerance used when deciding whether a group's lr is still
        # above `min_lr`.
        self.lr_epsilon = self.min_lr * 1e-4

    def reset(self):
        """Public alias for :meth:`_reset`."""
        self._reset()

    def step(self, metrics, epoch):
        """Record one epoch's metric and lower the lr if it has plateaued.

        Args:
            metrics: the monitored quantity for this epoch. If ``None``,
                a ``RuntimeWarning`` is issued and nothing else happens.
            epoch: epoch index; only used in the verbose message.
        """
        current = metrics
        if current is None:
            warnings.warn('Learning Rate Plateau Reducing requires metrics available!', RuntimeWarning)
            return

        if self.in_cooldown():
            self.cooldown_counter -= 1
            self.wait = 0

        if self.monitor_op(current, self.best):
            self.best = current
            self.wait = 0
        elif not self.in_cooldown():
            if self.wait >= self.patience and self._reduce_lr(epoch):
                self.cooldown_counter = self.cooldown
                self.wait = 0
            # Incremented after the possible reset above, so the counter
            # restarts at 1 (not 0) on the epoch that triggers a reduction.
            self.wait += 1

    def _reduce_lr(self, epoch):
        """Scale down every param group's lr that is still above min_lr.

        Returns True if at least one group's lr was lowered.
        """
        reduced = False
        for param_group in self.optimizer.param_groups:
            old_lr = float(param_group['lr'])
            if old_lr > self.min_lr + self.lr_epsilon:
                new_lr = max(old_lr * self.factor, self.min_lr)
                param_group['lr'] = new_lr
                if self.verbose > 0:
                    print('\nEpoch %05d: reducing learning rate to %s.' % (epoch, new_lr))
                reduced = True
        return reduced

    def in_cooldown(self):
        """Return True while the post-reduction cooldown is still active."""
        return self.cooldown_counter > 0