-
-
Notifications
You must be signed in to change notification settings - Fork 21
/
helpers.py
39 lines (31 loc) · 1.17 KB
/
helpers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
"""
Helper utilities.
"""
import numpy as np
def read_labels(file_path):
    """
    Helper for loading labels.txt

    Each non-blank line is expected to hold an integer index, whitespace,
    and then the label text (which may itself contain spaces), e.g.
    ``"7 golden retriever"``.

    Args:
        file_path: Path of the UTF-8 encoded labels file.

    Returns:
        A dict mapping each integer index to its label string, with
        surrounding whitespace removed from the label.

    Raises:
        ValueError: If the first token of a line is not an integer.
        IndexError: If a line has an index but no label text.
    """
    labels = {}
    with open(file_path, "r", encoding="utf-8") as f:
        # Stream the file line by line instead of materialising every line.
        for line in f:
            pair = line.split(maxsplit=1)
            if not pair:
                # Blank / whitespace-only lines (commonly a trailing empty
                # line at the end of the file) carry no label; skip them
                # rather than failing on pair[0].
                continue
            labels[int(pair[0])] = pair[1].strip()
    return labels
def set_input_tensor(interpreter, image):
    """
    Copy ``image`` into the first entry of the interpreter's first input.

    Looks up the ``"index"`` of the first input reported by
    ``interpreter.get_input_details()``, obtains the matching tensor via
    ``interpreter.tensor(index)()``, and overwrites element ``[0]`` of it
    in place. Returns nothing.

    NOTE(review): presumably ``interpreter`` is a TFLite ``Interpreter``
    and the tensor is a writable view of its input buffer -- confirm
    against the callers.
    """
    first_input = interpreter.get_input_details()[0]
    tensor_accessor = interpreter.tensor(first_input["index"])
    # The accessor must be called to obtain the array; [0] selects the
    # first entry along the leading axis.
    destination = tensor_accessor()[0]
    # Slice assignment writes into the existing array instead of rebinding
    # the local name, so the data lands in the interpreter-owned tensor.
    destination[:, :] = image
def classify_image(interpreter, image, top_k=1):
    """Returns a sorted array of classification results.

    Writes ``image`` to the interpreter's input, runs ``invoke()``, reads
    the first output tensor and returns the best-scoring entries.

    Args:
        interpreter: Object providing ``get_input_details``, ``tensor``,
            ``invoke``, ``get_output_details`` and ``get_tensor``
            (presumably a TFLite ``Interpreter`` -- confirm with callers).
        image: Input data, forwarded unchanged to ``set_input_tensor``.
        top_k: Maximum number of results to return. Values larger than
            the number of scores are clamped; values <= 0 yield ``[]``.

    Returns:
        A list of at most ``top_k`` ``(index, score)`` tuples ordered by
        descending score. For uint8 outputs the score is the dequantized
        value ``scale * (raw - zero_point)``; otherwise it is the raw
        output value. Assumes the squeezed output is a single score
        vector.
    """
    set_input_tensor(interpreter, image)
    interpreter.invoke()
    output_details = interpreter.get_output_details()[0]
    # atleast_1d keeps a single-score output indexable after squeeze.
    output = np.atleast_1d(
        np.squeeze(interpreter.get_tensor(output_details["index"]))
    )

    # If the model is quantized (uint8 data), then dequantize the results
    if output_details["dtype"] == np.uint8:
        scale, zero_point = output_details["quantization"]
        # Widen before subtracting: in uint8 arithmetic a raw value below
        # zero_point would wrap around to a large positive number.
        output = scale * (output.astype(np.float64) - zero_point)

    # argpartition rejects kth >= size, so never ask for more than exists.
    k = min(top_k, output.size)
    if k <= 0:
        return []

    # Rank on a float copy so that negation cannot overflow or wrap for
    # integer outputs that were not dequantized above.
    negated = -output.astype(np.float64)
    # argpartition only guarantees the k best come first, in no particular
    # order; sort just those k to honour the "sorted" contract.
    top = np.argpartition(negated, k - 1)[:k]
    ordered = top[np.argsort(negated[top])]
    return [(i, output[i]) for i in ordered]