'''
Usage:
    benchmark --gold=GOLD_OIE --out=OUTPUT_FILE (--openiefive=OPENIE5 | --stanford=STANFORD_OIE | --ollie=OLLIE_OIE | --reverb=REVERB_OIE | --clausie=CLAUSIE_OIE | --openiefour=OPENIEFOUR_OIE | --props=PROPS_OIE | --tabbed=TABBED_OIE | --benchmarkGold=BENCHMARK_GOLD | --allennlp=ALLENNLP_OIE ) [--exactMatch | --predMatch | --lexicalMatch | --binaryMatch | --simpleMatch | --strictMatch] [--error-file=ERROR_FILE] [--binary]

Options:
    --gold=GOLD_OIE                 The gold reference Open IE file (by default, it should be under ./oie_corpus/all.oie).
    --benchmarkGold=BENCHMARK_GOLD  The benchmark's gold reference.
    --out=OUTPUT_FILE               The output file, into which the precision-recall curve will be written.
    --clausie=CLAUSIE_OIE           Read ClausIE format from file CLAUSIE_OIE.
    --ollie=OLLIE_OIE               Read OLLIE format from file OLLIE_OIE.
    --openiefour=OPENIEFOUR_OIE     Read Open IE 4 format from file OPENIEFOUR_OIE.
    --openiefive=OPENIE5            Read Open IE 5 format from file OPENIE5.
    --props=PROPS_OIE               Read PropS format from file PROPS_OIE.
    --reverb=REVERB_OIE             Read ReVerb format from file REVERB_OIE.
    --stanford=STANFORD_OIE         Read Stanford format from file STANFORD_OIE.
    --tabbed=TABBED_OIE             Read simple tab format file, where each line consists of:
                                    sent, prob, pred, arg1, arg2, ...
    --exactMatch                    Use exact match when judging whether an extraction is correct.
'''
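
# Example invocation (illustrative; the output and system-output paths are placeholders,
# only the gold path is the documented default):
#
#   python carb.py --gold=./oie_corpus/all.oie --out=./results/stanford.dat --stanford=./system_outputs/stanford.txt
#
# The script writes a tab-separated Precision/Recall/Confidence table to --out and prints
# the area under the PR curve together with the best (precision, recall, F1) point.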
from __future__ import division
import docopt
import string
import numpy as np
from sklearn.metrics import precision_recall_curve
from sklearn.metrics import auc
import re
import logging
from collections import defaultdict

logging.basicConfig(level = logging.INFO)

from oie_readers.stanfordReader import StanfordReader
from oie_readers.ollieReader import OllieReader
from oie_readers.reVerbReader import ReVerbReader
from oie_readers.clausieReader import ClausieReader
from oie_readers.openieFourReader import OpenieFourReader
from oie_readers.openieFiveReader import OpenieFiveReader
from oie_readers.propsReader import PropSReader
from oie_readers.tabReader import TabReader
from oie_readers.benchmarkGoldReader import BenchmarkGoldReader
from oie_readers.goldReader import GoldReader
from matcher import Matcher
from operator import itemgetter
import pprint
from copy import copy

pp = pprint.PrettyPrinter(indent=4)

class Benchmark:
    ''' Compare the gold OIE dataset against a predicted equivalent '''

    def __init__(self, gold_fn):
        ''' Load the gold Open IE; it serves as the reference for the compare function. '''
        gr = GoldReader()
        gr.read(gold_fn)
        self.gold = gr.oie

    def compare(self, predicted, matchingFunc, output_fn, error_file = None, binary = False):
        ''' Compare gold against predicted using a specified matching function.
            Outputs the PR curve to output_fn. '''
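
        # matchingFunc is expected to return an indexable (precision, recall) pair for each
        # (gold, predicted) extraction pair; that is how scores[i][j] is consumed below.
        # compare() itself returns (AUC of the PR curve, (precision, recall, F1) at the best F1).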
        y_true = []
        y_scores = []
        errors = []
        correct = 0
        incorrect = 0

        correctTotal = 0
        unmatchedCount = 0
        predicted = Benchmark.normalizeDict(predicted)
        gold = Benchmark.normalizeDict(self.gold)
        if binary:
            predicted = Benchmark.binarize(predicted)
            gold = Benchmark.binarize(gold)
        #gold = self.gold

        # Take all distinct confidence values of the predicted extractions as thresholds.
        confidence_thresholds = set()
        for sent in predicted:
            for predicted_ex in predicted[sent]:
                confidence_thresholds.add(predicted_ex.confidence)

        confidence_thresholds = sorted(list(confidence_thresholds))
        num_conf = len(confidence_thresholds)

        results = {}
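
        # p/pl and r/rl accumulate, for every confidence threshold, the numerators and
        # denominators of precision and recall respectively: p sums the matched precision
        # scores, pl the number of predictions above the threshold, r the best per-gold
        # recall scores, and rl the number of gold extractions.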
        p = np.zeros(num_conf)
        pl = np.zeros(num_conf)
        r = np.zeros(num_conf)
        rl = np.zeros(num_conf)

        for sent, goldExtractions in gold.items():

            if sent in predicted:
                predictedExtractions = predicted[sent]
            else:
                predictedExtractions = []

            scores = [[None for _ in predictedExtractions] for __ in goldExtractions]

            # print("***Gold Extractions***")
            # print("\n".join([goldExtractions[i].pred + ' ' + " ".join(goldExtractions[i].args) for i in range(len(goldExtractions))]))
            # print("***Predicted Extractions***")
            # print("\n".join([predictedExtractions[i].pred + " ".join(predictedExtractions[i].args) for i in range(len(predictedExtractions))]))

            for i, goldEx in enumerate(goldExtractions):
                for j, predictedEx in enumerate(predictedExtractions):
                    score = matchingFunc(goldEx, predictedEx, ignoreStopwords = True, ignoreCase = True)
                    scores[i][j] = score

            # OPTIMISED GLOBAL MATCH
            sent_confidences = [extraction.confidence for extraction in predictedExtractions]
            sent_confidences.sort()
            prev_c = 0
            for conf in sent_confidences:
                c = confidence_thresholds.index(conf)
                ext_indices = []
                for ext_indx, extraction in enumerate(predictedExtractions):
                    if extraction.confidence >= conf:
                        ext_indices.append(ext_indx)

                recall_numerator = 0
                for i, row in enumerate(scores):
                    max_recall_row = max([row[ext_indx][1] for ext_indx in ext_indices], default = 0)
                    recall_numerator += max_recall_row
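
                # Greedily build a one-to-one matching between gold rows and the predictions
                # above this threshold: repeatedly pick the remaining (gold, prediction) pair
                # with the highest precision score, so each prediction's precision is counted
                # against at most one gold extraction.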
                precision_numerator = 0

                selected_rows = []
                selected_cols = []
                num_precision_matches = min(len(scores), len(ext_indices))
                for t in range(num_precision_matches):
                    matched_row = -1
                    matched_col = -1
                    matched_precision = -1  # initialised to < 0 so that it updates even when precision is 0
                    for i in range(len(scores)):
                        if i in selected_rows:
                            continue
                        for ext_indx in ext_indices:
                            if ext_indx in selected_cols:
                                continue
                            if scores[i][ext_indx][0] > matched_precision:
                                matched_precision = scores[i][ext_indx][0]
                                matched_row = i
                                matched_col = ext_indx

                    selected_rows.append(matched_row)
                    selected_cols.append(matched_col)
                    precision_numerator += scores[matched_row][matched_col][0]

                p[prev_c:c + 1] += precision_numerator
                pl[prev_c:c + 1] += len(ext_indices)
                r[prev_c:c + 1] += recall_numerator
                rl[prev_c:c + 1] += len(scores)
                prev_c = c + 1

            # For thresholds above the maximum confidence of this sentence, len(scores) still
            # has to be added to the denominator of recall.
            rl[prev_c:] += len(scores)

        prec_scores = [a / b if b > 0 else 1 for a, b in zip(p, pl)]
        rec_scores = [a / b if b > 0 else 0 for a, b in zip(r, rl)]

        f1s = [Benchmark.f1(prec, rec) for prec, rec in zip(prec_scores, rec_scores)]
        try:
            optimal_idx = np.nanargmax(f1s)
            optimal = (prec_scores[optimal_idx], rec_scores[optimal_idx], f1s[optimal_idx])
        except ValueError:
            # When there is no prediction
            optimal = (0, 0, 0)

        # In order to calculate AUC, we need to add the point corresponding to precision=1, recall=0 to the PR curve.
        temp_rec_scores = rec_scores.copy()
        temp_prec_scores = prec_scores.copy()
        temp_rec_scores.append(0)
        temp_prec_scores.append(1)
        # print("AUC: {}\t Optimal (precision, recall, F1): {}".format(np.round(auc(temp_rec_scores, temp_prec_scores), 3), np.round(optimal, 3)))

        with open(output_fn, 'w') as fout:
            fout.write('{0}\t{1}\t{2}\n'.format("Precision", "Recall", "Confidence"))
            for cur_p, cur_r, cur_conf in sorted(zip(prec_scores, rec_scores, confidence_thresholds), key = lambda cur: cur[1]):
                fout.write('{0}\t{1}\t{2}\n'.format(cur_p, cur_r, cur_conf))

        if len(f1s) > 0:
            return np.round(auc(temp_rec_scores, temp_prec_scores), 3), np.round(optimal, 3)
        else:
            # When there is no prediction
            return 0, (0, 0, 0)
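
    # A minimal programmatic sketch of how compare() is meant to be driven (hypothetical
    # file names; the __main__ block below does the same thing through docopt):
    #
    #   b = Benchmark('./oie_corpus/all.oie')
    #   reader = TabReader()
    #   reader.read('my_system_output.tsv')
    #   auc_score, (prec, rec, f1) = b.compare(predicted = reader.oie,
    #                                          matchingFunc = Matcher.binary_linient_tuple_match,
    #                                          output_fn = 'pr_curve.dat')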
    @staticmethod
    def binarize(extrs):
        res = defaultdict(lambda: [])
        for sent, extr in extrs.items():
            for ex in extr:
                # Add (a1, r, a2)
                temp = copy(ex)
                temp.args = ex.args[:2]
                res[sent].append(temp)

                if len(ex.args) <= 2:
                    continue

                # Add (a1, r a2, a3 ...)
                for arg in ex.args[2:]:
                    # Use a fresh copy for every extra argument, so tuples already
                    # appended to res are not mutated afterwards.
                    temp = copy(ex)
                    temp.args = [ex.args[0]]
                    temp.pred = ex.pred + ' ' + ex.args[1]
                    words = arg.split()

                    # Add the preposition of arg to the relation
                    if words[0].lower() in Benchmark.PREPS:
                        temp.pred += ' ' + words[0]
                        words = words[1:]
                    temp.args.append(' '.join(words))
                    res[sent].append(temp)

        return res
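
    # Illustrative example (not from the original file): binarize() turns an n-ary
    # extraction like pred="gave", args=["John", "Mary", "a book"] into the binary tuples
    # ("John", "gave", "Mary") and ("John", "gave Mary", "a book"); a leading preposition
    # in a trailing argument is folded into the relation instead of the argument.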
    @staticmethod
    def f1(prec, rec):
        try:
            return 2 * prec * rec / (prec + rec)
        except ZeroDivisionError:
            return 0

    @staticmethod
    def aggregate_scores_greedily(scores):
        # Greedy match: pick the prediction/gold match with the best F1 and exclude
        # them both, until nothing left matches. Each input cell is a [prec, rec]
        # pair. Returns precision and recall as score-and-denominator pairs.
        matches = []
        while True:
            max_s = 0
            gold, pred = None, None
            for i, gold_ss in enumerate(scores):
                if i in [m[0] for m in matches]:
                    # These rows are already taken
                    continue
                for j, pred_s in enumerate(scores[i]):
                    if j in [m[1] for m in matches]:
                        # These columns are already used
                        continue
                    if pred_s and Benchmark.f1(*pred_s) > max_s:
                        max_s = Benchmark.f1(*pred_s)
                        gold = i
                        pred = j
            if max_s == 0:
                break
            matches.append([gold, pred])

        # Now that matches are determined, compute the final scores.
        prec_scores = [scores[i][j][0] for i, j in matches]
        rec_scores = [scores[i][j][1] for i, j in matches]
        total_prec = sum(prec_scores)
        total_rec = sum(rec_scores)
        scoring_metrics = {"precision": [total_prec, len(scores[0])],
                           "recall": [total_rec, len(scores)],
                           "precision_of_matches": prec_scores,
                           "recall_of_matches": rec_scores}
        return scoring_metrics

    # Helper functions:
    @staticmethod
    def normalizeDict(d):
        return dict([(Benchmark.normalizeKey(k), v) for k, v in d.items()])

    @staticmethod
    def normalizeKey(k):
        # return Benchmark.removePunct(unicode(Benchmark.PTB_unescape(k.replace(' ','')), errors = 'ignore'))
        return Benchmark.removePunct(str(Benchmark.PTB_unescape(k.replace(' ', ''))))
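
    # Illustrative example (not from the original file): normalizeKey("The man -LRB- he -RRB- left .")
    # first strips spaces, then unescapes the PTB brackets, and finally removes punctuation,
    # yielding "Themanheleft". Gold and predicted sentences are aligned on this key.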
    @staticmethod
    def PTB_escape(s):
        for u, e in Benchmark.PTB_ESCAPES:
            s = s.replace(u, e)
        return s

    @staticmethod
    def PTB_unescape(s):
        for u, e in Benchmark.PTB_ESCAPES:
            s = s.replace(e, u)
        return s

    @staticmethod
    def removePunct(s):
        return Benchmark.regex.sub('', s)

    # CONSTANTS
    regex = re.compile('[%s]' % re.escape(string.punctuation))

    # Penn Treebank bracket escapes
    # Taken from: https://github.com/nlplab/brat/blob/master/server/src/gtbtokenize.py
    PTB_ESCAPES = [('(', '-LRB-'),
                   (')', '-RRB-'),
                   ('[', '-LSB-'),
                   (']', '-RSB-'),
                   ('{', '-LCB-'),
                   ('}', '-RCB-')]

    PREPS = ['above', 'across', 'against', 'along', 'among', 'around', 'at', 'before',
             'behind', 'below', 'beneath', 'beside', 'between', 'by', 'for', 'from', 'in',
             'into', 'near', 'of', 'off', 'on', 'to', 'toward', 'under', 'upon', 'with',
             'within']


def f_beta(precision, recall, beta = 1):
    """
    Get F_beta score from precision and recall.
    """
    beta = float(beta)  # Make sure that results are in float
    return (1 + pow(beta, 2)) * (precision * recall) / ((pow(beta, 2) * precision) + recall)
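
# For example, f_beta(0.4, 0.6) == 0.48 (the harmonic mean when beta = 1). Note that the
# formula is undefined, and raises ZeroDivisionError, when precision and recall are both 0.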

if __name__ == '__main__':
    args = docopt.docopt(__doc__)
    logging.debug(args)

    if args['--stanford']:
        predicted = StanfordReader()
        predicted.read(args['--stanford'])
    if args['--props']:
        predicted = PropSReader()
        predicted.read(args['--props'])
    if args['--ollie']:
        predicted = OllieReader()
        predicted.read(args['--ollie'])
    if args['--reverb']:
        predicted = ReVerbReader()
        predicted.read(args['--reverb'])
    if args['--clausie']:
        predicted = ClausieReader()
        predicted.read(args['--clausie'])
    if args['--openiefour']:
        predicted = OpenieFourReader()
        predicted.read(args['--openiefour'])
    if args['--openiefive']:
        predicted = OpenieFiveReader()
        predicted.read(args['--openiefive'])
    if args['--benchmarkGold']:
        predicted = BenchmarkGoldReader()
        predicted.read(args['--benchmarkGold'])
    if args['--tabbed']:
        predicted = TabReader()
        predicted.read(args['--tabbed'])

    if args['--binaryMatch']:
        matchingFunc = Matcher.binary_tuple_match
    elif args['--simpleMatch']:
        matchingFunc = Matcher.simple_tuple_match
    elif args['--exactMatch']:
        matchingFunc = Matcher.argMatch
    elif args['--predMatch']:
        matchingFunc = Matcher.predMatch
    elif args['--lexicalMatch']:
        matchingFunc = Matcher.lexicalMatch
    elif args['--strictMatch']:
        matchingFunc = Matcher.tuple_match
    else:
        matchingFunc = Matcher.binary_linient_tuple_match

    b = Benchmark(args['--gold'])
    out_filename = args['--out']

    logging.info("Writing PR curve of {} to {}".format(predicted.name, out_filename))
    # Renamed from "auc" to avoid shadowing sklearn.metrics.auc imported above.
    auc_score, optimal_f1_point = b.compare(predicted = predicted.oie,
                                            matchingFunc = matchingFunc,
                                            output_fn = out_filename,
                                            error_file = args["--error-file"],
                                            binary = args["--binary"])

    print("AUC: {}\t Optimal (precision, recall, F1): {}".format(auc_score, optimal_f1_point))