From b4b2e191f6c6d660cf7596eee8fabe19a921bbb0 Mon Sep 17 00:00:00 2001 From: Lyalyushkin Nikolay Date: Tue, 14 Nov 2023 13:44:25 +0100 Subject: [PATCH] Compress weights with a single reduction axis only (#2254) ### Changes Exclude from weight compression nodes that has more than one reduction axes ### Reason for changes There's only one model that has multiple reduction axes. It's `chatglm` with one embedding layer having [8132,32,2] shape. It was decided to not quantize this layer, since it would save just 6Mb in 4Gb model in case of int8 quantization with risk to reduce accuracy, and it can't be quantized group-wise. The idea is to switch to multiple reduction axes when it will be really needed. ### Related tickets n/a ### Tests Tested on 104 models from share with IR's for llm models. In all cases except chatglm there's a single reduction axis. --- nncf/common/graph/layer_attributes.py | 2 +- nncf/common/graph/operator_metatypes.py | 2 +- nncf/common/pruning/tensor_processor.py | 2 +- nncf/common/tensor_statistics/collectors.py | 2 +- nncf/onnx/graph/node_utils.py | 8 +-- .../weight_compression/openvino_backend.py | 70 ++++++++++--------- .../quantization/test_weights_compression.py | 14 ++-- 7 files changed, 52 insertions(+), 48 deletions(-) diff --git a/nncf/common/graph/layer_attributes.py b/nncf/common/graph/layer_attributes.py index ce934c23b8c..6ef514a1645 100644 --- a/nncf/common/graph/layer_attributes.py +++ b/nncf/common/graph/layer_attributes.py @@ -95,7 +95,7 @@ def __init__( :param weight_requires_grad: Is True if gradients need to be computed for the corresponding Tensor, False otherwise. :param weight_shape: shape of weight tensor. - :param filter_dimension_idx: the axis along which the filters are stored. + :param filter_dimension_idx: the axis, along which the filters are stored. """ super().__init__(weight_requires_grad=weight_requires_grad, with_bias=with_bias) self.weight_shape = weight_shape diff --git a/nncf/common/graph/operator_metatypes.py b/nncf/common/graph/operator_metatypes.py index 305a3b66668..fc9d89f5f27 100644 --- a/nncf/common/graph/operator_metatypes.py +++ b/nncf/common/graph/operator_metatypes.py @@ -21,7 +21,7 @@ class OperatorMetatype: :param name: The name of the operator. :param hw_config_names: The names of the hardware configurations. - :param output_channel_axis: The axis along which the output channels of the operator are arranged. + :param output_channel_axis: The axis, along which the output channels of the operator are arranged. :param ignored_input_ports: Input ports of the operations that should not be considered for purposes of compression. """ diff --git a/nncf/common/pruning/tensor_processor.py b/nncf/common/pruning/tensor_processor.py index c7c57432059..e160c045120 100644 --- a/nncf/common/pruning/tensor_processor.py +++ b/nncf/common/pruning/tensor_processor.py @@ -28,7 +28,7 @@ def concatenate(cls, tensors: List[NNCFTensor], axis: int) -> NNCFTensor: Join a list of NNCFTensors along an existing axis. :param tensors: List of NNCFTensors. - :param axis: The axis along which the tensors will be joined. + :param axis: The axis, along which the tensors will be joined. :returns: The concatenated List of the tensors. """ diff --git a/nncf/common/tensor_statistics/collectors.py b/nncf/common/tensor_statistics/collectors.py index 894fe9a2a39..e1b01310216 100644 --- a/nncf/common/tensor_statistics/collectors.py +++ b/nncf/common/tensor_statistics/collectors.py @@ -358,7 +358,7 @@ def cat(x: List[NNCFTensor], axis: int) -> NNCFTensor: Join a sequence of arrays along an existing axis. :param x: The input tensor. - :param axis: The axis along which the arrays will be joined. + :param axis: The axis, along which the arrays will be joined. :return: The concatenated array. """ diff --git a/nncf/onnx/graph/node_utils.py b/nncf/onnx/graph/node_utils.py index 1e9a162211d..7d46356904c 100644 --- a/nncf/onnx/graph/node_utils.py +++ b/nncf/onnx/graph/node_utils.py @@ -152,11 +152,11 @@ def get_reduction_shape(shape: List[int], axis: int) -> ReductionAxes: def _get_weight_quantization_axis(node: NNCFNode, port_id: int) -> int: """ - Returns weight tensor axis along quantizer parameters are calculated. + Returns weight tensor axis, along which quantizer parameters are calculated. :param node: NNCFNode, which has a weight on input port_id. :param port_id: Input port id on which there is a weight of a node. - :return: Axis along quantizer parameters are calculated. + :return: Axis, along which quantizer parameters are calculated. """ weight_channel_axis = node.metatype.weight_channel_axis if node.layer_attributes.has_node_attrs(): @@ -174,9 +174,9 @@ def _get_weight_quantization_axis(node: NNCFNode, port_id: int) -> int: def _get_activation_quantization_axis() -> int: """ - Returns activation tensor axis along quantizer parameters are calculated. + Returns activation tensor axis, along which quantizer parameters are calculated. - :return: Axis along quantizer parameters are calculated. + :return: Axis, along which quantizer parameters are calculated. """ return 1 # Activations have channel first layout: [N, C, Z, Y, X] diff --git a/nncf/quantization/algorithms/weight_compression/openvino_backend.py b/nncf/quantization/algorithms/weight_compression/openvino_backend.py index a30b279d95b..a6b55a16889 100644 --- a/nncf/quantization/algorithms/weight_compression/openvino_backend.py +++ b/nncf/quantization/algorithms/weight_compression/openvino_backend.py @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from dataclasses import dataclass -from typing import List, Optional, Tuple, TypeVar, Union +from typing import List, Optional, Tuple, TypeVar import numpy as np import openvino.runtime as ov @@ -73,13 +73,23 @@ def do_compression( continue const_shape = nncf_node.layer_attributes.constant_attributes[weight_port_id]["shape"] channel_axes = get_weight_channel_axes(nncf_node, weight_port_id) - axes = get_channel_agnostic_reduction_axes(channel_axes, const_shape) + reduction_axes = get_channel_agnostic_reduction_axes(channel_axes, const_shape) + if isinstance(reduction_axes, tuple) and len(reduction_axes) != 1: + nncf_logger.warning( + f"Weight compression expects a single reduction axes, but given {len(reduction_axes)}. " + f"Weight shape: {const_shape}, reduction axes: {reduction_axes}, node name: {nncf_node.name}. " + "The node won't be quantized." + ) + continue + reduction_axis = reduction_axes[0] if isinstance(reduction_axes, tuple) else reduction_axes + fq_name = f"{weight_op_friendly_name}/fq_weights_{weight_port_id}" num_weights = np.prod(const_shape) - weight_params = WeightNodeParams(axes, num_weights, fq_name, weight_node, original_weight_dtype) + weight_params = WeightNodeParams( + reduction_axis, num_weights, fq_name, weight_node, original_weight_dtype + ) all_weight_params.append(weight_params) quantized_nodes_ids.add(id(weight_node)) - if mode != CompressWeightsMode.INT8: primary_config = WeightCompressionConfig(mode=mode, group_size=group_size) _assign_mixed_precision(all_weight_params, ratio, primary_config) @@ -98,7 +108,7 @@ def do_compression( config = wp.compression_config if config.mode == CompressWeightsMode.NF4: original_shape = weight.shape - norm_weight, scale = _get_norm_weight_and_nf4_scale(weight, wp.reduction_axes, group_size) + norm_weight, scale = _get_norm_weight_and_nf4_scale(weight, wp.reduction_axis, group_size) compressed_const = opset.constant(norm_weight, dtype=ov.Type.nf4, name=weight_name) convert = opset.convert(compressed_const, original_weight_dtype) mul = opset.multiply(convert, scale.astype(original_weight_dtype), name=wp.fq_name) @@ -107,7 +117,7 @@ def do_compression( last_output = mul.output(0) else: original_shape = weight.shape - compressed_weights, scale, zero_point = _do_integer_quantization(weight, wp.reduction_axes, config) + compressed_weights, scale, zero_point = _do_integer_quantization(weight, wp.reduction_axis, config) compression_type = np.uint8 if config.num_bits == 8 else ov.Type.u4 compressed_weights_node = opset.constant(compressed_weights, dtype=compression_type, name=weight_name) convert_weights_node = opset.convert(compressed_weights_node, original_weight_dtype) @@ -153,7 +163,7 @@ class WeightNodeParams: """ Information about weight node in the ov.Model that is useful for weight compression. - :param reduction_axes: Axis or axes along which to reduce (collect) different statistics (e.g. min, max). + :param reduction_axis: Axis, along which to reduce (collect) different statistics (e.g. min, max). :param num_weights: Number of elements in the weight array. :param fq_name: Name for the inserted weight compression operation. :param weight_node: The weight node itself. @@ -161,7 +171,7 @@ class WeightNodeParams: :param compression_config: Configuration of weight compression for the weight node. """ - reduction_axes: Union[int, Tuple[int]] + reduction_axis: int num_weights: int fq_name: str weight_node: ov.Node @@ -170,7 +180,7 @@ class WeightNodeParams: def _do_integer_quantization( - weight: np.ndarray, reduction_axes: Union[int, Tuple[int]], config: WeightCompressionConfig + weight: np.ndarray, reduction_axis: int, config: WeightCompressionConfig ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: """ The method quantizes the given weights to integer data type in accordance with the compression config. @@ -186,7 +196,7 @@ def _do_integer_quantization( (scales). :param weight: Weight array to compress. - :param reduction_axes: Axis or axes along which to reduce (collect) different statistics (e.g. min, max). + :param reduction_axis: Axis, along which to reduce (collect) different statistics (e.g. min, max). :param config: Information on how to compress (quantize) a specific weight. :return: The compressed weights, scale and zero point that was used for its quantization. """ @@ -200,16 +210,16 @@ def _do_integer_quantization( if group_size != -1: # weights are reshaped from [a1, r, a2] to [a1, r//gs, gs, a2] - weight, reduction_axes = _reshape_weights_for_grouped_quantization(weight, reduction_axes, group_size) + weight, reduction_axis = _reshape_weights_for_grouped_quantization(weight, reduction_axis, group_size) if mode in [CompressWeightsMode.INT8, CompressWeightsMode.INT4_ASYM]: - min_values = np.min(weight, axis=reduction_axes, keepdims=True) # [a1, r, a2] -> [a1, 1, a2] - max_values = np.max(weight, axis=reduction_axes, keepdims=True) # [a1, r, a2] -> [a1, 1, a2] + min_values = np.min(weight, axis=reduction_axis, keepdims=True) # [a1, r, a2] -> [a1, 1, a2] + max_values = np.max(weight, axis=reduction_axis, keepdims=True) # [a1, r, a2] -> [a1, 1, a2] scale, zero_point = calculate_scale_zero_point( min_values, max_values, level_low, level_high, narrow_range=False ) else: - scale = np.max(np.abs(weight), axis=reduction_axes, keepdims=True) # [a1, r//gs, 1, a2] + scale = np.max(np.abs(weight), axis=reduction_axis, keepdims=True) # [a1, r//gs, 1, a2] level_low_sym = -(2 ** (num_bits - 1)) level_high_sym = 2 ** (num_bits - 1) - 1 scale = scale / level_high_sym @@ -223,33 +233,31 @@ def _do_integer_quantization( return compressed_weights, scale, zero_point -def _get_integer_quantization_error( - weight: np.ndarray, reduction_axes: Union[int, Tuple[int]], config: WeightCompressionConfig -) -> float: +def _get_integer_quantization_error(weight: np.ndarray, reduction_axis: int, config: WeightCompressionConfig) -> float: """ Calculates a quantity characterizing the difference between floating point weights and fake quantized (compressed and decompressed) to integer ones. :param weight: Weight array to compress. - :param reduction_axes: Axis or axes along which to reduce (collect) different statistics (e.g. min, max). + :param reduction_axis: Axis, along which to reduce (collect) different statistics (e.g. min, max). :param config: Information on how to compress (quantize) a specific weight. :return: The quantity characterizing the error of integer quantization. """ orig_shape = weight.shape - compressed_weights, scale, zero_point = _do_integer_quantization(weight, reduction_axes, config) + compressed_weights, scale, zero_point = _do_integer_quantization(weight, reduction_axis, config) decompressed_weight = compressed_weights.astype(dtype=scale.dtype) decompressed_weight = (compressed_weights - zero_point) * scale decompressed_weight = decompressed_weight.reshape(orig_shape) diff = (decompressed_weight - weight) ** 2 - layer_err = np.mean(diff, axis=reduction_axes) + layer_err = np.mean(diff, axis=reduction_axis) val = np.max(layer_err) return val def _reshape_weights_for_grouped_quantization( - weight: np.ndarray, reduction_axes: Union[int, Tuple[int]], group_size: int + weight: np.ndarray, reduction_axis: int, group_size: int ) -> Tuple[np.ndarray, int]: """ Reshapes weights for group-wise quantization and return a new reduction axis for collecting statistics per group @@ -257,16 +265,12 @@ def _reshape_weights_for_grouped_quantization( [c_out, c_in // 128, 128]. :param weight: Weight array to compress. - :param reduction_axes: Axis or axes along which to reduce (collect) different statistics (e.g. min, max). + :param reduction_axis: Axis, along which to reduce (collect) different statistics (e.g. min, max). :param group_size: Number of weights (e.g. 128) in the channel dimension that share quantization parameters (scale). :return: reshaped weights and new reduction axis. """ assert group_size != -1 - if isinstance(reduction_axes, tuple) and len(reduction_axes) != 1: - raise RuntimeError( - f"group-quantization is supported for a single reduction axes, but got {len(reduction_axes)}" - ) - reduction_axis = reduction_axes[0] if isinstance(reduction_axes, tuple) else reduction_axes + assert isinstance(reduction_axis, int) channel_size = weight.shape[reduction_axis] if channel_size % group_size != 0: raise RuntimeError(f"Channel size {channel_size} should be divisible by size of group {group_size}") @@ -280,24 +284,24 @@ def _reshape_weights_for_grouped_quantization( def _get_norm_weight_and_nf4_scale( - weight: np.ndarray, reduction_axes: Tuple[int], group_size: int = -1 + weight: np.ndarray, reduction_axis: int, group_size: int = -1 ) -> Tuple[np.ndarray, np.ndarray]: """ Calculates scale for nf4 quantization and normalizes weights by the scale. Weights are reshaped in case of positive value of group size. :param weight: Weight array to compress. - :param reduction_axes: Axis or axes along which to reduce (collect) different statistics (e.g. min, max). + :param reduction_axis: Axis, along which to reduce (collect) different statistics (e.g. min, max). :param group_size: Number of weights (e.g. 128) in the channel dimension that share quantization parameters (scale). The value -1 means no grouping. Defaults to -1. :return: Normalized weights and nf4 scale. """ if group_size != -1: # weights are reshaped: [a1, r, a2] -> [a1, r//gs, gs, a2] - weight, reduction_axis = _reshape_weights_for_grouped_quantization(weight, reduction_axes, group_size) + weight, reduction_axis = _reshape_weights_for_grouped_quantization(weight, reduction_axis, group_size) scale = np.max(np.abs(weight), axis=reduction_axis, keepdims=True) # [a1, r//gs, 1, a2] else: - scale = np.max(np.abs(weight), axis=reduction_axes, keepdims=True) # [a1, 1, a2] + scale = np.max(np.abs(weight), axis=reduction_axis, keepdims=True) # [a1, 1, a2] eps = np.finfo(weight.dtype).eps # NOTE: adding machine epsilon to avoid division by zero scale[np.abs(scale) < eps] = eps @@ -372,8 +376,8 @@ def _assign_mixed_precision( for weight_param in track(all_weight_params[1:-1], description="Searching for Mixed-Precision Configuration"): weight = get_const_value(weight_param.weight_node) backup_config = weight_param.compression_config - reduction_axes = weight_param.reduction_axes - backup_error = _get_integer_quantization_error(weight, reduction_axes, backup_config) + reduction_axis = weight_param.reduction_axis + backup_error = _get_integer_quantization_error(weight, reduction_axis, backup_config) eps = np.finfo(weight.dtype).eps error = 1 / (backup_error + eps) errors.append(error) diff --git a/tests/openvino/native/quantization/test_weights_compression.py b/tests/openvino/native/quantization/test_weights_compression.py index 94501cae790..561d1927509 100644 --- a/tests/openvino/native/quantization/test_weights_compression.py +++ b/tests/openvino/native/quantization/test_weights_compression.py @@ -298,7 +298,7 @@ def __str__(self): @pytest.mark.parametrize("desc", LIST_DESCS, ids=map(str, LIST_DESCS)) def test_quantization_error_calculation(desc: QuantErrorDesc): weight = desc.weight - axis = (1,) + axis = 1 actual_error = _get_integer_quantization_error(weight, axis, desc.config) ref_error = desc.ref_error atol = desc.atol if desc.atol is not None else 1e-8 @@ -374,20 +374,20 @@ def test_weight_compress_with_ignored_scope(ignored_scope, num_compressed): @pytest.mark.parametrize("desc", CALCULATE_SCALE_DESCS) def test_calculate_scale_per_group(desc: CalculateScaleDesc): reshaped_weight, reduction_axis = _reshape_weights_for_grouped_quantization( - desc.weight, reduction_axes=desc.axis, group_size=desc.group_size + desc.weight, reduction_axis=desc.axis, group_size=desc.group_size ) act_scale = np.max(np.abs(reshaped_weight), axis=reduction_axis, keepdims=True) # [a1, r//gs, 1, a2] assert np.allclose(act_scale, desc.ref_scale) def test_raise_error_for_many_axes(): - with pytest.raises(RuntimeError): - _reshape_weights_for_grouped_quantization(WEIGHTS_2x4, reduction_axes=(0, 1), group_size=1) + with pytest.raises(AssertionError): + _reshape_weights_for_grouped_quantization(WEIGHTS_2x4, reduction_axis=(0, 1), group_size=1) -def test_raise_error_with_incorrect_group_size(): - with pytest.raises(RuntimeError): - _reshape_weights_for_grouped_quantization(WEIGHTS_2x4, reduction_axes=(0,), group_size=3) +def test_raise_error_with_tuple(): + with pytest.raises(AssertionError): + _reshape_weights_for_grouped_quantization(WEIGHTS_2x4, reduction_axis=(0,), group_size=3) def test_raise_error_with_int8_and_non_default_ratio(mocker):