Skip to content

Commit

Permalink
Update object detection example to use new scannertools api
Browse files Browse the repository at this point in the history
  • Loading branch information
fpoms committed Apr 16, 2019
1 parent f56adf3 commit 6822cde
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 16 deletions.
11 changes: 6 additions & 5 deletions examples/apps/object_detection_tensorflow/kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,13 @@
import cv2
import os
import scannerpy
import scannerpy.stdlib.util
import pickle
import visualization_utils as vis_util
import tarfile

from scannerpy import FrameType, DeviceType
from scannerpy.stdlib import tensorflow
from scannerpy.types import NumpyArrayFloat32
from scannertools import tensorflow
from typing import Tuple, Sequence
from tqdm import tqdm

Expand All @@ -24,12 +23,12 @@
category_index = vis_util.create_category_index(categories)

def download_and_extract_model(url, local_path=None):
path = scannerpy.stdlib.util.download_temp_file(url, local_path)
path = scannerpy.util.download_temp_file(url, local_path)
tar_file = tarfile.open(path)
for f in tar_file.getmembers():
file_name = os.path.basename(f.name)
if 'frozen_inference_graph.pb' in file_name:
local_path = scannerpy.stdlib.util.temp_directory()
local_path = scannerpy.util.temp_directory()
tar_file.extract(f, local_path)
model_path = os.path.join(local_path, f.name)
break
Expand All @@ -40,11 +39,13 @@ def download_and_extract_model(url, local_path=None):
batch=2)
class ObjDetect(tensorflow.TensorFlowKernel):
def __init__(self, config, dnn_url):
print('objdet', config)
print([d.id for d in config.devices])
tensorflow.TensorFlowKernel.__init__(self, config)
self.dnn_url = dnn_url
self.model_name = dnn_url.rsplit('/', 1)[-1]
self.local_model_path = os.path.join(
scannerpy.stdlib.util.temp_directory(),
scannerpy.util.temp_directory(),
self.model_name.rsplit('.')[0],
'frozen_inference_graph.pb')

Expand Down
24 changes: 13 additions & 11 deletions examples/apps/object_detection_tensorflow/main.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from scannerpy import Client, DeviceType
from scannerpy.storage import NamedVideoStream, PythonStream
from scannertools.storage.python import PythonStream
import scannerpy as sp
import os
import sys
import math
Expand All @@ -21,10 +21,10 @@ def main():
print('Detecting objects in movie {}'.format(movie_path))
movie_name = os.path.splitext(os.path.basename(movie_path))[0]

sc = Client()
sc = sp.Client()

stride = 1
input_stream = NamedVideoStream(sc, movie_name, path=movie_path)
input_stream = sp.NamedVideoStream(sc, movie_name, path=movie_path)
frame = sc.io.Input([input_stream])
strided_frame = sc.streams.Stride(frame, [stride])

Expand All @@ -33,12 +33,14 @@ def main():
objdet_frame = sc.ops.ObjDetect(
frame=strided_frame,
dnn_url=model_url,
device=DeviceType.GPU if sc.has_gpu() else DeviceType.CPU,
device=sp.DeviceType.GPU if sc.has_gpu() else sp.DeviceType.CPU,
batch=2)

detect_stream = NamedVideoStream(sc, movie_name + '_detect')
detect_stream = sp.NamedVideoStream(sc, movie_name + '_detect')
output_op = sc.io.Output(objdet_frame, [detect_stream])
sc.run(output_op)
sc.run(output_op,
sp.PerfParams.estimate(),
cache_mode=sp.CacheMode.Overwrite)

print('Extracting data from Scanner output...')
# bundled_data_list is a list of bundled_data
Expand All @@ -58,9 +60,11 @@ def main():
drawn_frame = sc.ops.TFDrawBoxes(frame=strided_frame,
bundled_data=bundled_data,
min_score_thresh=0.5)
drawn_stream = NamedVideoStream(sc, movie_name + '_drawn_frames')
drawn_stream = sp.NamedVideoStream(sc, movie_name + '_drawn_frames')
output_op = sc.io.Output(drawn_frame, [drawn_stream])
sc.run(output_op)
sc.run(output_op,
sp.PerfParams.estimate(),
cache_mode=sp.CacheMode.Overwrite)

drawn_stream.save_mp4(movie_name + '_obj_detect')

Expand All @@ -73,5 +77,3 @@ def main():

if __name__ == '__main__':
main()


0 comments on commit 6822cde

Please sign in to comment.