"""main.py: training entry point (forked from varunagrawal/tiny-faces-pytorch)."""
import argparse
import os
import os.path as osp

import torch
from torch import optim
from torchvision import transforms

import trainer
from datasets import get_dataloader
from models.loss import DetectionCriterion
from models.model import DetectionModel

def arguments():
    """Parse the command-line arguments for training."""
    parser = argparse.ArgumentParser()
    parser.add_argument("traindata")
    parser.add_argument("valdata")
    parser.add_argument("--dataset-root", default="")
    parser.add_argument("--dataset", default="WIDERFace")
    parser.add_argument("--lr", default=1e-4, type=float)
    parser.add_argument("--weight-decay", default=0.0005, type=float)
    parser.add_argument("--momentum", default=0.9, type=float)
    parser.add_argument("--batch_size", default=12, type=int)
    parser.add_argument("--workers", default=8, type=int)
    parser.add_argument("--start-epoch", default=0, type=int)
    parser.add_argument("--epochs", default=50, type=int)
    parser.add_argument("--save-every", default=10, type=int)
    parser.add_argument("--resume", default="")
    parser.add_argument("--debug", action="store_true")
    return parser.parse_args()
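
# A sample invocation (hypothetical paths; the expected annotation format is
# whatever datasets.get_dataloader consumes for WIDER Face):
#
#   python main.py path/to/wider_train_annotations path/to/wider_val_annotations \
#       --dataset-root path/to/WIDER --batch_size 12 --epochs 50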

def main():
    args = arguments()

    num_templates = 25  # aka the number of clusters

    # normalize with the standard ImageNet mean/std
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    img_transforms = transforms.Compose([
        transforms.ToTensor(),
        normalize
    ])

    train_loader, _ = get_dataloader(args.traindata, args, num_templates,
                                     img_transforms=img_transforms)

    # a single object class (faces), detected against `num_templates` templates
    model = DetectionModel(num_objects=1, num_templates=num_templates)
    loss_fn = DetectionCriterion(num_templates)

    # directory where we'll store model weights
    weights_dir = "weights"
    if not osp.exists(weights_dir):
        os.mkdir(weights_dir)

    # check for CUDA
    if torch.cuda.is_available():
        device = torch.device('cuda:0')
    else:
        device = torch.device('cpu')

    # learnable_parameters presumably groups the model's parameters with
    # per-group learning rates derived from the base lr (see models/model.py)
    optimizer = optim.SGD(model.learnable_parameters(args.lr), lr=args.lr,
                          momentum=args.momentum, weight_decay=args.weight_decay)
    # optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

    if args.resume:
        checkpoint = torch.load(args.resume)
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        # set the start epoch from the checkpoint if it was not given
        # on the command line
        if not args.start_epoch:
            args.start_epoch = checkpoint['epoch']

    # decay the learning rate (by the StepLR default factor of 10) every 20 epochs
    scheduler = optim.lr_scheduler.StepLR(optimizer,
                                          step_size=20,
                                          last_epoch=args.start_epoch - 1)

    # train for `epochs` epochs, checkpointing every `save_every` epochs
    for epoch in range(args.start_epoch, args.epochs):
        trainer.train(model, loss_fn, optimizer, train_loader, epoch, device=device)
        scheduler.step()

        if (epoch + 1) % args.save_every == 0:
            trainer.save_checkpoint({
                'epoch': epoch + 1,
                'batch_size': train_loader.batch_size,
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict()
            }, filename="checkpoint_{0}.pth".format(epoch + 1), save_path=weights_dir)


if __name__ == '__main__':
    main()
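
# Resuming from a saved checkpoint (hypothetical filename); --start-epoch is
# then recovered from the checkpoint when not passed explicitly:
#
#   python main.py path/to/wider_train_annotations path/to/wider_val_annotations \
#       --resume weights/checkpoint_10.pth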