diff --git a/py_src/yolov4/tflite/__init__.py b/py_src/yolov4/tflite/__init__.py index 395df3a9..5998d9af 100644 --- a/py_src/yolov4/tflite/__init__.py +++ b/py_src/yolov4/tflite/__init__.py @@ -21,17 +21,15 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ -import time -from typing import Union - import numpy as np try: import tflite_runtime.interpreter as tflite + from tflite_runtime.interpreter import load_delegate except ModuleNotFoundError: import tensorflow.lite as tflite + from tensorflow.lite.experimental import load_delegate -from ..common import media, predict from ..common.base_class import BaseClass @@ -47,16 +45,15 @@ def __init__(self, tiny: bool = False, tpu: bool = False): self.output_index = None self.output_size = None - def load_tflite(self, tflite_path): + def load_tflite(self, tflite_path: str) -> None: if self.tpu: self.interpreter = tflite.Interpreter( model_path=tflite_path, - experimental_delegates=[ - tflite.load_delegate("libedgetpu.so.1") - ], + experimental_delegates=[load_delegate("libedgetpu.so.1")], ) else: self.interpreter = tflite.Interpreter(model_path=tflite_path) + self.interpreter.allocate_tensors() input_details = self.interpreter.get_input_details()[0] # width, height