forked from jdkato/prose
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathextract_test.go
100 lines (87 loc) · 1.95 KB
/
extract_test.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
package prose
import (
"bytes"
"encoding/json"
"io"
"io/ioutil"
"math"
"path/filepath"
"reflect"
"testing"
)
// makeNER creates a Document for text by calling NewDocument with the
// WithSegmentation(false) and UsingModel(model) options, and returns
// NewDocument's results unchanged.
func makeNER(text string, model *Model) (*Document, error) {
	doc, err := NewDocument(text, WithSegmentation(false), UsingModel(model))
	return doc, err
}
// prodigyOuput is one record decoded by readProdigy from the JSON fixture.
// The fields carry no JSON tags, so encoding/json matches them to object keys
// by field name (case-insensitively).
//
// NOTE(review): the type name is misspelled ("Ouput"); it is left as-is
// because the other helpers in this file refer to it by this name.
type prodigyOuput struct {
	// Text is the example text; span Start/End offsets are used to slice it.
	Text string
	// Spans are the labeled entity spans attached to the example.
	Spans []LabeledEntity
	// Answer is the annotation verdict; this file only distinguishes the
	// value "accept" from everything else.
	Answer string
}
// readProdigy decodes the sequence of JSON values in jsonLines into
// prodigyOuput records, preserving input order.
//
// Decoding stops cleanly at io.EOF. Any other decode error panics, so a
// malformed fixture aborts the run instead of yielding a partial data set.
// The returned slice is never nil, even when no record is present.
func readProdigy(jsonLines []byte) []prodigyOuput {
	records := []prodigyOuput{}
	decoder := json.NewDecoder(bytes.NewReader(jsonLines))
	for {
		var rec prodigyOuput
		switch err := decoder.Decode(&rec); err {
		case nil:
			records = append(records, rec)
		case io.EOF:
			// End of input: everything decoded so far is the result.
			return records
		default:
			panic(err)
		}
	}
}
// split partitions data into a training set and a held-out set.
//
// The first 80% of the records (rounded down) are converted to EntityContext
// values, with Accept set when the record's Answer is "accept". The remaining
// records are returned unchanged. Both returned slices are non-nil.
func split(data []prodigyOuput) ([]EntityContext, []prodigyOuput) {
	cutoff := int(float64(len(data)) * 0.8)

	train := []EntityContext{}
	for _, rec := range data[:cutoff] {
		ctx := EntityContext{
			Text:   rec.Text,
			Spans:  rec.Spans,
			Accept: rec.Answer == "accept",
		}
		train = append(train, ctx)
	}

	// Copy the tail so the held-out slice does not alias data's backing array.
	test := append([]prodigyOuput{}, data[cutoff:]...)

	return train, test
}
// TestSumLogs checks that sumLogs applied to {log2(3), log2(5)} yields 3.
//
// NOTE(review): 3 == log2(3+5), so sumLogs presumably adds values in log2
// space — confirm against its definition.
//
// The result is compared within a small tolerance rather than with ==,
// because values derived from math.Log2 are not guaranteed to be bit-identical
// across architectures or when the compiler fuses floating-point operations.
func TestSumLogs(t *testing.T) {
	const (
		want = 3.0
		tol  = 1e-9
	)

	got := sumLogs([]float64{math.Log2(3), math.Log2(5)})

	// NaN compares false against everything, so reject it explicitly to stay
	// at least as strict as an exact comparison would be.
	if math.IsNaN(got) || math.Abs(got-want) > tol {
		t.Errorf("sumLogs() expected = %v, got = %v", want, got)
	}
}
// TestNERProdigy trains a "PRODUCT" model on the first part of the
// reddit_product.jsonl fixture (see split) and scores it on the remainder.
//
// An entry counts as correct when either
//   - its Answer is not "accept" and the model finds no entities, or
//   - the texts sliced out by its Spans are deeply equal to doc.Entities().
//
// The test fails if the fraction of correct entries drops below minAccuracy.
// Setup problems (unreadable fixture, empty held-out set, document creation
// error) fail the test through t.Fatalf instead of panicking or being ignored.
func TestNERProdigy(t *testing.T) {
	const minAccuracy = 0.819444

	path := filepath.Join(testdata, "reddit_product.jsonl")
	raw, err := ioutil.ReadFile(path)
	if err != nil {
		t.Fatalf("reading %q: %v", path, err)
	}

	train, test := split(readProdigy(raw))
	// Without this guard the ratio below would be 0/0 = NaN, and
	// NaN < minAccuracy is false, so the test would pass without checking
	// anything.
	if len(test) == 0 {
		t.Fatalf("no held-out entries in %q; cannot compute accuracy", path)
	}

	model := ModelFromData("PRODUCT", UsingEntities(train))

	correct := 0.0
	for _, entry := range test {
		doc, err := makeNER(entry.Text, model)
		if err != nil {
			t.Fatalf("makeNER(%q): %v", entry.Text, err)
		}
		ents := doc.Entities()

		if entry.Answer != "accept" && len(ents) == 0 {
			correct++
			continue
		}

		// Kept as a non-nil empty slice: reflect.DeepEqual distinguishes nil
		// from empty slices.
		expected := []string{}
		for _, span := range entry.Spans {
			expected = append(expected, entry.Text[span.Start:span.End])
		}
		// NOTE(review): reflect.DeepEqual reports false whenever its operands
		// have different types. Confirm that doc.Entities() returns []string;
		// otherwise this branch can never count an entry as correct.
		if reflect.DeepEqual(expected, ents) {
			correct++
		}
	}

	if r := correct / float64(len(test)); r < minAccuracy {
		t.Errorf("NERProdigy() expected >= %v, got = %v", minAccuracy, r)
	}
}