-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
79 lines (68 loc) · 2.04 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import random
from easydict import EasyDict
from data import *
import pandas as pd
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
def setup_seed(seed):
    """Seed every random number generator used here so runs are repeatable.

    Seeds PyTorch's CPU generator and all CUDA generators, NumPy's global
    RNG and Python's ``random`` module, then configures cuDNN for
    deterministic behaviour.

    Args:
        seed: Integer seed applied to every generator.
    """
    torch.manual_seed(seed)
    # Seeds every visible GPU; PyTorch ignores this call when CUDA is absent.
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    # cuDNN reproducibility needs both flags: ``deterministic`` restricts it to
    # deterministic algorithms, and turning ``benchmark`` off stops it from
    # auto-tuning (and possibly choosing a different algorithm) on each run.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
def default_arg():
    """Return the default experiment configuration wrapped in an ``EasyDict``.

    A new object is built on every call, so callers may modify the result
    without affecting later calls.
    """
    # Checkpoint file locations, keyed by architecture name.
    model_ckpts = {
        'wideresnet': 'model/model_wrn_best.pth.tar',
        'resnext': 'model/model_resnext_best.pth.tar',
    }
    # Student checkpoint file locations, keyed the same way.
    student_ckpts = {
        'wideresnet': 'model/wrn_student_augmix.pth',
        'resnext': 'model/resnext_student_augmix.pth',
    }
    settings = dict(
        model='resnext',
        augmix=True,
        opt='adam',
        lr=1e-6,
        init='zero',
        epochs=2,
        reg='none',
        policy='resnet18',
        sigmoid=False,
        est='vanilla',
        transform='geometry',
        num_samples=12,
        paths=model_ckpts,
        student_paths=student_ckpts,
        # Directory locations.
        checkpoints='checkpoints',
        cifar='cifar',
        cifarc='cifar_corrupt/CIFAR-100-C',
    )
    return EasyDict(settings)
# test functions
def test_corrupt(tta, model, cifar_corrupt_dir, batch_size=256):
    """Evaluate ``tta`` followed by ``model`` on every corruption in ``CORRUPTIONS``.

    For each corruption name, a loader is obtained from
    ``get_corrupt_loader`` and scored with ``test``.

    Args:
        tta: Callable applied to each input batch before the model.
        model: Callable producing per-class scores for a batch.
        cifar_corrupt_dir: Directory handed to ``get_corrupt_loader``.
        batch_size: Batch size for every corruption loader. Defaults to 256,
            the value that used to be hard-coded.

    Returns:
        tuple: ``(accs, mean_acc)``, where ``accs`` lists the per-corruption
        accuracies in ``CORRUPTIONS`` order and ``mean_acc`` is their
        ``np.mean``.
    """
    # Each loader is built and consumed before the next one is created.
    accs = [
        test(tta, model, get_corrupt_loader(c, batch_size, cifar_corrupt_dir))
        for c in CORRUPTIONS
    ]
    return accs, np.mean(accs)
def test(tta, model, loader):
    """Return top-1 accuracy of ``model`` on ``tta``-transformed batches.

    Each ``(inputs, labels)`` batch from ``loader`` is moved to the
    module-level ``device``, passed through ``tta`` and then ``model``. The
    index of the maximum along dim 1 of the model output is compared with
    the labels.

    Args:
        tta: Callable applied to each input batch before the model.
        model: Callable producing per-class scores for a batch.
        loader: Iterable of ``(inputs, labels)`` tensor pairs.

    Returns:
        float: Number of correct predictions divided by the number of inputs
        seen. An empty ``loader`` raises ``ZeroDivisionError``.
    """
    # NOTE(review): this loop runs without ``torch.no_grad()``, so autograd
    # graphs are built during evaluation. Consider wrapping it if neither
    # ``tta`` nor ``model`` needs gradients here — confirm before changing.
    correct = 0
    seen = 0
    for batch_x, batch_y in loader:
        batch_x = batch_x.to(device)
        batch_y = batch_y.to(device)
        logits = model(tta(batch_x))
        predicted = logits.max(dim=1)[1]
        correct += predicted.eq(batch_y).sum().item()
        seen += batch_x.size(0)
    return correct / seen
# print result
def show_corrupted_acc(acc, avg_acc, corruptions=None):
    """Print per-corruption accuracies plus their average as a one-row table.

    Args:
        acc: Sequence of accuracies, one per corruption, in the same order as
            ``corruptions`` (e.g. the list returned first by ``test_corrupt``).
        avg_acc: Average accuracy, shown in the ``'average'`` column.
        corruptions: Column names matching ``acc``. Defaults to
            ``CORRUPTIONS``.

    Returns:
        pandas.DataFrame: The printed one-row table. Returning it lets callers
        log or save the results; callers that ignore it are unaffected.

    Raises:
        ValueError: If ``acc`` and ``corruptions`` differ in length.
    """
    if corruptions is None:
        corruptions = CORRUPTIONS
    names = list(corruptions)
    values = list(acc)
    if len(values) != len(names):
        raise ValueError(
            'expected {} accuracies, got {}'.format(len(names), len(values))
        )
    row = dict(zip(names, values))
    row['average'] = avg_acc
    # A dict holding only scalars needs an explicit row: pd.DataFrame(row)
    # alone raises "If using all scalar values, you must pass an index".
    table = pd.DataFrame([row])
    print(table)
    return table