utils.py
# %%
import pickle
import jsonlines
import nltk
from nltk.tokenize import sent_tokenize
from nltk import word_tokenize
import numpy as np
from tabulate import tabulate
from mosestokenizer import *
import random
from random import choices
import os
import sys
import re
from collections import defaultdict as ddict
from scipy.stats import pearsonr, spearmanr, kendalltau
# nltk.download('punkt')  # sent_tokenize / word_tokenize below require the punkt tokenizer models

detokenizer = MosesDetokenizer('en')

def read_pickle(file):
    with open(file, 'rb') as f:
        data = pickle.load(f)
    return data


def save_pickle(data, file):
    with open(file, 'wb') as f:
        pickle.dump(data, f)
    print(f'Saved to {file}.')


def read_file_to_list(file_name):
    lines = []
    with open(file_name, 'r', encoding='utf8') as f:
        for line in f:
            lines.append(line.strip())
    return lines


def write_list_to_file(list_to_write, filename):
    with open(filename, 'w') as out_file:
        for line in list_to_write:
            print(line, file=out_file)
    print(f'Saved to {filename}.')

def read_jsonlines_to_list(file_name):
    lines = []
    with jsonlines.open(file_name, 'r') as reader:
        for obj in reader:
            lines.append(obj)
    return lines


def write_list_to_jsonline(list_to_write, filename):
    with jsonlines.open(filename, 'w') as writer:
        writer.write_all(list_to_write)
    print(f'Saved to {filename}.')

def capitalize_sents(text: str):
    """ Given a string, capitalize the initial letter of each sentence. """
    sentences = sent_tokenize(text)
    sentences = [sent.strip() for sent in sentences]
    sentences = [sent.capitalize() for sent in sentences]
    return " ".join(sentences)

def is_capitalized(text: str):
    """ Given a string (system output etc.), check whether it is lowercased
    or normally capitalized.
    """
    return not text.islower()

def tokenize(text: str):
    words = word_tokenize(text)
    return " ".join(words)


def detokenize(text: str):
    words = text.split(" ")
    return detokenizer(words)

def use_original_bracket(text: str):
    """ Replace PTB-style bracket tokens (-lrb-, -rrb-, -lsb-, -rsb-) with the original brackets. """
    replacements = {
        '-lrb-': '(', '-LRB-': '(',
        '-rrb-': ')', '-RRB-': ')',
        '-lsb-': '[', '-LSB-': '[',
        '-rsb-': ']', '-RSB-': ']',
    }
    for token, bracket in replacements.items():
        text = text.replace(token, bracket)
    return text

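# A minimal usage sketch for the text helpers above; the sample strings are hypothetical and the
# exact tokenized/detokenized output depends on the installed NLTK / Moses versions.
#
# tokenize("Hello, world!")                  -> "Hello , world !"
# detokenize("Hello , world !")              -> "Hello, world!"
# use_original_bracket("a -lrb- b -rrb- c")  -> "a ( b ) c"
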
# Disable print
def blockPrint():
    sys.stdout = open(os.devnull, 'w')


# Restore print
def enablePrint():
    sys.stdout = sys.__stdout__

def retrieve_scores(saveto):
    """ Parse per-example (recall, precision, F1) triples for ROUGE-1/2/L from a saved
    ROUGE output file, reading the per-evaluation 'ROUGE-N Eval' lines.
    """
    def get_r_p_f(line):
        # The last three whitespace-separated fields carry the scores (e.g. R:0.xxxxx);
        # the trailing 7 characters of each field are the numeric value.
        line = line.split(" ")
        r = float(line[-3][-7:])
        p = float(line[-2][-7:])
        f = float(line[-1][-7:])
        return r, p, f

    lines = read_file_to_list(saveto)
    rouge1_list, rouge2_list, rougel_list = [], [], []
    for line in lines:
        if line.startswith('1 ROUGE-1 Eval'):
            rouge1_list.append(get_r_p_f(line))
        if line.startswith('1 ROUGE-2 Eval'):
            rouge2_list.append(get_r_p_f(line))
        if line.startswith('1 ROUGE-L Eval'):
            rougel_list.append(get_r_p_f(line))
    return rouge1_list, rouge2_list, rougel_list

def get_rank(data, metric):
    """ Rank all systems based on a metric (avg score) """
    scores = {}  # {sysname: [scores]}
    for doc_id in data:
        sys_summs = data[doc_id]['sys_summs']
        for sys_name in sys_summs:
            score = sys_summs[sys_name]['scores'][metric]
            scores.setdefault(sys_name, []).append(score)
    scores = {k: sum(v) / len(v) for k, v in scores.items()}
    new_scores = dict(sorted(scores.items(), key=lambda x: x[1], reverse=True))
    # for k in new_scores:
    #     print(k)
    return new_scores.keys()

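# A sketch of the nested layout that get_rank and the score utilities below expect, inferred from
# the accessors in this file; the document id, system names, summary text, and metric names are
# hypothetical placeholders.
#
# sd = {
#     'doc-1': {
#         'sys_summs': {
#             'system_A': {'sys_summ': 'a generated summary ...',
#                          'scores': {'rouge2_f': 0.21, 'human_score': 0.47}},
#             'system_B': {'sys_summ': 'EMPTY',
#                          'scores': {'rouge2_f': 0.00, 'human_score': 0.00}},
#         }
#     },
#     # ... one entry per document ...
# }
# get_rank(sd, 'rouge2_f')  # -> system names ordered by average score, best first
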
def get_sents_from_tags(text, sent_start_tag, sent_end_tag):
    sents = re.findall(r'%s (.+?) %s' % (sent_start_tag, sent_end_tag), text)
    sents = [sent for sent in sents if len(sent) > 0]  # remove empty sents
    return sents

def get_metrics_list(sd):
    """
    Does each system summary dict have the same all_metrics?
    :param sd: scores dict
    :return: list of all_metrics in the scores dict
    """
    metrics_tuple_set = set(
        [tuple(sorted(list(x['scores'].keys())))
         for d in sd.values() for x in d['sys_summs'].values()])
    assert len(metrics_tuple_set) == 1, (
        metrics_tuple_set, "all system summary score dicts should have the same set of all_metrics")
    metrics_list = list(list(metrics_tuple_set)[0])
    return metrics_list

def print_score_ranges(sd):
    metrics_list = get_metrics_list(sd)
    print_list = []
    headers = ["min", "25-perc", "median", "75-perc", "max", "mean"]
    for m in metrics_list:
        scores = [s['scores'][m] for d in sd.values() for s in d['sys_summs'].values()
                  if s['sys_summ'] != 'EMPTY']
        print_list.append([m,
                           np.min(scores),
                           np.percentile(scores, 25),
                           np.median(scores),
                           np.percentile(scores, 75),
                           np.max(scores),
                           np.mean(scores)])
    print(tabulate(print_list, headers=headers, floatfmt=".6f", tablefmt="simple"))

def get_system_level_scores(sd, metrics, agg='mean', nas=False):
    """
    systems[system_name][metric] = average_score or list of scores
    """
    systems = ddict(lambda: ddict(list))
    for isd in sd.values():
        for system_name, scores in isd['sys_summs'].items():
            for m in metrics:
                # Empty summary: keep a None placeholder so per-example lists stay aligned
                if scores['sys_summ'] == 'EMPTY':
                    systems[system_name][m].append(None)
                else:
                    systems[system_name][m].append(scores['scores'][m])
    for system_name, scores in systems.items():
        for m in scores:
            all_scores = systems[system_name][m]
            if agg == 'mean':
                all_scores = [x for x in all_scores if x is not None]
                systems[system_name][m] = np.mean(all_scores)
    if nas:
        # Normalized aggregate score: min-max normalize each metric across systems,
        # then average the normalized metrics per system.
        min_scores = {}
        max_scores = {}
        for m in metrics:
            min_scores[m] = np.min([systems[sys_name][m] for sys_name in systems.keys()])
            max_scores[m] = np.max([systems[sys_name][m] for sys_name in systems.keys()])
        for sys_name in systems:
            systems[sys_name]['nas'] = np.mean([
                (systems[sys_name][m] - min_scores[m]) / (max_scores[m] - min_scores[m]) for m in metrics
            ])
    return systems

def get_topk(systems, k, metric='rouge2_f'):
    systems_l = [(name, score[metric]) for name, score in systems.items()]
    systems_l = sorted(systems_l, key=lambda x: x[1], reverse=True)
    topk_system_names = [tup[0] for tup in systems_l[:k]]
    return {name: systems[name] for name in topk_system_names}

def print_correlation(topk_systems, metric_pairs):
    # Disagreement between every pair of metrics for the top-k systems
    headers = ['metric_pair', 'pearson', 'spearman', 'kendalltau']
    print_list = []
    for pair in metric_pairs:
        if 'bart_en_sim' in pair[1] or 'bart_sim' in pair[1]:
            continue
        m1_scores = []
        m2_scores = []
        for scores in topk_systems.values():
            m1_scores.append(scores[pair[0]])  # Human metric
            m2_scores.append(scores[pair[1]])
        pearson, _ = pearsonr(m1_scores, m2_scores)
        spearman, _ = spearmanr(m1_scores, m2_scores)
        ktau, _ = kendalltau(m1_scores, m2_scores)
        print_list.append([f'{pair[1]}', pearson, spearman, ktau])
    # sort based on spearman correlation
    print_list = sorted(print_list, key=lambda x: x[2], reverse=True)
    print(tabulate(print_list, headers=headers, tablefmt='simple'))

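# A minimal sketch of the intended system-level analysis flow, assuming a scores dict `sd`
# shaped like the example above; 'human_score' and 'metric_x' are hypothetical metric names.
#
# systems = get_system_level_scores(sd, ['human_score', 'metric_x'], agg='mean')
# top5 = get_topk(systems, k=5, metric='human_score')
# print_correlation(top5, metric_pairs=[('human_score', 'metric_x')])
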
def get_predictions_br(system_pairs, systems, metric):
    """ Pairwise system comparison with bootstrap resampling: for each pair, resample the
    per-example scores N times and count how often system pair[0] gets the higher mean.
    Requires per-example score lists rather than aggregated means.
    """
    random.seed(666)
    preds = {}
    for pair in system_pairs:
        sys1 = systems[pair[0]][metric]
        sys2 = systems[pair[1]][metric]
        n = len(sys1)
        points = list(range(n))
        is_better = 0
        N = 1000
        for i in range(N):
            sample = choices(points, k=n)
            sys1_, sys2_ = [], []
            # Due to EMPTY summaries (None scores), we have to ensure sys1_ and sys2_ are not
            # empty: keep resampling until at least one example has scores for both systems.
            while len(sys1_) == 0:
                for p in sample:
                    if sys1[p] is None or sys2[p] is None:
                        continue
                    sys1_.append(sys1[p])
                    sys2_.append(sys2[p])
                sample = choices(points, k=n)
            if np.mean(sys1_) > np.mean(sys2_):
                is_better += 1
        if is_better / N >= 0.95:
            preds[pair] = 0  # pair[0] is better
        elif is_better / N <= 0.05:
            preds[pair] = 1  # pair[1] is better
        else:
            preds[pair] = 2  # can't say
    return preds

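# A sketch of how get_predictions_br might be driven; it needs per-example score lists, so
# get_system_level_scores must be called with an agg value other than 'mean'. The use of
# itertools.combinations and the 'metric_x' name are assumptions for illustration only.
#
# from itertools import combinations
# systems = get_system_level_scores(sd, ['metric_x'], agg='none')
# system_pairs = list(combinations(systems.keys(), 2))
# preds = get_predictions_br(system_pairs, systems, metric='metric_x')
# # preds[pair]: 0 -> pair[0] better, 1 -> pair[1] better, 2 -> no significant difference
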