-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathevaluator.py
104 lines (77 loc) · 4.99 KB
/
evaluator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import argparse
import random, copy
import torch, yaml, csv
import numpy as np
import torch.nn as nn
from data.data_loader import build_datasets, build_data_loader
from models.basics import get_model
from attacks.utils import train, test
from attacks.member_inference import population_attack, reference_attack, shadow_attack
from analyzer.plot import *
from privacy_meter.audit_report import ROCCurveReport, SignalHistogramReport
def set_seed(seed_value):
    """Seed every RNG used by the evaluation with ``seed_value``.

    Covers Python's ``random``, NumPy's global RNG, and PyTorch's CPU and
    CUDA generators (current device and all devices), in that order.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(seed_value)
if __name__ == '__main__':
    # Entry point: load the YAML config for the chosen MIA attack, build the
    # model and datasets, run the attack, save ROC / signal-histogram plots,
    # and append one summary row to ./summary_results.csv.
    import os

    parser = argparse.ArgumentParser()
    parser.add_argument("--attack", type=str, default="population", help="Type of MIA attack", choices=["population", "reference", "shadow"])
    parser.add_argument("--plot", type=str, default="True", help="Whether to visualize results", choices=["True", "False"])
    parser.add_argument("--model", type=str, default="searchable_alexnet", help="Model to evaluate")
    parser.add_argument("--width", type=float, default=1.0, help="Width expand ratio")
    parser.add_argument("--depth", type=int, default=1, help="Number of model layers")
    parser.add_argument("--depth-multi", type=float, default=1.0, help="Depth expand ratio (used for searchable_mobilenet)")
    args = parser.parse_args()
    # Passed as `show=` to the report generators; the default "True" matches
    # the previously hard-coded behavior.
    show_plots = args.plot == "True"

    cf = "./configs/" + args.attack + "_attack_evaluate.yaml"
    # NOTE(review): yaml.Loader can construct arbitrary Python objects. This is
    # acceptable for trusted local configs; use yaml.safe_load if configs may
    # come from untrusted sources.
    with open(cf, "rb") as f:
        configs = yaml.load(f, Loader=yaml.Loader)

    configs['train']['model_name'] = args.model
    configs['train']['width_multi'] = args.width
    # searchable_mobilenet takes the float --depth-multi; every other model
    # takes the int --depth. The same value also tags the test name.
    depth = args.depth_multi if args.model == 'searchable_mobilenet' else args.depth
    configs['train']['depth_multi'] = depth
    configs['attack']['test_name'] = 'test_' + args.model + '_w' + str(args.width) + '_d' + str(depth)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    seed_value = configs['run']['seed']
    set_seed(seed_value)
    print('\n ******************************** START of THE SCRIPT **************************************************** \n')

    model = get_model(configs)
    num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    model.to(device)
    train_dataset, test_dataset = build_datasets(configs)

    # NOTE: target-model training/evaluation is currently disabled; re-enable
    # these lines to record real losses/accuracies in the summary CSV.
    # train_model = copy.deepcopy(model)
    # train_loader, test_loader, train_sampler = build_data_loader(configs)
    # criterion, path = nn.CrossEntropyLoss(), configs['run']['saved_models']
    # lr, w_decay = float(configs['train']['learning_rate']), float(configs['train']['weight_decay'])
    # train_model = train(train_model, configs['train']['epochs'], configs['train']['optimizer'], criterion, lr, w_decay, train_loader, test_loader, device, path)
    # train_loss, train_accuracy = test(train_model, train_loader, device, criterion)
    # test_loss, test_accuracy = test(train_model, test_loader, device, criterion)
    # print(f"Number of parameters in the model: {num_params} || Testset Accuracy {test_accuracy}")

    # The attack type from the config decides both which attack runs and how
    # its results are indexed below, so the two always agree.
    attack_type = configs['attack']['type']
    attacks = {
        'population': population_attack,
        'reference': reference_attack,
        'shadow': shadow_attack,
    }
    if attack_type not in attacks:
        raise NotImplementedError(f"{attack_type} is not implemented")
    audit_results, infer_game, _ = attacks[attack_type](configs, model, train_dataset, test_dataset, device)

    logdir = configs['attack']['log_dir'] + '/' + args.attack + '_' + configs['attack']['test_name']
    # Make sure the output directory exists before any report is saved into it.
    os.makedirs(logdir, exist_ok=True)
    ROCCurveReport.generate_report(metric_result=audit_results,
                                   inference_game_type=infer_game, show=show_plots,
                                   filename=logdir + '/roc_curve.png')
    if attack_type == 'shadow':
        # Shadow results are indexed one level deeper ([0][0]) than
        # population/reference results ([0]).
        SignalHistogramReport.generate_report(metric_result=audit_results[0][0], inference_game_type=infer_game,
                                              show=show_plots, filename=logdir + '/signal_histogram.png')
        roc_auc = audit_results[0][0].roc_auc
    else:
        SignalHistogramReport.generate_report(metric_result=audit_results[0], inference_game_type=infer_game,
                                              show=show_plots, filename=logdir + '/signal_histogram.png')
        roc_auc = plot_log_scale_roc_curve(audit_results, logdir, logdir + "/log_scale_roc_curve.png")

    # Placeholders, because the training/evaluation block above is disabled.
    train_loss, train_accuracy, test_loss, test_accuracy = 0, 0, 0, 0
    # newline='' as required by the csv module (prevents blank rows on Windows).
    with open('./summary_results.csv', 'a', newline='') as f:
        writer = csv.writer(f, delimiter=',')
        writer.writerow([args.attack, configs['train']['model_name'], configs['train']['width_multi'],
                         configs['train']['depth_multi'], num_params, train_loss, train_accuracy,
                         test_loss, test_accuracy, roc_auc])