From 5023a4097cd209b7bbe0cb0b919b52cba9636ce6 Mon Sep 17 00:00:00 2001 From: Longyin Hu Date: Thu, 30 Nov 2023 10:54:07 +0800 Subject: [PATCH] core/engines: add bf16 model support for tensorflow inference engine (#172) Signed-off-by: Longyin Hu --- cnap/core/engines/tensorflow_engine.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/cnap/core/engines/tensorflow_engine.py b/cnap/core/engines/tensorflow_engine.py index 1a60be6..a49b742 100644 --- a/cnap/core/engines/tensorflow_engine.py +++ b/cnap/core/engines/tensorflow_engine.py @@ -14,6 +14,7 @@ import numpy as np import tensorflow as tf +from tensorflow.core.protobuf import rewriter_config_pb2 # pylint: disable=no-name-in-module import cv2 from core.infereng import InferenceEngine @@ -238,6 +239,8 @@ class TensorFlowEngine(InferenceEngine): _input_size (Tuple[int, int]): The expected input size of the inference model. _model (dict): The dictionary representing the inference model. _session (tf.compat.v1.Session): The Tensorflow Session object of the inference model. + _config (tf.compat.v1.ConfigProto): The protocol buffer with configuration options for + tensorflow session. _preprocessor (Preprocessor): The Preprocessor object for preprocessing the input data. _postprocessor (Postprocessor): The Postprocessor object for postprocessing the output data. @@ -259,6 +262,7 @@ def __init__(self, config: TFModelConfig): self._input_size: Tuple[int, int] = None self._model = None self._session = None + self._config = None self._configure_environment() self._configure_optimizer() @@ -373,7 +377,7 @@ def _load_frozen_graph_model(self) -> None: self._model = {'input_tensor': input_tensor, 'output_tensor': output_tensor} - self._session = tf.compat.v1.Session(graph=graph) + self._session = tf.compat.v1.Session(graph=graph, config=self._config) def _load_saved_model(self) -> None: """Load the SavedModel from the .h5 file. @@ -394,6 +398,11 @@ def _load_saved_model(self) -> None: def _configure_optimizer(self) -> None: """Configure the optimizer.""" + if self._dtype == 'bfloat16': + self._config = tf.compat.v1.ConfigProto() + self._config.graph_options.rewrite_options.auto_mixed_precision_mkl \ + = rewriter_config_pb2.RewriterConfig.ON + def _configure_environment(self) -> None: """Configure the environment based on the data type. @@ -405,7 +414,7 @@ def _configure_environment(self) -> None: """ if self._dtype == 'float32': os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0' - elif self._dtype == 'float16': + elif self._dtype in ['float16', 'bfloat16']: os.environ['TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_ALLOWLIST_ADD'] \ = 'BiasAdd,Relu6,Mul,AddV2' os.environ['TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_INFERLIST_REMOVE'] \