-
Notifications
You must be signed in to change notification settings - Fork 1
/
visu_paper.py
83 lines (64 loc) · 1.97 KB
/
visu_paper.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
# -*- coding: utf-8 -*-
"""
Created on Tue Oct 31 13:28:33 2017
@author: mducoffe
visu curve
"""
import numpy as np
import pylab as pl
#%%
# Step 1: read the CSV result files
from contextlib import closing
import csv
import os
# NOTE(review): appears unused -- get_actif_data() receives the file name as a
# parameter and the loading loop below rebinds this name; consider removing.
filename = "random.csv"


def get_actif_data(repository, filename, max_value=800):
    """Read one active-learning result file and return its best-so-far curve.

    Each non-empty row of the ';'-separated file is ``x;y``. ``x`` is parsed
    as an int (presumably the number of labelled samples -- confirm) and ``y``
    as a float (presumably test accuracy -- confirm).

    Parameters
    ----------
    repository : str
        Directory containing the file.
    filename : str
        Name of the CSV file inside ``repository``.
    max_value : int, optional
        Rows whose ``x`` is >= this value are dropped (default 800).

    Returns
    -------
    (list of int, list of float)
        The kept ``x`` values, and for each one the running maximum of ``y``
        seen so far (starting from 0). The returned curve never decreases.
    """
    x_labels = []
    y_acc = []
    y_max = 0
    # File objects are context managers; wrapping them in closing() is not needed.
    with open(os.path.join(repository, filename)) as f:
        for row in csv.reader(f, delimiter=';', quotechar='|'):
            if not row:
                # csv.reader yields [] for blank lines. Skip them instead of
                # crashing on row[0].
                continue
            x, y = int(row[0]), float(row[1])
            if x >= max_value:
                continue
            # Keep the best value reached so far rather than the raw value.
            if y < y_max:
                y = y_max
            else:
                y_max = y
            x_labels.append(x)
            y_acc.append(y)
    return x_labels, y_acc
#%%
# Step 2: locate the result files for one dataset/network pair and load each
# method's curve into dico_actif[key] = [(x_labels, y_acc), legend, linestyle].
repository = "data/csv"
dataset = 'BagShoes'
network = 'LeNet5'
# Join every component with os.path.join instead of a hard-coded '/'.
repository = os.path.join(repository, dataset, network)
methods = ['random', 'aaq', 'saaq', 'uncertainty', 'bald', 'egl']
filenames = ['{}_{}_{}.csv'.format(dataset, network, method) for method in methods]
legends = methods
# One matplotlib format string per method. zip() below stops at the shortest
# list, so any extra trailing entries are ignored.
linestyles = ['r-', 'b--', 'b-', 'g-', 'k-', 'c-', 'k-']
dico_actif = {}
for filename, legend, linestyle in zip(filenames, legends, linestyles):
    actif_key = filename.split('.csv')[0]
    print((actif_key, linestyle))
    dico_actif[actif_key] = [get_actif_data(repository, filename), legend, linestyle]
#%%
# Step 3: draw every method's curve on one figure and save it as a PDF.
pl.figure(1)
pl.clf()
# Successive plot() calls accumulate on the same axes by default. pl.hold()
# was deprecated in matplotlib 2.0 and removed in 3.0, so it is not called.
for data, legend, linestyle in dico_actif.values():
    x_labels, y_acc = data
    pl.plot(x_labels, y_acc, linestyle, label=legend)
pl.grid()
pl.legend(bbox_to_anchor=(0.5, 0.6), loc=2, borderaxespad=0.)
# savefig() does not create missing directories.
if not os.path.isdir('img'):
    os.makedirs('img')
pl.savefig('img/test_acc_{}_{}.pdf'.format(dataset, network), dpi=300, bbox_inches='tight')
#%%