-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain_teacher.py
107 lines (85 loc) · 3.25 KB
/
train_teacher.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
from __future__ import print_function
import os
import argparse
import time
import torch
import torch.optim as optim
import torch.nn as nn
import torch.backends.cudnn as cudnn
from models import model_dict
from dataset.cifar100 import get_cifar100_dataloaders
from dataset.imagenet import get_imagenet_dataloader
from helper.util import adjust_learning_rate
from helper.loops import train_vanilla as train, validate
from helper.util import read_cfg
def _save_state(state, folder, filename):
    """Write ``state`` to ``folder/filename`` using ``torch.save``."""
    torch.save(state, os.path.join(folder, filename))


def main():
    """Train one model from a config file and write checkpoints.

    Steps, in order:
      1. Parse the ``-cfg`` command-line option and load it with ``read_cfg``.
      2. Build train/validation loaders for ``opt.dataset`` ('cifar100' with
         100 classes or 'tinyimagenet' with 200 classes).
      3. Instantiate ``model_dict[opt.model]``, an SGD optimizer and a
         cross-entropy criterion; move model and criterion to CUDA when
         available.
      4. For every epoch in ``1..opt.epochs``: call ``adjust_learning_rate``,
         train, validate, and save checkpoints into ``opt.save_folder``:
         ``<model>_best.pth`` whenever validation accuracy improves and
         ``ckpt_epoch_<n>.pth`` every ``opt.save_freq`` epochs.
      5. After the loop, print the best accuracy and save ``<model>_last.pth``.

    Raises:
        NotImplementedError: if ``opt.dataset`` is not one of the two
            supported dataset names.
    """
    best_acc = 0

    parser = argparse.ArgumentParser()
    parser.add_argument("-cfg", default="./configs/train_vallina/res32x4.yaml", metavar="FILE",
                        help="path to config file", type=str)
    args = parser.parse_args()
    opt = read_cfg(args.cfg)

    # Make sure the checkpoint directory exists before any training happens,
    # so a missing folder cannot crash the run at the first torch.save after
    # a full epoch of work. An empty save_folder means "current directory"
    # (os.path.join('', name) == name) and needs no creation.
    # NOTE(review): read_cfg may already create this folder; the check is then a no-op.
    if opt.save_folder and not os.path.isdir(opt.save_folder):
        os.makedirs(opt.save_folder)

    # dataloader
    if opt.dataset == 'cifar100':
        train_loader, val_loader = get_cifar100_dataloaders(
            opt, batch_size=opt.batch_size, num_workers=opt.num_workers)
        n_cls = 100
    elif opt.dataset == 'tinyimagenet':
        train_loader, val_loader = get_imagenet_dataloader(
            opt, batch_size=opt.batch_size, num_workers=opt.num_workers)
        n_cls = 200
    else:
        raise NotImplementedError(opt.dataset)

    # model
    model = model_dict[opt.model](num_classes=n_cls)

    # optimizer
    optimizer = optim.SGD(model.parameters(),
                          lr=opt.learning_rate,
                          momentum=opt.momentum,
                          weight_decay=opt.weight_decay)

    criterion = nn.CrossEntropyLoss()

    if torch.cuda.is_available():
        model = model.cuda()
        criterion = criterion.cuda()
        cudnn.benchmark = True

    # routine
    for epoch in range(1, opt.epochs + 1):
        adjust_learning_rate(epoch, opt, optimizer)
        print("==> training...")

        time1 = time.time()
        # Training metrics are returned by the loop helper but not used here.
        _train_acc, _train_loss = train(epoch, train_loader, model, criterion, optimizer, opt)
        time2 = time.time()
        print('epoch {}, total time {:.2f}'.format(epoch, time2 - time1))

        # Only top-1 accuracy drives checkpointing; top-5 and loss are unused.
        test_acc, _test_acc_top5, _test_loss = validate(val_loader, model, criterion, opt)

        # save the best model
        if test_acc > best_acc:
            best_acc = test_acc
            state = {
                'epoch': epoch,
                'model': model.state_dict(),
                'best_acc': best_acc,
                'optimizer': optimizer.state_dict(),
            }
            print('saving the best model!')
            _save_state(state, opt.save_folder, '{}_best.pth'.format(opt.model))

        # regular saving
        if epoch % opt.save_freq == 0:
            print('==> Saving...')
            state = {
                'epoch': epoch,
                'model': model.state_dict(),
                'accuracy': test_acc,
                'optimizer': optimizer.state_dict(),
            }
            _save_state(state, opt.save_folder, 'ckpt_epoch_{epoch}.pth'.format(epoch=epoch))

    print('best accuracy:', best_acc)

    # save model
    state = {
        'opt': opt,
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
    }
    _save_state(state, opt.save_folder, '{}_last.pth'.format(opt.model))
# Script entry point: start training only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()