-
Notifications
You must be signed in to change notification settings - Fork 20
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
added tflite support (experimental) (#40)
* tflite support (wip) * linting * just check for "resnet" in model path * fallback to single tensor input * adjust image_data shape to expected input shape also optional image_size input * removed commented out code * only resize tensor if input image doesn't match * added command to convert to tflite * added --model-architecture argument * added quantization arguments * updated readme * linting * typo in readme * improved wording in readme
- Loading branch information
Showing
6 changed files
with
196 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -10,4 +10,5 @@ disable= | |
too-few-public-methods, | ||
too-many-arguments, | ||
too-many-instance-attributes, | ||
duplicate-code, | ||
invalid-name |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
import logging | ||
|
||
import tensorflow as tf | ||
|
||
try: | ||
import tfjs_graph_converter | ||
except ImportError: | ||
tfjs_graph_converter = None | ||
|
||
|
||
LOGGER = logging.getLogger(__name__) | ||
|
||
|
||
def get_tflite_converter_for_tfjs_model_path(model_path: str) -> tf.lite.TFLiteConverter: | ||
if tfjs_graph_converter is None: | ||
raise ImportError('tfjs_graph_converter required') | ||
graph = tfjs_graph_converter.api.load_graph_model(model_path) | ||
tf_fn = tfjs_graph_converter.api.graph_to_function_v2(graph) | ||
return tf.lite.TFLiteConverter.from_concrete_functions([tf_fn]) | ||
|
||
|
||
def get_tflite_converter_for_model_path(model_path: str) -> tf.lite.TFLiteConverter: | ||
LOGGER.debug('converting model_path: %s', model_path) | ||
# if model_path.endswith('.json'): | ||
return get_tflite_converter_for_tfjs_model_path(model_path) |