diff --git a/examples/apps/object_detection_tensorflow/kernels.py b/examples/apps/object_detection_tensorflow/kernels.py index bbf17180..5ba0aefa 100644 --- a/examples/apps/object_detection_tensorflow/kernels.py +++ b/examples/apps/object_detection_tensorflow/kernels.py @@ -3,14 +3,13 @@ import cv2 import os import scannerpy -import scannerpy.stdlib.util import pickle import visualization_utils as vis_util import tarfile from scannerpy import FrameType, DeviceType -from scannerpy.stdlib import tensorflow from scannerpy.types import NumpyArrayFloat32 +from scannertools import tensorflow from typing import Tuple, Sequence from tqdm import tqdm @@ -24,12 +23,12 @@ category_index = vis_util.create_category_index(categories) def download_and_extract_model(url, local_path=None): - path = scannerpy.stdlib.util.download_temp_file(url, local_path) + path = scannerpy.util.download_temp_file(url, local_path) tar_file = tarfile.open(path) for f in tar_file.getmembers(): file_name = os.path.basename(f.name) if 'frozen_inference_graph.pb' in file_name: - local_path = scannerpy.stdlib.util.temp_directory() + local_path = scannerpy.util.temp_directory() tar_file.extract(f, local_path) model_path = os.path.join(local_path, f.name) break @@ -40,11 +39,13 @@ def download_and_extract_model(url, local_path=None): batch=2) class ObjDetect(tensorflow.TensorFlowKernel): def __init__(self, config, dnn_url): + print('objdet', config) + print([d.id for d in config.devices]) tensorflow.TensorFlowKernel.__init__(self, config) self.dnn_url = dnn_url self.model_name = dnn_url.rsplit('/', 1)[-1] self.local_model_path = os.path.join( - scannerpy.stdlib.util.temp_directory(), + scannerpy.util.temp_directory(), self.model_name.rsplit('.')[0], 'frozen_inference_graph.pb') diff --git a/examples/apps/object_detection_tensorflow/main.py b/examples/apps/object_detection_tensorflow/main.py index 7768270e..93ceb690 100644 --- a/examples/apps/object_detection_tensorflow/main.py +++ b/examples/apps/object_detection_tensorflow/main.py @@ -1,5 +1,5 @@ -from scannerpy import Client, DeviceType -from scannerpy.storage import NamedVideoStream, PythonStream +from scannertools.storage.python import PythonStream +import scannerpy as sp import os import sys import math @@ -21,10 +21,10 @@ def main(): print('Detecting objects in movie {}'.format(movie_path)) movie_name = os.path.splitext(os.path.basename(movie_path))[0] - sc = Client() + sc = sp.Client() stride = 1 - input_stream = NamedVideoStream(sc, movie_name, path=movie_path) + input_stream = sp.NamedVideoStream(sc, movie_name, path=movie_path) frame = sc.io.Input([input_stream]) strided_frame = sc.streams.Stride(frame, [stride]) @@ -33,12 +33,14 @@ def main(): objdet_frame = sc.ops.ObjDetect( frame=strided_frame, dnn_url=model_url, - device=DeviceType.GPU if sc.has_gpu() else DeviceType.CPU, + device=sp.DeviceType.GPU if sc.has_gpu() else sp.DeviceType.CPU, batch=2) - detect_stream = NamedVideoStream(sc, movie_name + '_detect') + detect_stream = sp.NamedVideoStream(sc, movie_name + '_detect') output_op = sc.io.Output(objdet_frame, [detect_stream]) - sc.run(output_op) + sc.run(output_op, + sp.PerfParams.estimate(), + cache_mode=sp.CacheMode.Overwrite) print('Extracting data from Scanner output...') # bundled_data_list is a list of bundled_data @@ -58,9 +60,11 @@ def main(): drawn_frame = sc.ops.TFDrawBoxes(frame=strided_frame, bundled_data=bundled_data, min_score_thresh=0.5) - drawn_stream = NamedVideoStream(sc, movie_name + '_drawn_frames') + drawn_stream = sp.NamedVideoStream(sc, movie_name + '_drawn_frames') output_op = sc.io.Output(drawn_frame, [drawn_stream]) - sc.run(output_op) + sc.run(output_op, + sp.PerfParams.estimate(), + cache_mode=sp.CacheMode.Overwrite) drawn_stream.save_mp4(movie_name + '_obj_detect') @@ -73,5 +77,3 @@ def main(): if __name__ == '__main__': main() - -