diff --git a/include/tvm/relay/dataflow_matcher.h b/include/tvm/relay/dataflow_matcher.h
index 12e4e3f45fef..e411245827e2 100644
--- a/include/tvm/relay/dataflow_matcher.h
+++ b/include/tvm/relay/dataflow_matcher.h
@@ -87,7 +87,8 @@ bool MatchPattern(DFPattern pattern, Expr expr);
 * \return Return An Expr with every match of the pattern inside the callbacks rewritten by the
 * functions inside the callbacks
 */
-Expr RewritePatterns(Array<DFPatternCallback> callbacks, Expr expr, IRModule mod = IRModule());
+Expr RewritePatterns(Array<DFPatternCallback> callbacks, Expr expr, IRModule mod = IRModule(),
+                     int allow_overlapping_groups = 0);

/*!
 * \brief Partition all matches of a DFPattern inside an Expr into separate Function calls
diff --git a/include/tvm/relay/qnn/attrs.h b/include/tvm/relay/qnn/attrs.h
index c5213fe07471..96fe36104164 100644
--- a/include/tvm/relay/qnn/attrs.h
+++ b/include/tvm/relay/qnn/attrs.h
@@ -78,6 +78,7 @@ struct QuantizeAttrs : public tvm::AttrsNode<QuantizeAttrs> {
 /*! \brief Attribute for dequantize operator */
 struct DequantizeAttrs : public tvm::AttrsNode<DequantizeAttrs> {
   int axis;
+  DataType out_dtype;

   TVM_DECLARE_ATTRS(DequantizeAttrs, "relay.attrs.DequantizeAttrs") {
     TVM_ATTR_FIELD(axis)
@@ -85,6 +86,10 @@ struct DequantizeAttrs : public tvm::AttrsNode<DequantizeAttrs> {
         "The channel axis for channel wise dequantization. Default value is -1,"
         "which corresponds to the last axis.")
         .set_default(-1);
+    TVM_ATTR_FIELD(out_dtype)
+        .describe(
+            "The datatype we are dequantizing to (float32 or int32). Defaults to float32.")
+        .set_default(DataType::Float(32));
   }
 };
diff --git a/python/tvm/data/__init__.py b/python/tvm/data/__init__.py
new file mode 100644
index 000000000000..e46a6353b11c
--- /dev/null
+++ b/python/tvm/data/__init__.py
@@ -0,0 +1,18 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=wildcard-import, redefined-builtin, invalid-name
+from ._dataset_manager import DatasetManager, TFDatasetManager, RandomDatasetManager
\ No newline at end of file
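Reviewer note: a minimal sketch of how the new `out_dtype` attribute on `qnn.dequantize` is meant to be used from the Python API added later in this diff. The tensor shape and scale values are placeholders.

```python
from tvm import relay

# Any quantized int8 tensor works here; the shape is illustrative only.
x = relay.var("x", shape=(1, 8), dtype="int8")
scale = relay.const(0.25, "float32")
zero_point = relay.const(0, "int32")

# Default behavior is unchanged: dequantize to float32.
f32_out = relay.qnn.op.dequantize(x, scale, zero_point)

# With this patch, an int32 output can be requested instead.
i32_out = relay.qnn.op.dequantize(x, scale, zero_point, out_dtype="int32")
```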
diff --git a/python/tvm/data/_dataset_manager.py b/python/tvm/data/_dataset_manager.py
new file mode 100644
index 000000000000..8028a6b9de35
--- /dev/null
+++ b/python/tvm/data/_dataset_manager.py
@@ -0,0 +1,138 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Wrapper classes to expose datasets during quantization."""
+
+import numpy as np
+
+
+class DatasetManager:
+    """Simple wrapper class to expose datasets in quantization."""
+
+    def get_next_batch(self):
+        """Returns the next batch of data.
+
+        Returns
+        -------
+        inputs : List
+            The inputs to be provided to the graph.
+            The list is of the form [batched_input_1, batched_input_2, ..., batched_input_n]
+
+        labels: List
+            The expected outputs of the graph.
+            The length of labels should be equal to the batch size.
+        """
+        raise NotImplementedError
+
+    def batch_size(self):
+        """Returns the size of each batch the dataset manager has.
+
+        Returns
+        -------
+        batch_size : int
+            The number of inputs in each batch.
+        """
+        raise NotImplementedError
+
+    def num_batches(self):
+        """Returns the number of batches the dataset manager has.
+
+        Returns
+        -------
+        num_batches : int
+            The number of batches the dataset manager contains.
+        """
+        raise NotImplementedError
+
+    def is_empty(self):
+        """Checks whether the dataset manager has gone through
+        all its batches.
+
+        Returns
+        -------
+        is_empty : bool
+            True if there are no batches left, False otherwise.
+        """
+        raise NotImplementedError
+
+    def reset(self):
+        """Resets the counter in the dataset manager to the beginning."""
+        raise NotImplementedError
+
+
+class TFDatasetManager(DatasetManager):
+    """DatasetManager wrapping a tensorflow dataset."""
+
+    def __init__(self, tf_dataset, batch_size, total_batches):
+        self.idx = 0
+        self.total_batches = total_batches
+        self.batch_sz = batch_size
+        self.tf_dataset = tf_dataset
+        self.tf_iter = iter(self.tf_dataset)
+
+    def get_next_batch(self):
+        if self.is_empty():
+            raise IndexError
+        self.idx += 1
+
+        data, label = next(self.tf_iter)
+
+        return [data.numpy()], label.numpy()
+
+    def num_batches(self):
+        return self.total_batches
+
+    def batch_size(self):
+        return self.batch_sz
+
+    def is_empty(self):
+        return self.idx >= self.total_batches
+
+    def reset(self):
+        self.tf_iter = iter(self.tf_dataset)
+        self.idx = 0
+
+
+class RandomDatasetManager(DatasetManager):
+    """DatasetManager that creates a random input of a specific shape.
+    This class is mostly used for testing, and as an example of how to
+    implement a DatasetManager.
+    """
+
+    def __init__(self, data_shape, dtype, batch_size, total_batches):
+        self.idx = 0
+        self.data_shape = data_shape
+        self.dtype = dtype
+        self.batch_sz = batch_size
+        self.total_batches = total_batches
+
+    def get_next_batch(self):
+        if self.is_empty():
+            raise IndexError
+        self.idx += 1
+        return [np.random.randn(*self.data_shape).astype(self.dtype)], None
+
+    def batch_size(self):
+        return self.batch_sz
+
+    def num_batches(self):
+        return self.total_batches
+
+    def is_empty(self):
+        return self.idx >= self.total_batches
+
+    def reset(self):
+        self.idx = 0
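Reviewer note: a quick sketch of the intended calling convention for this API; the tf.data pipeline here is a stand-in for a real calibration set.

```python
import tensorflow as tf
from tvm.data import TFDatasetManager

# A toy tf.data pipeline: 32 fake images, batched by 8 -> 4 batches.
dataset = tf.data.Dataset.from_tensor_slices(
    (tf.random.uniform((32, 28, 28, 1)), tf.range(32))
).batch(8)

manager = TFDatasetManager(dataset, batch_size=8, total_batches=4)
while not manager.is_empty():
    inputs, labels = manager.get_next_batch()
manager.reset()  # rewind before the next calibration pass
```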
diff --git a/python/tvm/relay/dataflow_pattern/__init__.py b/python/tvm/relay/dataflow_pattern/__init__.py
index d4a8481d106e..b238b79423b7 100644
--- a/python/tvm/relay/dataflow_pattern/__init__.py
+++ b/python/tvm/relay/dataflow_pattern/__init__.py
@@ -799,7 +799,7 @@ def __init__(self, require_type=False):
         self.pattern = None
         self.require_type = require_type

-    def rewrite(self, expr: Expr) -> Expr:
+    def rewrite(self, expr: Expr, allow_overlapping_groups: bool = False) -> Expr:
         """
         Rewrite expression with this callback
@@ -813,7 +813,7 @@ def rewrite(self, expr: Expr) -> Expr:
         result : tvm.relay.Expr
             The Expression with matched subgraphs rewritten by the callbacks.
         """
-        return rewrite(self, expr)
+        return rewrite(self, expr, allow_overlapping_groups=allow_overlapping_groups)

     def callback(self, pre: Expr, post: Expr, node_map: tvm.ir.container.Map) -> Expr:
         """
@@ -843,7 +843,8 @@ def __init__(self, pattern, callback, require_type):
         self.__init_handle_by_constructor__(ffi.DFPatternCallback, pattern, callback, require_type)


-def rewrite(callbacks, expr: Expr, mod: Optional[_ir.IRModule] = None) -> Expr:
+def rewrite(callbacks, expr: Expr, mod: Optional[_ir.IRModule] = None,
+            allow_overlapping_groups: bool = False) -> Expr:
     """
     Rewrite expression with the given callbacks.
@@ -868,8 +869,7 @@ def rewrite(callbacks, expr: Expr, mod: Optional[_ir.IRModule] = None) -> Expr:
     for callback in callbacks:
         assert callback.pattern is not None
         tmp.append(_DFPatternCallback(callback.pattern, callback.callback, callback.require_type))
-
-    return ffi.rewrite(tmp, expr, mod)
+    return ffi.rewrite(tmp, expr, mod, allow_overlapping_groups)


 def partition(
diff --git a/python/tvm/relay/qnn/op/qnn.py b/python/tvm/relay/qnn/op/qnn.py
index a5892f331f06..2e2ef872144c 100644
--- a/python/tvm/relay/qnn/op/qnn.py
+++ b/python/tvm/relay/qnn/op/qnn.py
@@ -20,10 +20,9 @@
 from __future__ import absolute_import as _abs
 from tvm.relay.expr import Tuple, TupleWrapper
 from tvm.relay.op.nn.utils import get_pad_tuple2d
-from . import _make
 from ... import op as reg
 from ...op import OpPattern
-
+from . import _make

 def requantize(
@@ -118,7 +117,7 @@ def quantize(data, output_scale, output_zero_point, axis=-1, out_dtype="int8"):
     return _make.quantize(data, output_scale, output_zero_point, axis, out_dtype)


-def dequantize(data, input_scale, input_zero_point, axis=-1):
+def dequantize(data, input_scale, input_zero_point, axis=-1, out_dtype="float32"):
     r"""Dequantize op
     This operator takes quantized int8 and uint8 as input and produces
     dequantized float32 as output. The output shape is the same as input shape. The input
@@ -134,13 +133,15 @@ def dequantize(data, input_scale, input_zero_point, axis=-1):
         The input scale.
     axis : int
         The channel axis for quantization. Default value is -1 which corresponds to the last axis.
+    out_dtype : str, optional
+        The output type to dequantize to. Can be either float32 or int32.
Returns ------- result : tvm.relay.Expr The computed result. """ - return _make.dequantize(data, input_scale, input_zero_point, axis) + return _make.dequantize(data, input_scale, input_zero_point, axis, out_dtype) def concatenate(data, input_scales, input_zero_points, output_scale, output_zero_point, axis): @@ -611,7 +612,6 @@ def subtract( output_zero_point, ) - # register fuse pattern for qnn ops reg.register_pattern("qnn.quantize", OpPattern.OPAQUE) reg.register_pattern("qnn.dequantize", OpPattern.OPAQUE) diff --git a/python/tvm/relay/transform/quantize/__init__.py b/python/tvm/relay/transform/quantize/__init__.py new file mode 100644 index 000000000000..56b57804ac49 --- /dev/null +++ b/python/tvm/relay/transform/quantize/__init__.py @@ -0,0 +1,49 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=wildcard-import, redefined-builtin, invalid-name +"""The namespace containing quantization and calibration passes""" +from ._calibration_callback import ( + CalibrationCallback, + GlobalCalibrationCallback, + AverageMaxCalibrationCallback, +) +from ._quantizer_patterns import ( + QuantizerPattern, + Conv2DBiasAddPattern, + Conv2DPattern, + DensePattern, + DenseBiasAddPattern, + AddPattern, + MultiplyPattern, + PerChannelPattern, +) +from ._average_max_channel_patterns import ( + AverageMaxPerChannelConv2DBiasAddPattern, + AverageMaxPerChannelConv2DPattern, + AverageMaxPerChannelDenseBiasAddPattern, + AverageMaxPerChannelDensePattern, +) + +from ._quantizer_pattern_utils import all_patterns, average_max_per_channel_patterns + +from ._quantizer import Quantizer +from ._calibrator import QuantizationCalibrator +from ._requantizer import Requantizer + +from ._quantize_pass import QuantizePass + +from . import _ffi as ffi diff --git a/python/tvm/relay/transform/quantize/_average_max_channel_patterns.py b/python/tvm/relay/transform/quantize/_average_max_channel_patterns.py new file mode 100644 index 000000000000..02636d1a3bb8 --- /dev/null +++ b/python/tvm/relay/transform/quantize/_average_max_channel_patterns.py @@ -0,0 +1,125 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Per channel implementation of Conv2DPattern, Conv2DBiasAddPattern, and DensePattern, using the
+average max algorithm to pick scales."""
+
+import numpy as np
+
+from tvm import relay
+from tvm.relay.transform.quantize import (
+    Conv2DPattern,
+    Conv2DBiasAddPattern,
+    DensePattern,
+    DenseBiasAddPattern,
+    PerChannelPattern,
+    CalibrationCallback,
+    QuantizerPattern,
+)
+
+
+class AverageMaxPerChannelPattern(PerChannelPattern):
+    """Per channel implementation of the AverageMax algorithm."""
+
+    def calibrate_pattern(self, calibration_info):
+        self.attr_callback(calibration_info.partition_info.expr)
+        scale_zp_values = {}
+
+        data_max_avg = 0
+        weight_max_avg = np.zeros(shape=self.get_scale_size())
+        num_inputs = (
+            calibration_info.dataset_manager.num_batches()
+            * calibration_info.dataset_manager.batch_size()
+        )
+
+        while not calibration_info.dataset_manager.is_empty():
+            # Get the original input from the dataset manager, run the unquantized
+            # graph with those inputs
+            image_list, _ = calibration_info.dataset_manager.get_next_batch()
+            unquantized_inputs = calibration_info.get_unquantized_layer_inputs(image_list)
+
+            data = unquantized_inputs[0]
+            weight = unquantized_inputs[1]
+
+            data_max_avg += np.max(np.abs(data)) / num_inputs
+
+            # Reduce over every axis except the output-channel axis
+            axis = list(range(len(weight.shape)))
+            axis.remove(0)
+            weight_max_avg += np.max(np.abs(weight), axis=tuple(axis)) / num_inputs
+
+        calibration_info.dataset_manager.reset()
+
+        # Since this is a symmetric distribution and we are quantizing to int8, there are 256 bins,
+        # and 128 are positive
+        data_scale = data_max_avg / 128
+        weight_scales = weight_max_avg / 128
+        scales = [np.array(data_scale), weight_scales]
+
+        for i, scale in enumerate(scales):
+            scale_name = calibration_info.partition_info.input_scale_zps[i][0].name_hint
+            zp_name = calibration_info.partition_info.input_scale_zps[i][1].name_hint
+
+            scale_zp_values[scale_name] = scale.astype("float32")
+            scale_zp_values[zp_name] = np.array(0).astype("int32")
+
+        return scale_zp_values
+
+
+class AverageMaxPerChannelConv2DPattern(AverageMaxPerChannelPattern, Conv2DPattern):
+    """Conv2DPattern with the per channel average max algorithm as the calibration method."""
+
+    def extract_attrs(self, pre, post, node_map):
+        conv2d = node_map[self.conv2d][0]
+        weight = node_map[self.conv_weight][0]
+
+        self.get_attrs(conv2d.attrs, weight.checked_type.shape)
+        return post
+
+    def get_scale_size(self):
+        return (self.channels,)
+
+
+class AverageMaxPerChannelConv2DBiasAddPattern(
+    AverageMaxPerChannelConv2DPattern, Conv2DBiasAddPattern
+):
+    """Per channel version of Conv2DBiasAddPattern, implementing the average max algorithm to
+    calculate scales and zero points."""
+
+
+class AverageMaxPerChannelDensePattern(AverageMaxPerChannelPattern, DensePattern):
+    """Per channel version of DensePattern, implementing the average max algorithm to
+    calculate scales and zero points."""
+
+    def __init__(self, calibration_callback: CalibrationCallback = None):
+        super().__init__(calibration_callback)
+
+    def extract_attrs(self, pre, post, node_map):
+        dense = node_map[self.dense][0]
+        weight = node_map[self.weight][0]
+
+        self.get_attrs(dense.attrs, weight.checked_type.shape)
+        self.units = self.attrs["units"]
+
+        return post
+
+    def get_scale_size(self):
+        return (self.units,)
+
+
+class AverageMaxPerChannelDenseBiasAddPattern(
+    AverageMaxPerChannelDensePattern, DenseBiasAddPattern
+):
+    """Per channel version of DenseBiasAddPattern, implementing the average max algorithm to
+    calculate scales and zero points."""
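Reviewer note: a worked numpy example of the per-channel statistic gathered above, assuming an OIHW weight (the shapes are illustrative only).

```python
import numpy as np

# An OIHW kernel with 4 output channels.
weight = np.random.randn(4, 3, 3, 3).astype("float32")

axis = list(range(weight.ndim))
axis.remove(0)  # reduce over everything except the output-channel axis
per_channel_max = np.max(np.abs(weight), axis=tuple(axis))  # shape (4,)

# Symmetric int8 leaves 128 positive bins, so scale = max / 128 per channel.
weight_scales = per_channel_max / 128
```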
diff --git a/python/tvm/relay/transform/quantize/_calibration_callback.py b/python/tvm/relay/transform/quantize/_calibration_callback.py
new file mode 100644
index 000000000000..fd7bbee24a79
--- /dev/null
+++ b/python/tvm/relay/transform/quantize/_calibration_callback.py
@@ -0,0 +1,114 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Methods for calibrating functions."""
+
+import numpy as np
+
+
+class CalibrationCallback:
+    """Abstract class that defines the API for calibrating a pattern."""
+
+    def calibrate_pattern(self, calibration_info):
+        """Calculates the scale and zero points for quantizing parts
+        of a generic pattern. If you would like to do per-channel or
+        pattern-specific calibration, please overwrite calibrate_pattern
+        in the relevant QuantizerPattern.
+
+        Parameters
+        ----------
+        calibration_info : CalibrationInfo
+            The class containing relevant information and utility functions
+            to calibrate one instance of a pattern.
+
+        Returns
+        -------
+        scale_zp_map : Dictionary
+            A map from the names of scales and zero point variables in this pattern to their values.
+        """
+        raise NotImplementedError
+
+
+class GlobalCalibrationCallback(CalibrationCallback):
+    """Sets the scales and zero points to a user-provided value."""
+
+    def __init__(self, scale_value, zp_value):
+        self.scale_value = np.array(scale_value).astype("float32")
+        self.zp_value = np.array(zp_value).astype("int32")
+
+    def calibrate_pattern(self, calibration_info):
+        """Returns the scale and zero point value set during initialization to the
+        QuantizationCalibrator.
+
+        Parameters
+        ----------
+        calibration_info : CalibrationInfo
+            Object containing information needed during calibration.
+
+        Returns
+        -------
+        scale_zp_map : dict of str to value
+            The map from names of scale and zero point variables to the global scale and zero point
+            values.
+        """
+        scale_zp_map = {}
+        for i in range(len(calibration_info.input_scale_zps)):
+            scale_name = calibration_info.input_scale_zps[i][0].name_hint
+            scale_zp_map[scale_name] = self.scale_value
+            zp_name = calibration_info.input_scale_zps[i][1].name_hint
+            scale_zp_map[zp_name] = self.zp_value
+
+        return scale_zp_map
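Reviewer note: a data-free configuration sketch; the values here are placeholders, not recommended settings.

```python
from tvm.relay.transform.quantize import GlobalCalibrationCallback, all_patterns

# Every scale is set to 0.05 and every zero point to 0, with no dataset needed.
cc = GlobalCalibrationCallback(scale_value=0.05, zp_value=0)
patterns = all_patterns(cc)
```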
+class AverageMaxCalibrationCallback(CalibrationCallback):
+    """Calibrates using the average-max algorithm: scales are derived from the
+    per-input minimum and maximum values averaged over all batches."""
+
+    def calibrate_pattern(self, calibration_info):
+        scale_zp_values = {}
+
+        min_sums = np.zeros(shape=(len(calibration_info.partition_info.input_scale_zps)))
+        max_sums = np.zeros(shape=(len(calibration_info.partition_info.input_scale_zps)))
+
+        while not calibration_info.dataset_manager.is_empty():
+            # Get the original input from the dataset manager, run the unquantized
+            # graph with those inputs
+            image_list, _ = calibration_info.dataset_manager.get_next_batch()
+            unquantized_inputs = calibration_info.get_unquantized_layer_inputs(image_list)
+
+            # Iterate through scale and zp variables
+            for i, unquantized_input in enumerate(unquantized_inputs):
+                # Calculate the average min, max across each batch
+                min_sums[i] += np.min(unquantized_input)
+                max_sums[i] += np.max(unquantized_input)
+
+        calibration_info.dataset_manager.reset()
+
+        avg_mins = min_sums / calibration_info.dataset_manager.num_batches()
+        avg_maxs = max_sums / calibration_info.dataset_manager.num_batches()
+
+        # Threshold for quantization of an input to a layer is
+        # mean(abs(avg_max), abs(avg_min))
+        thresholds = np.mean([np.abs(avg_mins), np.abs(avg_maxs)], axis=0)
+
+        # Since this is a symmetric distribution and we are quantizing to int8,
+        # there are 256 bins, and 128 are positive
+        scales = thresholds / 128
+
+        for i, scale_value in enumerate(scales):
+            scale_name = calibration_info.partition_info.input_scale_zps[i][0].name_hint
+            scale_zp_values[scale_name] = np.array(scale_value).astype("float32")
+            zp_name = calibration_info.partition_info.input_scale_zps[i][1].name_hint
+            scale_zp_values[zp_name] = np.array(0).astype("int32")
+
+        return scale_zp_values
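Reviewer note: the threshold/scale rule above, worked through on concrete numbers.

```python
import numpy as np

# Suppose one input's running averages over all batches come out to:
avg_min, avg_max = -3.2, 2.8

threshold = np.mean([np.abs(avg_min), np.abs(avg_max)])  # (3.2 + 2.8) / 2 = 3.0
scale = threshold / 128                                  # 0.0234375
```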
diff --git a/python/tvm/relay/transform/quantize/_calibrator.py b/python/tvm/relay/transform/quantize/_calibrator.py
new file mode 100644
index 000000000000..10df91a40f4d
--- /dev/null
+++ b/python/tvm/relay/transform/quantize/_calibrator.py
@@ -0,0 +1,380 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""API for calibrating a quantized function."""
+import numpy as np
+
+import tvm
+from tvm import relay
+from tvm.contrib import graph_runtime
+import tvm.relay.build_module as build_module
+
+
+class QuantizationCalibrator:
+    """The QuantizationCalibrator picks scales and zero points for all qnn ops
+    in the quantized module.
+
+    Parameters
+    ----------
+    quantizer : Quantizer
+        Quantizer created with the mod we are calibrating.
+
+    target : String, optional
+        The target to run the quantized function on during calibration.
+
+    ctx : tvm.runtime.TVMContext, optional
+        The context used for running the quantized function during calibration.
+
+    dataset_manager : DatasetManager, optional
+        The dataset manager containing data used to run the graph during
+        data-aware calibration.
+
+    show_scale_zps : bool, optional
+        If True, print each scale and zero point as it is set.
+    """
+
+    def __init__(
+        self, quantizer, target="llvm", ctx=tvm.cpu(), dataset_manager=None, show_scale_zps=False
+    ):
+        self.quantizer = quantizer
+
+        self.calibration_info = CalibrationInfo(
+            quantizer.tuple_subgraph_func,
+            quantizer.q_tuple_subgraph_func,
+            quantizer.partition_infos,
+            dataset_manager,
+            target,
+            ctx,
+        )
+
+        self.show_scale_zps = show_scale_zps
+
+    def calibrate(self):
+        """Picks the scales and zero points for all qnn ops in the quantized graph, using the
+        calibrate_pattern function from the quantizer.
+
+        Returns
+        -------
+        calibrated_func : relay.Function
+            The quantized function with the values for scales and zero points substituted into the
+            function.
+        """
+        # Create a map of DFPatternCallback to QuantizerPattern
+        pattern_map = {pattern.pattern: pattern for pattern in self.quantizer.patterns}
+
+        for partition_info in self.calibration_info.partition_infos:
+            # Set the partition info so we can access it from the callback
+            self.calibration_info.set_current_partition_info(partition_info)
+            quantizer_pattern = pattern_map[partition_info.pattern]
+
+            # Get the values for scales and ZPs in this layer, store
+            scale_zps = quantizer_pattern.calibrate_pattern(self.calibration_info)
+            if self.show_scale_zps:
+                self.report_scale_zps(scale_zps)
+            self.calibration_info.update_scale_zp_map(scale_zps)
+
+        calibrated_func = build_module.bind_params_by_name(
+            self.quantizer.q_tuple_subgraph_func, self.calibration_info.scale_zp_value_map
+        )
+
+        # If num_orig_outputs is -1, original output wasn't a tuple
+        params = calibrated_func.params
+        if self.quantizer.num_orig_outputs == -1:
+            calibrated_func = relay.Function(params, calibrated_func.body.fields[0])
+        else:
+            new_body = relay.Tuple(calibrated_func.body.fields[0 : self.quantizer.num_orig_outputs])
+            calibrated_func = relay.Function(params, new_body)
+
+        return calibrated_func
+
+    def report_scale_zps(self, scale_zp_map):
+        """Prints the scales and zero points out.
+
+        Parameters
+        ----------
+        scale_zp_map : dict of str to value
+            The map from names of scale and zero point variables to their assigned values.
+        """
+        for key, value in scale_zp_map.items():
+            print("Set ", key, " variable to ", value)
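Reviewer note: a sketch of the intended end-to-end flow. `mod`, `params`, and `manager` (a DatasetManager) are placeholders the caller is assumed to have already built.

```python
from tvm.relay.transform.quantize import (
    Quantizer,
    QuantizationCalibrator,
    average_max_per_channel_patterns,
)

quantizer = Quantizer(mod["main"], params, average_max_per_channel_patterns())
calibrator = QuantizationCalibrator(
    quantizer, target="llvm", dataset_manager=manager
)
calibrated_func = calibrator.calibrate()
```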
+class CalibrationInfo:
+    """Helper class that contains information necessary for picking scales and zero points into
+    calibrate_pattern. The state of CalibrationInfo is updated by QuantizationCalibrator.
+
+    Parameters
+    ----------
+    tuple_subgraph_func : relay.Function
+        A function whose output is a tuple that contains values we will need to access during
+        calibration.
+
+    q_tuple_subgraph_func : relay.Function
+        A quantized version of the tuple_subgraph_func. Note that to run this function, you
+        must pass in values for scales and zero points.
+
+    partition_infos : List[PatternCalibrationInfo]
+        A list of objects that correspond to every pattern matched during quantization. Each
+        contains scale and zero point variables, and indices into the tuple functions.
+
+    dataset_manager : DatasetManager
+        The dataset manager containing data used to run the graph during data-aware calibration.
+
+    target : String
+        The target to run the quantized function on during calibration.
+
+    ctx : tvm.runtime.TVMContext
+        The context used for running the quantized function during calibration.
+    """
+
+    def __init__(
+        self,
+        tuple_subgraph_func,
+        q_tuple_subgraph_func,
+        partition_infos,
+        dataset_manager,
+        target,
+        ctx,
+    ):
+        self.tuple_subgraph_func = tuple_subgraph_func
+        self.q_tuple_subgraph_func = q_tuple_subgraph_func
+        self.dataset_manager = dataset_manager
+        self.partition_infos = partition_infos
+        self.target = target
+        self.ctx = ctx
+
+        self.partition_info = None
+        self.input_scale_zps = None
+
+        tuple_subgraph_mod = tvm.ir.IRModule.from_expr(self.tuple_subgraph_func)
+        q_tuple_subgraph_mod = tvm.ir.IRModule.from_expr(self.q_tuple_subgraph_func)
+
+        self.tuple_subgraph_graphmodule = None
+        self.q_tuple_subgraph_graphmodule = None
+        self.init_subgraph_graphmodules(tuple_subgraph_mod, q_tuple_subgraph_mod)
+
+        self.scale_zp_value_map = {}
+        self.initialize_scale_zp_map()
+
+    def init_subgraph_graphmodules(self, tuple_subgraph_mod, q_tuple_subgraph_mod):
+        """Builds the tuple subgraphs so they can be run during calibration.
+
+        Parameters
+        ----------
+        tuple_subgraph_mod : tvm.ir.IRModule
+            Module wrapping tuple_subgraph_func.
+
+        q_tuple_subgraph_mod : tvm.ir.IRModule
+            Module wrapping q_tuple_subgraph_func.
+        """
+        # AlterOpLayout is disabled because it inserts some pads and other ops
+        with relay.build_config(opt_level=3, disabled_pass=["AlterOpLayout"]):
+            tuple_subgraph_lib = relay.build(tuple_subgraph_mod, target=self.target)
+            q_tuple_subgraph_lib = relay.build(q_tuple_subgraph_mod, target=self.target)
+
+        ts_graph_mod = graph_runtime.GraphModule(tuple_subgraph_lib["default"](self.ctx))
+        q_ts_graph_mod = graph_runtime.GraphModule(q_tuple_subgraph_lib["default"](self.ctx))
+        self.tuple_subgraph_graphmodule = ts_graph_mod
+        self.q_tuple_subgraph_graphmodule = q_ts_graph_mod
+
+    def initialize_scale_zp_map(self):
+        """Initializes scales to 1 and zero points to zero. These values will only be used
+        to calculate values in the tuple subgraph that are not returned to the user."""
+        for p_info in self.partition_infos:
+            for count in range(len(p_info.input_scale_zps)):
+                scale_var = p_info.input_scale_zps[count][0]
+                zp_var = p_info.input_scale_zps[count][1]
+
+                self.scale_zp_value_map[scale_var.name_hint] = np.array(1).astype("float32")
+                self.scale_zp_value_map[zp_var.name_hint] = np.array(0).astype("int32")
+
+    def set_current_partition_info(self, partition_info):
+        """Sets the partition_info for the current iteration, and exposes the list of scale and zp
+        variables directly instead of requiring the user to access it through the partition_info
+        object.
+
+        Parameters
+        ----------
+        partition_info : PatternCalibrationInfo
+            The PatternCalibrationInfo object corresponding to the pattern that will be quantized
+            in the calibrate_pattern callback."""
+        self.partition_info = partition_info
+        self.input_scale_zps = self.partition_info.input_scale_zps

+    def update_scale_zp_map(self, new_scale_zps):
+        """Updates the QuantizationCalibrator's scale and zero point map with values returned from
+        calibrate_pattern.
+
+        Parameters
+        ----------
+        new_scale_zps : dict
+            Dictionary mapping scale and zero point variable names to values."""
+        self.scale_zp_value_map.update(new_scale_zps)
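Reviewer note: the value maps exchanged between calibrate_pattern and the calibrator are plain dicts keyed by variable name (names below are illustrative).

```python
import numpy as np

scale_zps = {
    "conv2d_data_scale_0": np.array(0.0234375).astype("float32"),
    "conv2d_data_zero_pt_0": np.array(0).astype("int32"),
}
```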
+    def _run_tuple_mod(self, inputs, idx_list):
+        """Runs the graph that has all the intermediate outputs in it, and extracts the values for
+        the current pattern using indices into the tuple.
+
+        Parameters
+        ----------
+        inputs : List[ndarray]
+            A list of inputs to the original, unquantized function.
+
+        idx_list : List[int]
+            A list of indices into the tuple_subgraph_mod.
+
+        Returns
+        -------
+        value_list : List[ndarray]
+            A list of outputs from the tuple function corresponding to the indices in idx_list.
+        """
+        value_list = []
+
+        # Set the user provided inputs
+        for i, inp in enumerate(inputs):
+            self.tuple_subgraph_graphmodule.set_input(i, inp)
+
+        self.tuple_subgraph_graphmodule.run()
+
+        # Get the correct values out
+        for idx in idx_list:
+            value_list.append(self.tuple_subgraph_graphmodule.get_output(idx.value).asnumpy())
+
+        return value_list
+
+    def _run_quantized_tuple_mod(self, inputs, current_layer_scale_zps, idx_list):
+        """Runs the quantized version of the graph that has all the intermediate outputs in it,
+        and extracts the values for the current pattern using indices into the tuple. Because we
+        are running the quantized version, we need to pass in scales and zero points for the
+        current pattern.
+
+        Parameters
+        ----------
+        inputs : List[ndarray]
+            A list of inputs to the original, unquantized function.
+
+        current_layer_scale_zps : dict
+            A dictionary mapping scale and zero point variable names to values. These values are
+            the scales and zero point values for the current pattern. Note that if you pass in an
+            empty dictionary, the function will still run, but the current scale and zero point
+            will be 1 and 0 (the default values), respectively.
+
+        idx_list : List[int]
+            A list of indices into the tuple_subgraph_mod.
+
+        Returns
+        -------
+        value_list : List[ndarray]
+            A list of outputs from the tuple function corresponding to the indices in idx_list.
+        """
+        value_list = []
+
+        # Set user provided inputs
+        for i, inp in enumerate(inputs):
+            self.q_tuple_subgraph_graphmodule.set_input(i, inp)
+
+        # Set the scale and zero points
+        self.q_tuple_subgraph_graphmodule.set_input(**self.scale_zp_value_map)
+        self.q_tuple_subgraph_graphmodule.set_input(**current_layer_scale_zps)
+
+        self.q_tuple_subgraph_graphmodule.run()
+
+        for idx in idx_list:
+            value_list.append(self.q_tuple_subgraph_graphmodule.get_output(idx.value).asnumpy())
+
+        return value_list
+
+    def get_unquantized_layer_inputs(self, data):
+        """Utility function that evaluates the inputs to the current layer and returns the results
+        for given inputs. This function should be called from inside calibrate_pattern.
+
+        Parameters
+        ----------
+        data : List
+            List of inputs to pass into the mod. Inputs appear in the same order they appeared in
+            the original, unquantized function.
+
+        Returns
+        -------
+        input_values : tuple of numpy.ndarray
+            A tuple of the values of inputs to the unquantized layer. If the layer is a binop,
+            there will be two elements in the tuple, if an n-op, there will be n elements in
+            the tuple.
+        """
+        return self._run_tuple_mod(data, self.partition_info.input_idxs)
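Reviewer note: a sketch of a single-batch, max-based calibrate_pattern built on the helper above; it would live on a QuantizerPattern subclass.

```python
import numpy as np

def calibrate_pattern(self, calibration_info):
    scale_zps = {}
    image_list, _ = calibration_info.dataset_manager.get_next_batch()
    layer_inputs = calibration_info.get_unquantized_layer_inputs(image_list)
    for i, value in enumerate(layer_inputs):
        # Each entry of input_scale_zps is a (scale var, zero point var) pair.
        scale_var, zp_var = calibration_info.partition_info.input_scale_zps[i]
        scale = np.max(np.abs(value)) / 128
        scale_zps[scale_var.name_hint] = np.array(scale).astype("float32")
        scale_zps[zp_var.name_hint] = np.array(0).astype("int32")
    calibration_info.dataset_manager.reset()
    return scale_zps
```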
+    def get_quantized_layer_inputs(self, data, current_layer_scale_zps):
+        """Utility function that evaluates the quantized inputs to the current quantized layer,
+        and returns the results in a tuple. It uses previously set scale and zero points when
+        evaluating the graph. This function should be called from inside calibrate_pattern.
+
+        Parameters
+        ----------
+        data : list
+            List of inputs to pass into the mod. Inputs appear in the same order they appeared in
+            the original, unquantized function.
+
+        current_layer_scale_zps: dictionary
+            Map from names of scales and zero points you are setting in the current layer to
+            their values. This map should be of the same format as the map you return from
+            _calibration_callback.
+
+        Returns
+        -------
+        quantized_input_values : tuple of numpy.ndarray
+            A tuple of the values of the inputs to the quantized layer. If the layer is a binop,
+            there will be two elements in the tuple, if an n-op, there will be n elements in
+            the tuple.
+        """
+        return self._run_quantized_tuple_mod(
+            data, current_layer_scale_zps, self.partition_info.input_idxs
+        )
+
+    def get_unquantized_layer_output(self, inputs):
+        """Utility function that evaluates the unquantized output of the current layer and returns
+        it. This function should be called from inside calibrate_pattern.
+
+        Parameters
+        ----------
+        inputs : list
+            List of inputs to pass into the mod. Inputs appear in the same order they appeared in
+            the original, unquantized function.
+
+        Returns
+        -------
+        output_value : numpy.ndarray
+            The output of this layer.
+        """
+        return self._run_tuple_mod(inputs, [self.partition_info.output_idx])
+
+    def get_quantized_layer_output(self, data, current_layer_scale_zps):
+        """Utility function that evaluates the quantized output of the current layer.
+        This function should be called from inside calibrate_pattern.
+
+        Parameters
+        ----------
+        data : list
+            List of inputs to pass into the mod. Inputs appear in the same order they appeared in
+            the original, unquantized function.
+
+        current_layer_scale_zps: dictionary
+            Map from names of scales and zero points you are setting in the current layer to their
+            values. This map should be of the same format as the map you return from
+            _calibration_callback.
+
+        Returns
+        -------
+        output_value : numpy.ndarray
+            The output of the quantized layer.
+        """
+        return self._run_quantized_tuple_mod(
+            data, current_layer_scale_zps, [self.partition_info.output_idx]
+        )
diff --git a/python/tvm/relay/transform/quantize/_ffi.py b/python/tvm/relay/transform/quantize/_ffi.py
new file mode 100644
index 000000000000..1f04be739341
--- /dev/null
+++ b/python/tvm/relay/transform/quantize/_ffi.py
@@ -0,0 +1,20 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Quantization FFI bindings."""
+import tvm._ffi
+
+tvm._ffi._init_api("relay.transform.quantize", __name__)
diff --git a/python/tvm/relay/transform/quantize/_quantize_pass.py b/python/tvm/relay/transform/quantize/_quantize_pass.py
new file mode 100644
index 000000000000..42a99fbe07e7
--- /dev/null
+++ b/python/tvm/relay/transform/quantize/_quantize_pass.py
@@ -0,0 +1,92 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Relay pass wrapping the quantization and calibration workflow."""
+
+from typing import List
+
+import tvm
+from tvm.relay.transform.quantize import (
+    Quantizer,
+    QuantizationCalibrator,
+    Requantizer,
+    QuantizerPattern,
+)
+from .. import function_pass
+
+
+@function_pass(opt_level=5)
+class QuantizePass:
+    """Explicit relay pass wrapper around the quantization workflow.
+
+    Parameters
+    ----------
+    quantizer_pattern_list : List[QuantizerPattern]
+        The patterns we want to quantize.
+
+    params : dict of str to NDArray
+        Constants needed to run the mod. We need params so that we can run parts of the
+        graph during calibration.
+
+    target : str
+        Target to generate code for calibration on.
+
+    device : tvm.runtime.TVMContext
+        The device calibration is run on.
+
+    skip_first : bool
+        If True, we do not quantize the first quantizable pattern in the function. If False,
+        we will quantize it.
+
+    skip_last : bool
+        If True, we do not quantize the last quantizable pattern in the function. If False,
+        we will quantize it."""
+
+    def __init__(
+        self,
+        quantizer_pattern_list: List[QuantizerPattern],
+        params=None,
+        target="llvm",
+        device=tvm.cpu(0),
+        skip_first=True,
+        skip_last=False,
+    ):
+        self.quantizer_pattern_list = quantizer_pattern_list
+        self.params = params
+        self.target = target
+        self.device = device
+        self.skip_first = skip_first
+        self.skip_last = skip_last
+
+    def transform_function(self, func, mod, ctx):
+        """Quantizes, calibrates and requantizes the function.
+
+        Parameters
+        ----------
+        func : relay.Function
+            Function to apply the transformation on.
+
+        Returns
+        -------
+        transformed_func : relay.Function
+            The quantized, calibrated and requantized function.
+        """
+        params = {}
+        # Extract the params that are used by this function; guard against
+        # params being None (its default value)
+        if self.params is not None:
+            for param in func.params:
+                if param.name_hint in self.params:
+                    params[param.name_hint] = self.params[param.name_hint]
+        quantizer = Quantizer(
+            func, params, self.quantizer_pattern_list, self.skip_first, self.skip_last
+        )
+
+        calibrator = QuantizationCalibrator(quantizer, target=self.target, ctx=self.device)
+        transformed_func = calibrator.calibrate()
+        transformed_func = Requantizer().requantize(transformed_func)
+        return transformed_func
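Reviewer note: a sketch of driving the pass. Since QuantizePass does not thread a DatasetManager through to the calibrator, a data-free callback is used; `mod` and `params` are placeholders.

```python
import tvm
from tvm.relay.transform.quantize import (
    QuantizePass,
    GlobalCalibrationCallback,
    all_patterns,
)

cc = GlobalCalibrationCallback(scale_value=0.05, zp_value=0)
quantize = QuantizePass(all_patterns(cc), params=params, target="llvm")
with tvm.transform.PassContext(opt_level=5):
    mod = quantize(mod)
```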
diff --git a/python/tvm/relay/transform/quantize/_quantizer.py b/python/tvm/relay/transform/quantize/_quantizer.py
new file mode 100644
index 000000000000..e6f7341f2c00
--- /dev/null
+++ b/python/tvm/relay/transform/quantize/_quantizer.py
@@ -0,0 +1,156 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Quantizes functions by inserting qnn.quantize and qnn.dequantize ops."""
+from typing import List
+
+import tvm
+from tvm import relay
+from tvm.relay.dataflow_pattern import _DFPatternCallback
+from tvm.relay.transform.quantize import QuantizerPattern
+from tvm.relay.frontend.common import infer_type
+
+from . import _ffi as ffi
+
+
+class Quantizer:
+    """Class that inserts quantizes and dequantizes around patterns. It also constructs
+    important structures used by the QuantizationCalibrator.
+
+    Parameters
+    ----------
+    func : relay.Function
+        Function we are quantizing.
+
+    params : dict of str to NDArray
+        Parameters you would pass into relay.build or relay.build_module. We need params
+        so that we can run parts of the graph during calibration.
+
+    patterns : List[QuantizerPattern]
+        A list of all the patterns that we are going to quantize using this Quantizer.
+
+    skip_first : bool
+        If True, we do not quantize the first quantizable pattern in the function. If False,
+        we will quantize it.
+
+    skip_last : bool
+        If True, we do not quantize the last quantizable pattern in the function. If False,
+        we will quantize it.
+    """
+
+    def __init__(
+        self, func, params, patterns: List[QuantizerPattern], skip_first=True, skip_last=False
+    ):
+        self.patterns = patterns
+        self.original_func = prerequisite_optimize(func, params)
+
+        # num_orig_outputs is -1 if output is not a Tuple, else is length of tuple
+        if isinstance(self.original_func.body, tvm.relay.expr.Tuple):
+            self.num_orig_outputs = len(self.original_func.body.fields)
+        else:
+            self.num_orig_outputs = -1
+
+        # Partition the func into sub functions containing the patterns we want to quantize
+        partitioned_func = self.original_func
+        for q_pattern in self.patterns:
+            partitioned_func = q_pattern.pattern.partition(partitioned_func)
+
+        # Get rid of the first and last partitions, if requested
+        partitioned_func = skip_partitions(partitioned_func, skip_first, skip_last)
+        # Add outputs necessary for calibration
+        tuple_subgraph_func = partition_outputs(partitioned_func)
+
+        # Lower partitioned funcs and store in a mod
+        self.tuple_subgraph_func = lower_partitions(tuple_subgraph_func)
+
+        # Rewrite the multi-output graph to be quantized, and lower partitioned funcs
+        outs = rewrite_partitions(self.patterns, tuple_subgraph_func)
+        q_tuple_subgraph_func = outs["new_out"]
+
+        # Information about each partition used for calibration
+        self.partition_infos = outs["infos_"]
+
+        # Lower quantized partitions and store in a mod
+        self.q_tuple_subgraph_func = lower_partitions(q_tuple_subgraph_func)
+
+        # Create a function containing just the quantized original graph
+        quantized_func = self.q_tuple_subgraph_func
+        if self.num_orig_outputs == -1:
+            self.quantized_func = relay.Function(
+                self.q_tuple_subgraph_func.params, quantized_func.body.fields[0]
+            )
+        else:
+            tuple_body = relay.Tuple(quantized_func.body.fields[: self.num_orig_outputs])
+            self.quantized_func = relay.Function(self.q_tuple_subgraph_func.params, tuple_body)
+
+
+def prerequisite_optimize(func, params=None):
+    """Prerequisite optimization passes for quantization. Performs
+    "DynamicToStatic", "SimplifyInference", "FoldConstant", "FoldScaleAxis" and
+    "EliminateCommonSubexpr" before quantization.
+
+    Parameters
+    ----------
+    func : relay.Function
+        The function to optimize.
+
+    params : dict of str to NDArray
+        Parameters to bind into the function before optimizing.
+
+    Returns
+    -------
+    preopt_func : relay.Function
+        The original function with optimizations needed before quantization applied.
+    """
+    optimize = tvm.transform.Sequential(
+        [
+            relay.transform.DynamicToStatic(),
+            relay.transform.SimplifyInference(),
+            relay.transform.FoldConstant(),
+            relay.transform.FoldScaleAxis(),
+            relay.transform.FoldConstant(),
+            relay.transform.EliminateCommonSubexpr(),
+        ]
+    )
+
+    if params is not None:
+        func = relay.build_module.bind_params_by_name(func, params)
+
+    mod = tvm.ir.IRModule.from_expr(func)
+
+    with relay.build_config(opt_level=3):
+        mod = optimize(mod)
+
+    return mod["main"]
+
+
+def partition_outputs(expr):
+    return ffi.partition_outputs(expr)
+
+
+def rewrite_partitions(callbacks, expr):
+    return ffi.rewrite_partitions(
+        [
+            _DFPatternCallback(callback.pattern, callback.callback, callback.require_type)
+            for callback in callbacks
+        ],
+        infer_type(expr),
+    )
+
+
+def lower_partitions(expr):
+    return ffi.lower_partitions(expr)
+
+
+def skip_partitions(expr, skip_first=True, skip_last=True):
+    return ffi.skip_partitions(expr, skip_first, skip_last)
diff --git a/python/tvm/relay/transform/quantize/_quantizer_pattern_utils.py b/python/tvm/relay/transform/quantize/_quantizer_pattern_utils.py
new file mode 100644
index 000000000000..0e9c1b6e1b09
--- /dev/null
+++ b/python/tvm/relay/transform/quantize/_quantizer_pattern_utils.py
@@ -0,0 +1,54 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Helpers that assemble the default lists of QuantizerPatterns."""
+
+from tvm.relay.transform.quantize import (
+    CalibrationCallback,
+    Conv2DPattern,
+    Conv2DBiasAddPattern,
+    DensePattern,
+    DenseBiasAddPattern,
+    AddPattern,
+    MultiplyPattern,
+    AverageMaxCalibrationCallback,
+    AverageMaxPerChannelConv2DBiasAddPattern,
+    AverageMaxPerChannelConv2DPattern,
+    AverageMaxPerChannelDenseBiasAddPattern,
+    AverageMaxPerChannelDensePattern,
+)
+
+
+def all_patterns(cc: CalibrationCallback = None):
+    return [
+        Conv2DBiasAddPattern(cc),
+        Conv2DPattern(cc),
+        DenseBiasAddPattern(cc),
+        DensePattern(cc),
+        AddPattern(cc),
+        MultiplyPattern(cc),
+    ]
+
+
+def average_max_per_channel_patterns():
+    cc = AverageMaxCalibrationCallback()
+    return [
+        AverageMaxPerChannelConv2DBiasAddPattern(cc),
+        AverageMaxPerChannelConv2DPattern(cc),
+        AverageMaxPerChannelDenseBiasAddPattern(cc),
+        AverageMaxPerChannelDensePattern(cc),
+        AddPattern(cc),
+        MultiplyPattern(cc),
+    ]
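Reviewer note: the helpers above just bundle pattern instances that share one callback; a reduced custom list works the same way.

```python
from tvm.relay.transform.quantize import (
    AverageMaxCalibrationCallback,
    Conv2DPattern,
    DensePattern,
)

# Only quantize conv2d and dense, sharing one per-tensor average-max callback.
cc = AverageMaxCalibrationCallback()
patterns = [Conv2DPattern(cc), DensePattern(cc)]
```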
diff --git a/python/tvm/relay/transform/quantize/_quantizer_patterns.py b/python/tvm/relay/transform/quantize/_quantizer_patterns.py
new file mode 100644
index 000000000000..d396d1827aa1
--- /dev/null
+++ b/python/tvm/relay/transform/quantize/_quantizer_patterns.py
@@ -0,0 +1,711 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Patterns to quantize and how to quantize them."""
+
+import tvm
+from tvm import relay
+
+from tvm.relay.transform.quantize import CalibrationCallback
+from tvm.relay.dataflow_pattern import (
+    is_op,
+    wildcard,
+    is_constant,
+    DFPatternCallback,
+    _DFPatternCallback,
+)
+from tvm.relay.dataflow_pattern import ffi as pattern_ffi
+from tvm.relay.frontend.common import infer_type
+from tvm.relay.op.nn.utils import get_pad_tuple2d
+from . import _ffi as ffi
+
+
+class QuantizerPattern(DFPatternCallback):
+    """DFPatternCallback to rewrite patterns as quantized. Also contains extra information
+    used for quantization and calibration.
+
+    Parameters
+    ----------
+    calibration_callback : CalibrationCallback
+        The method we will use to calibrate this pattern.
+    """
+
+    # Counts the number of times we've added a scale and zp for variable naming
+    # This needs to be a global variable and not initialized in __init__ because
+    # each scale and zero point must be unique, even if they are created by different
+    # instances.
+    scales_count = 0
+    zp_count = 0
+
+    def __init__(self, calibration_callback: CalibrationCallback = None):
+        super().__init__()
+        self.calibration_callback = calibration_callback
+
+    def calibrate_pattern(self, calibration_info):
+        """Calculates the scale and zero points for quantizing parts of a generic pattern. By
+        default, we call the calibrate_pattern method of the CalibrationCallback object that is
+        passed into QuantizerPattern during initialization. However, if you want a pattern specific
+        quantization method or a per-channel quantization method, you should overwrite the
+        QuantizerPattern's calibrate_pattern method.
+
+        Parameters
+        ----------
+        calibration_info : CalibrationInfo
+            The class containing relevant information and utility functions to calibrate one
+            instance of a pattern.
+
+        Returns
+        -------
+        scale_zp_map : Dictionary
+            A map from the names of scales and zero point variables in this pattern to their
+            values.
+        """
+        return self.calibration_callback.calibrate_pattern(calibration_info)
+
+    def callback(self, pre, post, node_map):
+        raise NotImplementedError
+
+    def scale(self, name, is_weight=False):
+        """Helper to create the scale variable for qnn.quantize when rewriting our pattern.
+
+        Parameters
+        ----------
+        name : str
+            Identifier at the beginning of the scale variable.
+
+        is_weight : bool
+            Whether this scale is a weight scale or a data scale. If it is a weight scale, the
+            returned variable has shape (channels,). Only used for per-channel quantization.
+
+        Returns
+        -------
+        var : relay.Var
+            Relay variable for the scale. If the input name is 'conv2d_data', then the name of
+            the relay variable might be 'conv2d_data_scale_0'.
+        """
+
+        var = relay.var(
+            str(name) + "_scale_" + str(QuantizerPattern.scales_count), shape=(), dtype="float32"
+        )
+        QuantizerPattern.scales_count += 1
+        return var
+
+    def zero_point(self, name):
+        """Helper to create the zero point variable for qnn.quantize when rewriting our
+        pattern.
+
+        Parameters
+        ----------
+        name : str
+            Identifier at the beginning of the variable.
+
+        Returns
+        -------
+        var : relay.Var
+            Relay variable for the zero point. If the input name is 'conv2d_data', then the name
+            of the relay variable might be 'conv2d_data_zero_pt_0'.
+        """
+        var = relay.var(
+            str(name) + "_zero_pt_" + str(QuantizerPattern.zp_count), shape=(), dtype="int32"
+        )
+        QuantizerPattern.zp_count += 1
+        return var
+
+    def create_scale_zps(self, left_name, right_name):
+        """Helper to create scales and zero points for binops.
+
+        Parameters
+        ----------
+        left_name : str
+            Identifier of the left hand side scale and zero point.
+
+        right_name : str
+            Identifier of the right hand side scale and zero point.
+        """
+        data_scale = self.scale(left_name)
+        data_zp = self.zero_point(left_name)
+        weight_scale = self.scale(right_name)
+        weight_zp = self.zero_point(right_name)
+        self.scale_zps = [data_scale, data_zp, weight_scale, weight_zp]
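Reviewer note: the minimal shape of a new pattern subclass. The op here (nn.max_pool2d) is only an illustration and is not one of the patterns added by this diff.

```python
from tvm.relay.dataflow_pattern import is_op, wildcard
from tvm.relay.transform.quantize import QuantizerPattern


class MaxPool2DPattern(QuantizerPattern):
    def __init__(self, calibration_callback=None):
        super().__init__(calibration_callback)
        self.input = wildcard()
        self.pattern = is_op("nn.max_pool2d")(self.input)

    def callback(self, pre, post, node_map):
        # A real pattern would return the quantized rewrite of the match here.
        return post
```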
+class Conv2DPattern(QuantizerPattern):
+    """Pattern to rewrite nn.conv2d ops as qnn.conv2d ops.
+
+    Parameters
+    ----------
+    calibration_callback : CalibrationCallback
+        The method we will use to calibrate this pattern.
+    """
+
+    def __init__(self, calibration_callback: CalibrationCallback = None):
+        super().__init__(calibration_callback)
+        self.input = wildcard()
+        self.conv_weight = wildcard()
+        self.inputs = [self.input, self.conv_weight]
+        self.conv2d = is_op("nn.conv2d")(self.input, self.conv_weight)
+        self.pattern = self.conv2d
+        self.attrs = None
+        self.weight_channel_axis = None
+        self.data_channel_axis = None
+        self.channels = None
+        self.padding = None
+
+    def get_kernel_size(self, kernel_shape, kernel_layout):
+        """Gets the size of the kernel.
+
+        Parameters
+        ----------
+        kernel_shape : NDArray
+            Shape of the kernel
+
+        kernel_layout : str
+            Layout of the kernel
+
+        Returns
+        -------
+        kernel_size : NDArray
+            Size of the kernel
+        """
+        if kernel_layout == "OIHW":
+            kernel_size = tuple(kernel_shape[2:4])
+        elif kernel_layout == "HWIO":
+            kernel_size = tuple(kernel_shape[0:2])
+        else:
+            raise ValueError(
+                "Quantizing kernel layout %s for conv2d is not yet supported. "
+                "Please use OIHW or HWIO." % kernel_layout
+            )
+        return kernel_size
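Reviewer note: a tiny check of which dimensions get_kernel_size reads per layout (shapes are illustrative).

```python
oihw_shape = (64, 3, 7, 7)  # out, in, height, width
hwio_shape = (7, 7, 3, 64)  # height, width, in, out
assert tuple(oihw_shape[2:4]) == tuple(hwio_shape[0:2]) == (7, 7)
```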
+    def get_attrs(self, attrs, kernel_shape):
+        """Constructs the attributes for qnn.conv2d.
+
+        Parameters
+        ----------
+        attrs : dict
+            Attributes of the original nn.conv2d
+
+        kernel_shape : NDArray
+            Shape of the kernel
+
+        Returns
+        -------
+        quantized_attrs : dict
+            Attributes for the qnn.conv2d
+        """
+        new_attr_dict = {}
+        self.kernel_layout = attrs["kernel_layout"]
+        data_layout = attrs["data_layout"]
+
+        if self.kernel_layout == "OIHW":
+            self.weight_channel_axis = 0
+        elif self.kernel_layout == "HWIO":
+            self.weight_channel_axis = 3
+        else:
+            raise ValueError(
+                "Quantizing kernel layout %s for conv2d is not yet supported. "
+                "Please use OIHW or HWIO." % self.kernel_layout
+            )
+
+        if data_layout == "NCHW":
+            self.data_channel_axis = 1
+        elif data_layout == "NHWC":
+            self.data_channel_axis = 3
+        else:
+            raise ValueError(
+                "Quantizing data layout %s for conv2d is not yet supported. "
+                "Please use NCHW or NHWC." % data_layout
+            )
+
+        for attr in attrs.keys():
+            attr_value = attrs[attr]
+            if isinstance(attr_value, tvm.ir.container.Array):
+                attr_value = tuple(attr_value)
+            if attr == "kernel_size":
+                kernel_size = attrs[attr]
+                if kernel_size is None:
+                    kernel_size = self.get_kernel_size(kernel_shape, self.kernel_layout)
+                else:
+                    kernel_size = tuple([k.value for k in attrs[attr]])
+                new_attr_dict[attr] = kernel_size
+            elif attr == "channels":
+                self.channels = attrs[attr]
+                if self.channels is None:
+                    self.channels = kernel_shape[self.weight_channel_axis]
+                if isinstance(self.channels, tvm.tir.expr.IntImm):
+                    self.channels = self.channels.value
+                new_attr_dict[attr] = self.channels
+            elif attr == "padding":
+                # We don't need to put padding in the attr dict because we explicitly
+                # construct padding in the callback
+                self.padding = attrs[attr]
+            else:
+                new_attr_dict[attr] = attr_value
+
+        new_attr_dict["out_dtype"] = "int32"
+        self.attrs = new_attr_dict
+
+    def quantize_args(self):
+        """Helper to quantize the arguments to the qnn.conv2d."""
+        quantized_data = relay.qnn.op.quantize(
+            self.args[0], self.scale_zps[0], self.scale_zps[1], axis=self.data_channel_axis
+        )
+        quantized_weight = relay.qnn.op.quantize(
+            self.args[1], self.scale_zps[2], self.scale_zps[3], axis=self.weight_channel_axis
+        )
+        self.quantized_args = [quantized_data, quantized_weight]
+
+    def create_conv(self, args, node_map):
+        """Creates the qnn.conv2d.
+
+        Parameters
+        ----------
+        args : List[relay.Expr]
+            Quantized arguments for the qnn.conv2d.
+
+        node_map : tvm.ir.container.Map
+            Node map from DFPatternCallback's callback
+
+        Returns
+        -------
+        q_conv2d : relay.Expr
+            Quantized version of the pattern.
+        """
+        return relay.qnn.op.conv2d(*args, **self.attrs)
+
+    def callback(self, pre, post, node_map):
+        self.args = [node_map[i][0] for i in self.inputs]
+        conv2d = node_map[self.conv2d][0]
+
+        self.out_dtype = conv2d.checked_type.dtype
+
+        self.get_attrs(conv2d.attrs, infer_type(self.args[1]).checked_type.shape)
+
+        self.create_scale_zps("conv2d_data", "conv2d_weight")
+        self.quantize_args()
+
+        conv_scale = self.scale_zps[0] * self.scale_zps[2]  # data_scale * weight_scale
+
+        # Conv zp is zero since QNN deals with input zps for us
+        conv_zp = relay.const(0, dtype="int32")
+        # args = [quantized_data, quantized_weight, data_zp, weight_zp, data_scale, weight_scale]
+        args = self.quantized_args[0:2] + [self.scale_zps[i] for i in [1, 3, 0, 2]]
+
+        if self.padding is not None:
+
+            top, left, bottom, right = [p.value for p in get_pad_tuple2d(self.padding)]
+            if self.kernel_layout == "OIHW":
+                pad_width = ((0, 0), (0, 0), (top, bottom), (left, right))
+            elif self.kernel_layout == "HWIO":
+                pad_width = (
+                    (top, bottom),
+                    (left, right),
+                    (0, 0),
+                    (0, 0),
+                )
+            pad_val = 0
+            args[0] = relay.op.nn.pad(args[0], pad_width, pad_val)
+
+        # Construct quantized qnn.conv2d and dequantize
+        qnn_call = self.create_conv(args, node_map)
+        dequantized_call = relay.qnn.op.dequantize(
+            qnn_call, conv_scale, conv_zp, out_dtype=self.out_dtype, axis=self.data_channel_axis
+        )
+
+        return dequantized_call
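Reviewer note: a rough sketch of driving this callback on a single conv2d. Shapes are placeholders, and patterns are normally applied through the Quantizer machinery rather than rewritten directly like this.

```python
import tvm
from tvm import relay
from tvm.relay.transform.quantize import Conv2DPattern, GlobalCalibrationCallback

x = relay.var("x", shape=(1, 3, 32, 32), dtype="float32")
w = relay.var("w", shape=(8, 3, 3, 3), dtype="float32")
func = relay.Function([x, w], relay.nn.conv2d(x, w, kernel_size=(3, 3), channels=8))
func = relay.transform.InferType()(tvm.IRModule.from_expr(func))["main"]

pattern = Conv2DPattern(GlobalCalibrationCallback(0.05, 0))
# Produces quantize -> qnn.conv2d -> dequantize around the matched conv.
quantized_body = pattern.rewrite(func.body)
```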
+ """ + + def __init__(self, calibration_callback: CalibrationCallback = None): + super().__init__(calibration_callback) + self.bias_weight = is_constant() + self.inputs.append(self.bias_weight) + self.add = is_op("add")(self.conv2d, self.bias_weight) + self.bias_add = is_op("nn.bias_add")(self.conv2d, self.bias_weight) + self.pattern = self.bias_add | self.add + + def quantize_args(self): + """Quantizes the arguments to the nn.conv2d -> nn.bias_add pattern.""" + super().quantize_args() + quantized_bias = relay.qnn.op.quantize( + self.args[2], self.scale_zps[0], self.scale_zps[1], axis=0, out_dtype="int32" + ) + self.quantized_args.append(quantized_bias) + + def create_conv(self, args, node_map): + """Creates the qnn.dense -> nn.bias_add. + + Parameters + ---------- + args : List[relay.Expr] + Quantized arguments for the qnn.conv2d and bias_add. + + node_map : tvm.ir.container.Map + Node map from DFPatternCallback's callback + + Returns + ------- + q_conv2d : relay.Expr + Quantized version of the pattern. + """ + qnn_call = relay.qnn.op.conv2d(*args, **self.attrs) + if node_map.get(self.add) is not None: + bias_add = relay.op.add(qnn_call, self.quantized_args[2]) + else: # self.bias_add in node_map + bias_add = relay.op.nn.bias_add( + qnn_call, self.quantized_args[2], axis=self.data_channel_axis + ) + return bias_add + + +class DensePattern(QuantizerPattern): + """Pattern to rewrite nn.dense pattern as qnn.dense. + Parameters + ---------- + calibration_callback : CalibrationCallback + The method we will use to calibrate this pattern. + """ + + def __init__(self, calibration_callback: CalibrationCallback = None): + super().__init__(calibration_callback) + self.data = wildcard() + self.weight = wildcard() + self.inputs = [self.data, self.weight] + + self.dense = is_op("nn.dense")(self.data, self.weight) + + self.pattern = self.dense + self.attrs = None + self.units = None + + def get_attrs(self, attrs, weight_shape): + """Constructs the attributes for qnn.conv2d. + + Parameters + ---------- + attrs : dict + Attributes of the original nn.dense + + weight_shape : NDArray + Shape of the dense weights + + Returns + ------- + quantized_attrs : dict + Attributes for the qnn.conv2d + """ + self.attrs = {} + units = attrs["units"] + if units is None: + units = weight_shape[0] + self.units = units.value + self.attrs["units"] = self.units + + def quantize_args(self): + """Quantizes the arguments to the nn.dense pattern.""" + # Quantize data and construct args for qnn.dense + quantized_data = relay.qnn.op.quantize(self.args[0], self.scale_zps[0], self.scale_zps[1]) + quantized_weight = relay.qnn.op.quantize( + self.args[1], self.scale_zps[2], self.scale_zps[3], axis=0 + ) # Axis = 0 for per channel quantization + self.quantized_args = [quantized_data, quantized_weight] + + def create_dense(self, args, node_map): + """Creates the qnn.dense. + + Parameters + ---------- + args : List[relay.Expr] + Quantized arguments for the qnn.dense. + + node_map : tvm.ir.container.Map + Node map from DFPatternCallback's callback + + Returns + ------- + q_dense : relay.Expr + Quantized version of the pattern. 
+ """ + qnn_call = relay.qnn.op.dense(*args, **self.attrs) + return qnn_call + + def callback(self, pre, post, node_map): + self.args = [node_map[i][0] for i in self.inputs] + weight = node_map[self.weight][0] + + dense = node_map[self.dense][0] + out_dtype = dense.checked_type.dtype + self.get_attrs(dense.attrs, infer_type(weight).checked_type.shape) + self.create_scale_zps("dense_data", "dense_weight") + self.quantize_args() + + # args = [quantized_data, quantized_weight, data_zp, weight_zp, data_scale, weight_scale] + args = self.quantized_args[0:2] + [self.scale_zps[i] for i in [1, 3, 0, 2]] + qnn_call = self.create_dense(args, node_map) + + deq_call = relay.qnn.op.dequantize( + qnn_call, + self.scale_zps[0] * self.scale_zps[2], + relay.const(0, dtype="int32"), + out_dtype=out_dtype, + axis=1, + ) + + return deq_call + + +class DenseBiasAddPattern(DensePattern): + """Pattern to rewrite nn.dense -> add and nn.dense -> nn.bias_add pattern as qnn.dense -> nn.bias_add. + + Parameters + ---------- + calibration_callback : CalibrationCallback + The method we will use to calibrate this pattern. + """ + + def __init__(self, calibration_callback: CalibrationCallback = None): + super().__init__(calibration_callback) + self.bias_weight = is_constant() + self.inputs.append(self.bias_weight) + self.bias_add = is_op("nn.bias_add")(self.dense, self.bias_weight) + self.add = is_op("add")(self.dense, self.bias_weight) + self.pattern = self.bias_add | self.add + + def quantize_args(self): + super().quantize_args() + quantized_bias = relay.qnn.op.quantize( + self.args[2], self.scale_zps[0], self.scale_zps[1], axis=0, out_dtype="int32" + ) + self.quantized_args.append(quantized_bias) + + def create_dense(self, args, node_map): + qnn_call = relay.qnn.op.dense(*args, **self.attrs) + if node_map.get(self.add) is not None: + bias_add = relay.op.add(qnn_call, self.quantized_args[2]) + else: # self.bias_add in node_map + bias_add = relay.op.nn.bias_add( + qnn_call, self.quantized_args[2], axis=1 # Axis is always 1 for dense + ) + return bias_add + + +class AddPattern(QuantizerPattern): + """Pattern to rewrite add as quantized. + + Parameters + ---------- + calibration_callback : CalibrationCallback + The method we will use to calibrate this pattern. 
+ """ + + def __init__(self, calibration_callback: CalibrationCallback = None): + super().__init__(calibration_callback) + self.lhs = wildcard() + self.rhs = wildcard() + self.add = is_op("add")(self.lhs, self.rhs) + self.pattern = self.add + + def callback(self, pre, post, node_map): + lhs = node_map[self.lhs][0] + rhs = node_map[self.rhs][0] + + add = node_map[self.add][0] + + out_dtype = infer_type(add).checked_type.dtype + + # Create quantization parameters for arguments to this addition + self.create_scale_zps("add_lhs", "add_rhs") + + # Quantize, dequantize, and requantize inputs to have scale lhs_scale + rhs_scale + # (Scale represents the lowest possible value representable in the quantized type, + # so the smallest representable output is lhs_scale + rhs_scale) + + # We do this to avoid the requantize op in qnn's add, which causes issues with compilation + # Requantize will be inserted in a future pass + lhs_scale, lhs_zp, rhs_scale, rhs_zp = self.scale_zps + quantized_lhs = relay.qnn.op.quantize(lhs, lhs_scale, lhs_zp) + quantized_rhs = relay.qnn.op.quantize(rhs, rhs_scale, rhs_zp) + + dequantized_lhs = relay.qnn.op.dequantize( + quantized_lhs, lhs_scale, relay.const(0, dtype="int32"), out_dtype=out_dtype + ) + dequantized_rhs = relay.qnn.op.dequantize( + quantized_rhs, rhs_scale, relay.const(0, dtype="int32"), out_dtype=out_dtype + ) + + add_scale = relay.op.add(lhs_scale, rhs_scale) + + requantized_lhs = relay.qnn.op.quantize( + dequantized_lhs, add_scale, relay.const(0, dtype="int32") + ) + requantized_rhs = relay.qnn.op.quantize( + dequantized_rhs, add_scale, relay.const(0, dtype="int32") + ) + + add = relay.op.add(requantized_lhs, requantized_rhs) + dequantized_call = relay.qnn.op.dequantize( + add, add_scale, relay.const(0, dtype="int32"), out_dtype=out_dtype + ) + + return dequantized_call + + +class MultiplyPattern(QuantizerPattern): + """Pattern to rewrite multiply as quantized. + + Parameters + ---------- + calibration_callback : CalibrationCallback + The method we will use to calibrate this pattern. + """ + + def __init__(self, calibration_callback: CalibrationCallback = None): + super().__init__(calibration_callback) + self.lhs = wildcard() + self.rhs = wildcard() + + self.multiply = is_op("multiply")(self.lhs, self.rhs) + self.pattern = self.multiply + + def callback(self, pre, post, node_map): + lhs = node_map[self.lhs][0] + rhs = node_map[self.rhs][0] + + multiply = node_map[self.multiply][0] + + out_dtype = infer_type(multiply).checked_type.dtype + + # Create quantization parameters for arguments to this multiplication. 
+ self.create_scale_zps("mul_lhs", "mul_rhs") + lhs_scale, lhs_zp, rhs_scale, rhs_zp = self.scale_zps + + # Quantize inputs and construct args for multiply + quantized_lhs = tvm.relay.cast(relay.qnn.op.quantize(lhs, lhs_scale, lhs_zp), "int32") + quantized_rhs = tvm.relay.cast(relay.qnn.op.quantize(rhs, rhs_scale, rhs_zp), "int32") + + # Use normal relay multiply instead of qnn multiply to avoid requantize in qnn.mul + # Subtract zero points to center on zero so that we can multiply lhs, rhs directly + zeroed_quantized_lhs = relay.op.subtract(quantized_lhs, lhs_zp) + zeroed_quantized_rhs = relay.op.subtract(quantized_rhs, rhs_zp) + + multiply = relay.op.multiply(zeroed_quantized_lhs, zeroed_quantized_rhs) + dequantized_call = relay.qnn.op.dequantize( + multiply, lhs_scale * rhs_scale, relay.const(0, dtype="int32"), out_dtype=out_dtype + ) + + return dequantized_call + + +class PerChannelPattern: + """A parent class for patterns that will be per-channel quantized. PerChannelPattern should + only be inherited by a class that also inherits QuantizerPattern or a subclass of it. + """ + + def extract_attrs(self, pre, post, node_map): + """A callback to get the quantized attributes of this pattern. Usually, we just call + self.get_attrs on the attributes of the original, unquantized node to construct the + quantized attributes. Since this callback is used by the pattern rewriter, we must return + a relay.Expr from it. + + Parameters + ---------- + pre : relay.Expr + Expression before transformation + + post : relay.Expr + Expression after transformation + + node_map : Map of pattern to relay.Expr + Contains expressions matching parts of the pattern. + + Returns + ------- + post : relay.Expr + Expression to rewrite the input expression as. We don't actually want to rewrite + anything in this pass, so you should just return post. + """ + raise NotImplementedError() + + def get_scale_size(self): + """Returns the size of the per-channel scale variable + + Returns + ------- + scale_size : tuple + The size of the scale variable + """ + raise NotImplementedError + + def weight_scale(self, name): + """Helper to create a variable for a per-channel scale. + Parameters + ---------- + name : str + Name of the variable + """ + var = relay.var( + str(name) + "_scale_" + str(QuantizerPattern.scales_count), + shape=self.get_scale_size(), + dtype="float32", + ) + QuantizerPattern.scales_count += 1 + return var + + def create_scale_zps(self, left_name, right_name): + """Helper to create scales and zero points for binops, with the per channel weight scale quantized. + + Parameters + ---------- + left_name : str + Identifier of the left hand side scale and zero point. + + right_name : str + Identifier of the right hand side scale and zero point. + """ + # Create quantization parameters for arguments with per channel on the right + data_scale = self.scale(left_name) + data_zp = self.zero_point(left_name) + + weight_scale = self.weight_scale(right_name) + weight_zp = self.zero_point(right_name) + self.scale_zps = [data_scale, data_zp, weight_scale, weight_zp] + + def attr_callback(self, expr): + """A function to get the attributes of the quantized version of the current + pattern. Meant to be called from inside calibrate_pattern. + + Parameters + ---------- + expr : relay.Expr + Expression that we want the attributes from. This will be the unquantized + version of the expression. 
+ """ + pattern_ffi.rewrite( + [_DFPatternCallback(self.pattern, self.extract_attrs, self.require_type)], + infer_type(expr), + tvm.ir.IRModule(), + False, + ) diff --git a/python/tvm/relay/transform/quantize/_requantizer.py b/python/tvm/relay/transform/quantize/_requantizer.py new file mode 100644 index 000000000000..2f557fbfb68f --- /dev/null +++ b/python/tvm/relay/transform/quantize/_requantizer.py @@ -0,0 +1,312 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Removes extraneous qnn.quantize and qnn.dequantize from calibrated modules, and replaces them +with qnn.requanize ops.""" +import math + +import tvm +from tvm import relay +from tvm.relay.dataflow_pattern import DFPatternCallback, wildcard, is_op, dominates, rewrite + + +class Requantizer: + """Removes extraneous qnn.quantize and qnn.dequantize and replaces + them with qnn.requantize.""" + + class RequantizerCallback(DFPatternCallback): + """First pass that inserts requantize ops, specifically taking + qnn.dequantize -> qnn.quantize to qnn.requantize + and + qnn.dequantize -> int8_op* -> qnn.quantize to requantize -> int8_op* + """ + + def __init__(self): + super().__init__() + + self.data = wildcard() + self.dequantize_scale = wildcard() + self.dequantize_zp = wildcard() + + self.quantize_scale = wildcard() + self.quantize_zp = wildcard() + + # Ops that are permitted inbetween quantize and dequantize if we are + # rewriting to requantize + self.is_int_8_op = ( + is_op("nn.max_pool2d")(wildcard()) + | is_op("nn.max_pool2d")(wildcard()) + | is_op("nn.max_pool3d")(wildcard()) + | is_op("nn.relu")(wildcard()) + | is_op("transpose")(wildcard()) + | is_op("reshape")(wildcard()) + | is_op("nn.pad")(wildcard()) + | is_op("squeeze")(wildcard()) + | is_op("nn.global_avg_pool2d") + | is_op("nn.batch_flatten") + | is_op("copy") + | is_op("mean") + | is_op("sqrt") + ) + + # All ops in is_int_8_op must also be in self.op_map + self.op_map = { + relay.op.get("nn.max_pool2d"): relay.op.nn.max_pool2d, + relay.op.get("nn.max_pool3d"): relay.op.nn.max_pool3d, + relay.op.get("transpose"): relay.op.transpose, + relay.op.get("reshape"): relay.op.reshape, + relay.op.get("nn.pad"): relay.op.nn.pad, + relay.op.get("squeeze"): relay.op.squeeze, + relay.op.get("nn.global_avg_pool2d"): relay.op.nn.global_avg_pool2d, + relay.op.get("nn.batch_flatten"): relay.op.nn.batch_flatten, + relay.op.get("copy"): relay.op.copy, + relay.op.get("mean"): relay.op.mean, + relay.op.get("sqrt"): relay.op.sqrt, + } + + # Main pattern -- quantize(is_int_8_op*(dequantize(data))) -- + # (with 1 or more is_int_8_ops) + self.dequantize = is_op("qnn.dequantize")( + self.data, self.dequantize_scale, self.dequantize_zp + ) + + self.dominator = dominates(self.dequantize, self.is_int_8_op, self.is_int_8_op) + self.quantize = 
is_op("qnn.quantize")( + self.dominator, self.quantize_scale, self.quantize_zp + ) + + # Pattern with the null path : quantize(dequantize(data)) -- (no is_int_8_op inbetween) + # We have to do the null path outside the dominator pattern because of pattern matcher + # limitations + self.no_path_dequantize = is_op("qnn.dequantize")( + self.data, self.dequantize_scale, self.dequantize_zp + ) + self.no_path_quantize = is_op("qnn.quantize")( + self.no_path_dequantize, self.quantize_scale, self.quantize_zp + ) + + self.pattern = self.quantize | self.no_path_quantize + + def callback(self, pre, post, node_map): + # Extract data from the pattern + data = node_map[self.data][0] + dequantize_scale = node_map[self.dequantize_scale][0] + deq_zp = node_map[self.dequantize_zp][0] + + quantize_scale = node_map[self.quantize_scale][0] + quantize_zp = node_map[self.quantize_zp][0] + + # Case where there are no ops in between the dequantize and quantize + if self.no_path_quantize in node_map: + axis = node_map[self.no_path_dequantize][0].attrs.axis + res = relay.qnn.op.requantize( + data, dequantize_scale, deq_zp, quantize_scale, quantize_zp, axis=axis + ) + # Ops inbetween quantize and dequantize are dominated + elif self.quantize in node_map: + + axis = node_map[self.dequantize][0].attrs.axis + transformed_data = relay.qnn.op.requantize( + data, dequantize_scale, deq_zp, quantize_scale, quantize_zp, axis=axis + ) + for i in range(len(node_map[self.is_int_8_op]) - 1, -1, -1): + call = node_map[self.is_int_8_op][i] + # Transform relu into max(zeropoint) + if call.op == relay.op.get("nn.relu"): + if ( + quantize_zp.data.asnumpy() + == relay.const(0, dtype="int32").data.asnumpy() + ): + transformed_data = relay.op.nn.relu(transformed_data) + else: + transformed_data = relay.op.maximum( + transformed_data, relay.cast(quantize_zp, "int8") + ) + elif call.op in self.op_map.keys(): + transformed_data = self.op_map[call.op](transformed_data, **call.attrs) + else: + raise ValueError( + "Uh oh, %s is not copied properly in the requantizer. " % str(call.op) + ) + res = transformed_data + return res + + class RequantizeChainCallback(DFPatternCallback): + """Folds chains of requantizes into one requantize. 
+    class RequantizeChainCallback(DFPatternCallback):
+        """Folds chains of requantizes into one requantize.
+
+        requantize(scale_a, zp_a, scale_b, zp_b) -> requantize(scale_b, zp_b, scale_c, zp_c)
+        becomes requantize(scale_a, zp_a, scale_c, zp_c).
+        """
+
+        # Takes a chain of requantizes and turns them into one requantize
+        def __init__(self):
+            super().__init__()
+            self.data = wildcard()
+            self.rq_parent_scale_in = wildcard()
+            self.rq_parent_zp_in = wildcard()
+            self.rq_parent_scale_out = wildcard()
+            self.rq_parent_zp_out = wildcard()
+
+            self.rq_child_scale_in = wildcard()
+            self.rq_child_zp_in = wildcard()
+            self.rq_child_scale_out = wildcard()
+            self.rq_child_zp_out = wildcard()
+
+            self.rq_parent = is_op("qnn.requantize")(
+                self.data,
+                self.rq_parent_scale_in,
+                self.rq_parent_zp_in,
+                self.rq_parent_scale_out,
+                self.rq_parent_zp_out,
+            )
+            self.rq_child = is_op("qnn.requantize")(
+                wildcard(),
+                self.rq_child_scale_in,
+                self.rq_child_zp_in,
+                self.rq_child_scale_out,
+                self.rq_child_zp_out,
+            )
+
+            self.pattern = dominates(self.rq_parent, self.rq_child, self.rq_child)
+
+        def callback(self, pre, post, node_map):
+            data = node_map[self.data][0]
+            rq_parent = node_map[self.rq_parent][0]
+
+            rq_parent_scale_in = node_map[self.rq_parent_scale_in][0]
+            rq_parent_zp_in = node_map[self.rq_parent_zp_in][0]
+
+            rq_parent_scale_out = node_map[self.rq_parent_scale_out][0]
+            rq_parent_zp_out = node_map[self.rq_parent_zp_out][0]
+
+            child_in_scales = node_map[self.rq_child_scale_in]
+            child_in_zps = node_map[self.rq_child_zp_in]
+            child_out_scales = node_map[self.rq_child_scale_out]
+            child_out_zps = node_map[self.rq_child_zp_out]
+
+            len_children = len(node_map[self.rq_child_scale_out])
+
+            # Check to make sure output and input scales and zps match before we apply this
+            # transformation
+            out_scale = rq_parent_scale_out
+            out_zp = rq_parent_zp_out
+
+            for i in range(0, len_children):
+
+                in_scale = child_in_scales[i]
+                in_zp = child_in_zps[i]
+
+                assert math.isclose(
+                    out_scale.data.asnumpy(), in_scale.data.asnumpy(), rel_tol=1e-05, abs_tol=1e-05
+                ) and math.isclose(
+                    out_zp.data.asnumpy(), in_zp.data.asnumpy(), rel_tol=1e-05, abs_tol=1e-05
+                ), (
+                    "Out scales/zps should match in scales/zps. Indicates an internal issue "
+                    "in the quantizer somewhere."
+                )
+
+                out_scale = child_out_scales[i]
+                out_zp = child_out_zps[i]
+
+            parent_axis = rq_parent.attrs["axis"]
+
+            return relay.qnn.op.requantize(
+                data, rq_parent_scale_in, rq_parent_zp_in, out_scale, out_zp, axis=parent_axis
+            )
+
+    class ConsolidateRequantizeandQuantize(DFPatternCallback):
+        """Gets rid of unnecessary requantizes directly following a quantize.
Takes + quantize(scale_a, zp_a) -> requantize(scale_a, zp_a, scale_b, zp_b) to + quantize(scale_b, zp_b) + """ + + def __init__(self): + super().__init__() + + self.data = wildcard() + self.q_scale = wildcard() + self.q_zp = wildcard() + + self.rq_scale_out = wildcard() + self.rq_zp_out = wildcard() + self.rq_scale_in = wildcard() + self.rq_zp_in = wildcard() + + self.quantize = is_op("qnn.quantize")(self.data, self.q_scale, self.q_zp) + self.requantize = is_op("qnn.requantize")( + self.quantize, self.rq_scale_in, self.rq_zp_in, self.rq_scale_out, self.rq_zp_out + ) + + self.pattern = self.requantize + + def callback(self, pre, post, node_map): + + data = node_map[self.data][0] + requantize = node_map[self.requantize][0] + + q_scale = node_map[self.q_scale][0] + q_zp = node_map[self.q_zp][0] + + np_q_scale = q_scale.data.asnumpy() + np_q_zp = q_zp.data.asnumpy() + + rq_scale_in = node_map[self.rq_scale_in][0] + rq_zp_in = node_map[self.rq_zp_in][0] + + np_rq_scale = rq_scale_in.data.asnumpy() + np_rq_zp = rq_zp_in.data.asnumpy() + assert math.isclose( + np_q_scale, np_rq_scale, rel_tol=1e-05, abs_tol=1e-05 + ) and math.isclose(np_q_zp, np_rq_zp, rel_tol=1e-05, abs_tol=1e-05), ( + "Scales and zps should match between adjacent quantize and requantizes, " + "indicates a problem earlier in quantization" + ) + + output_scale = node_map[self.rq_scale_out][0] + output_zp = node_map[self.rq_zp_out][0] + + requantize_axis = requantize.attrs["axis"] + # Rewrite subgraph to just one quantize + return relay.qnn.op.quantize(data, output_scale, output_zp, axis=requantize_axis) + + def requantize(self, func): + """Removes extraneous qnn.quantize and qnn.dequantize ops and replaces them with + qnn.requantize ops. + + Parameters + ---------- + func : relay.Function + Function to requantize. 
+ """ + # We have to fold scale/zp expressions for requantize to work + # AlterOpLayout adds some extra ops that mess things up + optimize = tvm.transform.Sequential( + [relay.transform.FoldConstant(), relay.transform.EliminateCommonSubexpr()] + ) + + mod = tvm.ir.IRModule.from_expr(func) + with relay.build_config(opt_level=3, disabled_pass=["AlterOpLayout"]): + rewritten_func = optimize(mod)["main"] + + rewritten_func = rewrite( + self.RequantizerCallback(), rewritten_func, allow_overlapping_groups=True + ) + rewritten_func = rewrite(self.RequantizeChainCallback(), rewritten_func) + rewritten_func = rewrite(self.ConsolidateRequantizeandQuantize(), rewritten_func) + + rewritten_mod = tvm.ir.IRModule.from_expr(rewritten_func) + + return rewritten_mod["main"] diff --git a/python/tvm/relay/transform/quantize/demos/average_mean_quantize_bert.py b/python/tvm/relay/transform/quantize/demos/average_mean_quantize_bert.py new file mode 100644 index 000000000000..5ecde285a571 --- /dev/null +++ b/python/tvm/relay/transform/quantize/demos/average_mean_quantize_bert.py @@ -0,0 +1,102 @@ +# Demo based on code from https://www.tensorflow.org/tutorials/images/cnn + +import tensorflow as tf +import tvm +from tvm import relay +from tvm.relay.data import DatasetManager +from tvm.relay.transform.quantize import Quantizer, GlobalCalibrator, Requantizer + +from tensorflow.keras import datasets, layers, models +import matplotlib.pyplot as plt +import onnx +import numpy as np + +class HardcodedBertInputs(DatasetManager): + # Assumes numpy_data is in form [num_inputs, c, h, w] and labels is [num_inputs] + def __init__(self, n_batches=100): + self.idx = 0 + self.num_batches = 100 + + def get_next_batch(self): + if self.is_empty(): + raise IndexError + + unique_ids_raw_output = np.random.randn([1]) + segment_ids = np.random.randn([1, 256]) + input_mask = np.random.randn([1, 256]) + input_ids = np.random.randn([1, 256]) + self.idx += 1 + return [unique_ids_raw_output, segment_ids, input_mask, input_ids], None + + def num_batches(self): + return self.num_batches + + def is_empty(self): + return self.idx >= self.num_batches + + def reset(self): + self.idx = 0 + +batch_size = 1 +onnx_model = onnx.load('/home/lorthsmith/tvm/python/tvm/relay/new_quantize/demos/bertsquad-10.onnx') +input_dict = {'unique_ids_raw_output___9:0': [1], 'segment_ids:0': [1, 256], 'input_mask:0': [1, 256], 'input_ids:0': [1, 256]} +mod, params = relay.frontend.from_onnx(onnx_model, input_dict) + +quantized_mod, calibration_map = Quantizer().quantize(mod, params, skip_layers=[0]) + +with tvm.transform.PassContext(opt_level=3, disabled_pass=["AlterOpLayout"]): + #lib = relay.build(mod, target='llvm') + q_lib = relay.build(quantized_mod, target='llvm') + +# Calibrate +global_calibrator = GlobalCalibrator(0.05, 0) +calibrated_mod = global_calibrator.calibrate(quantized_mod, calibration_map) +print("Calibrated mod: \n", calibrated_mod.astext(False)) + +# Requantize +requantized_mod = Requantizer().requantize(calibrated_mod) +print("Requantized mod: \n", requantized_mod.astext(False)) + +with tvm.transform.PassContext(opt_level=3, disabled_pass=["AlterOpLayout"]): + #lib = relay.build(mod, params=params, target='llvm') + q_lib = relay.build(requantized_mod, target='llvm') + +from tvm.contrib import graph_runtime +q_gmod = graph_runtime.GraphModule(q_lib["default"](tvm.cpu())) +q_gmod.set_input(**{'unique_ids_raw_output___9:0': np.random.randn(*[1]), 'segment_ids:0': np.random.randn(*[1, 256]), 'input_mask:0': np.random.randn(*[1, 256]), 'input_ids:0': 
diff --git a/python/tvm/relay/transform/quantize/demos/average_mean_quantize_bert.py b/python/tvm/relay/transform/quantize/demos/average_mean_quantize_bert.py
new file mode 100644
index 000000000000..5ecde285a571
--- /dev/null
+++ b/python/tvm/relay/transform/quantize/demos/average_mean_quantize_bert.py
@@ -0,0 +1,102 @@
+# Demo based on code from https://www.tensorflow.org/tutorials/images/cnn
+
+import tensorflow as tf
+import tvm
+from tvm import relay
+from tvm.data import DatasetManager
+from tvm.relay.transform.quantize import Quantizer, GlobalCalibrator, Requantizer
+
+from tensorflow.keras import datasets, layers, models
+import matplotlib.pyplot as plt
+import onnx
+import numpy as np
+
+
+class HardcodedBertInputs(DatasetManager):
+    # DatasetManager that feeds random BERT-shaped inputs; used only for smoke testing
+    def __init__(self, n_batches=100):
+        self.idx = 0
+        self.n_batches = n_batches
+
+    def get_next_batch(self):
+        if self.is_empty():
+            raise IndexError
+
+        unique_ids_raw_output = np.random.randn(1)
+        segment_ids = np.random.randn(1, 256)
+        input_mask = np.random.randn(1, 256)
+        input_ids = np.random.randn(1, 256)
+        self.idx += 1
+        return [unique_ids_raw_output, segment_ids, input_mask, input_ids], None
+
+    def num_batches(self):
+        return self.n_batches
+
+    def is_empty(self):
+        return self.idx >= self.n_batches
+
+    def reset(self):
+        self.idx = 0
+
+
+batch_size = 1
+onnx_model = onnx.load('/home/lorthsmith/tvm/python/tvm/relay/new_quantize/demos/bertsquad-10.onnx')
+input_dict = {'unique_ids_raw_output___9:0': [1], 'segment_ids:0': [1, 256], 'input_mask:0': [1, 256], 'input_ids:0': [1, 256]}
+mod, params = relay.frontend.from_onnx(onnx_model, input_dict)
+
+quantized_mod, calibration_map = Quantizer().quantize(mod, params, skip_layers=[0])
+
+with tvm.transform.PassContext(opt_level=3, disabled_pass=["AlterOpLayout"]):
+    #lib = relay.build(mod, target='llvm')
+    q_lib = relay.build(quantized_mod, target='llvm')
+
+# Calibrate
+global_calibrator = GlobalCalibrator(0.05, 0)
+calibrated_mod = global_calibrator.calibrate(quantized_mod, calibration_map)
+print("Calibrated mod: \n", calibrated_mod.astext(False))
+
+# Requantize
+requantized_mod = Requantizer().requantize(calibrated_mod)
+print("Requantized mod: \n", requantized_mod.astext(False))
+
+with tvm.transform.PassContext(opt_level=3, disabled_pass=["AlterOpLayout"]):
+    #lib = relay.build(mod, params=params, target='llvm')
+    q_lib = relay.build(requantized_mod, target='llvm')
+
+from tvm.contrib import graph_runtime
+q_gmod = graph_runtime.GraphModule(q_lib["default"](tvm.cpu()))
+q_gmod.set_input(**{'unique_ids_raw_output___9:0': np.random.randn(1), 'segment_ids:0': np.random.randn(1, 256), 'input_mask:0': np.random.randn(1, 256), 'input_ids:0': np.random.randn(1, 256)})
+q_gmod.run()
+q_out = q_gmod.get_output(0).asnumpy()
+print(q_out)
+exit()
+
+from tvm.contrib import graph_runtime
+q_gmod = graph_runtime.GraphModule(q_lib["default"](tvm.cpu()))
+gmod = graph_runtime.GraphModule(lib["default"](tvm.cpu()))
+q_correct = 0
+correct = 0
+total = 0
+
+while not test_dataset_manager.is_empty():
+    image_list, label = test_dataset_manager.get_next_batch()
+    q_gmod.set_input(**{'conv2d_input:0': image_list[0]})
+    q_gmod.run()
+    q_out = q_gmod.get_output(0).asnumpy()
+
+    gmod.set_input(**{'conv2d_input:0': image_list[0]})
+    gmod.run()
+    out = gmod.get_output(0).asnumpy()
+
+    q_predicted_labels = np.argmax(q_out, axis=1)
+    predicted_labels = np.argmax(out, axis=1)
+
+    #print("Int8 labels: ", q_predicted_labels)
+    #print("Float32 labels: ", predicted_labels)
+    #print("Actual labels: ", label)
+
+    q_correct += np.sum(q_predicted_labels == label)
+    correct += np.sum(predicted_labels == label)
+    total += batch_size
+
+print("Int8 percent correct: ", (q_correct / total) * 100)
+print("Float32 percent correct: ", (correct / total) * 100)
+print("Difference: ", (((correct / total) * 100) - ((q_correct / total) * 100)))
diff --git a/python/tvm/relay/transform/quantize/demos/average_mean_quantize_cifar.py b/python/tvm/relay/transform/quantize/demos/average_mean_quantize_cifar.py
new file mode 100644
index 000000000000..ea621e20d363
--- /dev/null
+++ b/python/tvm/relay/transform/quantize/demos/average_mean_quantize_cifar.py
@@ -0,0 +1,179 @@
+# Demo based on code from https://www.tensorflow.org/tutorials/images/cnn
+import onnx
+import tensorflow as tf
+import tvm
+from tvm import relay
+from tvm.data import DatasetManager
+from tvm.relay.transform.quantize import (
+    Quantizer,
+    QuantizationCalibrator,
+    AverageMaxCalibrationCallback,
+    GlobalCalibrationCallback,
+    Requantizer,
+    AverageMaxPerChannelConv2DBiasAddPattern,
+    AverageMaxPerChannelConv2DPattern,
+    AverageMaxPerChannelDensePattern,
+    Conv2DBiasAddPattern,
+    Conv2DPattern,
+    DensePattern,
+    AddPattern,
+    MultiplyPattern,
+)
+
+from tensorflow.keras import datasets
+
+import numpy as np
+
+# tf and onnx use different versions of protobuf??
+# Versions that work: pip installed protobuf version 3.12.2
+# Need libprotobuf version 3.0.0 (libprotobuf is also protoc)
+
+
+class NumpyDatasetManager(DatasetManager):
+    # Assumes numpy_data is in form [num_inputs, c, h, w] and labels is [num_inputs]
+    def __init__(self, numpy_data, numpy_labels, batch_size=1, n_batches=None):
+        self.idx = 0
+        self.numpy_data = numpy_data
+        self.numpy_labels = numpy_labels
+        assert (
+            self.numpy_data.shape[0] == self.numpy_labels.shape[0]
+        ), "First dimension of data and label arrays must match."
+        assert (
+            self.numpy_data.shape[0] >= batch_size
+        ), "Batch size too large. You must provide enough data points for at least one batch."
+        self.batch_sz = batch_size
+        if n_batches is None:
+            self.n_batches = numpy_data.shape[0] // self.batch_sz
+        else:
+            assert n_batches * batch_size <= numpy_data.shape[0]
+            self.n_batches = n_batches
+
+    def get_next_batch(self):
+        if self.is_empty():
+            raise IndexError
+        batched_data = self.numpy_data[self.idx * self.batch_sz : (self.idx + 1) * self.batch_sz]
+        batched_label = self.numpy_labels[self.idx * self.batch_sz : (self.idx + 1) * self.batch_sz]
+        self.idx += 1
+        return [batched_data], batched_label
+
+    def batch_size(self):
+        return self.batch_sz
+
+    def num_batches(self):
+        return self.n_batches
+
+    def is_empty(self):
+        return self.idx >= self.n_batches
+
+    def reset(self):
+        self.idx = 0
+
+
+(train_images, train_labels), (test_images, test_labels) = datasets.cifar10.load_data()
+
+# Normalize pixel values to be between 0 and 1
+train_images, test_images = train_images / 255.0, test_images / 255.0
+
+# Create dataset manager
+# For "training", it seems like batch size 10 and n batches = 5000 works pretty well
+batch_size = 10
+train_dataset_manager = NumpyDatasetManager(
+    train_images, np.ndarray.flatten(train_labels), batch_size, n_batches=5000
+)
+test_dataset_manager = NumpyDatasetManager(
+    test_images, np.ndarray.flatten(test_labels), batch_size, n_batches=1000
+)
+
+# Load onnx model (model obtained from https://www.tensorflow.org/tutorials/images/cnn),
+# exported to onnx
+onnx_model = onnx.load(
+    "/home/lorthsmith/tvm/python/tvm/relay/transform/quantize/demos/cifar-model.onnx"
+)
+input_dict = {"conv2d_input:0": [batch_size, 32, 32, 3]}
+mod, params = relay.frontend.from_onnx(onnx_model, input_dict)
+print("main: ", mod["main"])
+cc = AverageMaxCalibrationCallback()
+quantizer = Quantizer(
+    mod["main"],
+    params,
+    [
+        AverageMaxPerChannelConv2DBiasAddPattern(cc),
+        AverageMaxPerChannelConv2DPattern(cc),
+        AverageMaxPerChannelDensePattern(cc),
+        AddPattern(cc),
+        MultiplyPattern(cc),
+    ],
+    skip_last=False,
+)
+# quantizer = Quantizer(mod['main'], params, [Conv2DBiasAddPattern(cc), Conv2DPattern(cc), DensePattern(cc), AddPattern(cc), MultiplyPattern(cc)], skip_last=True, skip_first=True)
+
+# cc = GlobalCalibrationCallback(2.0, 0)
+# quantizer = Quantizer(mod['main'], params, [Conv2DBiasAddPattern(cc), Conv2DPattern(cc), DensePattern(cc), AddPattern(cc), MultiplyPattern(cc)], skip_last=False)
+
+
+calibrator = QuantizationCalibrator(
+    quantizer,
+    target="llvm",
+    ctx=tvm.cpu(),
+    dataset_manager=train_dataset_manager,
+    show_scale_zps=True,
+)
+calibrated_func = calibrator.calibrate()
+calibrated_mod = tvm.ir.IRModule.from_expr(calibrated_func)
+print("Calibrated func: ", calibrated_func)
+print("Requantizing...")
+requantized_func = Requantizer().requantize(calibrated_func)
+print("Requantized func: ", requantized_func)
+requantized_mod = tvm.ir.IRModule.from_expr(requantized_func)
+
+print("Calculating accuracy...")
+with tvm.transform.PassContext(opt_level=3, disabled_pass=["AlterOpLayout"]):
+    lib = relay.build(mod, params=params, target="llvm")
+    c_lib = relay.build(calibrated_mod, params=params, target="llvm")
+    q_lib = relay.build(requantized_mod, params=params, target="llvm")
+
+
+from tvm.contrib import graph_runtime
+
+q_gmod = graph_runtime.GraphModule(q_lib["default"](tvm.cpu()))
+c_gmod = graph_runtime.GraphModule(c_lib["default"](tvm.cpu()))
+gmod = graph_runtime.GraphModule(lib["default"](tvm.cpu()))
graph_runtime.GraphModule(lib["default"](tvm.cpu())) +q_correct = 0 +c_correct = 0 +correct = 0 +total = 0 + +while not test_dataset_manager.is_empty(): + image_list, label = test_dataset_manager.get_next_batch() + q_gmod.set_input(**{"conv2d_input:0": image_list[0]}) + q_gmod.run() + q_out = q_gmod.get_output(0).asnumpy() + + c_gmod.set_input(**{"conv2d_input:0": image_list[0]}) + c_gmod.run() + c_out = q_gmod.get_output(0).asnumpy() + + gmod.set_input(**{"conv2d_input:0": image_list[0]}) + gmod.run() + out = gmod.get_output(0).asnumpy() + + q_predicted_labels = np.argmax(q_out, axis=1) + c_predicted_labels = np.argmax(c_out, axis=1) + predicted_labels = np.argmax(out, axis=1) + + print("Int8 labels: ", q_predicted_labels) + print("Calibrated int8 labels: ", c_predicted_labels) + print("Float32 labels: ", predicted_labels) + print("Actual labels: ", label) + print() + + q_correct += np.sum(q_predicted_labels == label) + c_correct += np.sum(c_predicted_labels == label) + correct += np.sum(predicted_labels == label) + total += batch_size + +print("Int8 percent correct: ", (q_correct / total) * 100) +print("Calibrated Int8 percent correct: ", (c_correct / total) * 100) +print("Float32 percent correct: ", (correct / total) * 100) +print("Difference: ", (((correct / total) * 100) - ((q_correct / total) * 100))) diff --git a/python/tvm/relay/transform/quantize/demos/average_mean_quantize_mnist.py b/python/tvm/relay/transform/quantize/demos/average_mean_quantize_mnist.py new file mode 100644 index 000000000000..4e65586edc60 --- /dev/null +++ b/python/tvm/relay/transform/quantize/demos/average_mean_quantize_mnist.py @@ -0,0 +1,139 @@ +import tvm +from tvm import relay +from tvm.data import TFDatasetManager +from tvm.relay.transform.quantize import ( + Quantizer, + QuantizationCalibrator, + AverageMaxCalibrationCallback, + Conv2DBiasAddPattern, + Conv2DPattern, + DenseBiasAddPattern, + DensePattern, + AddPattern, + MultiplyPattern, + Requantizer, + AverageMaxPerChannelConv2DPattern, + AverageMaxPerChannelConv2DBiasAddPattern, + AverageMaxPerChannelDensePattern, + AverageMaxPerChannelDenseBiasAddPattern, +) +import onnx +import tensorflow.compat.v2 as tf +import tensorflow_datasets as tfds + +tf.enable_v2_behavior() + +import numpy as np + +batch_size = 5 + +# TFDS loading from https://www.tensorflow.org/datasets/keras_example +(ds_train, ds_test), ds_info = tfds.load( + "mnist", split=["train", "test"], shuffle_files=True, as_supervised=True, with_info=True +) + +# Import data +def normalize_img(image, label): + """Normalizes images: `uint8` -> `float32`.""" + return tf.cast(image, tf.float32) / 255.0, label + + +ds_train = ds_train.map(normalize_img, num_parallel_calls=tf.data.experimental.AUTOTUNE) +ds_train = ds_train.cache() +ds_train = ds_train.shuffle(ds_info.splits["train"].num_examples) +ds_train = ds_train.batch(batch_size) +ds_train = ds_train.prefetch(tf.data.experimental.AUTOTUNE) + +ds_test = ds_test.map(normalize_img, num_parallel_calls=tf.data.experimental.AUTOTUNE) +ds_test = ds_test.batch(batch_size) +ds_test = ds_test.cache() +ds_test = ds_test.prefetch(tf.data.experimental.AUTOTUNE) + +num_batches = 2000 +mnist_train_manager = TFDatasetManager(ds_train, batch_size, 12000) +mnist_test_manager = TFDatasetManager(ds_test, batch_size, 2000) + +# Import onnx model, quantize and calibrate +onnx_model = onnx.load( + "/home/lorthsmith/tvm/python/tvm/relay/transform/quantize/demos/mnist_model.onnx" +) +input_dict = {"flatten_input": [batch_size, 28, 28, 1]} +mod, params = 
+mod, params = relay.frontend.from_onnx(onnx_model, input_dict)
+
+print(mod["main"])
+cc = AverageMaxCalibrationCallback()
+# quantizer = Quantizer(mod['main'], params, [Conv2DBiasAddPattern(cc), Conv2DPattern(cc), DenseBiasAddPattern(cc), DensePattern(cc), AddPattern(cc), MultiplyPattern(cc)], skip_first=False)
+# quantizer = Quantizer(mod['main'], params, [AverageMaxPerChannelConv2DBiasAddPattern(cc), AverageMaxPerChannelConv2DPattern(cc), AverageMaxPerChannelDenseBiasAddPattern(cc), AverageMaxPerChannelDensePattern(cc), AddPattern(cc), MultiplyPattern(cc)], skip_first=False, skip_last=False)
+quantizer = Quantizer(
+    mod["main"],
+    params,
+    [
+        AverageMaxPerChannelConv2DBiasAddPattern(cc),
+        AverageMaxPerChannelConv2DPattern(cc),
+        AverageMaxPerChannelDensePattern(cc),
+        AddPattern(cc),
+        MultiplyPattern(cc),
+    ],
+    skip_first=False,
+    skip_last=False,
+)
+
+calibrator = QuantizationCalibrator(
+    quantizer, target="llvm", ctx=tvm.cpu(), dataset_manager=mnist_train_manager
+)
+calibrated_func = calibrator.calibrate()
+calibrated_mod = tvm.ir.IRModule.from_expr(calibrated_func)
+requantized_func = Requantizer().requantize(calibrated_func)
+print(requantized_func)
+with tvm.transform.PassContext(opt_level=3, disabled_pass=["AlterOpLayout"]):
+    lib = relay.build(mod, params=params, target="llvm")
+    q_lib = relay.build(calibrated_mod, params=params, target="llvm")
+from tvm.contrib import graph_runtime
+
+q_gmod = graph_runtime.GraphModule(q_lib["default"](tvm.cpu()))
+gmod = graph_runtime.GraphModule(lib["default"](tvm.cpu()))
+q_correct = 0
+correct = 0
+total = 0
+
+while not mnist_test_manager.is_empty():
+    images, labels = mnist_test_manager.get_next_batch()
+
+    q_gmod.set_input(**{"flatten_input": images[0]})
+    q_gmod.run()
+    q_out = q_gmod.get_output(0).asnumpy()
+
+    gmod.set_input(**{"flatten_input": images[0]})
+    gmod.run()
+    out = gmod.get_output(0).asnumpy()
+
+    q_predicted_labels = np.argmax(q_out, axis=1)
+    predicted_labels = np.argmax(out, axis=1)
+
+    print("Int8 labels: ", q_predicted_labels)
+    print("Float32 labels: ", predicted_labels)
+    print("Actual labels: ", labels)
+
+    q_correct += np.sum(q_predicted_labels == labels)
+    correct += np.sum(predicted_labels == labels)
+    total += batch_size
+
+print("Int8 percent correct: ", (q_correct / total) * 100)
+print("Float32 percent correct: ", (correct / total) * 100)
+print("Difference: ", (((correct / total) * 100) - ((q_correct / total) * 100)))
diff --git a/python/tvm/relay/transform/quantize/demos/average_mean_quantize_resnet.py b/python/tvm/relay/transform/quantize/demos/average_mean_quantize_resnet.py
new file mode 100644
index 000000000000..20bea756f6fb
--- /dev/null
+++ b/python/tvm/relay/transform/quantize/demos/average_mean_quantize_resnet.py
@@ -0,0 +1,62 @@
+
+import tvm
+import tvm.relay.testing
+from tvm import relay
+import torch
+
+from torchvision.models import resnet
+from tvm.data import RandomDatasetManager
+from tvm.relay.transform.quantize import Quantizer, QuantizationCalibrator, AverageMaxCalibrationCallback, AverageMaxPerChannelConv2DBiasAddPattern, AverageMaxPerChannelConv2DPattern, AverageMaxPerChannelDenseBiasAddPattern, AverageMaxPerChannelDensePattern, AddPattern, MultiplyPattern, Requantizer
+
+import numpy as np
+
+
+pytorch_model = resnet.resnet18(pretrained=True)
+input_name = "input"  # the input name can be arbitrary for the PyTorch frontend.
+input_shape = (3, 3, 224, 224)
+named_input_shape = [(input_name, input_shape)]
+input_data = torch.randn(input_shape)
+script_module = torch.jit.trace(pytorch_model, input_data)
+
+mod, params = relay.frontend.from_pytorch(script_module, named_input_shape)
+print(mod['main'])
+cc = AverageMaxCalibrationCallback()
+# Conv2d bias does not work
+# Dense works
+
+#quantizer = Quantizer(mod['main'], params, [AverageMaxPerChannelDenseBiasAddPattern(cc), AverageMaxPerChannelDensePattern(cc)])
+quantizer = Quantizer(mod['main'], params, [AverageMaxPerChannelConv2DBiasAddPattern(cc), AverageMaxPerChannelConv2DPattern(cc), AverageMaxPerChannelDenseBiasAddPattern(cc), AverageMaxPerChannelDensePattern(cc), AddPattern(cc), MultiplyPattern(cc)])
+random_dataset_manager = RandomDatasetManager(input_shape, 'float32', 3, 20)
+
+calibrator = QuantizationCalibrator(quantizer, target='llvm', ctx=tvm.cpu(), dataset_manager=random_dataset_manager)
+calibrated_func = calibrator.calibrate()
+print(calibrated_func)
+requantized_func = Requantizer().requantize(calibrated_func)
+requantized_mod = tvm.ir.IRModule.from_expr(requantized_func)
+print(requantized_mod)
+
+with tvm.transform.PassContext(opt_level=3, disabled_pass=["AlterOpLayout"]):
+    lib = relay.build(mod, target='llvm')
+    q_lib = relay.build(requantized_mod, target='llvm')
+
+from tvm.contrib import graph_runtime
+input_np = np.random.randn(*input_shape).astype('float32')
+
+gmod = graph_runtime.GraphModule(lib["default"](tvm.cpu()))
+gmod.set_input(**params)
+gmod.set_input(input_name, input_np)
+gmod.run()
+out = gmod.get_output(0).asnumpy()
+print("Unquantized Output:")
+print(out)
+
+
+print(" ___________ ")
+q_gmod = graph_runtime.GraphModule(q_lib["default"](tvm.cpu()))
+q_gmod.set_input(input_name, input_np)
+q_gmod.set_input(**params)
+q_gmod.run()
+q_out = q_gmod.get_output(0).asnumpy()
+print("Quantized output:")
+print(q_out)
diff --git a/python/tvm/relay/transform/quantize/demos/per_channel_test.py b/python/tvm/relay/transform/quantize/demos/per_channel_test.py
new file mode 100644
index 000000000000..c8f6d52d600b
--- /dev/null
+++ b/python/tvm/relay/transform/quantize/demos/per_channel_test.py
@@ -0,0 +1,144 @@
+# Demo based on code from https://www.tensorflow.org/tutorials/images/cnn
+
+import tensorflow as tf
+import tvm
+from tvm import relay
+from tvm.relay.transform.quantize import Quantizer, AverageMeanCalibrator, DatasetManager, Requantizer, Calibrator
+
+from tensorflow.keras import datasets, layers, models
+import onnx
+import numpy as np
+
+
+class NumpyDatasetManager(DatasetManager):
+    # Assumes numpy_data is in form [num_inputs, c, h, w] and labels is [num_inputs]
+    def __init__(self, numpy_data, numpy_labels, batch_size=1, n_batches=None):
+        self.idx = 0
+        self.numpy_data = numpy_data
+        self.numpy_labels = numpy_labels
+        assert self.numpy_data.shape[0] == self.numpy_labels.shape[0], "First dimension of data and label arrays must match."
+        assert self.numpy_data.shape[0] >= batch_size, "Batch size too large. You must provide enough data points for at least one batch."
+ self.batch_size = batch_size + if n_batches is None: + self.n_batches = numpy_data.shape[0] // self.batch_size + else: + assert n_batches * batch_size <= numpy_data.shape[0] + self.n_batches = n_batches + + def get_next_batch(self): + if self.is_empty(): + raise IndexError + batched_data = self.numpy_data[self.idx * self.batch_size : (self.idx + 1) * self.batch_size] + batched_label = self.numpy_labels[self.idx * self.batch_size : (self.idx + 1) * self.batch_size] + self.idx += 1 + return [batched_data], batched_label + + def num_batches(self): + return self.n_batches + + def is_empty(self): + return self.idx >= self.n_batches + + def reset(self): + self.idx = 0 + +class PerChannelTestCalibrator(Calibrator): + + def __init__(self, input_shape): + super().__init__() + self.input_shape = input_shape + + def _calibration_callback(self, variable_pairs): + value_dict = {} + op = self._get_layer_op() + attrs = self._get_layer_attributes() + # How will dequantize work? I don't know + + if (op == relay.op.get('qnn.dense')): + units = attrs['units'] + scales = np.random.randn(units).astype('float32') + ((data_scale, data_zp), (weight_scale, weight_zp)) = variable_pairs + value_dict[data_scale.name_hint] = np.array(2.0).astype('float32') + value_dict[data_zp.name_hint] = np.array(0).astype('int32') + value_dict[weight_zp.name_hint] = np.array(0).astype('int32') + #value_dict[weight_scale.name_hint] = np.array(2.0).astype('float32') + value_dict[weight_scale.name_hint] = scales + + elif op == relay.op.get('qnn.conv2d'): + channels = attrs['channels'] + scales = np.random.randn(channels).astype('float32') + ((data_scale, data_zp), (weight_scale, weight_zp)) = variable_pairs + value_dict[data_scale.name_hint] = np.array(2.0).astype('float32') + value_dict[data_zp.name_hint] = np.array(0).astype('int32') + value_dict[weight_zp.name_hint] = np.array(0).astype('int32') + #value_dict[weight_scale.name_hint] = np.array(2.0).astype('float32') + value_dict[weight_scale.name_hint] = scales + print(scales.shape) + else: + for (scale_var, zp_var) in variable_pairs: + value_dict[scale_var.name_hint] = np.array(2.0).astype('float32') + value_dict[zp_var.name_hint] = np.array(0).astype('int32') + return value_dict + +(train_images, train_labels), (test_images, test_labels) = datasets.cifar10.load_data() + +# Normalize pixel values to be between 0 and 1 +train_images, test_images = train_images / 255.0, test_images / 255.0 + +# Create dataset manager +batch_size = 1 +train_dataset_manager = NumpyDatasetManager(train_images, np.ndarray.flatten(train_labels), batch_size, n_batches=100) +test_dataset_manager = NumpyDatasetManager(test_images, np.ndarray.flatten(test_labels), batch_size, n_batches=100) + +# Load onnx model (model obtained from https://www.tensorflow.org/tutorials/images/cnn), exported to onnx +onnx_model = onnx.load('/home/lorthsmith/tvm/python/tvm/relay/new_quantize/demos/cifar-model.onnx') +input_dict = {'conv2d_input:0': [batch_size, 32, 32, 3]} +mod, params = relay.frontend.from_onnx(onnx_model, input_dict) + +# Quantize +quantized_mod, calibration_map = Quantizer().quantize(mod, params, skip_layers=[0]) +#print("Quantized mod: \n", quantized_mod.astext(False)) + +# Calibrate +average_mean_calibrator = PerChannelTestCalibrator([batch_size, 32, 32, 3]) +calibrated_mod = average_mean_calibrator.calibrate(quantized_mod, calibration_map) +#print("Calibrated mod: \n", calibrated_mod.astext(False)) + +# Requantize +requantized_mod = Requantizer().requantize(calibrated_mod) +#print("Requantized mod: \n", 
requantized_mod.astext(False))
+
+with tvm.transform.PassContext(opt_level=3, disabled_pass=["AlterOpLayout"]):
+    lib = relay.build(mod, target='llvm')
+    q_lib = relay.build(requantized_mod, target='llvm')
+
+from tvm.contrib import graph_runtime
+q_gmod = graph_runtime.GraphModule(q_lib["default"](tvm.cpu()))
+gmod = graph_runtime.GraphModule(lib["default"](tvm.cpu()))
+q_correct = 0
+correct = 0
+total = 0
+
+while not test_dataset_manager.is_empty():
+    image_list, label = test_dataset_manager.get_next_batch()
+    q_gmod.set_input(**{'conv2d_input:0': image_list[0]})
+    q_gmod.run()
+    q_out = q_gmod.get_output(0).asnumpy()
+
+    gmod.set_input(**{'conv2d_input:0': image_list[0]})
+    gmod.run()
+    out = gmod.get_output(0).asnumpy()
+
+    q_predicted_labels = np.argmax(q_out, axis=1)
+    predicted_labels = np.argmax(out, axis=1)
+
+    #print("Int8 labels: ", q_predicted_labels)
+    #print("Float32 labels: ", predicted_labels)
+    #print("Actual labels: ", label)
+
+    q_correct += np.sum(q_predicted_labels == label)
+    correct += np.sum(predicted_labels == label)
+    total += batch_size
+
+print("Int8 percent correct: ", (q_correct / total) * 100)
+print("Float32 percent correct: ", (correct / total) * 100)
+print("Difference: ", (((correct / total) * 100) - ((q_correct / total) * 100)))
diff --git a/src/relay/ir/dataflow_matcher.cc b/src/relay/ir/dataflow_matcher.cc
index a43f50f600df..51049bf9a4af 100644
--- a/src/relay/ir/dataflow_matcher.cc
+++ b/src/relay/ir/dataflow_matcher.cc
@@ -22,6 +22,8 @@
  * \brief The dataflow pattern matcher for Relay.
  */
 
+#include "dataflow_matcher.h"
+
 #include
 #include
 #include
@@ -34,45 +36,6 @@
 namespace tvm {
 namespace relay {
 
-// Pattern Matcher
-
-class DominatorMatcher;
-
-class DFPatternMatcher : public DFPatternFunctor<bool(const DFPattern&, const Expr&)> {
- public:
-  explicit DFPatternMatcher(const Expr& root_expr) : expr_graph_(CreateIndexedGraph(root_expr)) {}
-  bool Match(const DFPattern& pattern, const Expr& expr);
-  Map<DFPattern, Array<Expr>> GetMemo() { return Map<DFPattern, Array<Expr>>(memo_); }
-  const IndexedGraph<Expr> expr_graph_;
-
- protected:
-  bool VisitDFPattern(const DFPattern& pattern, const Expr& expr) override;
-  bool VisitDFPattern_(const AltPatternNode* op, const Expr& expr) override;
-  bool VisitDFPattern_(const AttrPatternNode* op, const Expr& expr) override;
-  bool VisitDFPattern_(const CallPatternNode* op, const Expr& expr) override;
-  bool VisitDFPattern_(const ConstantPatternNode* op, const Expr& expr) override;
-  bool VisitDFPattern_(const DataTypePatternNode* op, const Expr& expr) override;
-  bool VisitDFPattern_(const DominatorPatternNode* op, const Expr& expr) override;
-  bool VisitDFPattern_(const ExprPatternNode* op, const Expr& expr) override;
-  bool VisitDFPattern_(const FunctionPatternNode* op, const Expr& expr) override;
-  bool VisitDFPattern_(const IfPatternNode* op, const Expr& expr) override;
-  bool VisitDFPattern_(const LetPatternNode* op, const Expr& expr) override;
-  bool VisitDFPattern_(const ShapePatternNode* op, const Expr& expr) override;
-  bool VisitDFPattern_(const TupleGetItemPatternNode* op, const Expr& expr) override;
-  bool VisitDFPattern_(const TuplePatternNode* op, const Expr& expr) override;
-  bool VisitDFPattern_(const TypePatternNode* op, const Expr& expr) override;
-  bool VisitDFPattern_(const VarPatternNode* op, const Expr& expr) override;
-  bool VisitDFPattern_(const WildcardPatternNode* op, const Expr& expr) override;
-
-  void ClearMap(size_t watermark);
-  bool MatchesPath(const DominatorPatternNode* op, const Expr& expr);
-  bool DominatesParent(const DominatorPatternNode* op, const Expr& expr);
-
-  std::unordered_map<DFPattern, Array<Expr>, ObjectPtrHash, ObjectPtrEqual> memo_;
-  std::vector<DFPattern> matched_nodes_;
-  bool memoize_ = true;
-};
-
 bool DFPatternMatcher::Match(const DFPattern& pattern, const Expr& expr) {
   memo_.clear();
   matched_nodes_.clear();
@@ -530,6 +493,8 @@ TVM_REGISTER_GLOBAL("relay.dataflow_pattern.match").set_body_typed(MatchPattern)
 class PatternGrouper {
  public:
   /*! \brief Internal Group class for storing analysis */
+  PatternGrouper(bool allow_overlapping_groups = false)
+      : allow_overlapping_groups_(allow_overlapping_groups) {}
   struct Group {
     Expr root_node;
     int gid;
@@ -658,8 +623,10 @@ class PatternGrouper {
     // Don't treat fuzzy Dominator patterns input variables for partition
     if (auto op = node->ref_.as<DominatorPatternNode>()) {
       for (auto fuzzy_op : {op->parent, op->path}) {
-        for (auto match : node_map[fuzzy_op]) {
-          fuzzy_matches.insert(match);
+        if (node_map.count(fuzzy_op)) {
+          for (auto match : node_map[fuzzy_op]) {
+            fuzzy_matches.insert(match);
+          }
         }
       }
     }
@@ -708,35 +675,37 @@ class PatternGrouper {
     group.function = Function(params, body, NullValue<Type>(), Array<TypeVar>());
     group.name = extractor.GetName();
-    // Check to make sure we aren't overlapping with another group or creating an invalid fusion
-    // The MatchExtractor will create a new graph by replacing nodes that match the inputs of the
-    // pattern with the input FunctionVar* Variables. The resulting memoization map will only
-    // contain nodes in the expression that matched the pattern. If a non-input node of the pattern
-    // (i.e., some piece of computation) overlaps with the nodes in a previous group, we'll have a
-    // situation where we try to rewrite the same node twice in the second rewriting or parition
-    // pass. This isn't valid, so we check for it here. We ignore Ops, functions, and constants
-    // because they exist more globally outside of the fusion.
-    // Similiarly, if interior nodes in a group are used outside of the group fusing to a single
-    // output would create an invalid graph tranformation, so we block the creation of such groups.
-    auto memo = extractor.GetMemo();
-    for (auto kv : memo) {
-      // Check to ensure that this node isn't an input or a global
-      if (inputs.count(kv.first) == 0 && kv.first.as<OpNode>() == nullptr &&
-          kv.first.as<FunctionNode>() == nullptr && kv.first.as<ConstantNode>() == nullptr) {
-        if (gid_assignments_.count(kv.first) != 0) {
-          // check to see if the node is use in other groups
-          // Exit due to overlapping partitions
-          return;
-        } else if (kv.second != body) {
-          // if the node isn't the ouput of the group
-          auto node = matcher_->expr_graph_.node_map_.at(kv.first);
-          for (auto* output : node->outputs_) {
-            // and the node is used by nodes outside of the group
-            if (memo.count(output->ref_) == 0 &&
-                !matcher_->expr_graph_.node_map_.at(expr)->Dominates(output)) {
-              // Exit because nodes in this pattern's body are used outside the pattern
-              // fusing it would be invalid
-              return;
+    if (!allow_overlapping_groups_) {
+      // Check to make sure we aren't overlapping with another group or creating an invalid fusion
+      // The MatchExtractor will create a new graph by replacing nodes that match the inputs of the
+      // pattern with the input FunctionVar* Variables. The resulting memoization map will only
+      // contain nodes in the expression that matched the pattern. If a non-input node of the
+      // pattern (i.e., some piece of computation) overlaps with the nodes in a previous group,
+      // we'll have a situation where we try to rewrite the same node twice in the second rewriting
+      // or partition pass. This isn't valid, so we check for it here.
+      // We ignore Ops, functions, and
+      // constants because they exist more globally outside of the fusion. Similarly, if interior
+      // nodes in a group are used outside of the group, fusing to a single output would create an
+      // invalid graph transformation, so we block the creation of such groups.
+      auto memo = extractor.GetMemo();
+      for (auto kv : memo) {
+        // Check to ensure that this node isn't an input or a global
+        if (inputs.count(kv.first) == 0 && kv.first.as<OpNode>() == nullptr &&
+            kv.first.as<FunctionNode>() == nullptr && kv.first.as<ConstantNode>() == nullptr) {
+          if (gid_assignments_.count(kv.first) != 0) {
+            // check to see if the node is used in other groups
+            // Exit due to overlapping partitions
+            return;
+          } else if (kv.second != body) {
+            // if the node isn't the output of the group
+            auto node = matcher_->expr_graph_.node_map_.at(kv.first);
+            for (auto* output : node->outputs_) {
+              // and the node is used by nodes outside of the group
+              if (memo.count(output->ref_) == 0 &&
+                  !matcher_->expr_graph_.node_map_.at(expr)->Dominates(output)) {
+                // Exit because nodes in this pattern's body are used outside the pattern
+                // fusing it would be invalid
+                return;
+              }
+            }
+          }
+        }
+      }
+    }
@@ -790,6 +759,7 @@ class PatternGrouper {
   IndexedGraph<DFPattern> pattern_graph_;
   int gid_ = 0;
   int graph_number_ = 0;
+  bool allow_overlapping_groups_ = false;
 };
 
 // Rewrite
@@ -817,7 +787,8 @@ TVM_REGISTER_GLOBAL("relay.dataflow_pattern.DFPatternCallback")
 */
 class PatternRewriter : protected MixedModeMutator {
  public:
-  PatternRewriter(IRModule mod) : mod_(mod) {}
+  PatternRewriter(IRModule mod, bool allow_overlapping_groups)
+      : mod_(mod), allow_overlapping_groups_(allow_overlapping_groups) {}
   /*! \brief Rewrite can take a number of callbacks and will repeatedly rewrite the graph with the
    * callbacks until it stops changing */
   Expr Rewrite(const Array<DFPatternCallback>& callbacks, const Expr& pre) {
@@ -835,7 +806,7 @@ class PatternRewriter : protected MixedModeMutator {
       if (callback_->require_type) {
         post = InferTypeWithModule(post, mod_);
       }
-      auto grouper = PatternGrouper();
+      auto grouper = PatternGrouper(allow_overlapping_groups_);
       groups_ = grouper.GroupMatches(callback_->pattern, post);
       gid_assignments_ = grouper.GetGIDAssignments();
       memo_.clear();
@@ -874,10 +845,12 @@ class PatternRewriter : protected MixedModeMutator {
   DFPatternCallback callback_;
   std::unordered_map<int, PatternGrouper::Group> groups_;
   std::unordered_map<Expr, int, ObjectPtrHash, ObjectPtrEqual> gid_assignments_;
+  bool allow_overlapping_groups_ = false;
 };
 
-Expr RewritePatterns(Array<DFPatternCallback> callbacks, Expr expr, IRModule mod) {
-  return PatternRewriter(mod).Rewrite(callbacks, expr);
+Expr RewritePatterns(Array<DFPatternCallback> callbacks, Expr expr, IRModule mod,
+                     int allow_overlapping_groups) {
+  return PatternRewriter(mod, allow_overlapping_groups).Rewrite(callbacks, expr);
 }
 
 TVM_REGISTER_GLOBAL("relay.dataflow_pattern.rewrite").set_body_typed(RewritePatterns);
diff --git a/src/relay/ir/dataflow_matcher.h b/src/relay/ir/dataflow_matcher.h
new file mode 100644
index 000000000000..115798101971
--- /dev/null
+++ b/src/relay/ir/dataflow_matcher.h
@@ -0,0 +1,75 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/ir/dataflow_matcher.h
+ * \brief The dataflow pattern matcher for Relay.
+ */
+
+#include
+#include
+#include
+#include
+
+#include "indexed_graph.h"
+
+namespace tvm {
+namespace relay {
+
+// Pattern Matcher
+
+class DominatorMatcher;
+
+class DFPatternMatcher : public DFPatternFunctor<bool(const DFPattern&, const Expr&)> {
+ public:
+  explicit DFPatternMatcher(const Expr& root_expr) : expr_graph_(CreateIndexedGraph(root_expr)) {}
+  bool Match(const DFPattern& pattern, const Expr& expr);
+  Map<DFPattern, Array<Expr>> GetMemo() { return Map<DFPattern, Array<Expr>>(memo_); }
+  const IndexedGraph<Expr> expr_graph_;
+
+ protected:
+  bool VisitDFPattern(const DFPattern& pattern, const Expr& expr) override;
+  bool VisitDFPattern_(const AltPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const AttrPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const CallPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const ConstantPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const DataTypePatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const DominatorPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const ExprPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const FunctionPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const IfPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const LetPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const ShapePatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const TupleGetItemPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const TuplePatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const TypePatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const VarPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const WildcardPatternNode* op, const Expr& expr) override;
+
+  void ClearMap(size_t watermark);
+  bool MatchesPath(const DominatorPatternNode* op, const Expr& expr);
+  bool DominatesParent(const DominatorPatternNode* op, const Expr& expr);
+
+  std::unordered_map<DFPattern, Array<Expr>, ObjectPtrHash, ObjectPtrEqual> memo_;
+  std::vector<DFPattern> matched_nodes_;
+  bool memoize_ = true;
+};
+
+}  // namespace relay
+}  // namespace tvm
\ No newline at end of file
diff --git a/src/relay/op/tensor/transform.h b/src/relay/op/tensor/transform.h
index 95a83a905908..60b500b575d4 100644
--- a/src/relay/op/tensor/transform.h
+++ b/src/relay/op/tensor/transform.h
@@ -93,10 +93,10 @@ bool ConcatenateRel(const Array<Type>& types, int num_inputs, const Attrs& attrs
     int e_ndim = static_cast<int>(e->shape.size());
     const DataType& e_dtype = e->dtype;
     if (e_ndim != ndim) {
-      throw Error("relay.concatenate requires all tensors have the same ndim");
+      throw Error("relay.concatenate requires all tensors to have the same ndim");
     }
     if (e_dtype != dtype) {
-      throw Error("relay.concatenate requires all tensors have the same dtype");
+      throw Error("relay.concatenate requires all tensors to have the same dtype");
     }
   }
diff --git a/src/relay/qnn/op/add.cc b/src/relay/qnn/op/add.cc
index b0dc3e4af5c4..7f4de1ee1805 100644
b0dc3e4af5c4..7f4de1ee1805 100644 --- a/src/relay/qnn/op/add.cc +++ b/src/relay/qnn/op/add.cc @@ -53,7 +53,7 @@ Expr QnnAddCanonicalize(const Attrs& attrs, const Array& new_args, // Since the input qnn params can be different from the output qnn params, we first requantize the // input tensors to the output qnn params. Then we call relay.add on the requantized inputs. This // addition results in an extra addition of the output zero point. We further subtract the zero - // point. The whole process can be represented using following equations + // point. The whole process can be represented using the following equations for C = A + B: // // scale_c * (Q_c - zp_c) = scale_a * (Q_a - zp_a) + scale_b * (Q_b - zp_b) // diff --git a/src/relay/qnn/op/dequantize.cc b/src/relay/qnn/op/dequantize.cc index 724441e0c523..48874062ef97 100644 --- a/src/relay/qnn/op/dequantize.cc +++ b/src/relay/qnn/op/dequantize.cc @@ -44,36 +44,44 @@ bool DequantizeRel(const Array& types, int num_inputs, const Attrs& attrs, if (data == nullptr) { return false; } - const auto input_dtype = data->dtype; ICHECK(input_dtype == DataType::Int(8) || input_dtype == DataType::UInt(8) || input_dtype == DataType::Int(32)) << "Input type should be one of the quantized types [uint8, int8, int32] but was " << input_dtype; - const auto* dequantize_attrs = attrs.as(); - int axis = dequantize_attrs->axis; - axis = (axis == -1) ? data->shape.size() - 1 : axis; - ICHECK_LT(axis, static_cast(data->shape.size())) - << "axis " << dequantize_attrs->axis << " is out of range"; - ICHECK_GE(axis, 0) << "axis " << dequantize_attrs->axis << " is out of range"; - - // Check and assign types for scale and zero points. - AssignType(types[1], DataType::Float(32), data->shape[axis], reporter); // scale - AssignType(types[2], DataType::Int(32), data->shape[axis], reporter); // zero point + const DequantizeAttrs* dequantize_attrs = attrs.as(); + DataType out_dtype = dequantize_attrs->out_dtype; + ICHECK(out_dtype == DataType::Float(32) || out_dtype == DataType::Int(32)) + << "out_dtype for dequantize must be float32 or int32, but got " << out_dtype; + + // Check and assign types for the scale and zero point; skipped for scalar (rank-0) + // inputs, which have no channel axis. + if (data->shape.size() != 0) { + int axis = dequantize_attrs->axis; + axis = (axis == -1) ? data->shape.size() - 1 : axis; + ICHECK_LT(axis, static_cast(data->shape.size())) + << "axis " << dequantize_attrs->axis << " is out of range"; + ICHECK_GE(axis, 0) << "axis " << dequantize_attrs->axis << " is out of range"; + + // Check and assign types for scale and zero points. + AssignType(types[1], DataType::Float(32), data->shape[axis], reporter); // scale + AssignType(types[2], DataType::Int(32), data->shape[axis], reporter); // zero point + } const Array oshape = data->shape; - // assign output type, output will always be float 32. - reporter->Assign(types[3], TensorType(oshape, DataType::Float(32))); + // assign output type based on out_dtype attribute.
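+ // For example, dequantizing an int8 tensor with out_dtype=int32 yields an int32 tensor + // of the same shape; the cast itself is inserted at the end of DequantizeLower below.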
+ reporter->Assign(types[3], TensorType(oshape, out_dtype)); return true; } -Expr MakeDequantize(Expr data, Expr input_scale, Expr input_zero_point, int axis) { +Expr MakeDequantize(Expr data, Expr input_scale, Expr input_zero_point, int axis, + DataType out_dtype) { // real_value = scale * (quantized_value - zero_point) // A more detailed explanation can be found here - // https://github.com/google/gemmlowp/blob/master/doc/quantization.md auto attrs = make_object(); attrs->axis = axis; + attrs->out_dtype = out_dtype; static const Op& op = Op::Get("qnn.dequantize"); return Call(op, {data, input_scale, input_zero_point}, Attrs(attrs), {}); } @@ -82,6 +90,7 @@ Expr DequantizeLower(const Expr& input_tensor, const Expr& input_scale, const Expr& input_zero_point, const Array& types, const DequantizeAttrs* attrs) { const auto axis = attrs->axis; + const DataType out_dtype = attrs->out_dtype; ICHECK_EQ(types.size(), 4); auto in_type = types[0]; @@ -105,6 +114,10 @@ Expr DequantizeLower(const Expr& input_tensor, const Expr& input_scale, auto shift = Subtract(Cast(input_tensor, DataType::Int(32)), expanded_input_zero_point); auto scaled_output = Multiply(Cast(shift, DataType::Float(32)), expanded_input_scale); + + if (out_dtype != DataType::Float(32)) { + scaled_output = Cast(scaled_output, out_dtype); + } return scaled_output; } diff --git a/src/relay/qnn/op/quantize.cc b/src/relay/qnn/op/quantize.cc index 9829834f43a3..8e68b3bd59b3 100644 --- a/src/relay/qnn/op/quantize.cc +++ b/src/relay/qnn/op/quantize.cc @@ -38,27 +38,30 @@ TVM_REGISTER_NODE_TYPE(QuantizeAttrs); bool QuantizeRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { + // types = [data_type, scale_type, zp_type, ret_type] ICHECK_EQ(types.size(), 4); const auto* data = types[0].as(); - if (data == nullptr) { return false; } const auto input_dtype = data->dtype; - ICHECK(input_dtype == DataType::Float(32)) - << "Input type should be one of float32 but was " << input_dtype; + ICHECK(input_dtype == DataType::Float(32) || input_dtype == DataType::Int(32)) + << "Input type should be one of float32 or int32 but was " << input_dtype; const auto* quantize_attrs = attrs.as(); - int axis = quantize_attrs->axis; - axis = (axis == -1) ? data->shape.size() - 1 : axis; - ICHECK_LT(axis, static_cast(data->shape.size())) - << "axis " << quantize_attrs->axis << " is out of range"; - ICHECK_GE(axis, 0) << "axis " << quantize_attrs->axis << " is out of range"; - - // Check and assign types for scale and zero points. - AssignType(types[1], DataType::Float(32), data->shape[axis], reporter); // scale - AssignType(types[2], DataType::Int(32), data->shape[axis], reporter); // zero point + + // Check and assign types for the scale and zero point; skipped for scalar (rank-0) + // inputs, which have no channel axis. + if (data->shape.size() != 0) { + int axis = quantize_attrs->axis; + axis = (axis == -1) ? data->shape.size() - 1 : axis; + ICHECK_LT(axis, static_cast(data->shape.size())) + << "axis " << quantize_attrs->axis << " is out of range"; + ICHECK_GE(axis, 0) << "axis " << quantize_attrs->axis << " is out of range"; + // Check and assign types for scale and zero points.
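+ // They may be scalars or, for per-channel quantization, 1-D tensors of length + // data->shape[axis].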
+ AssignType(types[1], DataType::Float(32), data->shape[axis], reporter); // scale + AssignType(types[2], DataType::Int(32), data->shape[axis], reporter); // zero point + } const Array oshape = data->shape; const DataType out_dtype = quantize_attrs->out_dtype; diff --git a/src/relay/qnn/op/requantize.cc b/src/relay/qnn/op/requantize.cc index 2ae879595659..7b9c9b79eecd 100644 --- a/src/relay/qnn/op/requantize.cc +++ b/src/relay/qnn/op/requantize.cc @@ -256,8 +256,9 @@ Expr RequantizeQnnCanonicalize(const Attrs& attrs, const Array& new_args, */ bool RequantizeRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { - // Expected Types: data, input_scale, input_zero_point, output_scale, output_zero_point, output - ICHECK_EQ(types.size(), 6); + // types = [data_type, input_scale_type, input_zero_point_type, + // output_scale_type, output_zero_point_type, ret_type] + ICHECK_EQ(types.size(), 6); const auto* data = types[0].as(); if (data == nullptr) { diff --git a/src/relay/transforms/quantize.cc b/src/relay/transforms/quantize.cc new file mode 100644 index 000000000000..7dd000719fce --- /dev/null +++ b/src/relay/transforms/quantize.cc @@ -0,0 +1,374 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*!
+ * \file src/relay/transforms/quantize.cc + * \brief Relay quantization related passes + */ + +#include +#include +#include +#include + +#include "../ir/dataflow_matcher.h" + +namespace tvm { +namespace relay { +namespace quantize { + +class PatternCalibrationInfoNode : public Object { // TODO: rename + public: + DFPattern pattern; + Expr expr; + + Array> input_scale_zps; + Array input_idxs; + Integer output_idx; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("pattern", &pattern); + v->Visit("expr", &expr); + v->Visit("input_scale_zps", &input_scale_zps); + v->Visit("input_idxs", &input_idxs); + v->Visit("output_idx", &output_idx); + } + + static constexpr const char* _type_key = "PatternCalibrationInfoNode"; + TVM_DECLARE_BASE_OBJECT_INFO(PatternCalibrationInfoNode, Object); +}; + +class PatternCalibrationInfo : public ObjectRef { + public: + TVM_DLL PatternCalibrationInfo(DFPattern pattern, Expr expr, Array> input_scale_zps, + Array input_idxs, Integer output_idx); + TVM_DEFINE_OBJECT_REF_METHODS(PatternCalibrationInfo, ObjectRef, PatternCalibrationInfoNode); +}; + +PatternCalibrationInfo::PatternCalibrationInfo(DFPattern pattern, Expr expr, + Array> input_scale_zps, + Array input_idxs, Integer output_idx) { + ObjectPtr n = make_object(); + n->pattern = std::move(pattern); + n->expr = std::move(expr); + n->input_scale_zps = std::move(input_scale_zps); + n->input_idxs = std::move(input_idxs); + n->output_idx = std::move(output_idx); + data_ = std::move(n); +} + +TVM_REGISTER_NODE_TYPE(PatternCalibrationInfoNode); + +TVM_REGISTER_GLOBAL("relay.new_quantize.PatternCalibrationInfo") + .set_body_typed([](DFPattern pattern, Expr expr, Array> input_scale_zps, + Array input_idxs, Integer output_idx) { + return PatternCalibrationInfo(pattern, expr, input_scale_zps, input_idxs, output_idx); + }); + +class PartitionOutputs : public MixedModeMutator { + public: + Expr GetPartitionOutputs(const Expr& expr) { + new_outputs.clear(); + if (auto func = expr.as()) { + new_outputs.push_back(func->body); + } else if (auto tuple = expr.as()) { + new_outputs = tuple->fields; // Array is copy-on-write, so a plain assignment is safe here
+ } else { + new_outputs.push_back(expr); + } + VisitExpr(expr); + Expr out; + if (auto func = expr.as()) { + out = Function(func->params, Tuple(new_outputs), Type{}, Array{}, func->attrs); + } else { + out = Tuple(new_outputs); + } + return out; + } + + protected: + Expr Rewrite_(const CallNode* pre, const Expr& post) { + auto* post_node = post.as(); + ICHECK(post_node != nullptr); + if (auto* func_node = post_node->op.as()) { + if (func_node->attrs.defined() && + func_node->attrs->dict.count(attr::kPartitionedFromPattern) != 0) { + for (const auto& arg : post_node->args) { + new_outputs.push_back(arg); + } + new_outputs.push_back(post); + } + } + return post; + } + + Array new_outputs; +}; + +class PartitionsInOrder : protected MixedModeVisitor { + public: + PartitionsInOrder(bool skip_first, bool skip_last) + : skip_first_(skip_first), skip_last_(skip_last) {} + Array partitions; + bool skip_first_; + bool skip_last_; + Array GetPartitionsInOrder(const Expr& expr) { + VisitExpr(expr); + Array out; + if (partitions.size() > 0) { + if (skip_first_) { + out.push_back(partitions[0]); + } + if (skip_last_) { + out.push_back(partitions.back()); + } + } + return out; + } + void VisitExpr_(const CallNode* op) override { + if (auto func_node = op->op.as()) { + // If it's calling a function, check its attributes to see whether it was partitioned + // from a pattern + if (func_node->attrs.defined() && + func_node->attrs->dict.count(attr::kPartitionedFromPattern) != 0) { + // If this is a pattern function, record it as a partition + partitions.push_back(op->op); + } + } + } +}; + +class RewritePartitions : protected MixedModeMutator { + public: + RewritePartitions(const Array& callbacks) : callbacks_(callbacks) {} + Map Rewrite(const Expr& expr) { + // Preprocessing + if (auto* func = expr.as()) { + if (auto* tuple = func->body.as()) { + orig_outputs_ = tuple->fields; + } else { + orig_outputs_.push_back(func->body); + } + for (auto param : func->params) { + new_params_.push_back(param); + } + } else { + if (auto* tuple = expr.as()) { + orig_outputs_ = tuple->fields; + } else { + orig_outputs_.push_back(expr); + } + } + Expr new_out = MixedModeMutator::Mutate(expr); + + // Add new parameters to the function + if (auto* new_out_func = new_out.as()) { + new_out = + Function(new_params_, new_out_func->body, Type{}, Array{}, new_out_func->attrs); + } + // TVM's object system doesn't have pairs, so we return new_out and infos_ in a Map + Map out_pair = {{"new_out", new_out}, {"infos_", infos_}}; + return out_pair; + } + + protected: + Array FindScaleZp(const Expr& input, const Expr& new_body) { + Array ScaleZp; + auto x = WildcardPattern(make_object()); + auto scale = WildcardPattern(make_object()); + auto zp = WildcardPattern(make_object()); + DFPattern pattern = IsOp("qnn.quantize")({x, scale, zp}); + + runtime::PackedFunc callback([&](TVMArgs args, TVMRetValue* ret) { + Expr post = args[1]; + Map> node_map = args[2]; + + if (node_map[x][0] == input) { + auto scale_var = node_map[scale][0].as(); + auto zp_var = node_map[zp][0].as(); + + ICHECK((scale_var && zp_var) || (!scale_var && !zp_var)) + << "The scale and zero point passed to a " + << "qnn.quantize must either both be variables or both be expressions composed of " + "other variables.
" + << "Please change the AST returned from your QuantizerPattern to meet this " + "requirement."; + + // Only add them to the list of scales / zps we will set later if they are not expressions + if (scale_var && zp_var) { + ScaleZp.push_back(GetRef(scale_var)); + ScaleZp.push_back(GetRef(zp_var)); + } + } + + *ret = post; + }); + RewritePatterns({DFPatternCallback(pattern, callback, false)}, new_body); + return ScaleZp; + } + Expr Rewrite_(const CallNode* pre, const Expr& post) { + // Cast the post as a call node and assert it actually is a call + auto* post_node = post.as(); + ICHECK(post_node != nullptr); + // Check to see if the Call is calling a Function + if (auto* func_node = post_node->op.as()) { + // If it's calling a function, check to see if it has attributes that it's a been partitioned + // from a Pattern + if (func_node->attrs.defined() && + func_node->attrs->dict.count(attr::kPartitionedFromPattern) != 0) { + // If this is a pattern function, create a matcher on it's body + auto matcher = DFPatternMatcher(func_node->body); + // Find the callback that matches this pattern + for (const auto& callback : callbacks_) { + if (matcher.Match(callback->pattern, func_node->body)) { + // extract the current params and call-level args + Array params = func_node->params; + Array call_args = post_node->args; + + Array input_idx; + // Get the indices of the arguments to this function in the output tuple + for (auto arg : pre->args) { + auto itr = std::find(orig_outputs_.begin(), orig_outputs_.end(), arg); + ICHECK(itr != orig_outputs_.end()) + << "Didn't find the arguement in the output tuple. Indicates a possible problem " + "in PartitionOutputs. "; + input_idx.push_back(std::distance(orig_outputs_.begin(), itr)); + } + // Get the index of the output of this function + auto itr = std::find(orig_outputs_.begin(), orig_outputs_.end(), GetRef(pre)); + ICHECK(itr != orig_outputs_.end()) + << "Didn't find the output in the output tuple. Indicates a possible problem in " + "PartitionOutputs. "; + Integer output_idx(std::distance(orig_outputs_.begin(), itr)); + + // create a new body based on the callback + Expr new_body = callback->function(pre->op.as()->body, func_node->body, + matcher.GetMemo()); + + // FIND THE SCALE / ZPS + Array> input_scale_zps; + for (auto param : params) { + Array scale_zp = FindScaleZp(param, new_body); + // If FindScaleZp returns an empty array, we don't need to provide these as parameters + if (scale_zp.size() != 0) { + ICHECK(scale_zp.size() == 2) + << "scale_zp should have two items in it, the scale variable and the zp " + "variable. This points to an issue with FindScaleZp. 
"; + input_scale_zps.push_back(FindScaleZp(param, new_body)); + } + } + + infos_.push_back(PatternCalibrationInfo(callback->pattern, + pre->op.as()->body, + input_scale_zps, input_idx, output_idx)); + // find parameters added to the new body that weren't there before + // find all of the free variables in the new body + for (const auto& param : FreeVars(new_body)) { + // check to see if that free variable is in the old parameter list + if (std::find(params.begin(), params.end(), param) == params.end()) { + // if not, add it to the new parameter list + params.push_back(param); + // Create a new call-level arg for it + // Make that new arg an input to the top-level function + new_params_.push_back(Var(param->name_hint(), param->type_annotation)); + call_args.push_back(new_params_.back()); + } + } + // Create a new function with new params and body + Expr new_func = Function(params, new_body, Type{}, Array{}, func_node->attrs); + // Call the new function with the new args + return Call(new_func, call_args, Attrs{}, Array{}); + } + } + } + } + return post; + } + Array callbacks_; + Array infos_; + Array new_params_; + Array orig_outputs_; +}; + +class ReplaceArgs : protected MixedModeMutator { + public: + Expr Rewrite(const Expr& body, + std::unordered_map arg_map) { + // Leverage the memoizer to replace parameters with arguments automatically + memo_ = arg_map; + return MixedModeMutator::Mutate(body); + } +}; + +class LowerPartitions : protected MixedModeMutator { + public: + LowerPartitions(const Array targets = Array(), const bool skipping_partitions = false) + : targets_(targets), skipping_partitions_(skipping_partitions) {} + Expr Rewrite(const Expr& expr) { + Expr new_out = MixedModeMutator::Mutate(expr); + return new_out; + } + Expr Rewrite_(const CallNode* pre, const Expr& post) { + // Targets is usually length 0, 1, or 2 + if ((!skipping_partitions_) || + (skipping_partitions_ && + std::find(targets_.begin(), targets_.end(), pre->op) != targets_.end())) { + auto* post_node = post.as(); + ICHECK(post_node != nullptr); + if (auto* func_node = post_node->op.as()) { + // If the function was created by the pattern matcher, remove it + if (func_node->attrs.defined() && + func_node->attrs->dict.count(attr::kPartitionedFromPattern) != 0) { + std::unordered_map arg_map = {}; + Array args = post_node->args; + Array params = func_node->params; + + for (uint i = 0; i < args.size(); i++) { + arg_map.insert({params[i], args[i]}); + } + return ReplaceArgs().Rewrite(func_node->body, arg_map); + } + } + } + return post; + } + + protected: + Array targets_; + bool skipping_partitions_; +}; + +TVM_REGISTER_GLOBAL("relay.transform.quantize.partition_outputs").set_body_typed([](const Expr& expr) { + return PartitionOutputs().GetPartitionOutputs(expr); +}); +TVM_REGISTER_GLOBAL("relay.transform.quantize.rewrite_partitions") + .set_body_typed([](const Array& callbacks, const Expr& expr) { + return RewritePartitions(callbacks).Rewrite(expr); + }); +TVM_REGISTER_GLOBAL("relay.transform.quantize.lower_partitions").set_body_typed([](const Expr& expr) { + return LowerPartitions().Rewrite(expr); +}); +TVM_REGISTER_GLOBAL("relay.transform.quantize.skip_partitions") + .set_body_typed([](const Expr& expr, bool skip_first, bool skip_last) { + auto targets = PartitionsInOrder(skip_first, skip_last).GetPartitionsInOrder(expr); + return LowerPartitions(targets, true).Rewrite(expr); + }); +} // namespace quantize +} // namespace relay +} // namespace tvm diff --git a/tests/python/relay/quantize/test_calibrate.py 
b/tests/python/relay/quantize/test_calibrate.py new file mode 100644 index 000000000000..5ec36e8cdc35 --- /dev/null +++ b/tests/python/relay/quantize/test_calibrate.py @@ -0,0 +1,501 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import tvm +from tvm import relay +from tvm.data import RandomDatasetManager +from tvm.relay.transform.quantize import ( + Quantizer, + QuantizerPattern, + QuantizationCalibrator, + Conv2DPattern, + Conv2DBiasAddPattern, + DensePattern, + DenseBiasAddPattern, + AddPattern, + MultiplyPattern, + CalibrationCallback, + GlobalCalibrationCallback, + AverageMaxCalibrationCallback, + AverageMaxPerChannelConv2DPattern, + AverageMaxPerChannelConv2DBiasAddPattern, + AverageMaxPerChannelDensePattern, + AverageMaxPerChannelDenseBiasAddPattern, +) +from test_quantize import ( + create_conv2d_func, + create_q_conv2d_func, + create_conv2d_bias_func, + create_q_conv2d_bias_func, + create_dense_func, + create_q_dense_func, + create_dense_bias_func, + create_q_dense_bias_func, + create_add_func, + create_q_add_func, + create_mul_func, + create_q_mul_func, +) +from tvm.relay.frontend.common import infer_type + +import numpy as np + +# Calls all the methods of CalibrationCallback to make sure they work. + + +class TestCalibrationCallback(CalibrationCallback): + def __init__(self): + self.scale_value = np.array(2).astype("float32") + self.zp_value = np.array(0.5).astype("int32") + + def calibrate_pattern(self, calibration_info): + scale_zp_values = {} + + for i in range(len(calibration_info.input_scale_zps)): + scale_name = calibration_info.input_scale_zps[i][0].name_hint + scale_zp_values[scale_name] = self.scale_value + zp_name = calibration_info.input_scale_zps[i][1].name_hint + scale_zp_values[zp_name] = self.zp_value + + inputs, _ = calibration_info.dataset_manager.get_next_batch() + + calibration_info.get_unquantized_layer_inputs(inputs) + calibration_info.get_unquantized_layer_output(inputs) + calibration_info.get_quantized_layer_inputs(inputs, scale_zp_values) + calibration_info.get_quantized_layer_output(inputs, scale_zp_values) + + return scale_zp_values + + +# Helper rather than a test case: the check_ prefix keeps pytest from collecting it. +def check_calibrate(quantizer, quantized_func, params, dataset_manager): + calibrator = QuantizationCalibrator(quantizer, dataset_manager=dataset_manager) + calibrated_func = calibrator.calibrate() + + quantized_func = relay.build_module.bind_params_by_name( + quantized_func, calibrator.calibration_info.scale_zp_value_map + ) + quantized_func = relay.build_module.bind_params_by_name(quantized_func, params) + quantized_func = infer_type(quantized_func) + calibrated_func = infer_type(calibrated_func) + + assert tvm.ir.structural_equal(quantized_func, calibrated_func) + + +def reset_scale_zp_counter(): + # For testing purposes, we reset the static scale and zero point counters to zero before
# calibrating, so that the generated variable names match up. + QuantizerPattern.scales_count = 0 + QuantizerPattern.zp_count = 0 + + +def verify_conv2d(data_shape, weight_shape, attrs, cc=None, pattern_list=None): + reset_scale_zp_counter() + + conv2d_func, data, weight = create_conv2d_func(data_shape, weight_shape, attrs) + q_conv2d_func = create_q_conv2d_func(data, weight, weight_shape, attrs) + + if cc is None: + cc = TestCalibrationCallback() + + params = {"weight": np.random.randn(*weight_shape).astype("float32")} + if pattern_list is None: + pattern_list = [Conv2DPattern(cc)] + quantizer = Quantizer(conv2d_func, params, pattern_list, skip_first=False, skip_last=False) + + check_calibrate( + quantizer, q_conv2d_func, params, RandomDatasetManager(data_shape, "float32", 1, 3) + ) + + +def verify_conv2d_bias( + data_shape, weight_shape, bias_shape, attrs, bias_type="bias_add", cc=None, pattern_list=None +): + reset_scale_zp_counter() + + conv2d_func, data, weight, bias = create_conv2d_bias_func( + data_shape, weight_shape, bias_shape, attrs, bias_type + ) + q_conv2d_func = create_q_conv2d_bias_func(data, weight, bias, weight_shape, attrs, bias_type) + + if cc is None: + cc = TestCalibrationCallback() + + params = {"weight": np.random.randn(*weight_shape).astype("float32")} + if pattern_list is None: + pattern_list = [Conv2DBiasAddPattern(cc)] + quantizer = Quantizer(conv2d_func, params, pattern_list, skip_first=False, skip_last=False) + + check_calibrate( + quantizer, q_conv2d_func, params, RandomDatasetManager(data_shape, "float32", 1, 3) + ) + + +def verify_dense(data_shape, weight_shape, attrs, cc=None, pattern_list=None): + reset_scale_zp_counter() + + dense_func, data, weight = create_dense_func(data_shape, weight_shape, attrs) + q_dense_func = create_q_dense_func(data, weight, attrs) + + if cc is None: + cc = TestCalibrationCallback() + + params = {"weight": np.random.randn(*weight_shape).astype("float32")} + + if pattern_list is None: + pattern_list = [DensePattern(cc)] + quantizer = Quantizer(dense_func, params, pattern_list, skip_first=False, skip_last=False) + + check_calibrate( + quantizer, q_dense_func, params, RandomDatasetManager(data_shape, "float32", 1, 3) + ) + + +def verify_dense_bias( + data_shape, weight_shape, bias_shape, attrs, bias_type="bias_add", cc=None, pattern_list=None +): + reset_scale_zp_counter() + + dense_bias_func, data, weight, bias = create_dense_bias_func( + data_shape, weight_shape, bias_shape, attrs, bias_type + ) + q_dense_bias_func = create_q_dense_bias_func(data, weight, bias, attrs, bias_type) + + if cc is None: + cc = TestCalibrationCallback() + + params = {"weight": np.random.randn(*weight_shape).astype("float32")} + if pattern_list is None: + pattern_list = [DenseBiasAddPattern(cc)] + quantizer = Quantizer(dense_bias_func, params, pattern_list, skip_first=False, skip_last=False) + + check_calibrate( + quantizer, q_dense_bias_func, params, RandomDatasetManager(data_shape, "float32", 1, 3) + ) + + +def verify_add(lhs_shape, rhs_shape, cc=None): + reset_scale_zp_counter() + + add_func, lhs, rhs = create_add_func(lhs_shape, rhs_shape) + q_add_func = create_q_add_func(lhs, rhs) + + if cc is None: + cc = TestCalibrationCallback() + + params = {"weight": np.random.randn(*rhs_shape).astype("float32")} + quantizer = Quantizer(add_func, params, [AddPattern(cc)], skip_first=False, skip_last=False) + + check_calibrate(quantizer, q_add_func, params, RandomDatasetManager(lhs_shape, "float32", 1, 3)) + + +def verify_mul(lhs_shape, rhs_shape, cc=None):
+ reset_scale_zp_counter() + + mul_func, lhs, rhs = create_mul_func(lhs_shape, rhs_shape) + q_mul_func = create_q_mul_func(lhs, rhs) + + if cc is None: + cc = TestCalibrationCallback() + + params = {"weight": np.random.randn(*rhs_shape).astype("float32")} + quantizer = Quantizer( + mul_func, params, [MultiplyPattern(cc)], skip_first=False, skip_last=False + ) + + check_calibrate(quantizer, q_mul_func, params, RandomDatasetManager(lhs_shape, "float32", 1, 3)) + + +def verify_cc_all_patterns(cc): + verify_conv2d( + (2, 3, 32, 32), + (32, 3, 3, 3), + { + "kernel_size": [3, 3], + "kernel_layout": "OIHW", + "data_layout": "NCHW", + "padding": [0, 0, 0, 0], + }, + cc=cc, + ) + verify_conv2d( + (2, 32, 32, 3), + (3, 3, 3, 32), + { + "kernel_size": [3, 3], + "kernel_layout": "HWIO", + "data_layout": "NHWC", + "padding": [0, 0, 0, 0], + }, + cc=cc, + ) + verify_conv2d_bias( + (2, 3, 32, 32), + (32, 3, 3, 3), + (32,), + { + "kernel_size": [3, 3], + "kernel_layout": "OIHW", + "data_layout": "NCHW", + "padding": [0, 0, 0, 0], + }, + cc=cc, + ) + verify_conv2d_bias( + (2, 32, 32, 3), + (3, 3, 3, 32), + (32,), + { + "kernel_size": [3, 3], + "kernel_layout": "HWIO", + "data_layout": "NHWC", + "padding": [0, 0, 0, 0], + }, + cc=cc, + ) + verify_conv2d_bias( + (2, 3, 32, 32), + (32, 3, 3, 3), + (1, 32, 1, 1), + { + "kernel_size": [3, 3], + "kernel_layout": "OIHW", + "data_layout": "NCHW", + "padding": [0, 0, 0, 0], + }, + bias_type="normal_add", + cc=cc, + ) + verify_conv2d_bias( + (2, 32, 32, 3), + (3, 3, 3, 32), + (1, 1, 1, 32), + { + "kernel_size": [3, 3], + "kernel_layout": "HWIO", + "data_layout": "NHWC", + "padding": [0, 0, 0, 0], + }, + bias_type="normal_add", + cc=cc, + ) + verify_dense((1, 8), (16, 8), {"units": 16}, cc=cc) + verify_dense((1, 4), (3, 4), {"units": 3}, cc=cc) + verify_dense_bias((1, 8), (16, 8), (16,), {"units": 16}, cc=cc) + verify_dense_bias((1, 4), (3, 4), (3,), {"units": 3}, cc=cc) + verify_dense_bias((1, 8), (16, 8), (16,), {"units": 16}, bias_type="normal_add", cc=cc) + verify_dense_bias((1, 4), (3, 4), (3,), {"units": 3}, bias_type="normal_add", cc=cc) + verify_add((1, 2, 3), (1, 2, 3), cc=cc) + verify_mul((1, 2, 3), (1, 2, 3), cc=cc) + + +def test_conv2d(): + verify_conv2d( + (2, 3, 32, 32), + (32, 3, 3, 3), + { + "kernel_size": [3, 3], + "kernel_layout": "OIHW", + "data_layout": "NCHW", + "padding": [0, 0, 0, 0], + }, + ) + verify_conv2d( + (2, 32, 32, 3), + (3, 3, 3, 32), + { + "kernel_size": [3, 3], + "kernel_layout": "HWIO", + "data_layout": "NHWC", + "padding": [0, 0, 0, 0], + }, + ) + + +def test_conv2d_bias(): + verify_conv2d_bias( + (2, 3, 32, 32), + (32, 3, 3, 3), + (32,), + { + "kernel_size": [3, 3], + "kernel_layout": "OIHW", + "data_layout": "NCHW", + "padding": [0, 0, 0, 0], + }, + ) + verify_conv2d_bias( + (2, 32, 32, 3), + (3, 3, 3, 32), + (32,), + { + "kernel_size": [3, 3], + "kernel_layout": "HWIO", + "data_layout": "NHWC", + "padding": [0, 0, 0, 0], + }, + ) + verify_conv2d_bias( + (2, 3, 32, 32), + (32, 3, 3, 3), + (1, 32, 1, 1), + { + "kernel_size": [3, 3], + "kernel_layout": "OIHW", + "data_layout": "NCHW", + "padding": [0, 0, 0, 0], + }, + bias_type="normal_add", + ) + verify_conv2d_bias( + (2, 32, 32, 3), + (3, 3, 3, 32), + (1, 1, 1, 32), + { + "kernel_size": [3, 3], + "kernel_layout": "HWIO", + "data_layout": "NHWC", + "padding": [0, 0, 0, 0], + }, + bias_type="normal_add", + ) + + +def test_dense(): + verify_dense((1, 8), (16, 8), {"units": 16}) + verify_dense((1, 4), (3, 4), {"units": 3}) + + +def test_dense_bias(): + verify_dense_bias((1, 8), (16,
8), (16,), {"units": 16}) + verify_dense_bias((1, 4), (3, 4), (3,), {"units": 3}) + verify_dense_bias((1, 8), (16, 8), (16,), {"units": 16}, bias_type="normal_add") + verify_dense_bias((1, 4), (3, 4), (3,), {"units": 3}, bias_type="normal_add") + + +def test_add(): + verify_add((1, 2, 3), (1, 2, 3)) + + +def test_mul(): + verify_mul((1, 2, 3), (1, 2, 3)) + + +def test_global_cc(): + verify_cc_all_patterns(GlobalCalibrationCallback(0.05, 0.01)) + + +def test_average_max_cc(): + verify_cc_all_patterns(AverageMaxCalibrationCallback()) + + +def test_per_channel_average_max_cc(): + pl = [AverageMaxPerChannelConv2DPattern()] + verify_conv2d( + (2, 3, 32, 32), + (32, 3, 3, 3), + { + "kernel_size": [3, 3], + "kernel_layout": "OIHW", + "data_layout": "NCHW", + "padding": [0, 0, 0, 0], + }, + pattern_list=pl, + ) + verify_conv2d( + (2, 32, 32, 3), + (3, 3, 3, 32), + { + "kernel_size": [3, 3], + "kernel_layout": "HWIO", + "data_layout": "NHWC", + "padding": [0, 0, 0, 0], + }, + pattern_list=pl, + ) + pl = [AverageMaxPerChannelConv2DBiasAddPattern()] + verify_conv2d_bias( + (2, 3, 32, 32), + (32, 3, 3, 3), + (32,), + { + "kernel_size": [3, 3], + "kernel_layout": "OIHW", + "data_layout": "NCHW", + "padding": [0, 0, 0, 0], + }, + pattern_list=pl, + ) + verify_conv2d_bias( + (2, 32, 32, 3), + (3, 3, 3, 32), + (32,), + { + "kernel_size": [3, 3], + "kernel_layout": "HWIO", + "data_layout": "NHWC", + "padding": [0, 0, 0, 0], + }, + pattern_list=pl, + ) + verify_conv2d_bias( + (2, 3, 32, 32), + (32, 3, 3, 3), + (1, 32, 1, 1), + { + "kernel_size": [3, 3], + "kernel_layout": "OIHW", + "data_layout": "NCHW", + "padding": [0, 0, 0, 0], + }, + bias_type="normal_add", + pattern_list=pl, + ) + verify_conv2d_bias( + (2, 32, 32, 3), + (3, 3, 3, 32), + (1, 1, 1, 32), + { + "kernel_size": [3, 3], + "kernel_layout": "HWIO", + "data_layout": "NHWC", + "padding": [0, 0, 0, 0], + }, + bias_type="normal_add", + pattern_list=pl, + ) + pl = [AverageMaxPerChannelDensePattern()] + verify_dense((1, 8), (16, 8), {"units": 16}, pattern_list=pl) + verify_dense((1, 4), (3, 4), {"units": 3}, pattern_list=pl) + pl = [AverageMaxPerChannelDenseBiasAddPattern()] + verify_dense_bias((1, 8), (16, 8), (16,), {"units": 16}, pattern_list=pl) + verify_dense_bias((1, 4), (3, 4), (3,), {"units": 3}, pattern_list=pl) + verify_dense_bias( + (1, 8), (16, 8), (16,), {"units": 16}, bias_type="normal_add", pattern_list=pl + ) + verify_dense_bias((1, 4), (3, 4), (3,), {"units": 3}, bias_type="normal_add", pattern_list=pl) + + +if __name__ == "__main__": + test_conv2d() + test_conv2d_bias() + test_dense() + test_dense_bias() + test_add() + test_mul() + test_global_cc() + test_average_max_cc() + test_per_channel_average_max_cc() diff --git a/tests/python/relay/quantize/test_pass.py b/tests/python/relay/quantize/test_pass.py new file mode 100644 index 000000000000..46fcd0f750d0 --- /dev/null +++ b/tests/python/relay/quantize/test_pass.py @@ -0,0 +1,74 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import tvm +from tvm import relay +from tvm.contrib import graph_runtime +from test_quantize import ( + create_conv2d_bias_func, + create_q_conv2d_bias_func, +) + +from tvm.relay.transform.quantize import ( + QuantizePass, + all_patterns, + average_max_per_channel_patterns, + GlobalCalibrationCallback, +) + +import numpy as np + + +def verify_pass(pre_mod, params, input_dict, quantizer_pattern_list): + opt = QuantizePass(quantizer_pattern_list, params, skip_first=False) + with relay.build_config(opt_level=3): + post_mod = opt(pre_mod) + q_lib = relay.build(post_mod, params=params, target="llvm") + + q_gmod = graph_runtime.GraphModule(q_lib["default"](tvm.cpu())) + q_gmod.set_input(**input_dict) + q_gmod.run() + + +def create_conv2d_bias_mods(data_shape, weight_shape, bias_shape, attrs, bias_type="bias_add"): + pre_func, data, weight, bias = create_conv2d_bias_func( + data_shape, weight_shape, bias_shape, attrs, bias_type + ) + pre_mod = tvm.IRModule.from_expr(pre_func) + params = {"weight": np.random.randn(*weight_shape).astype("float32")} + input_dict = {"data": np.random.randn(*data_shape).astype("float32")} + return pre_mod, params, input_dict + + +def test_pass(): + cc = GlobalCalibrationCallback(0.05, 0.1) + + pre_mod, params, input_dict = create_conv2d_bias_mods( + (2, 3, 32, 32), + (32, 3, 3, 3), + (32,), + { + "kernel_size": [3, 3], + "kernel_layout": "OIHW", + "data_layout": "NCHW", + "padding": [0, 0, 0, 0], + }, + ) + verify_pass(pre_mod, params, input_dict, all_patterns(cc)) + + +if __name__ == "__main__": + test_pass() diff --git a/tests/python/relay/quantize/test_quantize.py b/tests/python/relay/quantize/test_quantize.py new file mode 100644 index 000000000000..41f8947a1d7c --- /dev/null +++ b/tests/python/relay/quantize/test_quantize.py @@ -0,0 +1,574 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
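+ +# These tests build a float32 Relay function, quantize it with the pattern-based Quantizer, +# and compare the result structurally against a hand-constructed QNN graph. A minimal usage +# sketch of the flow exercised below (names as imported in this file): +# +# quantizer = Quantizer(pre_func, params, [Conv2DPattern(None)]) +# assert tvm.ir.structural_equal(infer_type(quantizer.quantized_func), expected_func) +#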
+import tvm +from tvm import relay +from tvm.relay.transform.quantize import ( + Quantizer, + Conv2DPattern, + Conv2DBiasAddPattern, + DensePattern, + DenseBiasAddPattern, + AddPattern, + MultiplyPattern, +) +from tvm.relay.op.nn.utils import get_pad_tuple2d +from tvm.relay.frontend.common import infer_type +import numpy as np + + +def quantize_and_check( + pre_func, expected_func, quantizer_pattern_list, skip_first=False, skip_last=False +): + quantizer = Quantizer( + pre_func, None, quantizer_pattern_list, skip_first=skip_first, skip_last=skip_last + ) + q_func = infer_type(quantizer.quantized_func) + expected_func = infer_type(expected_func) + assert tvm.ir.structural_equal(q_func, expected_func) + + +def create_scale_zps(lhs_name, rhs_name, channels=None): + data_scale_var = relay.var(lhs_name + "_scale_0", shape=(), dtype="float32") + data_zp_var = relay.var(lhs_name + "_zero_pt_0", shape=(), dtype="int32") + + if not channels: + weight_scale_var = relay.var(rhs_name + "_scale_1", shape=(), dtype="float32") + weight_zp_var = relay.var(rhs_name + "_zero_pt_1", shape=(), dtype="int32") + else: + weight_scale_var = relay.var(rhs_name + "_scale_1", shape=(channels,), dtype="float32") + weight_zp_var = relay.var(rhs_name + "_zero_pt_1", shape=(channels,), dtype="int32") + + return data_scale_var, data_zp_var, weight_scale_var, weight_zp_var + + +def get_conv2d_axes(attrs): + kernel_layout = attrs["kernel_layout"] + data_layout = attrs["data_layout"] + + if kernel_layout == "OIHW": + weight_channel_axis = 0 + elif kernel_layout == "HWIO": + weight_channel_axis = 3 + else: + raise ValueError( + "We don't support layouts other than OIHW or HWIO, but got %s. Please provide a compatible layout to the test." + % kernel_layout + ) + + if data_layout == "NCHW": + data_channel_axis = 1 + elif data_layout == "NHWC": + data_channel_axis = 3 + else: + raise ValueError( + "We don't support layouts other than NCHW or NHWC, but got %s. Please provide a compatible layout to the test." + % data_layout
", + data_layout, + ) + + return data_channel_axis, weight_channel_axis + + +def create_conv2d_func(data_shape, weight_shape, attrs): + data = relay.var("data", relay.TensorType(data_shape, "float32")) + weight = relay.var("weight", relay.TensorType(weight_shape, "float32")) + + # Pre quantize input + conv2d = relay.op.nn.conv2d(data, weight, **attrs) + pre_func = relay.Function([data, weight], conv2d) + return pre_func, data, weight + + +def create_q_conv2d_func(data, weight, weight_shape, attrs): + data_channel_axis, weight_channel_axis = get_conv2d_axes(attrs) + # Post quantize output + data_scale_var, data_zp_var, weight_scale_var, weight_zp_var = create_scale_zps( + "conv2d_data", "conv2d_weight" + ) + + q_data = relay.qnn.op.quantize(data, data_scale_var, data_zp_var, axis=data_channel_axis) + q_weight = relay.qnn.op.quantize( + weight, weight_scale_var, weight_zp_var, axis=weight_channel_axis + ) + + if "padding" in attrs.keys(): + padding = attrs["padding"] + else: + padding = None + + kernel_layout = attrs["kernel_layout"] + data_layout = attrs["data_layout"] + + if padding is not None: + top, left, bottom, right = get_pad_tuple2d(padding) + if kernel_layout == "OIHW": + pad_width = ((0, 0), (0, 0), (top, bottom), (left, right)) + elif kernel_layout == "HWIO": + pad_width = ( + (top, bottom), + (left, right), + (0, 0), + (0, 0), + ) + pad_val = 0 + q_data = relay.op.nn.pad(q_data, pad_width, pad_val) + + if kernel_layout == "OIHW": + kernel_size = tuple(weight_shape[2:4]) + elif kernel_layout == "HWIO": + kernel_size = tuple(weight_shape[0:2]) + else: + raise ValueError( + "We don't support layouts other than OIHW or HWIO, but got %s. Please provide a compatible layout to the test. ", + kernel_layout, + ) + + q_conv2d = relay.qnn.op.conv2d( + q_data, + q_weight, + data_zp_var, + weight_zp_var, + data_scale_var, + weight_scale_var, + data_layout=data_layout, + kernel_layout=kernel_layout, + kernel_size=kernel_size, + channels=weight_shape[weight_channel_axis], + ) + + deq_conv2d = relay.qnn.op.dequantize( + q_conv2d, + data_scale_var * weight_scale_var, + relay.const(0, dtype="int32"), + out_dtype="float32", + axis=data_channel_axis, + ) + quantized_func = relay.Function( + [data, weight, data_scale_var, data_zp_var, weight_scale_var, weight_zp_var], deq_conv2d + ) + return quantized_func + + +def verify_conv2d(data_shape, weight_shape, attrs): + pre_func, data, weight = create_conv2d_func(data_shape, weight_shape, attrs) + + quantized_func = create_q_conv2d_func(data, weight, weight_shape, attrs) + quantize_and_check(pre_func, quantized_func, [Conv2DPattern(None)]) + + +def create_conv2d_bias_func(data_shape, weight_shape, bias_shape, attrs, bias_type="bias_add"): + data = relay.var("data", relay.TensorType(data_shape, "float32")) + weight = relay.var("weight", relay.TensorType(weight_shape, "float32")) + bias = relay.const(np.random.rand(*bias_shape).astype("float32"), "float32") + + conv2d = relay.op.nn.conv2d(data, weight, **attrs) + data_channel_axis, _ = get_conv2d_axes(attrs) + if bias_type == "normal_add": + bias_add = relay.op.add(conv2d, bias) + elif bias_type == "bias_add": + bias_add = relay.op.nn.bias_add(conv2d, bias, axis=data_channel_axis) + else: + raise ValueError( + "Please pass in a valid bias type to the test function, got %s" % bias_type + ) + pre_func = relay.Function([data, weight], bias_add) + return pre_func, data, weight, bias + + +def create_q_conv2d_bias_func(data, weight, bias, weight_shape, attrs, bias_type="bias_add"): + + data_scale_var, 
data_zp_var, weight_scale_var, weight_zp_var = create_scale_zps( + "conv2d_data", "conv2d_weight" + ) + data_channel_axis, weight_channel_axis = get_conv2d_axes(attrs) + + q_data = relay.qnn.op.quantize(data, data_scale_var, data_zp_var, axis=data_channel_axis) + q_weight = relay.qnn.op.quantize( + weight, weight_scale_var, weight_zp_var, axis=weight_channel_axis + ) + + if "padding" in attrs: + padding = attrs["padding"] + else: + padding = None + + kernel_layout = attrs["kernel_layout"] + data_layout = attrs["data_layout"] + + if padding is not None: + top, left, bottom, right = get_pad_tuple2d(padding) + if kernel_layout == "OIHW": + pad_width = ((0, 0), (0, 0), (top, bottom), (left, right)) + elif kernel_layout == "HWIO": + pad_width = ( + (top, bottom), + (left, right), + (0, 0), + (0, 0), + ) + pad_val = 0 + q_data = relay.op.nn.pad(q_data, pad_width, pad_val) + + if kernel_layout == "OIHW": + kernel_size = tuple(weight_shape[2:4]) + elif kernel_layout == "HWIO": + kernel_size = tuple(weight_shape[0:2]) + else: + raise ValueError( + "We don't support layouts other than OIHW or HWIO, but got %s. Please provide a compatible layout to the test." + % kernel_layout + ) + + q_conv2d = relay.qnn.op.conv2d( + q_data, + q_weight, + data_zp_var, + weight_zp_var, + data_scale_var, + weight_scale_var, + data_layout=data_layout, + kernel_layout=kernel_layout, + kernel_size=kernel_size, + channels=weight_shape[weight_channel_axis], + ) + + if bias_type == "normal_add": + bias_add = relay.op.add( + q_conv2d, + relay.qnn.op.quantize(bias, data_scale_var, data_zp_var, axis=0, out_dtype="int32"), + ) + elif bias_type == "bias_add": + bias_add = relay.op.nn.bias_add( + q_conv2d, + relay.qnn.op.quantize(bias, data_scale_var, data_zp_var, axis=0, out_dtype="int32"), + axis=data_channel_axis, + ) + else: + raise ValueError( + "Please pass in a valid bias type to the test function, got %s" % bias_type + ) + + deq_conv2d = relay.qnn.op.dequantize( + bias_add, + data_scale_var * weight_scale_var, + relay.const(0, dtype="int32"), + out_dtype="float32", + axis=data_channel_axis, + ) + quantized_func = relay.Function( + [data, weight, data_scale_var, data_zp_var, weight_scale_var, weight_zp_var], deq_conv2d + ) + return quantized_func + + +def verify_conv2d_bias(data_shape, weight_shape, bias_shape, attrs, bias_type="bias_add"): + pre_func, data, weight, bias = create_conv2d_bias_func( + data_shape, weight_shape, bias_shape, attrs, bias_type + ) + quantized_func = create_q_conv2d_bias_func(data, weight, bias, weight_shape, attrs, bias_type) + quantize_and_check(pre_func, quantized_func, [Conv2DBiasAddPattern(None)]) + + +def create_dense_func(data_shape, weight_shape, attrs): + data = relay.var("data", relay.TensorType(data_shape, "float32")) + weight = relay.var("weight", relay.TensorType(weight_shape, "float32")) + + pre_func = relay.Function([data, weight], relay.nn.dense(data, weight, **attrs)) + + return pre_func, data, weight + + +def create_q_dense_func(data, weight, attrs): + data_scale_var, data_zp_var, weight_scale_var, weight_zp_var = create_scale_zps( + "dense_data", "dense_weight" + ) + + q_data = relay.qnn.op.quantize(data, data_scale_var, data_zp_var) + q_weight = relay.qnn.op.quantize(weight, weight_scale_var, weight_zp_var, axis=0) + + q_dense =
relay.qnn.op.dense( + q_data, q_weight, data_zp_var, weight_zp_var, data_scale_var, weight_scale_var, **attrs + ) + deq_dense = relay.qnn.op.dequantize( + q_dense, data_scale_var * weight_scale_var, relay.const(0, dtype="int32"), axis=1 + ) + quantized_func = relay.Function( + [data, weight, data_scale_var, data_zp_var, weight_scale_var, weight_zp_var], deq_dense + ) + + return quantized_func + + +def verify_dense(data_shape, weight_shape, attrs): + pre_func, data, weight = create_dense_func(data_shape, weight_shape, attrs) + quantized_func = create_q_dense_func(data, weight, attrs) + quantize_and_check(pre_func, quantized_func, [DensePattern(None)]) + + +def create_dense_bias_func(data_shape, weight_shape, bias_shape, attrs, bias_type="bias_add"): + data = relay.var("data", relay.TensorType(data_shape, "float32")) + weight = relay.var("weight", relay.TensorType(weight_shape, "float32")) + bias = relay.const(np.random.rand(*bias_shape).astype("float32"), "float32") + dense = relay.nn.dense(data, weight, **attrs) + + if bias_type == "normal_add": + bias_add = relay.op.add(dense, bias) + elif bias_type == "bias_add": + bias_add = relay.op.nn.bias_add(dense, bias, axis=1) + else: + raise ValueError( + "Please pass in a valid bias type to the test function, got %s" % bias_type + ) + + pre_func = relay.Function([data, weight], bias_add) + + return pre_func, data, weight, bias + + +def create_q_dense_bias_func(data, weight, bias, attrs, bias_type="bias_add"): + data_scale_var, data_zp_var, weight_scale_var, weight_zp_var = create_scale_zps( + "dense_data", "dense_weight" + ) + + q_data = relay.qnn.op.quantize(data, data_scale_var, data_zp_var) + q_weight = relay.qnn.op.quantize(weight, weight_scale_var, weight_zp_var, axis=0) + q_bias = relay.qnn.op.quantize(bias, data_scale_var, data_zp_var, axis=0, out_dtype="int32") + + q_dense = relay.qnn.op.dense( + q_data, q_weight, data_zp_var, weight_zp_var, data_scale_var, weight_scale_var, **attrs + ) + + if bias_type == "normal_add": + bias_add = relay.op.add(q_dense, q_bias) + elif bias_type == "bias_add": + bias_add = relay.op.nn.bias_add(q_dense, q_bias, axis=1) + else: + raise ValueError( + "Please pass in a valid bias type to the test function, got %s" % bias_type + ) + + deq_dense = relay.qnn.op.dequantize( + bias_add, data_scale_var * weight_scale_var, relay.const(0, dtype="int32"), axis=1 + ) + quantized_func = relay.Function( + [data, weight, data_scale_var, data_zp_var, weight_scale_var, weight_zp_var], deq_dense + ) + return quantized_func + + +def verify_dense_bias(data_shape, weight_shape, bias_shape, attrs, bias_type="bias_add"): + pre_func, data, weight, bias = create_dense_bias_func( + data_shape, weight_shape, bias_shape, attrs, bias_type + ) + quantized_func = create_q_dense_bias_func(data, weight, bias, attrs, bias_type) + quantize_and_check(pre_func, quantized_func, [DenseBiasAddPattern(None)]) + + +def create_add_func(lhs_shape, rhs_shape): + lhs = relay.var("lhs", relay.TensorType(lhs_shape, dtype="float32")) + rhs = relay.var("rhs", relay.TensorType(rhs_shape, dtype="float32")) + pre_func = relay.Function([lhs, rhs], relay.add(lhs, rhs)) + + return pre_func, lhs, rhs + + +def create_q_add_func(lhs, rhs): + lhs_scale_var, lhs_zp_var, rhs_scale_var, rhs_zp_var = create_scale_zps("add_lhs", "add_rhs") + q_lhs = relay.qnn.op.quantize(lhs, lhs_scale_var, lhs_zp_var) + q_rhs = relay.qnn.op.quantize(rhs, rhs_scale_var, rhs_zp_var) + + deq_lhs = relay.qnn.op.dequantize(q_lhs, lhs_scale_var, relay.const(0, dtype="int32")) + deq_rhs = 
relay.qnn.op.dequantize(q_rhs, rhs_scale_var, relay.const(0, dtype="int32")) + + add_scale = relay.op.add(lhs_scale_var, rhs_scale_var) + + requantized_lhs = relay.qnn.op.quantize(deq_lhs, add_scale, relay.const(0, dtype="int32")) + requantized_rhs = relay.qnn.op.quantize(deq_rhs, add_scale, relay.const(0, dtype="int32")) + + add = relay.op.add(requantized_lhs, requantized_rhs) + deq_add = relay.qnn.op.dequantize(add, add_scale, relay.const(0, dtype="int32")) + + quantized_func = relay.Function( + [lhs, rhs, lhs_scale_var, lhs_zp_var, rhs_scale_var, rhs_zp_var], deq_add + ) + return quantized_func + + +def verify_add(lhs_shape, rhs_shape): + pre_func, lhs, rhs = create_add_func(lhs_shape, rhs_shape) + quantized_func = create_q_add_func(lhs, rhs) + + quantize_and_check(pre_func, quantized_func, [AddPattern(None)]) + + +def create_mul_func(lhs_shape, rhs_shape): + lhs = relay.var("lhs", relay.TensorType(lhs_shape, dtype="float32")) + rhs = relay.var("rhs", relay.TensorType(rhs_shape, dtype="float32")) + pre_func = relay.Function([lhs, rhs], relay.multiply(lhs, rhs)) + + return pre_func, lhs, rhs + + +def create_q_mul_func(lhs, rhs): + lhs_scale_var, lhs_zp_var, rhs_scale_var, rhs_zp_var = create_scale_zps("mul_lhs", "mul_rhs") + q_lhs = relay.qnn.op.quantize(lhs, lhs_scale_var, lhs_zp_var) + q_rhs = relay.qnn.op.quantize(rhs, rhs_scale_var, rhs_zp_var) + + zeroed_q_lhs = relay.op.subtract(relay.op.cast(q_lhs, "int32"), lhs_zp_var) + zeroed_q_rhs = relay.op.subtract(relay.op.cast(q_rhs, "int32"), rhs_zp_var) + + multiply = relay.op.multiply(zeroed_q_lhs, zeroed_q_rhs) + deq_multiply = relay.qnn.op.dequantize( + multiply, lhs_scale_var * rhs_scale_var, relay.const(0, dtype="int32") + ) + + quantized_func = relay.Function( + [lhs, rhs, lhs_scale_var, lhs_zp_var, rhs_scale_var, rhs_zp_var], deq_multiply + ) + return quantized_func + + +def verify_mul(lhs_shape, rhs_shape): + pre_func, lhs, rhs = create_mul_func(lhs_shape, rhs_shape) + quantized_func = create_q_mul_func(lhs, rhs) + quantize_and_check(pre_func, quantized_func, [MultiplyPattern(None)]) + + +def verify_skip_layers(data_shape, weight_shape, attrs): + # We'll test skip_layers with the dense op + data = relay.var("data", relay.TensorType(data_shape, "float32")) + weight = relay.var("weight", relay.TensorType(weight_shape, "float32")) + pre_func = relay.Function([data, weight], relay.nn.dense(data, weight)) + + quantize_and_check(pre_func, pre_func, [DensePattern(None)], skip_first=True, skip_last=False) + quantize_and_check(pre_func, pre_func, [DensePattern(None)], skip_first=False, skip_last=True) + quantize_and_check(pre_func, pre_func, [DensePattern(None)], skip_first=True, skip_last=True) + + +def test_conv2d(): + verify_conv2d( + (2, 3, 32, 32), + (32, 3, 3, 3), + { + "kernel_size": [3, 3], + "kernel_layout": "OIHW", + "data_layout": "NCHW", + "padding": [0, 0, 0, 0], + }, + ) + verify_conv2d( + (2, 32, 32, 3), + (3, 3, 3, 32), + { + "kernel_size": [3, 3], + "kernel_layout": "HWIO", + "data_layout": "NHWC", + "padding": [0, 0, 0, 0], + }, + ) + + +def test_conv2d_bias(): + verify_conv2d_bias( + (2, 3, 32, 32), + (32, 3, 3, 3), + (32,), + { + "kernel_size": [3, 3], + "kernel_layout": "OIHW", + "data_layout": "NCHW", + "padding": [0, 0, 0, 0], + }, + ) + verify_conv2d_bias( + (2, 32, 32, 3), + (3, 3, 3, 32), + (32,), + { + "kernel_size": [3, 3], + "kernel_layout": "HWIO", + "data_layout": "NHWC", + "padding": [0, 0, 0, 0], + }, + ) + verify_conv2d_bias( + (2, 3, 32, 32), + (32, 3, 3, 3), + (1, 32, 1, 1), + { + "kernel_size": [3, 3], 
+ "kernel_layout": "OIHW", + "data_layout": "NCHW", + "padding": [0, 0, 0, 0], + }, + bias_type="normal_add", + ) + verify_conv2d_bias( + (2, 32, 32, 3), + (3, 3, 3, 32), + (1, 1, 1, 32), + { + "kernel_size": [3, 3], + "kernel_layout": "HWIO", + "data_layout": "NHWC", + "padding": [0, 0, 0, 0], + }, + bias_type="normal_add", + ) + + +def test_dense(): + verify_dense((1, 8), (16, 8), {"units": 16}) + verify_dense((1, 4), (3, 4), {"units": 3}) + + +def test_dense_bias(): + verify_dense_bias((1, 8), (16, 8), (16,), {"units": 16}) + verify_dense_bias((1, 4), (3, 4), (3,), {"units": 3}) + verify_dense_bias((1, 8), (16, 8), (16,), {"units": 16}, bias_type="normal_add") + verify_dense_bias((1, 4), (3, 4), (3,), {"units": 3}, bias_type="normal_add") + + +def test_add(): + verify_add((1, 2, 3), (1, 2, 3)) + + +def test_mul(): + verify_mul((1, 2, 3), (1, 2, 3)) + + +def test_skip_layers(): + verify_skip_layers((1, 8), (16, 8), {"units": 16}) + verify_skip_layers((1, 4), (3, 4), {"units": 3}) + + +if __name__ == "__main__": + test_conv2d() + test_conv2d_bias() + test_dense() + test_dense_bias() + test_add() + test_mul() + test_skip_layers() diff --git a/tests/python/relay/quantize/test_requantize.py b/tests/python/relay/quantize/test_requantize.py new file mode 100644 index 000000000000..4f8dc6ccaed8 --- /dev/null +++ b/tests/python/relay/quantize/test_requantize.py @@ -0,0 +1,150 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +import tvm +from tvm import relay +from tvm.relay.transform.quantize import Requantizer +from tvm.relay.frontend.common import infer_type + +import numpy as np + + +def check_requantize(pre_graph, expected_graph): + post_graph = Requantizer().requantize(pre_graph) + + post_graph = infer_type(post_graph) + expected_graph = infer_type(expected_graph) + + assert tvm.ir.structural_equal(post_graph, expected_graph) + + +def test_simple_requantize(): + data_shape = (1, 2, 3, 4) + int8_data = relay.var("int8_data", relay.TensorType(data_shape, dtype="int8")) + scale1, zp1 = relay.const(np.array(1).astype("float32")), relay.const( + np.array(2).astype("int32") + ) + + deq_data = relay.qnn.op.dequantize(int8_data, scale1, zp1) + scale2, zp2 = relay.const(np.array(3).astype("float32")), relay.const( + np.array(4).astype("int32") + ) + pre_graph = relay.Function([int8_data], relay.qnn.op.quantize(deq_data, scale2, zp2)) + + expected_graph = relay.Function( + [int8_data], relay.qnn.op.requantize(int8_data, scale1, zp1, scale2, zp2) + ) + check_requantize(pre_graph, expected_graph) + + +def test_int8_requantize(): + data_shape = (1, 2, 3, 4) + int8_data = relay.var("int8_data", relay.TensorType(data_shape, dtype="int8")) + scale1, zp1 = relay.const(np.array(1).astype("float32")), relay.const( + np.array(2).astype("int32") + ) + deq = relay.qnn.op.dequantize(int8_data, scale1, zp1) + int8_op = relay.op.nn.relu(deq) + scale2, zp2 = relay.const(np.array(3).astype("float32")), relay.const( + np.array(0).astype("int32") + ) + quantize = relay.qnn.op.quantize(int8_op, scale2, zp2) + pre_graph = relay.Function([int8_data], quantize) + + requantize = relay.qnn.op.requantize(int8_data, scale1, zp1, scale2, zp2) + int8_op = relay.op.nn.relu(requantize) + expected_graph = relay.Function([int8_data], int8_op) + + check_requantize(pre_graph, expected_graph) + + +def test_int8_requantize_zp(): + data_shape = (1, 2, 3, 4) + int8_data = relay.var("int8_data", relay.TensorType(data_shape, dtype="int8")) + scale1, zp1 = relay.const(np.array(1).astype("float32")), relay.const( + np.array(2).astype("int32") + ) + deq = relay.qnn.op.dequantize(int8_data, scale1, zp1) + int8_op = relay.op.nn.relu(deq) + scale2, zp2 = relay.const(np.array(3).astype("float32")), relay.const( + np.array(4).astype("int32") + ) + quantize = relay.qnn.op.quantize(int8_op, scale2, zp2) + pre_graph = relay.Function([int8_data], quantize) + + requantize = relay.qnn.op.requantize(int8_data, scale1, zp1, scale2, zp2) + zp = relay.op.cast(zp2, dtype="int8") + int8_op = relay.op.maximum(requantize, zp) + expected_graph = relay.Function([int8_data], int8_op) + + check_requantize(pre_graph, expected_graph) + + +def test_chain_removal(): + data_shape = (1, 2, 3, 4) + int8_data = relay.var("int8_data", relay.TensorType(data_shape, dtype="int8")) + scale1, zp1 = relay.const(np.array(1).astype("float32")), relay.const( + np.array(2).astype("int32") + ) + scale2, zp2 = relay.const(np.array(3).astype("float32")), relay.const( + np.array(4).astype("int32") + ) + requantize = relay.qnn.op.requantize(int8_data, scale1, zp1, scale2, zp2) + + scale3, zp3 = relay.const(np.array(5).astype("float32")), relay.const( + np.array(6).astype("int32") + ) + requantize2 = relay.qnn.op.requantize(requantize, scale2, zp2, scale3, zp3) + + scale4, zp4 = relay.const(np.array(7).astype("float32")), relay.const( + np.array(8).astype("int32") + ) + requantize3 = relay.qnn.op.requantize(requantize2, scale3, zp3, scale4, zp4) + pre_graph = relay.Function([int8_data], requantize3) + 
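+ # The chain of requantizes composes, so (scale1, zp1) -> (scale2, zp2) -> (scale3, zp3) + # -> (scale4, zp4) should collapse into a single requantize from (scale1, zp1) to (scale4, zp4).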
+ + expected_graph = relay.Function( + [int8_data], relay.qnn.op.requantize(int8_data, scale1, zp1, scale4, zp4) + ) + + check_requantize(pre_graph, expected_graph) + + +def test_consolidate(): + data_shape = (1, 2, 3, 4) + data = relay.var("data", relay.TensorType(data_shape, dtype="float32")) + scale1, zp1 = relay.const(np.array(1).astype("float32")), relay.const( + np.array(2).astype("int32") + ) + quantize = relay.qnn.op.quantize(data, scale1, zp1) + + scale2, zp2 = relay.const(np.array(3).astype("float32")), relay.const( + np.array(4).astype("int32") + ) + requantize = relay.qnn.op.requantize(quantize, scale1, zp1, scale2, zp2) + pre_graph = relay.Function([data], requantize) + + expected_graph = relay.Function([data], relay.qnn.op.quantize(data, scale2, zp2)) + + check_requantize(pre_graph, expected_graph) + + +if __name__ == "__main__": + test_simple_requantize() + test_int8_requantize() + test_int8_requantize_zp() + test_chain_removal() + test_consolidate() diff --git a/tutorials/quantization/README.txt b/tutorials/quantization/README.txt new file mode 100644 index 000000000000..2b953835db73 --- /dev/null +++ b/tutorials/quantization/README.txt @@ -0,0 +1,2 @@ +Quantization: A flexible, extensible framework for auto-quantizing models +-------------------------------------------------------------------------- diff --git a/tutorials/quantization/quantize_resnet.py b/tutorials/quantization/quantize_resnet.py new file mode 100644 index 000000000000..e69de29bb2d1