forked from husencd/DriverPostureClassification
-
Notifications
You must be signed in to change notification settings - Fork 0
/
test.py
52 lines (40 loc) · 1.64 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import time

import torch

from utils import AverageMeter, calculate_accuracy
def test(data_loader, model, args, device):
    """Evaluate ``model`` on ``data_loader`` and report top-1 / top-3 precision.

    The model is switched to eval mode and run under ``torch.no_grad()`` so
    that no autograd graph is recorded during inference. Running precision
    and timing statistics are printed every ``args.log_interval`` batches,
    and a one-line summary is printed after the last batch.

    Args:
        data_loader: Iterable of ``(inputs, target)`` batches. ``len()`` is
            called on it for the progress output.
        model: Callable as ``model(inputs)`` and providing ``.eval()``.
        args: Namespace providing ``log_interval`` (number of batches between
            progress lines; must be non-zero).
        device: Target passed to ``.to(device)`` for inputs and targets.

    Returns:
        ``top1.avg``: the running top-1 precision accumulated with
        ``AverageMeter.update(value, batch_size)`` over all batches.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    top1 = AverageMeter()
    top3 = AverageMeter()

    # switch to evaluate mode
    model.eval()

    # Inference only: disable autograd so activations are not retained for a
    # backward pass that never happens (saves memory and time).
    with torch.no_grad():
        # perf_counter is monotonic, so measured intervals are not affected
        # by system clock adjustments (unlike time.time()).
        end_time = time.perf_counter()
        for i, (inputs, target) in enumerate(data_loader):
            # measure data loading time
            data_time.update(time.perf_counter() - end_time)

            inputs = inputs.to(device)
            target = target.to(device)
            batch_size = inputs.size(0)

            # compute output (evaluation needs no loss)
            output = model(inputs)

            # measure accuracy
            prec1, prec3 = calculate_accuracy(output, target, topk=(1, 3))
            # prec1[0]: convert torch.Size([1]) to torch.Size([])
            top1.update(prec1[0].item(), batch_size)
            top3.update(prec3[0].item(), batch_size)

            # measure elapsed time
            batch_time.update(time.perf_counter() - end_time)
            end_time = time.perf_counter()

            if (i + 1) % args.log_interval == 0:
                print('Test Iter [{0}/{1}]\t'
                      'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                      'Prec@3 {top3.val:.3f} ({top3.avg:.3f})\t'
                      'Batch Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Data Time {data_time.val:.3f} ({data_time.avg:.3f})'.format(
                          i + 1,
                          len(data_loader),
                          top1=top1,
                          top3=top3,
                          batch_time=batch_time,
                          data_time=data_time))

    print(' * Prec@1 {top1.avg:.2f}% | Prec@3 {top3.avg:.2f}%'.format(
        top1=top1, top3=top3))
    return top1.avg