-
-
Notifications
You must be signed in to change notification settings - Fork 44
/
Copy pathtflite_experimental.go
387 lines (344 loc) · 15.7 KB
/
tflite_experimental.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
package tflite
/*
#ifndef GO_TFLITE_EXPERIMENTAL_H
#include "tflite_experimental.go.h"
#endif
typedef void* (*f_tflite_registration_init)(TfLiteContext* context, const char* buffer, size_t length);
void* _tflite_registration_init(TfLiteContext* context, char* buffer, size_t length);
typedef void (*f_tflite_registration_free)(TfLiteContext* context, void* buffer);
void _tflite_registration_free(TfLiteContext* context, void* buffer);
typedef TfLiteStatus (*f_tflite_registration_prepare)(TfLiteContext* context, TfLiteNode* node);
TfLiteStatus _tflite_registration_prepare(TfLiteContext* context, TfLiteNode* node);
typedef TfLiteStatus (*f_tflite_registration_invoke)(TfLiteContext* context, TfLiteNode* node);
TfLiteStatus _tflite_registration_invoke(TfLiteContext* context, TfLiteNode* node);
typedef const char* (*f_tflite_registration_profiling_string)(const TfLiteContext* context, const TfLiteNode* node);
char* _tflite_registration_profiling_string(TfLiteContext* context, TfLiteNode* node);
static TfLiteRegistration*
_make_registration(void* o_init, void* o_free, void* o_prepare, void* o_invoke, void* o_profiling_string) {
TfLiteRegistration* r = (TfLiteRegistration*)malloc(sizeof(TfLiteRegistration));
r->init = (f_tflite_registration_init) o_init;
r->free = (f_tflite_registration_free) o_free;
r->prepare = (f_tflite_registration_prepare) o_prepare;
r->invoke = (f_tflite_registration_invoke) o_invoke;
r->profiling_string = (f_tflite_registration_profiling_string) o_profiling_string;
return r;
}
static void look_context(TfLiteContext *context) {
TfLiteIntArray *plan = NULL;
context->GetExecutionPlan(context, &plan);
if (plan == NULL) return;
int i;
for (i = 0; i < plan->size; i++) {
TfLiteNode *node = NULL;
TfLiteRegistration *reg = NULL;
context->GetNodeAndRegistration(context, i, &node, ®);
printf("%s\n", reg->custom_name);
}
}
static void writeToTensorAsVector(TfLiteTensor *tensor, char *bytes, size_t size, int nelem) {
static TfLiteIntArray dummy;
TfLiteIntArray* new_shape = (TfLiteIntArray*)malloc(sizeof(dummy) + sizeof(dummy.data[0]) * 1);
if (new_shape) {
new_shape->size = 1;
new_shape->data[0] = nelem;
memcpy(new_shape->data, tensor->dims->data, tensor->dims->size * sizeof(int));
}
// TfLiteTensorDataFree
if (tensor->allocation_type == kTfLiteDynamic && tensor->data.raw) {
free(tensor->data.raw);
}
tensor->data.raw = NULL;
if (tensor->dims) free(tensor->dims);
if (tensor->quantization.type == kTfLiteAffineQuantization) {
TfLiteAffineQuantization* q_params =
(TfLiteAffineQuantization*)(tensor->quantization.params);
if (q_params->scale) {
free(q_params->scale);
q_params->scale = NULL;
}
if (q_params->zero_point) {
free(q_params->zero_point);
q_params->zero_point = NULL;
}
free(q_params);
}
tensor->dims = new_shape;
tensor->data.raw = bytes;
tensor->bytes = size;
tensor->allocation_type = kTfLiteMmapRo;
tensor->quantization.type = kTfLiteNoQuantization;
tensor->quantization.params = NULL;
}
*/
import "C"
import (
"bytes"
"encoding/binary"
"io"
"unsafe"
)
const (
	// sizeof_int32_t is the width in bytes of each int32 header field (the
	// string count and the offsets) in a serialized string tensor.
	sizeof_int32_t = 4
)
// ResetVariableTensors asks the underlying C interpreter to reset its
// variable tensors (TfLiteInterpreterResetVariableTensors) and reports the
// status it returns.
func (i *Interpreter) ResetVariableTensors() Status {
	st := C.TfLiteInterpreterResetVariableTensors(i.i)
	return Status(st)
}
/*
type Registration interface {
}
func (o *InterpreterOptions) AddCustomOp(name string, reg *Registration, minVersion, maxVersion int) {
ptr := C.CString(name)
defer C.free(unsafe.Pointer(ptr))
r := C._make_registration()
C.TfLiteInterpreterOptionsAddCustomOp(o.o, ptr, r, C.int(minVersion), C.int(maxVersion))
}
type registration struct {
ccxt *C.TfLiteContext
}
//export _tflite_registration_init
func _tflite_registration_init(ccxt *C.TfLiteContext, buffer *C.char, length C.size_t) unsafe.Pointer {
println("registration.init")
C.look_context(ccxt)
//var executionPlan *TfLiteIntArray
//status := ccxt.GetExecutionPlan(ccxt, &executionPlan)
//if status != C.kTfLiteOk {
//return nil
//}
//var registration *C.TfLiteRegistration
//var node *C.TfLiteNode
//for i := 0; i < executionPlan.size; i++ {
//ccxt.GetNodeAndRegistration(ccxt, 0, &node, &registration)
//}
println(buffer, length)
return nil
}
//export _tflite_registration_free
func _tflite_registration_free(ccxt *C.TfLiteContext, buffer unsafe.Pointer) {
println("registration.free")
}
//export _tflite_registration_prepare
func _tflite_registration_prepare(ccxt *C.TfLiteContext, node *C.TfLiteNode) C.TfLiteStatus {
println("registration.prepare")
return C.kTfLiteOk
}
//export _tflite_registration_invoke
func _tflite_registration_invoke(ccxt *C.TfLiteContext, node *C.TfLiteNode) C.TfLiteStatus {
println("registration.invoke")
return C.kTfLiteOk
}
//export _tflite_registration_profiling_string
func _tflite_registration_profiling_string(ccxt *C.TfLiteContext, node *C.TfLiteNode) *C.char {
println("registration.profiling_string")
return nil
}
*/
// ExpRegistration holds the callbacks of an experimental operator
// registration. ExpAddBuiltinOp and ExpAddCustomOp cast each field directly
// to the matching callback of a C TfLiteRegistration, so every field must be
// nil or point to a C function with that callback's signature.
// NOTE(review): a Go func value cannot be stored here; presumably the C
// address of a cgo-exported function is intended. Confirm.
type ExpRegistration struct {
	Init            unsafe.Pointer // TfLiteRegistration.init
	Free            unsafe.Pointer // TfLiteRegistration.free
	Prepare         unsafe.Pointer // TfLiteRegistration.prepare
	Invoke          unsafe.Pointer // TfLiteRegistration.invoke
	ProfilingString unsafe.Pointer // TfLiteRegistration.profiling_string
}
// BuiltinOperator identifies a TensorFlow Lite builtin operator. Values are
// passed to C unchanged as TfLiteBuiltinOperator, so each code must match the
// TFLite builtin operator enum.
type BuiltinOperator int

// Builtin operator codes, numbered consecutively from 0 by iota. The decade
// markers make it easy to cross-check the codes against the TFLite enum.
const (
	// 0-4
	BuiltinOperator_ADD BuiltinOperator = iota
	BuiltinOperator_AVERAGE_POOL_2D
	BuiltinOperator_CONCATENATION
	BuiltinOperator_CONV_2D
	BuiltinOperator_DEPTHWISE_CONV_2D
	// Code 5 is not exposed by this package; the blank identifier keeps every
	// following operator on its code.
	_
	// 6-9
	BuiltinOperator_DEQUANTIZE
	BuiltinOperator_EMBEDDING_LOOKUP
	BuiltinOperator_FLOOR
	BuiltinOperator_FULLY_CONNECTED
	// 10-19
	BuiltinOperator_HASHTABLE_LOOKUP
	BuiltinOperator_L2_NORMALIZATION
	BuiltinOperator_L2_POOL_2D
	BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION
	BuiltinOperator_LOGISTIC
	BuiltinOperator_LSH_PROJECTION
	BuiltinOperator_LSTM
	BuiltinOperator_MAX_POOL_2D
	BuiltinOperator_MUL
	BuiltinOperator_RELU
	// 20-29
	BuiltinOperator_RELU_N1_TO_1
	BuiltinOperator_RELU6
	BuiltinOperator_RESHAPE
	BuiltinOperator_RESIZE_BILINEAR
	BuiltinOperator_RNN
	BuiltinOperator_SOFTMAX
	BuiltinOperator_SPACE_TO_DEPTH
	BuiltinOperator_SVDF
	BuiltinOperator_TANH
	BuiltinOperator_CONCAT_EMBEDDINGS
	// 30-39
	BuiltinOperator_SKIP_GRAM
	BuiltinOperator_CALL
	BuiltinOperator_CUSTOM
	BuiltinOperator_EMBEDDING_LOOKUP_SPARSE
	BuiltinOperator_PAD
	BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN
	BuiltinOperator_GATHER
	BuiltinOperator_BATCH_TO_SPACE_ND
	BuiltinOperator_SPACE_TO_BATCH_ND
	BuiltinOperator_TRANSPOSE
	// 40-49
	BuiltinOperator_MEAN
	BuiltinOperator_SUB
	BuiltinOperator_DIV
	BuiltinOperator_SQUEEZE
	BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM
	BuiltinOperator_STRIDED_SLICE
	BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN
	BuiltinOperator_EXP
	BuiltinOperator_TOPK_V2
	BuiltinOperator_SPLIT
	// 50-59
	BuiltinOperator_LOG_SOFTMAX
	BuiltinOperator_DELEGATE
	BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM
	BuiltinOperator_CAST
	BuiltinOperator_PRELU
	BuiltinOperator_MAXIMUM
	BuiltinOperator_ARG_MAX
	BuiltinOperator_MINIMUM
	BuiltinOperator_LESS
	BuiltinOperator_NEG
	// 60-69
	BuiltinOperator_PADV2
	BuiltinOperator_GREATER
	BuiltinOperator_GREATER_EQUAL
	BuiltinOperator_LESS_EQUAL
	BuiltinOperator_SELECT
	BuiltinOperator_SLICE
	BuiltinOperator_SIN
	BuiltinOperator_TRANSPOSE_CONV
	BuiltinOperator_SPARSE_TO_DENSE
	BuiltinOperator_TILE
	// 70-79
	BuiltinOperator_EXPAND_DIMS
	BuiltinOperator_EQUAL
	BuiltinOperator_NOT_EQUAL
	BuiltinOperator_LOG
	BuiltinOperator_SUM
	BuiltinOperator_SQRT
	BuiltinOperator_RSQRT
	BuiltinOperator_SHAPE
	BuiltinOperator_POW
	BuiltinOperator_ARG_MIN
	// 80-89
	BuiltinOperator_FAKE_QUANT
	BuiltinOperator_REDUCE_PROD
	BuiltinOperator_REDUCE_MAX
	BuiltinOperator_PACK
	BuiltinOperator_LOGICAL_OR
	BuiltinOperator_ONE_HOT
	BuiltinOperator_LOGICAL_AND
	BuiltinOperator_LOGICAL_NOT
	BuiltinOperator_UNPACK
	BuiltinOperator_REDUCE_MIN
	// 90-99
	BuiltinOperator_FLOOR_DIV
	BuiltinOperator_REDUCE_ANY
	BuiltinOperator_SQUARE
	BuiltinOperator_ZEROS_LIKE
	BuiltinOperator_FILL
	BuiltinOperator_FLOOR_MOD
	BuiltinOperator_RANGE
	BuiltinOperator_RESIZE_NEAREST_NEIGHBOR
	BuiltinOperator_LEAKY_RELU
	BuiltinOperator_SQUARED_DIFFERENCE
	// 100-109
	BuiltinOperator_MIRROR_PAD
	BuiltinOperator_ABS
	BuiltinOperator_SPLIT_V
	BuiltinOperator_UNIQUE
	BuiltinOperator_CEIL
	BuiltinOperator_REVERSE_V2
	BuiltinOperator_ADD_N
	BuiltinOperator_GATHER_ND
	BuiltinOperator_COS
	BuiltinOperator_WHERE
	// 110-115
	BuiltinOperator_RANK
	BuiltinOperator_ELU
	BuiltinOperator_REVERSE_SEQUENCE
	BuiltinOperator_MATRIX_DIAG
	BuiltinOperator_QUANTIZE
	BuiltinOperator_MATRIX_SET_DIAG

	// Smallest and largest operator codes defined above.
	BuiltinOperator_MIN BuiltinOperator = BuiltinOperator_ADD
	BuiltinOperator_MAX BuiltinOperator = BuiltinOperator_MATRIX_SET_DIAG
)
// ExpAddBuiltinOp registers reg as the implementation of the builtin
// operator op for the version range minVersion..maxVersion, via
// TfLiteInterpreterOptionsAddBuiltinOp. This API is experimental and a work
// in progress.
//
// It panics if reg is nil.
func (o *InterpreterOptions) ExpAddBuiltinOp(op BuiltinOperator, reg *ExpRegistration, minVersion, maxVersion int) {
	r := expMakeRegistration(reg)
	C.TfLiteInterpreterOptionsAddBuiltinOp(o.o, C.TfLiteBuiltinOperator(op), r, C.int(minVersion), C.int(maxVersion))
}

// ExpAddCustomOp registers reg as the implementation of the custom operator
// called name for the version range minVersion..maxVersion, via
// TfLiteInterpreterOptionsAddCustomOp. This API is experimental and a work
// in progress.
//
// It panics if reg is nil.
func (o *InterpreterOptions) ExpAddCustomOp(name string, reg *ExpRegistration, minVersion, maxVersion int) {
	r := expMakeRegistration(reg)
	cname := C.CString(name)
	defer C.free(unsafe.Pointer(cname))
	C.TfLiteInterpreterOptionsAddCustomOp(o.o, cname, r, C.int(minVersion), C.int(maxVersion))
}

// expMakeRegistration builds the C TfLiteRegistration shared by
// ExpAddBuiltinOp and ExpAddCustomOp from the callbacks in reg.
// It panics on a nil reg, and also when the C allocation returns NULL,
// rather than handing a NULL registration to the C API.
// NOTE(review): the returned struct is never freed. Whether the TFLite C API
// copies it or keeps the pointer depends on the TFLite version; confirm
// before adding a free.
func expMakeRegistration(reg *ExpRegistration) *C.TfLiteRegistration {
	if reg == nil {
		panic("tflite: nil *ExpRegistration")
	}
	r := C._make_registration(
		reg.Init,
		reg.Free,
		reg.Prepare,
		reg.Invoke,
		reg.ProfilingString,
	)
	if r == nil {
		panic("tflite: cannot allocate TfLiteRegistration")
	}
	return r
}
// SetUseNNAPI sets the NNAPI flag on these interpreter options via
// TfLiteInterpreterOptionsSetUseNNAPI: true enables the NN API, false
// disables it.
func (o *InterpreterOptions) SetUseNNAPI(enable bool) {
	useNNAPI := C.bool(enable)
	C.TfLiteInterpreterOptionsSetUseNNAPI(o.o, useNNAPI)
}
// DynamicBuffer accumulates strings for serialization into a TFLite string
// tensor. The zero value is an empty buffer ready for use.
type DynamicBuffer struct {
	// data holds the bytes of every added string, concatenated in order.
	data bytes.Buffer
	// offset[i] is the total byte length of strings 0..i, i.e. where string
	// i ends within the concatenated content.
	offset []int
}

// AddString appends s to the buffer and records its end offset.
func (d *DynamicBuffer) AddString(s string) {
	end := len(s)
	if n := len(d.offset); n > 0 {
		end += d.offset[n-1]
	}
	d.data.WriteString(s)
	d.offset = append(d.offset, end)
}
// WriteToTensorAsVector serializes the buffered strings into t and reshapes
// t to the 1-D shape [number of strings].
//
// Layout written (little-endian int32 header, then the raw string bytes):
//
//	[count][offset_0][offset_1]...[offset_count][string bytes...]
//
// offset_0 is the header size and offset_k (k >= 1) is where string k-1
// ends; all offsets are measured from the start of the buffer. The
// DynamicBuffer itself is not consumed, so it can be written again.
func (d *DynamicBuffer) WriteToTensorAsVector(t *Tensor) {
	numStrings := len(d.offset)
	// Header size: the count plus numStrings+1 offsets.
	start := sizeof_int32_t * (numStrings + 2)

	var out bytes.Buffer
	out.Grow(start + d.data.Len())
	word := make([]byte, sizeof_int32_t)
	putInt32 := func(v int) {
		binary.LittleEndian.PutUint32(word, uint32(v))
		out.Write(word)
	}

	putInt32(numStrings)
	// The first offset is written even when there are no strings, so the
	// header always carries numStrings+1 offsets.
	putInt32(start)
	for _, end := range d.offset {
		putInt32(start + end)
	}
	// Copy through a separate reader: io.Copy from &d.data would drain it, and
	// a later call would then emit offsets with no string bytes behind them.
	// Writes to a bytes.Buffer cannot fail.
	_, _ = io.Copy(&out, bytes.NewReader(d.data.Bytes()))

	b := out.Bytes()
	// NOTE: cgo forbids C from retaining b after this call returns; the C
	// helper must copy the bytes rather than keep the pointer.
	C.writeToTensorAsVector(t.t, (*C.char)(unsafe.Pointer(&b[0])), C.size_t(len(b)), C.int(numStrings))
}
// GetString returns the string at position index of a String tensor.
//
// The tensor data is read as a native-endian int32 count followed by
// offsets measured from the start of the data; string i spans
// [offset_i, offset_{i+1}). GetString returns "" when t is not a String
// tensor, has no data, index is out of range (including negative), or the
// offsets for index are inconsistent.
func (t *Tensor) GetString(index int) string {
	if t.Type() != String {
		return ""
	}
	base := t.Data()
	if base == nil || index < 0 {
		return ""
	}
	// readInt32 returns the i-th int32 word of the data. The pointer
	// arithmetic stays in a single expression, as package unsafe requires.
	readInt32 := func(i int) int {
		return int(*(*C.int32_t)(unsafe.Pointer(uintptr(base) + uintptr(sizeof_int32_t*i))))
	}
	count := readInt32(0)
	if index >= count {
		return ""
	}
	start, end := readInt32(index+1), readInt32(index+2)
	if start < 0 || end < start {
		return ""
	}
	return C.GoStringN((*C.char)(unsafe.Pointer(uintptr(base)+uintptr(start))), C.int(end-start))
}