-
Notifications
You must be signed in to change notification settings - Fork 1
/
conllx_counts.py
208 lines (160 loc) · 7.48 KB
/
conllx_counts.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
""" This script contains functions used to evaluate two ConllX files.
The Sentence lists of the ConllX files are passed to the get_sentence_list_counts
function, and the tokenization, POS tag, UAS, label, and LAS scores are computed.
First the sentences are aligned using ced_word_alignment,
then the scores for each sentence are computed.
Finally, the mean of each score is obtained for all sentences.
"""
from typing import List, Tuple
from pandas import DataFrame, Series, concat
from align_trees import align_trees
from ced_word_alignment.alignment import align_words
from classes import AlignmentNumbers, ConllxStatistics, TreeCounts, TreeMatches
from class_conllx import Sentence
def get_alignment_numbers(gold_df, parsed_df):
    """Count the null-alignment placeholders ('tok') on each side of an
    aligned tree pair.

    Args:
        gold_df (DataFrame): aligned gold tree
        parsed_df (DataFrame): aligned parsed tree

    Returns:
        AlignmentNumbers: placeholder counts for the gold side (insertions)
        and the parsed side (deletions)
    """
    insertions = (gold_df['FORM'] == 'tok').sum()
    deletions = (parsed_df['FORM'] == 'tok').sum()
    return AlignmentNumbers(insertions, deletions)
def get_column_matches(col_1: Series, col_2: Series) -> int:
    """Returns the number of element-wise matches between two columns.

    Note: the original annotation said ``float``, but the value is a count;
    the result is now returned as a plain ``int``.

    Args:
        col_1 (Series): selected column of the first DataFrame
        col_2 (Series): selected column of the second DataFrame

    Returns:
        int: the number of positions where the two columns are equal
    """
    return int((col_1 == col_2).sum())
def get_pp_matches(col_1: Series, col_2: Series) -> int:
    """Returns whether a sentence is perfect on one column: 1 when the two
    columns agree at every position, 0 otherwise.

    Args:
        col_1 (Series): selected column of the first DataFrame
        col_2 (Series): selected column of the second DataFrame

    Returns:
        int: 1 for a perfect sentence, 0 otherwise
    """
    fully_matching = (col_1 == col_2).all()
    return 1 if fully_matching else 0
def get_pp_las_matches(head_1: Series, deprel_1: Series, head_2: Series, deprel_2: Series) -> int:
    """Returns whether a sentence is perfect on labeled attachment: 1 when
    both the HEAD and DEPREL columns agree at every position, 0 otherwise.

    Args:
        head_1 (Series): HEAD column of the first tree
        deprel_1 (Series): DEPREL column of the first tree
        head_2 (Series): HEAD column of the second tree
        deprel_2 (Series): DEPREL column of the second tree

    Returns:
        int: 1 for a perfect sentence, 0 otherwise
    """
    return int((head_1 == head_2).all() & (deprel_1 == deprel_2).all())
def get_word_list(col_1):
    """Rebuild a detokenized sentence string from a FORM column.

    Drops the null-alignment 'tok' placeholders, joins the remaining tokens
    with spaces, then fuses clitic boundaries marked with '+'
    (e.g. ['a', '+b', 'c+', 'd'] -> 'a+b c+d').
    """
    tokens = [form for form in list(col_1) if form != 'tok']
    joined = ' '.join(tokens)  # e.g. 'a +b c+ d'
    return joined.replace('+ ', '+').replace(' +', '+')
def get_word_matches(col_1: Series, col_2: Series) -> float:
    """Returns the number of word-level matches between two FORM columns.

    Both columns are detokenized into sentence strings, aligned word-by-word
    with align_words, padded with 'word' placeholders at null-alignment
    positions, and compared; placeholder rows are excluded from the count.

    Args:
        col_1 (Series): FORM column of the first DataFrame
        col_2 (Series): FORM column of the second DataFrame

    Returns:
        float: the number of matching aligned words
    """
    gold_sentence = get_word_list(col_1)
    parsed_sentence = get_word_list(col_2)
    pairs = align_words(gold_sentence, parsed_sentence)
    gold_words = gold_sentence.split()
    parsed_words = parsed_sentence.split()
    # pad each side with a placeholder wherever the aligner produced a null
    # alignment so that both lists line up index-by-index
    for position, pair in enumerate(pairs):
        if pair[0] is None:
            gold_words.insert(position, 'word')
        if pair[1] is None:
            parsed_words.insert(position, 'word')
    aligned = concat(
        [Series(gold_words).rename('g'), Series(parsed_words).rename('p')],
        axis=1
    )
    no_placeholders = aligned[~((aligned['g'] == 'word') | (aligned['p'] == 'word'))]
    return get_column_matches(no_placeholders['g'], no_placeholders['p'])
def get_las_matches(head_1: Series, deprel_1: Series, head_2: Series, deprel_2: Series) -> float:
    """Returns the number of labeled-attachment matches between two trees:
    positions where both the HEAD and the DEPREL values agree.

    Args:
        head_1 (Series): HEAD column of the first tree
        deprel_1 (Series): DEPREL column of the first tree
        head_2 (Series): HEAD column of the second tree
        deprel_2 (Series): DEPREL column of the second tree

    Returns:
        float: the number of matches of the labels and attachments
    """
    # (removed a leftover commented-out pdb debugging line)
    return ((head_1 == head_2) & (deprel_1 == deprel_2)).sum()
def get_tree_matches(gold_df: DataFrame, parsed_df: DataFrame) -> TreeMatches:
    """Gets the matches of two trees. The matches are on
    tokenization, POS tags, UAS, label, and LAS.

    Note: the return annotation previously claimed ``Tuple[TreeMatches,
    TreeCounts]`` but the function returns only a ``TreeMatches``; the
    annotation is corrected here.

    assumption: gold_df and parsed_df are aligned by adding null alignment tok
    TODO: we insert null alignment tokens ... add to align_trees
    C=correct match
    S=Sub
    I=Inserted in prediction
    D=Deleted in prediction
    Length of Reference R=C+S+D (length of gold without null alignment tok)
    Length of Prediction P=C+S+I (length of parsed without null alignment tok)
    Precision of edits PREC= C / P (what is correct of the prediction)
    Recall of edits REC = C / R (what is correct of the reference)
    The Tokenization F-score = 2*PREC*REC / (PREC+REC)

    Args:
        gold_df (DataFrame): the first (gold) tree
        parsed_df (DataFrame): the second (parsed) tree

    Returns:
        TreeMatches: matching scores of the two trees
    """
    assert gold_df.shape[0] == parsed_df.shape[0], 'trees must be aligned!'
    # tokenization and word matches are computed on the full aligned columns,
    # including the null-alignment placeholder rows
    tokenization_matches = get_column_matches(gold_df['FORM'], parsed_df['FORM'])
    word_matches = get_word_matches(gold_df['FORM'], parsed_df['FORM'])
    dfs = gold_df.merge(parsed_df, on='ID', suffixes=('_gold', '_parsed'))
    # remove insertions before calculating matches other than tokenization
    dfs = dfs[dfs['FORM_gold'] != 'tok']
    return TreeMatches(
        tokenization_matches,
        get_column_matches(dfs['UPOS_gold'], dfs['UPOS_parsed']),
        get_column_matches(dfs['HEAD_gold'], dfs['HEAD_parsed']),
        get_column_matches(dfs['DEPREL_gold'], dfs['DEPREL_parsed']),
        get_las_matches(dfs['HEAD_gold'], dfs['DEPREL_gold'], dfs['HEAD_parsed'], dfs['DEPREL_parsed']),
        word_matches,
        get_pp_matches(dfs['HEAD_gold'], dfs['HEAD_parsed']),
        get_pp_matches(dfs['DEPREL_gold'], dfs['DEPREL_parsed']),
        get_pp_las_matches(dfs['HEAD_gold'], dfs['DEPREL_gold'], dfs['HEAD_parsed'], dfs['DEPREL_parsed'])
    )
def get_tree_counts(gold_df, parsed_df):
    """Compute reference/prediction token and word counts for an aligned
    tree pair.

    Args:
        gold_df (DataFrame): aligned gold tree
        parsed_df (DataFrame): aligned parsed tree

    Returns:
        TreeCounts: token counts excluding null-alignment 'tok' rows, the
        aligned length, and the gold word count (rows whose FORM does not
        contain a '+' clitic marker)
    """
    assert gold_df.shape[0] == parsed_df.shape[0], 'trees must be aligned!'
    dfs = gold_df.merge(parsed_df, on='ID', suffixes=('_gold', '_parsed'))
    ref_token_count = dfs[dfs['FORM_gold'] != 'tok'].shape[0]
    pred_token_count = dfs[dfs['FORM_parsed'] != 'tok'].shape[0]
    # '\+' as a non-raw string is an invalid escape sequence (SyntaxWarning
    # on modern Python); a literal-substring match is what is intended here
    ref_word_count = dfs[~dfs['FORM_gold'].str.contains('+', regex=False)].shape[0]
    return TreeCounts(
        ref_token_count,
        pred_token_count,
        gold_df.shape[0],
        ref_word_count
    )
def get_sentence_list_counts(
    gold_sen_list: List[Sentence],
    parsed_sen_list: List[Sentence]
) -> ConllxStatistics:
    """Given two Sentence lists, compute the following scores:
    tokenization, POS tags, UAS, label, and LAS.

    Each sentence pair is tree-aligned with align_trees, per-sentence counts,
    matches, and alignment numbers are collected, and the column-wise sums
    over all sentences are returned.

    Args:
        gold_sen_list (List[Sentence]): first (gold) sentence list
        parsed_sen_list (List[Sentence]): second (parsed) sentence list

    Returns:
        ConllxStatistics: summed counts, matches, and alignment numbers,
        plus the number of sentences
    """
    assert len(gold_sen_list) == len(parsed_sen_list)
    counts_per_sentence = []
    matches_per_sentence = []
    alignment_per_sentence = []
    for gold_sen, parsed_sen in zip(gold_sen_list, parsed_sen_list):
        # copy so alignment padding never mutates the caller's trees
        gold_tree, parsed_tree = align_trees(
            gold_sen.dependency_tree.copy(),
            parsed_sen.dependency_tree.copy()
        )
        counts_per_sentence.append(get_tree_counts(gold_tree, parsed_tree))
        matches_per_sentence.append(get_tree_matches(gold_tree, parsed_tree))
        alignment_per_sentence.append(get_alignment_numbers(gold_tree, parsed_tree))
    return ConllxStatistics(
        DataFrame(counts_per_sentence).sum(),
        DataFrame(matches_per_sentence).sum(),
        DataFrame(alignment_per_sentence).sum(),
        len(gold_sen_list)
    )