cosine.py
import math
import torch


class WarmupCosineScheduler(torch.optim.lr_scheduler._LRScheduler):
    """Per-step LR schedule: linear warmup, then cosine decay to zero."""

    def __init__(
        self,
        optimizer: torch.optim.Optimizer,
        warmup_epochs: int,
        total_epochs: int,
        steps_per_epoch: int,
        last_epoch=-1,
        verbose=False,
    ):
        self.warmup_steps = warmup_epochs * steps_per_epoch
        self.total_steps = total_epochs * steps_per_epoch
        super().__init__(optimizer, last_epoch=last_epoch, verbose=verbose)

    def get_lr(self):
        if self._step_count < self.warmup_steps:
            # Linear warmup: ramp each base LR up in proportion to warmup progress.
            return [self._step_count / self.warmup_steps * base_lr for base_lr in self.base_lrs]
        # Cosine decay from the base LR down to zero over the remaining steps.
        decay_steps = self.total_steps - self.warmup_steps
        # Clamp progress at 1.0 so the LR stays at zero if training runs past total_steps.
        progress = min((self._step_count - self.warmup_steps) / decay_steps, 1.0)
        cos_val = math.cos(math.pi * progress)
        return [0.5 * base_lr * (1 + cos_val) for base_lr in self.base_lrs]
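

# Example usage (a minimal sketch, not part of the original file): the model,
# optimizer, and hyperparameter values below are illustrative assumptions.
# The scheduler is stepped once per batch, so steps_per_epoch should match
# the number of batches in one epoch (e.g. len(train_loader)).
if __name__ == "__main__":
    model = torch.nn.Linear(10, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    steps_per_epoch = 100  # assumed batch count per epoch
    scheduler = WarmupCosineScheduler(
        optimizer,
        warmup_epochs=5,
        total_epochs=50,
        steps_per_epoch=steps_per_epoch,
    )
    for epoch in range(50):
        for _ in range(steps_per_epoch):
            optimizer.step()   # in real training this follows loss.backward()
            scheduler.step()   # advance the per-step warmup/cosine schedule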