forked from zhutmost/lsq-net
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path main.py
108 lines (87 loc) · 4.83 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import logging
from pathlib import Path
import torch as t
import yaml
import process
import util
from model import create_model
def _setup_devices(args, logger):
    """Validate ``args.device`` and configure PyTorch for the selected device(s).

    Mutates ``args.device`` in place:

    * ``args.device.type`` falls back from ``'cuda'`` to ``'cpu'`` (with a
      warning) when CUDA is requested but not available. Without this, later
      ``.to(args.device.type)`` calls would fail.
    * ``args.device.gpu`` is reset to ``[]`` when running on CPU, when CUDA is
      unavailable, or when no GPU IDs were requested (empty list or ``None``).

    When GPUs are used, every requested ID must lie in
    ``[0, t.cuda.device_count())``. Otherwise an error is logged and the process
    exits with status 1. The first listed GPU becomes the current CUDA device,
    and the cuDNN auto-tuner is enabled.
    """
    if args.device.type == 'cuda' and not t.cuda.is_available():
        logger.warning('CUDA was requested but is not available; falling back to CPU')
        args.device.type = 'cpu'

    if args.device.type == 'cpu' or not t.cuda.is_available() or not args.device.gpu:
        args.device.gpu = []
        return

    available_gpu = t.cuda.device_count()
    for dev_id in args.device.gpu:
        if not 0 <= dev_id < available_gpu:
            logger.error('GPU device ID %s requested, but only %d devices available',
                         dev_id, available_gpu)
            raise SystemExit(1)

    # Make the first GPU in the list the current (default) CUDA device.
    t.cuda.set_device(args.device.gpu[0])
    # Enable the cuDNN built-in auto-tuner to accelerate training. It
    # introduces some run-to-run fluctuation within a narrow range.
    t.backends.cudnn.benchmark = True
    t.backends.cudnn.deterministic = False


def main():
    """Run one experiment (training or evaluation) described by the config.

    Steps:
      1. Load the configuration via ``util.get_config`` (with
         ``config.yaml`` in the current working directory as ``default_file``).
         Then set up the output/log directories and the progress and
         TensorBoard monitors.
      2. Validate and configure the compute device(s).
      3. Build the model (optionally resuming from a checkpoint), the data
         loaders, the loss, the SGD optimizer and the LR scheduler.
      4. If ``args.eval`` is set, evaluate on the test set only. Otherwise
         train for ``args.epochs`` epochs, validating and checkpointing after
         each epoch, then evaluate the final model on the test set.

    The TensorBoard writer is always closed on exit so pending events are
    flushed to disk, even if an error occurs.
    """
    # Paths are resolved relative to the current working directory, not the
    # location of this file.
    work_dir = Path.cwd()
    args = util.get_config(default_file=work_dir / 'config.yaml')

    output_dir = work_dir / args.output_dir
    # parents=True so that nested output paths (e.g. "out/exp1") also work.
    output_dir.mkdir(parents=True, exist_ok=True)

    log_dir = util.init_logger(args.name, output_dir, 'logging.conf')
    logger = logging.getLogger()

    # Dump the experiment config next to the logs for reproducibility.
    with open(log_dir / "args.yaml", "w", encoding="utf-8") as yaml_file:
        yaml.safe_dump(args, yaml_file)

    pymonitor = util.ProgressMonitor(logger)
    tbmonitor = util.TensorBoardMonitor(logger, log_dir)
    monitors = [pymonitor, tbmonitor]

    try:
        _setup_devices(args, logger)

        # Create the model and place it on the target device. The loss
        # function below is moved to args.device.type, so the model's
        # parameters must live there too.
        model = create_model(args)
        model.to(args.device.type)

        start_epoch = 0
        perf_scoreboard = process.PerformanceScoreboard(args.log.num_best_scores)
        if args.resume.path:
            model, start_epoch, _ = util.load_checkpoint(
                model, args.resume.path, args.device.type, lean=args.resume.lean)

        # Initialize data loaders
        train_loader, val_loader, test_loader = util.load_data(
            args.dataloader.dataset, args.dataloader.path, args.batch_size,
            args.dataloader.workers, args.dataloader.val_split)
        logger.info('Dataset `%s` size:' % args.dataloader.dataset +
                    '\n training = %d (%d)' % (len(train_loader.sampler), len(train_loader)) +
                    '\n validation = %d (%d)' % (len(val_loader.sampler), len(val_loader)) +
                    '\n test = %d (%d)' % (len(test_loader.sampler), len(test_loader)))

        # Loss function (criterion), optimizer and learning-rate schedule.
        # The optimizer is created after the model has been moved so it holds
        # the on-device parameters.
        criterion = t.nn.CrossEntropyLoss().to(args.device.type)
        optimizer = t.optim.SGD(model.parameters(),
                                lr=args.optimizer.learning_rate,
                                momentum=args.optimizer.momentum,
                                weight_decay=args.optimizer.weight_decay)
        lr_scheduler = util.lr_scheduler(optimizer,
                                         batch_size=train_loader.batch_size,
                                         num_samples=len(train_loader.sampler),
                                         **args.lr_scheduler)
        logger.info(('Optimizer: %s' % optimizer).replace('\n', '\n' + ' ' * 11))
        logger.info('LR scheduler: %s\n' % lr_scheduler)

        if args.eval:
            process.validate(test_loader, model, criterion, -1, monitors, args)
        else:  # training
            if args.resume.path or args.pre_trained:
                # Baseline accuracy of the loaded weights before any training.
                logger.info('>>>>>>>> Epoch -1 (pre-trained model evaluation)')
                top1, top5, _ = process.validate(val_loader, model, criterion,
                                                 start_epoch - 1, monitors, args)
                perf_scoreboard.update(top1, top5, start_epoch - 1)
            for epoch in range(start_epoch, args.epochs):
                logger.info('>>>>>>>> Epoch %3d' % epoch)
                t_top1, t_top5, t_loss = process.train(train_loader, model, criterion, optimizer,
                                                       lr_scheduler, epoch, monitors, args)
                v_top1, v_top5, v_loss = process.validate(val_loader, model, criterion, epoch,
                                                          monitors, args)

                tbmonitor.writer.add_scalars('Train_vs_Validation/Loss', {'train': t_loss, 'val': v_loss}, epoch)
                tbmonitor.writer.add_scalars('Train_vs_Validation/Top1', {'train': t_top1, 'val': v_top1}, epoch)
                tbmonitor.writer.add_scalars('Train_vs_Validation/Top5', {'train': t_top5, 'val': v_top5}, epoch)

                perf_scoreboard.update(v_top1, v_top5, epoch)
                is_best = perf_scoreboard.is_best(epoch)
                util.save_checkpoint(epoch, args.arch, model, {'top1': v_top1, 'top5': v_top5},
                                     is_best, args.name, log_dir)

            # Test the model as it stands after the last epoch (this is not
            # necessarily the best checkpoint).
            logger.info('>>>>>>>> Epoch -1 (final model evaluation)')
            process.validate(test_loader, model, criterion, -1, monitors, args)
    finally:
        # Flush and release the TensorBoard event writer (it exposes the
        # SummaryWriter API, cf. add_scalars above) so that no events are lost.
        tbmonitor.writer.close()

    logger.info('Program completed successfully ... exiting ...')
    logger.info('If you have any questions or suggestions, please visit: github.com/zhutmost/lsq-net')
# Entry point: run the experiment only when this file is executed directly,
# not when it is imported as a module.
if __name__ == '__main__':
    main()