
Commit

Update
husencd committed Jun 5, 2018
1 parent cfaf3b5 commit aa8c5cc
Showing 1 changed file with 7 additions and 4 deletions.
11 changes: 7 additions & 4 deletions main.py
@@ -53,7 +53,7 @@ def main():
                               lr=args.lr,
                               momentum=args.momentum,
                               weight_decay=args.weight_decay)
-    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, threshold=0.001, patience=args.lr_patience)
+    # scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, threshold=0.001, patience=args.lr_patience)
 
     lr_mult = []
     for param_group in optimizer.param_groups:
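Note on this hunk: the commit disables the validation-driven ReduceLROnPlateau scheduler and relies instead on the manual step decay in adjust_learning_rate (last hunk below). A minimal sketch of the two styles for comparison; val_loss is a hypothetical metric and patience=10 is a placeholder for args.lr_patience, while the other scheduler arguments are taken from the diff:

import torch
import torch.optim as optim

model = torch.nn.Linear(10, 2)
optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

# (a) Commented out by this commit: cut the LR by 10x once a metric plateaus.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.1, threshold=0.001, patience=10)
# scheduler.step(val_loss)  # would be called once per epoch with a validation metric

# (b) What the repo keeps: a fixed 10x decay every 30 epochs, independent of any metric.
for epoch in range(1, 91):
    lr = 0.1 * (0.1 ** ((epoch - 1) // 30))  # 0.1 for epochs 1-30, 0.01 for 31-60, 0.001 for 61-90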
@@ -98,7 +98,7 @@ def main():
     vis = Visualizer(env=args.env)
     for epoch in range(args.begin_epoch, args.epochs + 1):
         if args.train:
-            adjust_learning_rate(optimizer, epoch, lr_mult)
+            adjust_learning_rate(optimizer, epoch, lr_mult, args)
             train_epoch(epoch, train_loader, model, criterion, optimizer, args, device, train_logger, train_batch_logger, vis)
             print()
@@ -139,11 +139,14 @@ def main():
         test.test(test_loader, model, args, device)
 
 
-def adjust_learning_rate(optimizer, epoch, lr_mult):
+def adjust_learning_rate(optimizer, epoch, lr_mult, args):
     """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
     lr = args.lr * (0.1**((epoch - 1) // 30))
     for i, param_group in enumerate(optimizer.param_groups):
-        param_group['lr'] = lr * lr_mult[i]
+        if args.finetune and args.ft_begin_index:
+            param_group['lr'] = lr * lr_mult[i]
+        else:
+            param_group['lr'] = lr
 
 
 if __name__ == '__main__':
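For context, a self-contained sketch of the new adjust_learning_rate behavior. The SimpleNamespace stands in for the repo's argparse namespace; only the field names used in the diff (lr, finetune, ft_begin_index) come from the source, and the two-group optimizer and lr_mult values are invented for illustration:

from types import SimpleNamespace
import torch
import torch.optim as optim

def adjust_learning_rate(optimizer, epoch, lr_mult, args):
    """Initial LR decayed by 10 every 30 epochs; per-group scaling only when fine-tuning."""
    lr = args.lr * (0.1 ** ((epoch - 1) // 30))
    for i, param_group in enumerate(optimizer.param_groups):
        if args.finetune and args.ft_begin_index:
            param_group['lr'] = lr * lr_mult[i]  # keep per-group ratios (smaller LR for pretrained layers)
        else:
            param_group['lr'] = lr               # one global LR when training from scratch

args = SimpleNamespace(lr=0.1, finetune=True, ft_begin_index=1)
model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Linear(4, 2))
optimizer = optim.SGD(
    [{'params': model[0].parameters(), 'lr': args.lr * 0.1},  # pretrained layers at 1/10th LR
     {'params': model[1].parameters()}],                      # new head at the base LR
    lr=args.lr, momentum=0.9)
lr_mult = [g['lr'] / args.lr for g in optimizer.param_groups]  # [0.1, 1.0] — stands in for the lr_mult list built in main()

for epoch in (1, 31, 61):
    adjust_learning_rate(optimizer, epoch, lr_mult, args)
    print(epoch, [round(g['lr'], 5) for g in optimizer.param_groups])
# 1 [0.01, 0.1]; 31 [0.001, 0.01]; 61 [0.0001, 0.001]

The previous version applied lr_mult unconditionally; with this change the multipliers apply only in the fine-tuning case, so a from-scratch run gets the plain stepped LR in every parameter group.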
