evaluation.py
import os

import torch
from tqdm import tqdm

from input_util import prepare_input

os.environ["TOKENIZERS_PARALLELISM"] = "true"


def tint(t):
    """Return a Python scalar for a tensor value; pass other values through unchanged."""
    if isinstance(t, torch.Tensor):
        return t.item()
    return t
def decode(dataloader, model, args, ema=None):
    """Run inference over `dataloader` and format predicted spans as
    'start,end type' strings, one list per example. When an EMA wrapper is
    given, decoding runs under its averaged parameters."""
    # if accelerator is not None:
    #     return decode_accl(dataloader, model, args, accelerator)
    model.eval()
    predict_labels = []
    if ema is None:
        with torch.no_grad():
            for batch in dataloader:
                inputs = prepare_input(batch, args, train=False)
                result = model.predict(**inputs)
                for i in range(batch[0].size(0)):
                    # Deduplicate spans; the end index is shifted by +1 to match the gold-label convention.
                    predict_label = list(set(
                        f'{tint(l[0])},{tint(l[1]) + 1} {dataloader.dataset.id2type[tint(l[2])]}'
                        for l in result[i]
                    ))
                    predict_labels.append(predict_label)
    else:
        with torch.no_grad(), ema.average_parameters():
            for batch in dataloader:
                inputs = prepare_input(batch, args, train=False)
                result = model.predict(**inputs)
                for i in range(batch[0].size(0)):
                    predict_label = list(set(
                        f'{tint(l[0])},{tint(l[1]) + 1} {dataloader.dataset.id2type[tint(l[2])]}'
                        for l in result[i]
                    ))
                    predict_labels.append(predict_label)
    return predict_labels, []
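
# Usage sketch (hypothetical `dev_loader`, `model`, `args`, `ema` objects built
# elsewhere in this repo; the names here are illustrative, not the repo's API):
#
#     predictions, _ = decode(dev_loader, model, args)           # plain weights
#     predictions, _ = decode(dev_loader, model, args, ema=ema)  # EMA-averaged weights
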
def decode_threshold(dataloader, model, args=None, accelerator=None, threshold_list=[0.5], device=None):
    """Decode with one or more confidence thresholds; returns one prediction
    list per threshold so a single pass over the data can evaluate a sweep."""
    if accelerator is not None:
        raise NotImplementedError
    if isinstance(threshold_list, float):
        threshold_list = [threshold_list]
    model.eval()
    predict_labels = [[] for _ in threshold_list]
    with torch.no_grad():
        for batch in tqdm(dataloader):
            inputs = prepare_input(batch, args, train=False, device=device)
            result = model.predict(**inputs)
            for i in range(batch[0].size(0)):
                for idx, threshold in enumerate(threshold_list):
                    # Keep only spans whose score (l[3]) clears this threshold.
                    predict_label = list(set(
                        f'{tint(l[0])},{tint(l[1]) + 1} {dataloader.dataset.id2type[tint(l[2])]}'
                        for l in result[i]
                        if l[3] >= threshold
                    ))
                    predict_labels[idx].append(predict_label)
    return predict_labels, []
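
# Threshold-sweep sketch (hypothetical names; assumes `metric` defined below and
# a dev split whose gold entities live on `dev_loader.dataset`):
#
#     per_threshold, _ = decode_threshold(dev_loader, model, args,
#                                         threshold_list=[0.3, 0.5, 0.7], device='cuda')
#     scores = [metric(dev_loader.dataset, preds) for preds in per_threshold]
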
def write_predict(predict_labels, file_path):
    """Write one line per example, with predicted spans joined by '|'."""
    with open(file_path, 'w', encoding='utf-8') as f:
        for predict_label in predict_labels:
            f.write('|'.join(predict_label) + '\n')
def metric(dataset, predict_labels_0, suffix=None):
    """Span-level precision/recall/F1: a prediction counts as correct only if
    its 'start,end type' string exactly matches a gold entity of the same example."""
    correct = 0
    predict_count = 0
    label_count = 0
    for i in range(len(dataset)):
        true_label = []
        for ent in dataset.df[i]['entities']:
            st = ent['start']
            ed = ent['end']
            tp = ent['type']
            true_label.append(f'{st},{ed} {tp}')
        predict_label = predict_labels_0[i]
        for l in predict_label:
            if l in true_label:
                correct += 1
        predict_count += len(predict_label)
        label_count += len(true_label)
    if correct == 0:
        p = 0
        r = 0
    else:
        p = correct / predict_count
        r = correct / label_count
    f1 = 0 if p == 0 or r == 0 else 2 * p * r / (p + r)
    if suffix is None:
        return {'p': p, 'r': r, 'f1': f1}
    return {f'{suffix}_p': p, f'{suffix}_r': r, f'{suffix}_f1': f1}
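

if __name__ == "__main__":
    # Minimal, self-contained sketch of the metric()/write_predict() contract.
    # _ToyDataset is a hypothetical stand-in: it only mimics the two attributes
    # metric() actually touches (len() and .df[i]['entities']).
    class _ToyDataset:
        def __init__(self, records):
            self.df = records

        def __len__(self):
            return len(self.df)

    toy = _ToyDataset([
        {'entities': [{'start': 0, 'end': 2, 'type': 'PER'},
                      {'start': 5, 'end': 7, 'type': 'LOC'}]},
    ])
    # One prediction list per example, each entry formatted as 'start,end type'.
    toy_predictions = [['0,2 PER', '3,4 ORG']]
    print(metric(toy, toy_predictions))  # expected: {'p': 0.5, 'r': 0.5, 'f1': 0.5}
    write_predict(toy_predictions, 'predict_demo.txt')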