Skip to content

Commit

Permalink
core/engines: add bf16 model support for tensorflow inference engine (intel#172)
Browse files Browse the repository at this point in the history

Signed-off-by: Longyin Hu <[email protected]>
  • Loading branch information
Hulongyin authored Nov 30, 2023
1 parent a097e41 commit 5023a40
Showing 1 changed file with 11 additions and 2 deletions.
13 changes: 11 additions & 2 deletions cnap/core/engines/tensorflow_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import numpy as np
import tensorflow as tf
from tensorflow.core.protobuf import rewriter_config_pb2 # pylint: disable=no-name-in-module
import cv2

from core.infereng import InferenceEngine
Expand Down Expand Up @@ -238,6 +239,8 @@ class TensorFlowEngine(InferenceEngine):
_input_size (Tuple[int, int]): The expected input size of the inference model.
_model (dict): The dictionary representing the inference model.
_session (tf.compat.v1.Session): The Tensorflow Session object of the inference model.
_config (tf.compat.v1.ConfigProto): The protocol buffer with configuration options for
tensorflow session.
_preprocessor (Preprocessor): The Preprocessor object for preprocessing the input data.
_postprocessor (Postprocessor): The Postprocessor object for postprocessing the output
data.
Expand All @@ -259,6 +262,7 @@ def __init__(self, config: TFModelConfig):
self._input_size: Tuple[int, int] = None
self._model = None
self._session = None
self._config = None

self._configure_environment()
self._configure_optimizer()
Expand Down Expand Up @@ -373,7 +377,7 @@ def _load_frozen_graph_model(self) -> None:

self._model = {'input_tensor': input_tensor, 'output_tensor': output_tensor}

self._session = tf.compat.v1.Session(graph=graph)
self._session = tf.compat.v1.Session(graph=graph, config=self._config)

def _load_saved_model(self) -> None:
"""Load the SavedModel from the .h5 file.
Expand All @@ -394,6 +398,11 @@ def _load_saved_model(self) -> None:

def _configure_optimizer(self) -> None:
    """Build the TF session config appropriate for the model's data type.

    When the configured dtype is ``'bfloat16'``, create a
    ``tf.compat.v1.ConfigProto`` with the oneDNN (MKL) auto mixed
    precision graph rewrite enabled and store it in ``self._config``
    so it is passed to the session at model-load time. For any other
    dtype, ``self._config`` is left as ``None``.
    """
    if self._dtype != 'bfloat16':
        return
    session_config = tf.compat.v1.ConfigProto()
    rewrite_options = session_config.graph_options.rewrite_options
    rewrite_options.auto_mixed_precision_mkl = rewriter_config_pb2.RewriterConfig.ON
    self._config = session_config


def _configure_environment(self) -> None:
"""Configure the environment based on the data type.
Expand All @@ -405,7 +414,7 @@ def _configure_environment(self) -> None:
"""
if self._dtype == 'float32':
os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'
elif self._dtype == 'float16':
elif self._dtype in ['float16', 'bfloat16']:
os.environ['TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_ALLOWLIST_ADD'] \
= 'BiasAdd,Relu6,Mul,AddV2'
os.environ['TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_INFERLIST_REMOVE'] \
Expand Down

0 comments on commit 5023a40

Please sign in to comment.