Update lr_schedules.py (microsoft#4563)
add cosine annealing scheduler

This scheduler is widely used in image classification tasks, and many LLMs
(e.g. LLaMA) also use it. A configuration sketch is given below.

---------

Co-authored-by: Olatunji Ruwase <[email protected]>
Co-authored-by: Logan Adams <[email protected]>
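
For context, here is a minimal sketch of how the new scheduler can be selected from a DeepSpeed config, based on the keys introduced in this commit (WarmupCosineLR, total_num_steps, warmup_min_ratio, warmup_num_steps, cos_min_ratio). The batch size and optimizer values are placeholders, not part of the commit; the dict is passed to deepspeed.initialize as in the test added further down.

# Illustrative config fragment only; the scheduler keys mirror this diff,
# everything else is a placeholder.
ds_config = {
    "train_batch_size": 2,
    "optimizer": {"type": "Adam", "params": {"lr": 0.0015}},
    "scheduler": {
        "type": "WarmupCosineLR",
        "params": {
            "total_num_steps": 100,    # length of the full schedule
            "warmup_min_ratio": 0.2,   # lr starts at 0.2 * base lr
            "warmup_num_steps": 10,    # steps to ramp up to the base lr
            "cos_min_ratio": 0.1,      # lr decays to 0.1 * base lr by the end
        }
    },
}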
3 people authored Nov 10, 2023
1 parent da652d0 commit 4388a60
Showing 2 changed files with 180 additions and 2 deletions.
113 changes: 111 additions & 2 deletions deepspeed/runtime/lr_schedules.py
@@ -19,7 +19,8 @@
ONE_CYCLE = 'OneCycle'
WARMUP_LR = 'WarmupLR'
WARMUP_DECAY_LR = 'WarmupDecayLR'
VALID_LR_SCHEDULES = [LR_RANGE_TEST, ONE_CYCLE, WARMUP_LR, WARMUP_DECAY_LR]
WARMUP_COSINE_LR = 'WarmupCosineLR'
VALID_LR_SCHEDULES = [LR_RANGE_TEST, ONE_CYCLE, WARMUP_LR, WARMUP_DECAY_LR, WARMUP_COSINE_LR]

LR_RANGE_TEST_MIN_LR = 'lr_range_test_min_lr'
LR_RANGE_TEST_STEP_RATE = 'lr_range_test_step_rate'
@@ -50,6 +51,9 @@
WARMUP_LOG_RATE = 'log'
WARMUP_LINEAR_RATE = 'linear'

WARMUP_MIN_RATIO = 'warmup_min_ratio'
COS_MIN_RATIO = 'cos_min_ratio'

TOTAL_NUM_STEPS = 'total_num_steps'


@@ -109,6 +113,11 @@ def add_tuning_arguments(parser):
                       type=str,
                       default=WARMUP_LOG_RATE,
                       help='WarmupLR increasing function during warmup')

    # WarmupCosineLR
    group.add_argument("--warmup_min_ratio", type=float, default=0.01, help='Cosine LR warmup start ratio.')
    group.add_argument("--cos_min_ratio", type=float, default=0.01, help='Cosine LR lower bound.')

    return parser


@@ -457,7 +466,6 @@ def __init__(self,
        if cycle_momentum:
            self._initialize_momentum(self.optimizer, cycle_min_mom, cycle_max_mom, decay_mom_rate,
                                      last_batch_iteration)

        # Initialize batch iteration tracker
        self.last_batch_iteration = last_batch_iteration

@@ -761,3 +769,104 @@ def _get_gamma(self):
            0.0,
            float(self.total_num_steps - self.last_batch_iteration) /
            float(max(1.0, self.total_num_steps - self.warmup_num_steps)))


class WarmupCosineLR(object):
    """Increase the learning rate of each parameter group from the min lr ratio to the max lr ratio
    over warmup_num_steps steps, then decay at a cosine rate over the remaining training steps down to cos_min_ratio.

        Args:
            optimizer (Optimizer): Wrapped optimizer.
            total_num_steps (int): total number of training steps
            warmup_min_ratio (float or list): warmup start learning rate ratio. Default: 0
            warmup_num_steps (int): number of steps to warm up from warmup_min_ratio to 1.0. Default: 1000
            warmup_type {'log', 'linear'}: increasing function from min_lr to max_lr during warmup. Default: log
            cos_min_ratio (float): cosine end learning rate ratio. Default: 0.0001
            last_batch_iteration (int): The index of the last batch. Default: -1.
        Example:
            >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
            >>> scheduler = WarmupCosineLR(optimizer, 1000000)
            >>> data_loader = torch.utils.data.DataLoader(...)
            >>> for epoch in range(10):
            >>>     for batch in data_loader:
            >>>         train_batch(...)
            >>>         scheduler.step()
    """

    def __init__(self,
                 optimizer: Optimizer,
                 total_num_steps: int,
                 warmup_min_ratio: float = 0.0,
                 warmup_num_steps: int = 1000,
                 cos_min_ratio: float = 0.0001,
                 warmup_type: str = WARMUP_LOG_RATE,
                 last_batch_iteration: int = -1):

        self.optimizer = get_torch_optimizer(optimizer)

        self.total_num_steps = total_num_steps
        self.last_batch_iteration = last_batch_iteration
        self.cos_min_ratio = cos_min_ratio

        self.warmup_type = warmup_type
        self.warmup_min_ratio = warmup_min_ratio
        self.warmup_num_steps = max(2, warmup_num_steps)  # clamp so the log-warmup denominator is well-defined
        self.inverse_log_warm_up = 1.0 / math.log(self.warmup_num_steps)

        if self.total_num_steps < self.warmup_num_steps:
            logger.warning('total_num_steps {} is less than warmup_num_steps {}'.format(
                total_num_steps, warmup_num_steps))
        self.org_lrs = [group['lr'] for group in self.optimizer.param_groups]

    def get_lr_ratio(self):
        if self.last_batch_iteration < 0:
            logger.warning("Attempting to get learning rate from scheduler before it has started")
            return [0.0]

        if self.last_batch_iteration < self.warmup_num_steps:
            # Warmup phase: scale from warmup_min_ratio up to 1.0.
            if self.warmup_type == WARMUP_LOG_RATE:
                ratio = self.inverse_log_warm_up * math.log(self.last_batch_iteration + 1)
            elif self.warmup_type == WARMUP_LINEAR_RATE:
                ratio = self.last_batch_iteration / self.warmup_num_steps
            ratio_delta = 1. - self.warmup_min_ratio
            ratio = self.warmup_min_ratio + ratio * ratio_delta
            return ratio

        # Cosine phase: decay from 1.0 down to cos_min_ratio over the remaining steps.
        real_last_step = self.last_batch_iteration - self.warmup_num_steps + 1
        real_total_steps = self.total_num_steps - self.warmup_num_steps
        ratio_delta = 1. - self.cos_min_ratio
        ratio = (1 + math.cos(math.pi * real_last_step / real_total_steps)) / 2
        ratio = max(0.0, self.cos_min_ratio + ratio_delta * ratio)
        return ratio

    def step(self, last_batch_iteration=None):
        if last_batch_iteration is None:
            last_batch_iteration = self.last_batch_iteration + 1
        self.last_batch_iteration = last_batch_iteration

        lrs = self.get_lr()
        for param_group, lr in zip(self.optimizer.param_groups, lrs):
            param_group['lr'] = lr
        self._last_lr = [group['lr'] for group in self.optimizer.param_groups]

    def get_lr(self):
        if self.last_batch_iteration < 0:
            logger.warning("Attempting to get learning rate from scheduler before it has started")
            return [0.0]
        lr_ratio = self.get_lr_ratio()
        return [org_lr * lr_ratio for org_lr in self.org_lrs]

    def state_dict(self):
        return {'last_batch_iteration': self.last_batch_iteration}

    def load_state_dict(self, sd):
        self.last_batch_iteration = sd['last_batch_iteration']

    def _format_param(self, optimizer, param_value, param_name):
        if isinstance(param_value, list) or isinstance(param_value, tuple):
            if len(param_value) != len(optimizer.param_groups):
                raise ValueError("expected {} value for {}, got {}".format(len(optimizer.param_groups), param_name,
                                                                           param_value))
            return list(param_value)
        return [param_value] * len(optimizer.param_groups)
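
Restating what get_lr_ratio above computes (a summary of the code, not an addition to it): with base learning rate lr_0, warmup length W (warmup_num_steps), total steps T (total_num_steps), warmup floor r_w (warmup_min_ratio), cosine floor r_c (cos_min_ratio), and step index t (last_batch_iteration), the scheduler sets

$$
\mathrm{lr}(t) = \mathrm{lr}_0 \cdot
\begin{cases}
r_w + (1 - r_w)\,\dfrac{\ln(t+1)}{\ln W}, & t < W \ \text{(log warmup)}\\[6pt]
r_w + (1 - r_w)\,\dfrac{t}{W}, & t < W \ \text{(linear warmup)}\\[6pt]
r_c + (1 - r_c)\,\dfrac{1 + \cos\!\left(\pi\,\frac{t - W + 1}{T - W}\right)}{2}, & t \ge W
\end{cases}
$$

so the learning rate ramps from r_w·lr_0 up to lr_0 during warmup and then follows a half cosine down to r_c·lr_0 at step T-1.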
69 changes: 69 additions & 0 deletions tests/unit/runtime/test_lr_schedulers.py
@@ -13,6 +13,7 @@
from deepspeed.runtime.lr_schedules import ONE_CYCLE, CYCLE_MIN_LR, CYCLE_MAX_LR, CYCLE_FIRST_STEP_SIZE, DECAY_LR_RATE, DECAY_STEP_SIZE
from deepspeed.runtime.lr_schedules import CYCLE_MIN_MOM, CYCLE_MAX_MOM, DECAY_MOM_RATE
from deepspeed.runtime.lr_schedules import WARMUP_DECAY_LR, TOTAL_NUM_STEPS
from deepspeed.runtime.lr_schedules import WARMUP_COSINE_LR, WARMUP_MIN_RATIO, COS_MIN_RATIO


def _verify_continuous_decrease(values):
@@ -441,3 +442,71 @@ def test_mom(self, min_mom, max_mom, decay_rate, step_size):
        # Verify decay phase
        if decay_rate > 0:
            _verify_continuous_increase(step_moms[(step_size * 2):])


class TestWarmupCosineLR(DistributedTest):
    world_size = 1

    @pytest.mark.parametrize("total_num_steps, warmup_num_steps, cos_min_ratio, warmup_min_ratio",
                             [
                                 (100, 10, 0.1, 0.2),
                                 (200, 20, 0.1, 0.2),
                                 (500, 30, 0.0, 0.2),
                                 (600, 300, 0.1, 0.0),
                                 (600, 550, 0.0, 0.0),
                             ]) # yapf: disable
    def test_lr(self, total_num_steps, warmup_num_steps, cos_min_ratio, warmup_min_ratio):
        opt_lr = 0.0015
        config_dict = {
            "train_batch_size": 2,
            "steps_per_print": 1,
            "optimizer": {
                "type": "Adam",
                "params": {
                    "lr": opt_lr
                },
            },
            "scheduler": {
                "type": WARMUP_COSINE_LR,
                "params": {
                    TOTAL_NUM_STEPS: total_num_steps,
                    WARMUP_MIN_RATIO: warmup_min_ratio,
                    WARMUP_NUM_STEPS: warmup_num_steps,
                    COS_MIN_RATIO: cos_min_ratio,
                }
            },
            "gradient_clipping": 1.0
        }
        hidden_dim = 10

        model = SimpleModel(hidden_dim, empty_grad=False)
        model, _, _, lr_scheduler = deepspeed.initialize(config=config_dict,
                                                         model=model,
                                                         model_parameters=model.parameters())
        data_loader = random_dataloader(model=model,
                                        total_samples=max(50, total_num_steps * 3),
                                        hidden_dim=hidden_dim,
                                        device=model.device,
                                        dtype=torch.float)

        step_lrs = []
        for _, batch in enumerate(data_loader):
            loss = model(batch[0], batch[1])
            model.backward(loss)
            model.step()
            step_lrs.extend(lr_scheduler.get_lr())

        # Verify starting lr
        assert abs(step_lrs[0] - opt_lr * warmup_min_ratio) < 1e-7

        # Verify peak lr
        assert abs(step_lrs[warmup_num_steps - 1] - opt_lr) < 1e-7

        # Verify end lr
        assert abs(step_lrs[total_num_steps - 1] - opt_lr * cos_min_ratio) < 1e-7

        # Verify increasing phase
        _verify_continuous_increase(step_lrs[:warmup_num_steps])

        # Verify decreasing phase
        _verify_continuous_decrease(step_lrs[warmup_num_steps:total_num_steps])
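
For readers who want to sanity-check the schedule outside the DeepSpeed engine, here is a minimal standalone sketch. It assumes WarmupCosineLR is importable from deepspeed.runtime.lr_schedules after this change; the Linear model and SGD optimizer are throwaway placeholders used only to provide parameter groups.

import torch
from deepspeed.runtime.lr_schedules import WarmupCosineLR

# Placeholder model and optimizer; only the param groups and base lr matter here.
model = torch.nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

total_steps, warmup_steps = 100, 10
scheduler = WarmupCosineLR(optimizer,
                           total_num_steps=total_steps,
                           warmup_num_steps=warmup_steps,
                           warmup_min_ratio=0.2,
                           cos_min_ratio=0.1)

lrs = []
for _ in range(total_steps):
    optimizer.step()   # a real loop would compute a loss and backprop first
    scheduler.step()   # advance the schedule by one batch
    lrs.append(scheduler.get_lr()[0])

# Endpoints mirror the assertions in the test above: warmup starts at
# base_lr * warmup_min_ratio, peaks at base_lr, and the cosine phase ends
# at base_lr * cos_min_ratio.
print(lrs[0], lrs[warmup_steps - 1], lrs[total_steps - 1])  # ~0.02, 0.1, 0.01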
