-
Notifications
You must be signed in to change notification settings - Fork 30
/
Copy pathlogger.py
141 lines (122 loc) · 5.14 KB
/
logger.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
import torch
from collections import defaultdict
# Method names for which create_print_dict() returns a hyper-parameter dict.
printable_method = {'transgnn', 'gat'}
def create_print_dict(args):
    """Build the hyper-parameter summary dict for the selected method.

    Reads ``args.method``. For ``'transgnn'`` and ``'gat'`` the returned dict
    holds the layer count, hidden width, the method-specific head count
    (``trans_heads`` or ``gat_heads``), learning rate and epoch count, in
    that key order. Any other method yields ``None``.
    """
    method = args.method
    if method == 'transgnn':
        heads_key = 'trans_heads'
    elif method == 'gat':
        heads_key = 'gat_heads'
    else:
        # Unknown method: no other attribute of ``args`` is touched.
        return None

    return {
        'n_layer': args.num_layers,
        'hidden_channels': args.hidden_channels,
        heads_key: getattr(args, heads_key),
        'lr': args.lr,
        'epochs': args.epochs,
    }
class Logger(object):
    """Collect per-epoch results for several runs and print summary statistics.

    Adapted from https://github.com/snap-stanford/ogb/

    Every recorded result is a 4-element sequence. ``print_statistics``
    reports column 0 as "Train", column 1 as "Valid" and column 2 as "Test";
    column 3 is only used to pick the epoch (its minimum) when ``mode`` is
    not ``'max_acc'`` (presumably a validation loss -- confirm with callers).
    All values are multiplied by 100 before being reported.
    """

    def __init__(self, runs, info=None):
        """Create storage for ``runs`` independent runs.

        ``info`` is stored on the instance unchanged.
        """
        self.info = info
        self.results = [[] for _ in range(runs)]

    def add_result(self, run, result):
        """Append one 4-element ``result`` to the history of run ``run``."""
        assert len(result) == 4
        assert run >= 0 and run < len(self.results)
        self.results[run].append(result)

    @staticmethod
    def _chosen_epoch(result, mode):
        """Return the row index selected by ``mode`` for one run's results."""
        if mode == 'max_acc':
            return result[:, 1].argmax().item()
        return result[:, 3].argmin().item()

    def print_statistics(self, run=None, mode='max_acc'):
        """Print statistics for one run, or aggregated over all runs.

        Args:
            run: Index of the run to report. ``None`` aggregates over every
                run and prints mean ± std across runs.
            mode: ``'max_acc'`` picks the epoch with the highest column-1
                value; any other value picks the epoch with the lowest
                column-3 value.

        Side effects:
            Sets ``self.test`` to the test value at the chosen epoch (single
            run) or to the mean of those values across runs (all runs). The
            stored value is a 0-dim tensor.
        """
        if run is not None:
            result = 100 * torch.tensor(self.results[run])
            ind = self._chosen_epoch(result, mode)
            print(f'Run {run + 1:02d}:')
            print(f'Highest Train: {result[:, 0].max():.2f}')
            print(f'Highest Valid: {result[:, 1].max():.2f}')
            print(f'Highest Test: {result[:, 2].max():.2f}')
            print(f'Chosen epoch: {ind+1}')
            print(f'Final Train: {result[ind, 0]:.2f}')
            print(f'Final Test: {result[ind, 2]:.2f}')
            self.test = result[ind, 2]
        else:
            best_results = []
            # Tensorise each run on its own: runs may have recorded different
            # numbers of epochs, and a single torch.tensor() over the ragged
            # nested list would raise.
            for run_results in self.results:
                r = 100 * torch.tensor(run_results)
                ind = self._chosen_epoch(r, mode)
                best_results.append((
                    r[:, 0].max().item(),  # highest train
                    r[:, 2].max().item(),  # highest test
                    r[:, 1].max().item(),  # highest valid
                    r[ind, 0].item(),      # train at chosen epoch
                    r[ind, 2].item(),      # test at chosen epoch
                ))
            best_result = torch.tensor(best_results)

            print('All runs:')
            # Labels are listed in the column order of ``best_result``.
            labels = ('Highest Train', 'Highest Test', 'Highest Valid',
                      ' Final Train', ' Final Test')
            for col, label in enumerate(labels):
                r = best_result[:, col]
                print(f'{label}: {r.mean():.2f} ± {r.std():.2f}')
            self.test = best_result[:, 4].mean()
        return

    def output(self, out_path, info):
        """Append ``info`` and a ``test acc:`` line to the file ``out_path``.

        ``self.test`` is only set by ``print_statistics``, so that method
        must have been called first.
        """
        with open(out_path, 'a') as f:
            f.write(info)
            f.write(f'test acc:{self.test}\n')
class SimpleLogger(object):
    """Store value tuples per (run, argument tuple) and summarise across runs.

    Adapted from https://github.com/CUAI/CorrectAndSmooth
    """

    def __init__(self, desc, param_names, num_values=2):
        """Set up an empty logger.

        Args:
            desc: Description printed in the header of ``display``.
            param_names: Names of the entries of each argument tuple.
            num_values: Number of values expected in every recorded result.
        """
        self.results = defaultdict(dict)
        self.param_names = tuple(param_names)
        self.used_args = list()
        self.desc = desc
        self.num_values = num_values

    def add_result(self, run, args, values):
        """Takes run=int, args=tuple, value=tuple(float)"""
        assert(len(args) == len(self.param_names))
        assert(len(values) == self.num_values)
        self.results[run][args] = values
        # Remember each argument tuple once, in first-seen order.
        if args not in self.used_args:
            self.used_args.append(args)

    def _scaled_results(self, args):
        """Return the values recorded for ``args`` in every run, times 100."""
        per_run = [i[args] for i in self.results.values() if args in i]
        return torch.tensor(per_run) * 100

    def get_best(self, top_k=1):
        """Return the ``top_k`` argument tuples ranked by mean of the last value.

        The mean is taken across runs; larger is better.
        """
        all_results = []
        for args in self.used_args:
            results_mean = self._scaled_results(args).mean(dim=0)[-1]
            all_results.append((args, results_mean))
        results = sorted(all_results, key=lambda x: x[1], reverse=True)[:top_k]
        return [i[0] for i in results]

    def prettyprint(self, x):
        """Format floats with two decimals; everything else via ``str``."""
        if isinstance(x, float):
            return '%.2f' % x
        return str(x)

    def display(self, args=None):
        """Print mean ± std across runs of every value, per argument tuple.

        Args:
            args: Iterable of argument tuples to show; ``None`` shows every
                argument tuple seen so far.

        Returns:
            The scaled (x100) result tensor of the last argument tuple
            displayed, or ``None`` when there was nothing to display.
        """
        disp_args = self.used_args if args is None else args
        if len(disp_args) > 1:
            print(f'{self.desc} {self.param_names}, {len(self.results.keys())} runs')
        # Defined up front so an empty ``disp_args`` returns None instead of
        # raising UnboundLocalError.
        results = None
        for args in disp_args:
            results = self._scaled_results(args)
            results_mean = results.mean(dim=0)
            results_std = results.std(dim=0)
            res_str = f'{results_mean[0]:.2f} ± {results_std[0]:.2f}'
            for i in range(1, self.num_values):
                # Pair each mean with the std of the same column.
                res_str += f' -> {results_mean[i]:.2f} ± {results_std[i]:.2f}'
            print(f'Args {[self.prettyprint(x) for x in args]}: {res_str}')
        if len(disp_args) > 1:
            print()
        return results