-
Notifications
You must be signed in to change notification settings - Fork 2
/
example_test.go
109 lines (95 loc) · 2.03 KB
/
example_test.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
package treelite_test
import (
"bufio"
"fmt"
"log"
"math"
"os"
"runtime"
"strconv"
"strings"
"testing"
"github.com/getumen/go-treelite"
)
// Example walks through the end-to-end treelite workflow on the bundled
// test data:
//  1. load the feature matrix (loadData) and wrap it in a DMatrix;
//  2. load the XGBoost model from testdata;
//  3. build an annotation from the model and data and save it as JSON;
//  4. compile the model with the "ast_native" compiler, using that annotation,
//     into a shared library built with gcc;
//  5. load the shared library as a predictor and print the batch predictions.
//
// The example has no "// Output:" comment, so `go test` only compiles it;
// TestEndToEnd_Example is what actually executes it.
func Example() {
	// The steps run in a helper so that its deferred Close calls execute
	// before log.Fatal terminates the process (log.Fatal calls os.Exit, which
	// does not run deferred functions).
	if err := runExample(); err != nil {
		log.Fatal(err)
	}
}

// runExample performs the steps described on Example and returns the first
// error encountered. Every native handle it opens is released by a deferred
// Close on all return paths.
func runExample() error {
	const (
		modelPath      = "testdata/xgboost.model"
		annotationPath = "testdata/go-example-annotation.json"
		libBasePath    = "testdata/go_example_compiled_model"
	)

	data, nRow, nCol := loadData()
	// NOTE(review): NaN is presumably the missing-value marker for
	// CreateFromMat — confirm against the treelite binding.
	dMatrix, err := treelite.CreateFromMat(data, nRow, nCol, float32(math.NaN()))
	if err != nil {
		return fmt.Errorf("creating DMatrix: %w", err)
	}
	defer dMatrix.Close()

	model, err := treelite.LoadXGBoostModel(modelPath)
	if err != nil {
		return fmt.Errorf("loading model %q: %w", modelPath, err)
	}
	defer model.Close()

	annotator, err := treelite.NewAnnotator(model, dMatrix, 1, true)
	if err != nil {
		return fmt.Errorf("creating annotator: %w", err)
	}
	defer annotator.Close()

	if err := annotator.Save(annotationPath); err != nil {
		return fmt.Errorf("saving annotation %q: %w", annotationPath, err)
	}

	compiler, err := treelite.NewCompiler(
		"ast_native",
		&treelite.CompilerParam{
			AnnotationPath: annotationPath,
			Quantize:       true,
			ParallelComp:   runtime.NumCPU(),
			Verbose:        true,
		},
	)
	if err != nil {
		return fmt.Errorf("creating compiler: %w", err)
	}
	defer compiler.Close()

	if err := compiler.ExportSharedLib(model, libBasePath, "gcc", nil); err != nil {
		return fmt.Errorf("exporting shared library %q: %w", libBasePath, err)
	}

	// ExportSharedLib is given the path without an extension; the loader
	// needs the platform-specific suffix appended.
	libPath := fmt.Sprintf("%s.%s", libBasePath, treelite.GetSharedLibExtension())
	predictor, err := treelite.NewPredictor(libPath, runtime.NumCPU())
	if err != nil {
		return fmt.Errorf("loading predictor %q: %w", libPath, err)
	}
	defer predictor.Close()

	scores, err := predictor.PredictBatch(dMatrix, true, false)
	if err != nil {
		return fmt.Errorf("predicting: %w", err)
	}
	fmt.Printf("%+v\n", scores)
	return nil
}
// loadData reads testdata/feature.csv into a flat, row-major float32 slice
// suitable for treelite.CreateFromMat.
//
// Each non-blank line is one row of comma-separated numbers, and every row
// must contain exactly nCol (30) values. It returns the values, the number
// of rows and the number of columns. Any I/O, parse or shape error is fatal,
// because the example cannot proceed without its input.
func loadData() ([]float32, int, int) {
	// NOTE(review): 30 is the feature width the bundled model is assumed to
	// expect — confirm against testdata/xgboost.model.
	const nCol = 30
	feature, nRow, err := readFeatureCSV("testdata/feature.csv", nCol)
	if err != nil {
		log.Fatal(err)
	}
	return feature, nRow, nCol
}

// readFeatureCSV parses path as rows of nCol comma-separated floats and
// returns the row-major values plus the row count. The file is closed on
// every return path.
func readFeatureCSV(path string, nCol int) ([]float32, int, error) {
	f, err := os.Open(path)
	if err != nil {
		return nil, 0, err
	}
	defer f.Close() // read-only; a Close error carries no useful information

	feature := make([]float32, 0)
	nRow, lineNo := 0, 0
	scanner := bufio.NewScanner(f)
	for scanner.Scan() {
		lineNo++
		line := strings.TrimSpace(scanner.Text())
		if line == "" {
			// Tolerate blank lines (e.g. a trailing newline) instead of
			// failing on ParseFloat("").
			continue
		}
		fields := strings.Split(line, ",")
		// A short or long row would silently shift every later value in the
		// flattened matrix, so reject it outright.
		if len(fields) != nCol {
			return nil, 0, fmt.Errorf("%s:%d: got %d columns, want %d", path, lineNo, len(fields), nCol)
		}
		for _, field := range fields {
			value, err := strconv.ParseFloat(strings.TrimSpace(field), 32)
			if err != nil {
				return nil, 0, fmt.Errorf("%s:%d: %w", path, lineNo, err)
			}
			feature = append(feature, float32(value))
		}
		nRow++
	}
	// Scan returns false on read errors and over-long lines as well as at
	// EOF; surface the former instead of returning truncated data.
	if err := scanner.Err(); err != nil {
		return nil, 0, fmt.Errorf("reading %s: %w", path, err)
	}
	return feature, nRow, nil
}
// TestEndToEnd_Example executes Example as a regular test. Example has no
// "// Output:" comment, so without this wrapper `go test` would only compile
// it.
//
// The run compiles the model with gcc and writes artifacts under testdata,
// so it is skipped when tests are run with -short.
func TestEndToEnd_Example(t *testing.T) {
	if testing.Short() {
		t.Skip("skipping end-to-end treelite example (invokes gcc) in -short mode")
	}
	Example()
}