Quant tool: make removal of Clip/Relu ops configurable (microsoft#20616)
### Description
Adds the extra option `QDQKeepRemovableActivations` to optionally
prevent automatic removal of Clip/Relu ops in QDQ models. The current
default behavior, which is to remove Clip/Relu, remains the same if the
new option is not enabled.

### Motivation and Context
Explicitly representing these Relu/Clip operators in the QDQ model is
necessary if optimizations or EP transformations will later remove
QuantizeLinear/DequantizeLinear operators from the model.
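
As a rough usage sketch (mirroring the new unit test below; the model paths and `data_reader` are placeholders, not part of this change):

```python
from onnxruntime.quantization import QuantFormat, QuantType, quantize_static

# Keep Clip/Relu explicit in the QDQ graph instead of folding them into the
# preceding operator's quantization range.
quantize_static(
    "model.onnx",        # placeholder float32 model path
    "model.qdq.onnx",    # placeholder output path
    data_reader,         # a CalibrationDataReader implementation
    quant_format=QuantFormat.QDQ,
    activation_type=QuantType.QUInt8,
    weight_type=QuantType.QUInt8,
    extra_options={"QDQKeepRemovableActivations": True},
)
```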
adrianlizarraga authored May 11, 2024
1 parent 49d197a commit 643ed14
Showing 5 changed files with 156 additions and 2 deletions.
@@ -50,6 +50,7 @@ def get_qnn_qdq_config(
add_qtype_converts: bool = True,
activation_symmetric: bool = False,
weight_symmetric: bool | None = None,
keep_removable_activations: bool = False,
) -> StaticQuantConfig:
"""
Returns a static quantization configuration suitable for running QDQ models on QNN EP.
@@ -109,6 +110,11 @@ def get_qnn_qdq_config(
the zero-point values are 128 and 32,768, respectively.
weight_symmetric: True if weights should be quantized symmetrically (i.e., rmax == -rmin) by default.
Defaults to None. If set to None, weight_symmetric is assumed true if the weight_type is a signed int.
keep_removable_activations: Defaults to false. If true, "removable" activations (e.g., Clip or Relu) will not
be removed, and will be explicitly represented in the QDQ model. If false, these activations
are automatically removed if activations are asymmetrically quantized. Keeping these activations
is necessary if optimizations or EP transformations will later remove
QuantizeLinear/DequantizeLinear operators from the model.
Returns:
A StaticQuantConfig object
@@ -160,6 +166,7 @@ def get_qnn_qdq_config(
extra_options = {
"MinimumRealRange": 0.0001,
"DedicatedQDQPair": False, # Let ORT optimizer duplicate DQ nodes
"QDQKeepRemovableActivations": keep_removable_activations,
"TensorQuantOverrides": overrides_helper.get_dict(),
"ActivationSymmetric": activation_symmetric,
"WeightSymmetric": weight_symmetric,
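
For reference, a minimal sketch of how the new `keep_removable_activations` keyword would be passed through the QNN helper. The import path, model paths, and `data_reader` here are assumptions for illustration, not taken from this diff:

```python
from onnxruntime.quantization import QuantType, quantize
from onnxruntime.quantization.execution_providers.qnn import get_qnn_qdq_config

# Build a QDQ config for QNN EP that keeps Clip/Relu nodes explicit, so later
# EP transformations can safely drop QuantizeLinear/DequantizeLinear pairs.
qnn_config = get_qnn_qdq_config(
    "model.onnx",                     # placeholder float32 model path
    data_reader,                      # a CalibrationDataReader implementation
    activation_type=QuantType.QUInt16,
    weight_type=QuantType.QUInt8,
    keep_removable_activations=True,  # the option added in this commit
)
quantize("model.onnx", "model.qdq.onnx", qnn_config)
```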
6 changes: 4 additions & 2 deletions onnxruntime/python/tools/quantization/operators/activation.py
@@ -106,8 +106,10 @@ def quantize(self):
if not self.quantizer.is_tensor_quantized(node.input[0]):
return

-        if not self.quantizer.is_activation_symmetric and self.quantizer.try_replacing_upstream_output(
-            node.input[0], node.output[0]
+        if (
+            not self.quantizer.is_activation_symmetric
+            and not self.quantizer.qdq_keep_removable_activations
+            and self.quantizer.try_replacing_upstream_output(node.input[0], node.output[0])
):
self.quantizer.remove_node(self.node)
else:
4 changes: 4 additions & 0 deletions onnxruntime/python/tools/quantization/qdq_quantizer.py
@@ -187,6 +187,10 @@ def __init__(

self.qdq_op_domain = ms_domain if extra_options.get("UseQDQContribOps", False) else None

# User can specify if removable activations, like Clip/Relu, should be kept in the graph.
# Used in the QDQRemovableActivation class.
self.qdq_keep_removable_activations = extra_options.get("QDQKeepRemovableActivations", False)

# The ONNX spec did not support 16-bit Q/DQ ops before opset 21.
# So, may have to override the Q/DQ op domain to 'com.microsoft' if the activation or weight types
# are 16-bit integers.
12 changes: 12 additions & 0 deletions onnxruntime/python/tools/quantization/quantize.py
@@ -186,6 +186,12 @@ def __init__(
Invalid if also set `scale` or `zero_point`.
'rmin' = Float : Override the minimum real tensor value in calibration data.
Invalid if also set `scale` or `zero_point`.
QDQKeepRemovableActivations = True/False:
Default is False. If true, "removable" activations (e.g., Clip or Relu) will not be removed, and
will be explicitly represented in the QDQ model. If false, these activations are automatically
removed if activations are asymmetrically quantized. Keeping these activations is necessary if
optimizations or EP transformations will later remove QuantizeLinear/DequantizeLinear
operators from the model.
execution_provider : An enum indicating the Execution Provider, such as CPU, TRT, NNAPI, SNPE, etc.
Raises:
ValueError: Raise ValueError if execution provider is unknown
@@ -423,6 +429,12 @@ def quantize_static(
Invalid if also set `scale` or `zero_point`.
'rmin' = Float : Override the minimum real tensor value in calibration data.
Invalid if also set `scale` or `zero_point`.
QDQKeepRemovableActivations = True/False:
Default is False. If true, "removable" activations (e.g., Clip or Relu) will not be removed, and
will be explicitly represented in the QDQ model. If false, these activations are automatically
removed if activations are asymmetrically quantized. Keeping these activations is necessary if
optimizations or EP transformations will later remove QuantizeLinear/DequantizeLinear
operators from the model.
"""
if activation_type == QuantType.QFLOAT8E4M3FN or weight_type == QuantType.QFLOAT8E4M3FN:
if calibrate_method != CalibrationMethod.Distribution:
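
The same extra option can also be set on a `StaticQuantConfig` passed to `quantize()`; a hedged sketch, with the same placeholder model paths and `data_reader` as above:

```python
from onnxruntime.quantization import QuantFormat, QuantType, StaticQuantConfig, quantize

config = StaticQuantConfig(
    data_reader,  # a CalibrationDataReader implementation
    quant_format=QuantFormat.QDQ,
    activation_type=QuantType.QUInt8,
    weight_type=QuantType.QUInt8,
    extra_options={"QDQKeepRemovableActivations": True},
)
quantize("model.onnx", "model.qdq.onnx", config)
```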
129 changes: 129 additions & 0 deletions onnxruntime/test/python/quantization/test_qdq.py
@@ -39,6 +39,18 @@ def input_feeds(self, n, name2shape, np_float_type=np.float32):


class TestQDQExtraOptions(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls._tmp_model_dir = tempfile.TemporaryDirectory(prefix="ort.qdq.extra_options_")

# Note: swap with the commented line if you want to see the models in local test dir.
cls._tmp_dir_path = cls._tmp_model_dir.name
# cls._tmp_dir_path = "."

@classmethod
def tearDownClass(cls):
cls._tmp_model_dir.cleanup()

def test_qdq_extra_options(self):
# (input)
# |
@@ -236,6 +248,123 @@ def td(vals):
},
)

def test_qdq_keep_removable_activations_option(self):
#
# Create f32 model with Relu and Clip.
# input0 ---> Conv ---> Relu ---> Conv ---> Clip ---> output
#
shape1 = (1, 1, 3, 3)
w_shape1 = (2, 1, 2, 2)
w_shape2 = (2, 2, 2, 2)
shape3 = (1, 2, 1, 1)

input0 = onnx.helper.make_tensor_value_info("input0", onnx.TensorProto.FLOAT, shape1)
output = onnx.helper.make_tensor_value_info("output", onnx.TensorProto.FLOAT, shape3)

# Conv1
weight1_data = np.random.normal(-1.0, 1.0, w_shape1).astype(np.float32)
weight1_const = onnx.numpy_helper.from_array(weight1_data, "weight1_const")
conv1_node = onnx.helper.make_node("Conv", ["input0", "weight1_const"], ["conv1_out"], name="conv1_node")

# Relu1
relu1_node = onnx.helper.make_node("Relu", ["conv1_out"], ["relu1_out"], name="relu1_node")

# Conv2
weight2_data = np.random.normal(-1.8, 1.8, w_shape2).astype(np.float32)
weight2_const = onnx.numpy_helper.from_array(weight2_data, "weight2_const")
conv2_node = onnx.helper.make_node("Conv", ["relu1_out", "weight2_const"], ["conv2_out"], name="conv2_node")

# Clip1
min_const = onnx.numpy_helper.from_array(np.array(0.0, dtype=np.float32), "min_const")
max_const = onnx.numpy_helper.from_array(np.array(0.5, dtype=np.float32), "max_const")
clip1_node = onnx.helper.make_node(
"Clip", ["conv2_out", "min_const", "max_const"], ["output"], name="clip1_node"
)

graph = onnx.helper.make_graph(
[conv1_node, relu1_node, conv2_node, clip1_node],
"keep_qdq_activations",
[input0],
[output],
initializer=[weight1_const, weight2_const, min_const, max_const],
)
opset_imports = [
onnx.helper.make_opsetid("", 18),
]
f32_model = onnx.helper.make_model(graph, opset_imports=opset_imports)
f32_model = onnx.shape_inference.infer_shapes(f32_model)
f32_model_path = os.path.join(self._tmp_dir_path, "keep.act.model.onnx")
onnx.save_model(f32_model, f32_model_path)

# Create a data reader.
input_data_list = []
for _ in range(5):
inputs = {"input0": np.random.randint(-10, 10, shape1).astype(np.float32)}
input_data_list.extend([inputs])
data_reader = TestDataFeeds(input_data_list)

#
# Quantize model with extra option to KEEP removable activations.
#
qdq_model_path = os.path.join(self._tmp_dir_path, "keep.act.model.qdq.onnx")

# Create u8_act/u8_wgt qdq model
quantize_static(
f32_model_path,
qdq_model_path,
data_reader,
quant_format=QuantFormat.QDQ,
activation_type=QuantType.QUInt8,
weight_type=QuantType.QUInt8,
op_types_to_quantize=[node.op_type for node in f32_model.graph.node],
extra_options={"QDQKeepRemovableActivations": True},
)

has_relu = False
has_clip = False

qdq_model = onnx.load_model(qdq_model_path)

for node in qdq_model.graph.node:
if node.op_type == "Relu":
has_relu = True
if node.op_type == "Clip":
has_clip = True

self.assertTrue(has_relu)
self.assertTrue(has_clip)

#
# Quantize model without extra option. Clip and Relu should be removed by default.
#
qdq_model_path = os.path.join(self._tmp_dir_path, "nokeep.act.model.qdq.onnx")
data_reader.rewind()

# Create u8_act/u8_wgt qdq model
quantize_static(
f32_model_path,
qdq_model_path,
data_reader,
quant_format=QuantFormat.QDQ,
activation_type=QuantType.QUInt8,
weight_type=QuantType.QUInt8,
op_types_to_quantize=[node.op_type for node in f32_model.graph.node],
)

has_relu = False
has_clip = False

qdq_model = onnx.load_model(qdq_model_path)

for node in qdq_model.graph.node:
if node.op_type == "Relu":
has_relu = True
if node.op_type == "Clip":
has_clip = True

self.assertFalse(has_relu)
self.assertFalse(has_clip)


class TestQDQFormatConv(TestQDQFormat):
def check_per_channel_counts(self, model_path, channel_count: int, axis: int = 0):
