From e6bf1d5bd2408535b875a0c3482d79d69d143b2d Mon Sep 17 00:00:00 2001 From: dlyakhov Date: Wed, 23 Oct 2024 14:54:41 +0200 Subject: [PATCH] WIP experimental quantization --- .../algorithms/post_training/__init__.py | 10 + .../algorithms/post_training/algorithm.py | 103 ++++ .../algorithms/post_training/pipeline.py | 139 +++++ .../algorithms/quantizer/fx_quantizer.py | 110 ++++ .../algorithms/quantizer/quantizer.py | 26 + .../algorithms/range_estimator/backend.py | 154 ++++++ .../range_estimator/range_estimator.py | 495 ++++++++++++++++++ .../range_estimator/torch_fx_backend.py | 221 ++++++++ .../torch/fx/quantization/quantize_pt2e.py | 91 ++++ torch_fx_experimental_q.py | 107 ++++ 10 files changed, 1456 insertions(+) create mode 100644 nncf/experimental/common/quantization/algorithms/post_training/__init__.py create mode 100644 nncf/experimental/common/quantization/algorithms/post_training/algorithm.py create mode 100644 nncf/experimental/common/quantization/algorithms/post_training/pipeline.py create mode 100644 nncf/experimental/common/quantization/algorithms/quantizer/fx_quantizer.py create mode 100644 nncf/experimental/common/quantization/algorithms/quantizer/quantizer.py create mode 100644 nncf/experimental/common/quantization/algorithms/range_estimator/backend.py create mode 100644 nncf/experimental/common/quantization/algorithms/range_estimator/range_estimator.py create mode 100644 nncf/experimental/common/quantization/algorithms/range_estimator/torch_fx_backend.py create mode 100644 nncf/experimental/torch/fx/quantization/quantize_pt2e.py create mode 100644 torch_fx_experimental_q.py diff --git a/nncf/experimental/common/quantization/algorithms/post_training/__init__.py b/nncf/experimental/common/quantization/algorithms/post_training/__init__.py new file mode 100644 index 00000000000..2e49d63977d --- /dev/null +++ b/nncf/experimental/common/quantization/algorithms/post_training/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) 2024 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nncf/experimental/common/quantization/algorithms/post_training/algorithm.py b/nncf/experimental/common/quantization/algorithms/post_training/algorithm.py new file mode 100644 index 00000000000..0f7a90d3d67 --- /dev/null +++ b/nncf/experimental/common/quantization/algorithms/post_training/algorithm.py @@ -0,0 +1,103 @@ +# Copyright (c) 2024 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import itertools +from typing import Callable, List, Optional, TypeVar + +from nncf import Dataset +from nncf.common.graph.graph import NNCFGraph +from nncf.common.tensor_statistics.statistic_point import StatisticPointsContainer +from nncf.common.utils.backend import BackendType +from nncf.experimental.common.quantization.algorithms.post_training.pipeline import create_ptq_pipeline +from nncf.experimental.common.quantization.algorithms.quantizer.quantizer import NNCFQuantizer +from nncf.parameters import ModelType +from nncf.quantization.advanced_parameters import AdvancedQuantizationParameters +from nncf.quantization.algorithms.algorithm import Algorithm + +TModel = TypeVar("TModel") +TPass = Callable[[TModel], TModel] + + +class PostTrainingQuantization(Algorithm): + """ + Implements Post-Training Quantization algorithm, which basically includes: + 1) ChannelAlignment + 2) MinMaxQuantization + 3) FastBiasCorrection or BiasCorrection + """ + + def __init__( + self, + quantizer: NNCFQuantizer, + subset_size: int = 300, + fast_bias_correction: bool = True, + model_type: Optional[ModelType] = None, + advanced_parameters: Optional[AdvancedQuantizationParameters] = None, + ): + """ + :param mode: Special quantization mode that specify different ways of the optimization. + :param preset: A preset controls the quantization mode (symmetric and asymmetric). + It can take the following values: + - `performance`: Symmetric quantization of weights and activations. + - `mixed`: Symmetric quantization of weights and asymmetric quantization of activations. + Default value is None. In this case, `mixed` preset is used for `transformer` + model type otherwise `performace`. + :param target_device: A target device the specificity of which will be taken + into account while compressing in order to obtain the best performance + for this type of device. + :param subset_size: Size of a subset to calculate activations + statistics used for quantization. + :param fast_bias_correction: Setting this option to `False` enables a different + bias correction method which is more accurate, in general, and takes + more time but requires less memory. + :param model_type: Model type is needed to specify additional patterns + in the model. Supported only `transformer` now. + :param ignored_scope: An ignored scope that defined the list of model control + flow graph nodes to be ignored during quantization. + :param advanced_parameters: Advanced quantization parameters for + fine-tuning the quantization algorithm + """ + self._pipeline = create_ptq_pipeline( + quantizer=quantizer, + subset_size=subset_size, + fast_bias_correction=fast_bias_correction, + model_type=model_type, + advanced_parameters=advanced_parameters, + ) + + @property + def available_backends(self) -> List[BackendType]: + backends = set(BackendType) + for algorithm in itertools.chain.from_iterable(self._pipeline.pipeline_steps): + backends = backends.intersection(algorithm.available_backends) + return list(backends) + + def get_statistic_points(self, model: TModel, graph: NNCFGraph) -> StatisticPointsContainer: + return self._pipeline.get_statistic_points_for_step(0, model, graph) + + def apply( + self, + model: TModel, + graph: NNCFGraph, + statistic_points: Optional[StatisticPointsContainer] = None, + dataset: Optional[Dataset] = None, + ) -> TModel: + if dataset is None and len(self._pipeline.pipeline_steps) > 1: + raise ValueError( + "A dataset is required for the post-training quantization " + "algorithm to collect statistics for intermediate models." + ) + + step_index_to_statistics = None + if statistic_points: + step_index_to_statistics = {0: statistic_points} + + return self._pipeline.run_from_step(model, dataset, graph, 0, step_index_to_statistics) diff --git a/nncf/experimental/common/quantization/algorithms/post_training/pipeline.py b/nncf/experimental/common/quantization/algorithms/post_training/pipeline.py new file mode 100644 index 00000000000..bb1c0ec0bba --- /dev/null +++ b/nncf/experimental/common/quantization/algorithms/post_training/pipeline.py @@ -0,0 +1,139 @@ +# Copyright (c) 2024 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional, TypeVar + +from nncf.common.deprecation import warning_deprecated +from nncf.experimental.common.quantization.algorithms.quantizer.quantizer import NNCFQuantizer +from nncf.experimental.common.quantization.algorithms.range_estimator.range_estimator import MinMaxRangeEstimator +from nncf.parameters import ModelType +from nncf.quantization.advanced_parameters import AdvancedQuantizationParameters +from nncf.quantization.algorithms.bias_correction.algorithm import BIAS_CORRECTION_THRESHOLD +from nncf.quantization.algorithms.bias_correction.algorithm import BiasCorrection +from nncf.quantization.algorithms.channel_alignment.algorithm import ChannelAlignment +from nncf.quantization.algorithms.fast_bias_correction.algorithm import FAST_BIAS_CORRECTION_THRESHOLD +from nncf.quantization.algorithms.fast_bias_correction.algorithm import FastBiasCorrection +from nncf.quantization.algorithms.pipeline import Pipeline +from nncf.quantization.algorithms.smooth_quant.algorithm import SmoothQuant + +TModel = TypeVar("TModel") + + +def create_ptq_pipeline( + quantizer: NNCFQuantizer, + subset_size: int = 300, + fast_bias_correction: bool = True, + model_type: Optional[ModelType] = None, + advanced_parameters: Optional[AdvancedQuantizationParameters] = None, +) -> Pipeline: + """ + Creates a post-training quantization pipeline. + + The post-training quantization pipeline includes the following steps: + 1) SmoothQuant + 2) ChannelAlignment + 3) MinMaxQuantization + 4) FastBiasCorrection or BiasCorrection + + :param mode: Special quantization mode that specify different ways of the optimization. + :param preset: A preset controls the quantization mode (symmetric and asymmetric). + It can take the following values: + - `performance`: Symmetric quantization of weights and activations. + - `mixed`: Symmetric quantization of weights and asymmetric quantization of activations. + Default value is None. In this case, `mixed` preset is used for `transformer` + model type otherwise `performace`. + :param target_device: A target device the specificity of which will be taken + into account while compressing in order to obtain the best performance + for this type of device. + :param subset_size: Size of a subset to calculate activations + statistics used for quantization. + :param fast_bias_correction: Setting this option to `False` enables a different + bias correction method which is more accurate, in general, and takes + more time but requires less memory. + :param model_type: Model type is needed to specify additional patterns + in the model. Supported only `transformer` now. + :param advanced_parameters: Advanced quantization parameters for + fine-tuning the quantization algorithm + :return: A post-training quantization pipeline. + """ + + if advanced_parameters is None: + advanced_parameters = AdvancedQuantizationParameters() + + # Build the post-training quantization pipeline. + pipeline_steps = [] + + # Add the `SmoothQuant` algorithm as the first step of the pipeline. + # It is added only for `ModelType.TRANSFORMER`. + sq_params = advanced_parameters.smooth_quant_alphas + sq_alpha = advanced_parameters.smooth_quant_alpha + if sq_alpha is not None: + warning_deprecated( + "`AdvancedQuantizationParameters(smooth_quant_alpha=..)` is deprecated." + "Please, use `AdvancedQuantizationParameters(smooth_quant_alphas)` option " + "with AdvancedSmoothQuantParameters(convolution=.., matmul=..) as value instead." + ) + if sq_alpha < 0: + sq_params.convolution = -1 + sq_params.matmul = -1 + else: + sq_params.matmul = sq_alpha + + if model_type == ModelType.TRANSFORMER and (sq_params.convolution >= 0 or sq_params.matmul >= 0): + alpha_map = {"convolution": sq_params.convolution, "matmul": sq_params.matmul} + pipeline_steps.append([SmoothQuant(subset_size, advanced_parameters.inplace_statistics, alpha_map=alpha_map)]) + + # Add the `ChannelAlignment` algorithm as the second step of the pipeline. + if not advanced_parameters.disable_channel_alignment: + pipeline_steps.append([ChannelAlignment(subset_size, advanced_parameters.inplace_statistics)]) + + # Add the `MinMaxQuantization` algorithm as the third step of the pipeline. + pipeline_steps.append( + [ + MinMaxRangeEstimator( + quantizer=quantizer, + subset_size=subset_size, + inplace_statistics=advanced_parameters.inplace_statistics, + batchwise_statistics=advanced_parameters.batchwise_statistics, + activations_range_estimator_params=advanced_parameters.activations_range_estimator_params, + weights_range_estimator_params=advanced_parameters.weights_range_estimator_params, + ) + ] + ) + + if not advanced_parameters.disable_bias_correction: + # Add the `FastBiasCorrection` or `BiasCorrection` as additional algorithm + # inside the third step of the pipeline. It is added after `MinMaxQuantization` + # algorithm. + bias_correction_params = advanced_parameters.bias_correction_params + if fast_bias_correction: + threshold = FAST_BIAS_CORRECTION_THRESHOLD + bias_correction_subset_size = subset_size + bias_correction_cls = FastBiasCorrection + else: + threshold = BIAS_CORRECTION_THRESHOLD + bias_correction_subset_size = max(int(subset_size * 0.2), 1) + bias_correction_cls = BiasCorrection + + if bias_correction_params.threshold is not None: + threshold = bias_correction_params.threshold + + pipeline_steps[-1].append( + bias_correction_cls( + bias_correction_subset_size, + threshold, + bias_correction_params.apply_for_all_nodes, + advanced_parameters.inplace_statistics, + advanced_parameters.backend_params, + ) + ) + + return Pipeline(pipeline_steps) diff --git a/nncf/experimental/common/quantization/algorithms/quantizer/fx_quantizer.py b/nncf/experimental/common/quantization/algorithms/quantizer/fx_quantizer.py new file mode 100644 index 00000000000..e7d80fbaff4 --- /dev/null +++ b/nncf/experimental/common/quantization/algorithms/quantizer/fx_quantizer.py @@ -0,0 +1,110 @@ +# Copyright (c) 2024 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from collections import defaultdict +from copy import deepcopy + +import torch +import torch.fx +from torch.ao.quantization.pt2e.prepare import _get_edge_or_node_to_group_id +from torch.ao.quantization.pt2e.prepare import _get_edge_or_node_to_qspec +from torch.ao.quantization.pt2e.prepare import _get_obs_or_fq_map +from torch.ao.quantization.quantizer import Quantizer +from torch.ao.quantization.quantizer.quantizer import QuantizationSpec +from torch.ao.quantization.quantizer.quantizer import SharedQuantizationSpec + +import nncf +from nncf.common.graph.graph import NNCFGraph +from nncf.common.quantization.quantizer_setup import ActivationQuantizationInsertionPoint +from nncf.common.quantization.quantizer_setup import SingleConfigQuantizationPoint +from nncf.common.quantization.quantizer_setup import SingleConfigQuantizerSetup +from nncf.common.quantization.quantizer_setup import WeightQuantizationInsertionPoint +from nncf.common.quantization.structs import QuantizationScheme as QuantizationMode +from nncf.common.quantization.structs import QuantizerConfig +from nncf.experimental.common.quantization.algorithms.quantizer.quantizer import NNCFQuantizer + + +class NNCFFXQuantizer(NNCFQuantizer): + def __init__(self, quantizer: Quantizer): + self._quantizer = quantizer + + def get_quantization_setup(self, model: torch.fx.GraphModule, nncf_graph: NNCFGraph) -> SingleConfigQuantizerSetup: + anotated_model = deepcopy(model) + + self._quantizer.transform_for_annotation(anotated_model) + self._quantizer.annotate(anotated_model) + self._quantizer.validate(anotated_model) + return self.get_quantizer_config_from_anotated_model(anotated_model) + + @staticmethod + def get_quantizer_config_from_anotated_model(anotated_model: torch.fx.GraphModule) -> SingleConfigQuantizerSetup: + is_qat = False + edge_or_node_to_qspec = _get_edge_or_node_to_qspec(anotated_model) + edge_or_node_to_group_id = _get_edge_or_node_to_group_id(edge_or_node_to_qspec) + obs_or_fq_map = _get_obs_or_fq_map(edge_or_node_to_group_id, edge_or_node_to_qspec, is_qat) + if obs_or_fq_map: + pass + + q_map = defaultdict(list) + for edge, qspec in edge_or_node_to_qspec.items(): + if not isinstance(edge, tuple): + continue + from_n, to_n = edge + q_map[from_n].append(to_n) + + q_setup = SingleConfigQuantizerSetup() + for from_n, to_nodes in q_map.items(): + to_n = to_nodes[0] + qspec = edge_or_node_to_qspec[(from_n, to_n)] + if qspec is None: + continue + if isinstance(qspec, QuantizationSpec): + if qspec.qscheme in [torch.per_channel_affine, torch.per_channel_symmetric]: + per_channel = True + elif qspec.qscheme in [torch.per_tensor_affine, torch.per_tensor_symmetric]: + per_channel = False + else: + raise nncf.InternalError(f"Unknown qscheme: {qspec.qscheme}") + signed = qspec.dtype is torch.uint8 + mode = ( + QuantizationMode.SYMMETRIC + if qspec.qscheme in [torch.per_channel_symmetric, torch.per_tensor_symmetric] + else QuantizationMode.ASYMMETRIC + ) + qconfig = QuantizerConfig(mode=mode, signedness_to_force=signed, per_channel=per_channel) + qps = [] + # If input node is a constant and placed not at activations port (0) + if from_n.op == "get_attr" and to_n.args.index(from_n) != 0: + qip = WeightQuantizationInsertionPoint(to_n.name) + qp = SingleConfigQuantizationPoint(qip, qconfig, [x.name for x in to_nodes]) + qps.append(qp) + else: + if len(from_n.users) == len(to_nodes): + qip = ActivationQuantizationInsertionPoint(from_n.name) + qp = SingleConfigQuantizationPoint(qip, qconfig, [x.name for x in to_nodes]) + qps.append(qp) + else: + for to_n_ in to_nodes: + input_port_id = to_n_.args.index(from_n) + qip = ActivationQuantizationInsertionPoint(to_n_.name, input_port_id) + qp = SingleConfigQuantizationPoint(qip, qconfig, [to_n_.name]) + qps.append(qp) + + for qp in qps: + q_setup.add_independent_quantization_point(qp) + + elif isinstance(qspec, SharedQuantizationSpec): + pass + else: + raise nncf.InternalError(f"Unknown torch.ao quantization spec: {qspec}") + + return q_setup diff --git a/nncf/experimental/common/quantization/algorithms/quantizer/quantizer.py b/nncf/experimental/common/quantization/algorithms/quantizer/quantizer.py new file mode 100644 index 00000000000..b0d40234210 --- /dev/null +++ b/nncf/experimental/common/quantization/algorithms/quantizer/quantizer.py @@ -0,0 +1,26 @@ +# Copyright (c) 2024 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from abc import abstractmethod +from typing import TypeVar + +from nncf.common.graph.graph import NNCFGraph +from nncf.common.quantization.quantizer_setup import SingleConfigQuantizerSetup + +TModel = TypeVar("TModel") + + +class NNCFQuantizer: + @abstractmethod + def get_quantization_setup(self, model: TModel, nncf_graph: NNCFGraph) -> SingleConfigQuantizerSetup: + """ + Return quantization setup. + """ diff --git a/nncf/experimental/common/quantization/algorithms/range_estimator/backend.py b/nncf/experimental/common/quantization/algorithms/range_estimator/backend.py new file mode 100644 index 00000000000..dbd11f3f6b7 --- /dev/null +++ b/nncf/experimental/common/quantization/algorithms/range_estimator/backend.py @@ -0,0 +1,154 @@ +# Copyright (c) 2024 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from abc import ABC +from abc import abstractmethod +from typing import List, Optional, Set, Tuple, TypeVar + +from nncf.common.graph.graph import NNCFGraph +from nncf.common.graph.graph import NNCFNode +from nncf.common.graph.transformations.commands import TargetPoint +from nncf.common.graph.transformations.commands import TargetType +from nncf.common.graph.transformations.commands import TransformationCommand +from nncf.common.quantization.structs import QuantizerConfig +from nncf.common.tensor_statistics.collectors import TensorStatisticCollectorBase +from nncf.quantization.fake_quantize import FakeQuantizeParameters +from nncf.quantization.range_estimator import RangeEstimatorParameters + +TModel = TypeVar("TModel") + + +class RangeEstimatorAlgoBackend(ABC): + @staticmethod + @abstractmethod + def target_point(target_type: TargetType, target_node_name: str, port_id: int) -> TargetPoint: + """ + Returns backend-specific target point. + + :param target_type: Type of the location that should be modified. + :param target_node_name: Name of the located node. + :param port_id: Port ID of the tensor for the statistics distribution. + :return: Backend-specific TargetPoint. + """ + + @staticmethod + @abstractmethod + def create_quantizer_insertion_command( + nncf_graph: NNCFGraph, + target_point: TargetPoint, + quantizer_config: QuantizerConfig, + parameters: FakeQuantizeParameters, + ) -> TransformationCommand: + """ + Returns backend-specific quantizer insertion command. + + :param nncf_graph: NNCFGraph to get input/output shapes for the target point. + :param target_point: Target location for the quantizer insertion. + :param quantizer_config: QuantizerConfig instance for the current layer. + :param parameters: FakeQuantizeParameters to calculate activation quantization parameters. + :return: Backend-specific TransformationCommand for the quantizer insertion operation. + """ + + @staticmethod + @abstractmethod + def create_unified_scales_quantizers_insertion_commands( + nncf_graph: NNCFGraph, + target_points: List[TargetPoint], + quantizer_config: QuantizerConfig, + parameters: FakeQuantizeParameters, + ) -> List[TransformationCommand]: + """ + Returns backend-specific unified scales quantizers insertion commands. + + :param nncf_graph: NNCFGraph to get input/output shapes for the target point. + :param target_points: List of target locations for the quantizers insertion. + :param quantizer_config: QuantizerConfig instance for the current layer. + :param parameters: FakeQuantizeParameters to calculate activation quantization parameters. + :return: List of backend-specific TransformationCommands + for the quantizers with unified scales insertion operations. + """ + + @staticmethod + @abstractmethod + def get_target_point_shape(nncf_graph: NNCFGraph, node: NNCFNode, target_point: TargetPoint) -> Tuple[int, ...]: + """ + Returns shape of a target point tensor. + + :param nncf_graph: NNCFGraph instance. + :param node: NNCFNode. + :param target_point: Target point of which tensor shape is seeked. + :return: Shape of target point tensor. + """ + + @staticmethod + @abstractmethod + def get_weight_quantization_axes(node: NNCFNode, target_point: TargetPoint, ndims: int) -> Tuple[int, ...]: + """ + Returns axes for per-channel quantization of weights of the node placed on a input port_id. + + :param node: Quantized node with the weight. + :param target_point: Corresponding target point. + :param ndims: Number of dimensions of weight. + :return: Axes for per-channel quantization of weights. + """ + + @staticmethod + @abstractmethod + def get_statistic_collector( + range_estimator_params: RangeEstimatorParameters, + use_abs_max: bool, + reduction_axes: Optional[Tuple[int, ...]], + aggregation_axes: Optional[Tuple[int, ...]], + inplace: bool, + num_samples: Optional[int] = None, + ) -> TensorStatisticCollectorBase: + """ + Returns backend-specific statistic collector. + + :param range_estimator_params: Parameters that specify estimators types. + :param use_abs_max: Wheather reduce absolute values of input tensors or not. + :param reduction_axes: Axes for reducer. + :param aggregation_axes: Axes for aggregator. + :param inplace: Whether to calculate statistic inplace or not. + :param num_samples: Maximum number of samples to collect. + :return: Backend-specific TensorStatisticCollectorBase for the statistics calculation. + """ + + @staticmethod + @abstractmethod + def get_weight_tensor_port_ids(node: NNCFNode, graph: NNCFGraph) -> List[Optional[int]]: + """ + Returns node's input port indices with weight tensors. + + :param node: NNCFNode to find its weight input port indices. + :param graph: NNCFGraph instance. + :return: Weights input port indices. + """ + + @staticmethod + def get_weight_name(nncf_graph: NNCFGraph, target_point: TargetPoint) -> str: + """ + Returns node's weight name corresponding to port ID. + + :param nncf_graph: NNCFGraph instance. + :param target_point: The TargetPoint instance that contains layer's information. + :return: Weight name. + """ + + @staticmethod + def should_quantize_weight(weight_name: str, quantized_weight_names: Set[str]) -> bool: + """ + Return True if weight should be quantized. + + :param weight_name: Weight name. + :param quantized_weight_names: Set containing already quantized weight names. + :return: A boolean value specifying whether a weight should be quantized. + """ diff --git a/nncf/experimental/common/quantization/algorithms/range_estimator/range_estimator.py b/nncf/experimental/common/quantization/algorithms/range_estimator/range_estimator.py new file mode 100644 index 00000000000..5431703cdb1 --- /dev/null +++ b/nncf/experimental/common/quantization/algorithms/range_estimator/range_estimator.py @@ -0,0 +1,495 @@ +# Copyright (c) 2024 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import collections +import dataclasses +from copy import deepcopy +from typing import List, Optional, OrderedDict, Tuple, TypeVar + +import nncf +import nncf.tensor.functions as fns +from nncf import Dataset +from nncf.common.factory import ModelTransformerFactory +from nncf.common.graph.graph import NNCFGraph +from nncf.common.graph.transformations.commands import TargetPoint +from nncf.common.graph.transformations.commands import TargetType +from nncf.common.graph.transformations.layout import TransformationLayout +from nncf.common.logging import nncf_logger +from nncf.common.quantization.initialization.range import RangeInitCollectorParams +from nncf.common.quantization.quantizer_setup import SingleConfigQuantizationPoint +from nncf.common.quantization.quantizer_setup import SingleConfigQuantizerSetup +from nncf.common.quantization.structs import QuantizerConfig +from nncf.common.quantization.structs import QuantizerGroup +from nncf.common.tensor_statistics.collectors import TensorStatisticCollectorBase +from nncf.common.tensor_statistics.statistic_point import StatisticPoint +from nncf.common.tensor_statistics.statistic_point import StatisticPointsContainer +from nncf.common.utils.backend import BackendType +from nncf.common.utils.backend import get_backend +from nncf.experimental.common.quantization.algorithms.quantizer.quantizer import NNCFQuantizer +from nncf.experimental.common.tensor_statistics.statistics import MinMaxTensorStatistic +from nncf.quantization.advanced_parameters import changes_asdict +from nncf.quantization.algorithms.algorithm import Algorithm +from nncf.quantization.fake_quantize import calculate_quantizer_parameters +from nncf.quantization.fake_quantize import get_quantizer_narrow_range +from nncf.quantization.range_estimator import RangeEstimatorParameters +from nncf.quantization.range_estimator import RangeEstimatorParametersSet + +TModel = TypeVar("TModel") + + +class MinMaxRangeEstimator(Algorithm): + def __init__( + self, + quantizer: NNCFQuantizer, + subset_size: int = 300, + inplace_statistics: bool = True, + batchwise_statistics: bool = False, + activations_range_estimator_params: Optional[RangeEstimatorParameters] = None, + weights_range_estimator_params: Optional[RangeEstimatorParameters] = None, + ): + """ + :param subset_size: Size of a subset to calculate activations statistics used + for quantization, defaults to 300. + :param inplace_statistics: Defines wheather to calculate quantizers statistics + by backend graph operations or by default Python implementation, defaults + to True. + :param batchwise_statistics: Determines whether quantizer statistics should be calculated + for each item of the batch or for the entire batch, default is False. + :param activations_range_estimator_params: Quantization range estimation + parameters for activation. + :param weights_range_estimator_params: Quantization range estimation parameters + for weights. + """ + self._quantizer = quantizer + self._subset_size = subset_size + self._inplace_statistics = inplace_statistics + self._batchwise_statistics = batchwise_statistics + self._activations_range_estimator_params = activations_range_estimator_params + self._weights_range_estimator_params = weights_range_estimator_params + + self._range_estimator_params = { + QuantizerGroup.WEIGHTS: self._weights_range_estimator_params, + QuantizerGroup.ACTIVATIONS: self._activations_range_estimator_params, + } + # Calculates global quantizer constraints + self._reset_cache() + self._algorithm_key = f"MMQ_{hash(self)}" + + def _reset_cache(self) -> None: + """ + Marks cache by noninitialized values. Needs to be called when the new quantizer setup is needed. + """ + self._quantization_target_points_to_qconfig: OrderedDict[TargetPoint, QuantizerConfig] = None + self._unified_scale_groups = None + + def _init_cache(self) -> None: + """ + Initializes cache. + """ + self._quantization_target_points_to_qconfig: OrderedDict[TargetPoint, QuantizerConfig] = ( + collections.OrderedDict() + ) + self._unified_scale_groups = [] + + @property + def available_backends(self) -> List[BackendType]: + return [BackendType.TORCH_FX] + + def _set_backend_entity(self, model: TModel) -> None: + """ + Creates a helper class with a backed-specific logic of the algorithm + + :param model: backend-specific input model + """ + model_backend = get_backend(model) + if model_backend == BackendType.TORCH_FX: + from nncf.experimental.common.quantization.algorithms.range_estimator.torch_fx_backend import ( + FXRangeEstimatorAlgoBackend, + ) + + self._backend_entity = FXRangeEstimatorAlgoBackend() + else: + raise nncf.UnsupportedBackendError( + "Cannot return backend-specific entity because {} is not supported!".format(model_backend.value) + ) + + def _get_range_estimator_parameters( + self, target_point: TargetPoint, quantizer_config: QuantizerConfig + ) -> RangeEstimatorParameters: + """ + Returns range estimator parameters. + + :param target_point: Quantizer target point. + :param quantizer_config: Quantizer config. + :return: Range estimator parameters. + """ + quantizer_group = QuantizerGroup.ACTIVATIONS + if target_point.is_weight_target_point(): + quantizer_group = QuantizerGroup.WEIGHTS + + if quantizer_group == QuantizerGroup.WEIGHTS or ( + quantizer_group == QuantizerGroup.ACTIVATIONS and quantizer_config.per_channel + ): + params = RangeEstimatorParametersSet.MINMAX + else: + params = RangeEstimatorParametersSet.MEAN_MINMAX + + user_params = self._range_estimator_params[quantizer_group] + if user_params is None: + return deepcopy(params) + + min_changes = changes_asdict(user_params.min) + min_statistic_collector = dataclasses.replace(params.min, **min_changes) + + max_changes = changes_asdict(user_params.max) + max_statistic_collector = dataclasses.replace(params.max, **max_changes) + + return RangeEstimatorParameters(min_statistic_collector, max_statistic_collector) + + def _get_stat_collector( + self, + graph: NNCFGraph, + target_point: TargetPoint, + qconfig: QuantizerConfig, + batchwise_statistics: bool, + ) -> TensorStatisticCollectorBase: + """ + Creates and returns a statistic collector based on the quantizer's configuration. + + :param graph: NNCFGraph instance. + :param target_point: Target point indicates where statistics should be collected. + :param qconfig: Configuration of a quantizer layer, + defining the configuration of created statistic collector. + :param batchwise_statistics: Determines whether quantizer statistics should be calculated + for each item of the batch or for the entire batch. + :return: Statistic Collector. + """ + is_weight = target_point.is_weight_target_point() + node = graph.get_node_by_name(target_point.target_node_name) + shape = self._backend_entity.get_target_point_shape(graph, node, target_point) + range_estimator_params = self._get_range_estimator_parameters(target_point, qconfig) + + channel_axes = () + if qconfig.per_channel: + channel_axes = ( + self._backend_entity.get_weight_quantization_axes(node, target_point, len(shape)) if is_weight else (1,) + ) + + # Weight statistics is constant, so only one collection is enough. + num_samples = self._subset_size if not is_weight else 1 + + batchwise_statistics = batchwise_statistics and not is_weight + + collector_params = RangeInitCollectorParams( + is_weights=is_weight, scheme=qconfig.mode, per_channel=qconfig.per_channel + ) + reduction_axes, aggregation_axes = None, None + if shape is not None: + reduction_axes, aggregation_axes = collector_params.get_reduction_aggregation_axes( + shape, channel_axes, batchwise_statistics + ) + + return self._backend_entity.get_statistic_collector( + range_estimator_params, + collector_params.use_abs_max, + reduction_axes, + aggregation_axes, + self._inplace_statistics, + num_samples=num_samples, + ) + + def _add_weight_quantization_target_point( + self, quantization_point: SingleConfigQuantizationPoint, nncf_graph: NNCFGraph + ) -> None: + """ + Adds weight quantization target point to the set of existing points. + + :param quantization_point: SingleConfigQuantizationPoint for the needed layer. + :param nncf_graph: The built NNCFGraph of the model. + """ + weight_quantization_target_points = self._get_weight_quantization_target_points(quantization_point, nncf_graph) + for weight_quantization_target_point in weight_quantization_target_points: + self._quantization_target_points_to_qconfig[weight_quantization_target_point] = quantization_point.qconfig + + def _add_activation_quantization_target_point( + self, quantization_point: SingleConfigQuantizationPoint, nncf_graph: NNCFGraph + ) -> None: + """ + Adds activation quantization target point to the set of existing points. + + :param quantization_point: SingleConfigQuantizationPoint for the needed layer. + :param nncf_graph: NNCFGraph instance for working with the graph and nodes. + """ + activation_quantization_target_point = self._get_activation_quantization_target_point( + quantization_point, nncf_graph + ) + self._quantization_target_points_to_qconfig[activation_quantization_target_point] = quantization_point.qconfig + + def _get_weight_quantization_target_points( + self, quantization_point: SingleConfigQuantizationPoint, nncf_graph: NNCFGraph + ) -> List[SingleConfigQuantizationPoint]: + """ + Returns weight quantization target points to the set of existing points. + + :param quantization_point: SingleConfigQuantizationPoint for the needed layer. + :param nncf_graph: NNCFGraph instance for working with the graph and nodes. + :return: List of SingleConfigQuantizationPoints for the needed layer. + """ + weight_quantization_target_points = [] + node_name = quantization_point.insertion_point.target_node_name + node = nncf_graph.get_node_by_name(node_name) + weights_port_ids = self._backend_entity.get_weight_tensor_port_ids(node, nncf_graph) + for port_id in weights_port_ids: + weight_quantization_target_points.append( + self._backend_entity.target_point(TargetType.OPERATION_WITH_WEIGHTS, node_name, port_id) + ) + return weight_quantization_target_points + + def _get_activation_quantization_target_point( + self, quantization_point: SingleConfigQuantizationPoint, nncf_graph: NNCFGraph + ) -> SingleConfigQuantizationPoint: + """ + Returns activation quantization target point to the set of existing points. + + :param quantization_point: SingleConfigQuantizationPoint for the needed layer. + :param nncf_graph: NNCFGraph instance for working with the graph and nodes. + :return: SingleConfigQuantizationPoint for the needed layer. + """ + node_name = quantization_point.insertion_point.target_node_name + # If Quantization of node's input + if quantization_point.insertion_point.input_port_id is not None: + input_port_id = quantization_point.insertion_point.input_port_id + activation_quantization_target_point = self._backend_entity.target_point( + TargetType.PRE_LAYER_OPERATION, node_name, input_port_id + ) + # If quantization of node's output or Model Input node + else: + # NOTE: Assumes that the operation has output edges only from one output port because + # we haven't encountered a model with operations that have multiple output edges with different + # output port IDs. Currently, such models are not supported. Usually, `output_port_id = 0` is used. + # However, there are operations, such as LSTMSequence, where the `output_port_id` changes from case + # to case. Therefore, the code below is required to dynamically determine the `output_port_id` where + # the quantize operation should be inserted." + node = nncf_graph.get_node_by_name(node_name) + unique_output_port_ids = set(e.output_port_id for e in nncf_graph.get_output_edges(node)) + if len(unique_output_port_ids) > 1: + nncf_logger.warning( + f"Cannot determine the output_port_id for the operation: {node_name}, " + "output_port_id = 0 will be used." + ) + output_port_id = 0 + else: + output_port_id = next(iter(unique_output_port_ids)) + + activation_quantization_target_point = self._backend_entity.target_point( + TargetType.POST_LAYER_OPERATION, node_name, output_port_id + ) + return activation_quantization_target_point + + def _find_quantization_target_points( + self, model: TModel, nncf_graph: NNCFGraph + ) -> Tuple[OrderedDict[TargetPoint, QuantizerConfig], List[List[TargetPoint]]]: + """ + Initializes a cache, finds quantization target points and them puts in the cache. + + :param model: Backend-specific model, for which Quantization Target Points are being seek. + :param nncf_graph: NNCFGraph instance. + :return: Mapping of quantization target points with associated quantization configuration, + along with target points for scale unification. + """ + quantizer_setup = self._quantizer.get_quantization_setup(model, nncf_graph) + self._unified_scale_groups = self._collect_unified_groups(quantizer_setup, nncf_graph) + quantization_points = list(quantizer_setup.quantization_points.values()) + quantization_points = self._topological_sort_quantization_points(quantization_points, nncf_graph) + for quantization_point in quantization_points: + if quantization_point.is_weight_quantization_point(): + self._add_weight_quantization_target_point(quantization_point, nncf_graph) + elif quantization_point.is_activation_quantization_point(): + self._add_activation_quantization_target_point(quantization_point, nncf_graph) + else: + raise nncf.InternalError("Incorrect quantization point") + return self._quantization_target_points_to_qconfig, self._unified_scale_groups + + def _get_quantization_target_points( + self, model: TModel, nncf_graph: NNCFGraph + ) -> Tuple[OrderedDict[TargetPoint, QuantizerConfig], List[List[TargetPoint]]]: + """ + Returns Quantization Target Points. + Returns a cache with target points if exists. Otherwise, initiates a procedure of finding them. + + :param model: Backend-specific model, for which Quantization Target Points are being seek. + :param nncf_graph: NNCFGraph instance. + :return: Mapping of quantization target points with associated quantization configuration, + along with target points for scale unification. + """ + if self._quantization_target_points_to_qconfig is not None: + return self._quantization_target_points_to_qconfig, self._unified_scale_groups + self._init_cache() + return self._find_quantization_target_points(model, nncf_graph) + + def _collect_unified_groups( + self, quantizer_setup: SingleConfigQuantizerSetup, nncf_graph: NNCFGraph + ) -> List[List[TargetPoint]]: + """ + Collects the group of quantizers for unification. + + :param quantizer_setup: SingleConfigQuantizerSetup instance. + :param nncf_graph: NNCFGraph instance. + :return: List with the groups of the TargetPoints. + """ + unified_scale_groups = [] + for quantizer_ids in quantizer_setup.unified_scale_groups.values(): + unified_scale_group = [] + for quantizer_id in quantizer_ids: + quantization_point = quantizer_setup.quantization_points[quantizer_id] + + # Only activation quantizers can be unified + if quantization_point.is_activation_quantization_point(): + activation_target_point = self._get_activation_quantization_target_point( + quantization_point, nncf_graph + ) + unified_scale_group.append(activation_target_point) + else: + weight_target_points = self._get_weight_quantization_target_points(quantization_point, nncf_graph) + for weight_target_point in weight_target_points: + unified_scale_group.append(weight_target_point) + unified_scale_groups.append(unified_scale_group) + return unified_scale_groups + + def _topological_sort_quantization_points( + self, quantization_points: List[SingleConfigQuantizationPoint], nncf_graph: NNCFGraph + ) -> List[SingleConfigQuantizationPoint]: + """ + Sorts quantization_points based on the topological order of nodes obtained form nncf_graph. + + :param quantization_points: Quantization points. + :param nncf_graph: Instance of NNCFgraph used to get topological sort. + :return: Sorted quantization_points. + """ + node_names_to_pos = {node.node_name: i for i, node in enumerate(nncf_graph.topological_sort())} + quantization_points.sort(key=lambda point: node_names_to_pos[point.insertion_point.target_node_name]) + return quantization_points + + def apply( + self, + model: TModel, + graph: NNCFGraph, + statistic_points: Optional[StatisticPointsContainer] = None, + dataset: Optional[Dataset] = None, + ) -> TModel: + transformation_layout = TransformationLayout() + model_transformer = ModelTransformerFactory.create(model) + quantization_target_points, unified_scale_groups = self._get_quantization_target_points(model, graph) + weight_layer_names = set() + + def filter_func(point: StatisticPoint) -> bool: + return ( + self._algorithm_key in point.algorithm_to_tensor_collectors + and point.target_point == quantization_target_point + ) + + unified_ops_list = set() + for unified_scale_group in unified_scale_groups: + group_statistics = [] + for quantization_target_point in unified_scale_group: + target_node_name = quantization_target_point.target_node_name + for tensor_collector in statistic_points.get_algo_statistics_for_node( + target_node_name, filter_func, self._algorithm_key + ): + statistics = tensor_collector.get_statistics() + if statistics.min_values is None or statistics.max_values is None: + raise nncf.InternalError(f"Statistics were not collected for the node {target_node_name}") + group_statistics.append(statistics) + + unified_values = self._unify_statistics(group_statistics) + qconfigs = [quantization_target_points[qtp] for qtp in unified_scale_group] + if any(qconfigs[0] != qconfig for qconfig in qconfigs[1:]): + raise nncf.InternalError(f"QConfigs for unified scale group {unified_scale_group} are not equal") + qconfig = qconfigs[0] + q_group = QuantizerGroup.ACTIVATIONS + narrow_range = get_quantizer_narrow_range(qconfig, q_group) + parameters = calculate_quantizer_parameters(unified_values, qconfig, q_group, narrow_range) + commands = self._backend_entity.create_unified_scales_quantizers_insertion_commands( + graph, unified_scale_group, qconfig, parameters + ) + for command in commands: + transformation_layout.register(command) + unified_ops_list.update(unified_scale_group) + + for quantization_target_point, qconfig in quantization_target_points.items(): + if quantization_target_point in unified_ops_list: + continue + target_node_name = quantization_target_point.target_node_name + for tensor_collector in statistic_points.get_algo_statistics_for_node( + target_node_name, filter_func, self._algorithm_key + ): + if quantization_target_point.is_weight_target_point(): + weights_name = self._backend_entity.get_weight_name(graph, quantization_target_point) + if not self._backend_entity.should_quantize_weight(weights_name, weight_layer_names): + continue + weight_layer_names.add(weights_name) + quant_group = QuantizerGroup.WEIGHTS + else: + quant_group = QuantizerGroup.ACTIVATIONS + + half_range = False + narrow_range = get_quantizer_narrow_range(qconfig, quant_group) + statistics = tensor_collector.get_statistics() + if statistics.min_values is None or statistics.max_values is None: + raise nncf.InternalError(f"Statistics were not collected for the node {target_node_name}") + parameters = calculate_quantizer_parameters(statistics, qconfig, quant_group, narrow_range, half_range) + command = self._backend_entity.create_quantizer_insertion_command( + graph, quantization_target_point, qconfig, parameters + ) + transformation_layout.register(command) + if not transformation_layout.transformations: + nncf_logger.info("The model has no operations to apply quantization.") + quantized_model = model_transformer.transform(transformation_layout) + return quantized_model + + def get_statistic_points(self, model: TModel, graph: NNCFGraph) -> StatisticPointsContainer: + self._set_backend_entity(model) + self._reset_cache() + quantization_target_points, _ = self._get_quantization_target_points(model, graph) + output = StatisticPointsContainer() + for quantization_target_point, qconfig in quantization_target_points.items(): + nncf_logger.debug( + f"Adding target point {quantization_target_point.target_node_name}" + f" with type {quantization_target_point.type} for statistics collection" + ) + stat_collector = self._get_stat_collector( + graph, quantization_target_point, qconfig, self._batchwise_statistics + ) + output.add_statistic_point( + StatisticPoint( + target_point=quantization_target_point, + tensor_collector=stat_collector, + algorithm=self._algorithm_key, + ) + ) + return output + + @staticmethod + def _unify_statistics(statistics: List[MinMaxTensorStatistic]) -> MinMaxTensorStatistic: + """ + Returns backend-specific unified statistics. + + :param statistics: List of MinMaxTensorStatistic instances. + :return: Unified MinMaxTensorStatistic value. + """ + + max_values, min_values = [], [] + for statistic in statistics: + max_values.append(statistic.max_values.flatten()) + min_values.append(statistic.min_values.flatten()) + max_values = fns.max(fns.stack(max_values), axis=0) + min_values = fns.min(fns.stack(min_values), axis=0) + return MinMaxTensorStatistic(min_values=min_values, max_values=max_values) diff --git a/nncf/experimental/common/quantization/algorithms/range_estimator/torch_fx_backend.py b/nncf/experimental/common/quantization/algorithms/range_estimator/torch_fx_backend.py new file mode 100644 index 00000000000..0e30e70ae57 --- /dev/null +++ b/nncf/experimental/common/quantization/algorithms/range_estimator/torch_fx_backend.py @@ -0,0 +1,221 @@ +# Copyright (c) 2024 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Optional, Set, Tuple + +import torch +from torch.quantization.fake_quantize import FakeQuantize + +import nncf +from nncf.common.graph.graph import NNCFGraph +from nncf.common.graph.graph import NNCFNode +from nncf.common.graph.transformations.commands import TargetType +from nncf.common.quantization.structs import QuantizationScheme as QuantizationMode +from nncf.common.quantization.structs import QuantizerConfig +from nncf.experimental.common.quantization.algorithms.range_estimator.backend import RangeEstimatorAlgoBackend +from nncf.experimental.common.tensor_statistics.collectors import AGGREGATORS_MAP +from nncf.experimental.common.tensor_statistics.collectors import TensorCollector +from nncf.experimental.common.tensor_statistics.statistics import MinMaxTensorStatistic +from nncf.experimental.torch.fx.commands import FXApplyTransformationCommand +from nncf.experimental.torch.fx.model_utils import get_target_point +from nncf.experimental.torch.fx.transformations import qdq_insertion_transformation_builder +from nncf.quantization.advanced_parameters import StatisticsType +from nncf.quantization.fake_quantize import FakeQuantizeParameters +from nncf.quantization.range_estimator import AggregatorType +from nncf.quantization.range_estimator import RangeEstimatorParameters +from nncf.torch.graph.graph import PTNNCFGraph +from nncf.torch.graph.graph import PTTargetPoint +from nncf.torch.graph.transformations.commands import PTSharedFnInsertionCommand +from nncf.torch.model_graph_manager import get_weight_tensor_port_ids +from nncf.torch.quantization.layers import QUANTIZATION_MODULES +from nncf.torch.quantization.layers import AsymmetricQuantizer +from nncf.torch.quantization.layers import BaseQuantizer +from nncf.torch.quantization.layers import PTQuantizerSpec +from nncf.torch.quantization.layers import get_scale_shape +from nncf.torch.quantization.strip import convert_to_torch_fakequantizer +from nncf.torch.tensor_statistics.collectors import PT_REDUCERS_MAP + + +class FXRangeEstimatorAlgoBackend(RangeEstimatorAlgoBackend): + @staticmethod + def target_point(target_type: TargetType, target_node_name: str, port_id: int) -> PTTargetPoint: + return get_target_point(target_type, target_node_name, port_id) + + @staticmethod + def get_target_point_shape(nncf_graph: PTNNCFGraph, node: NNCFNode, target_point: PTTargetPoint) -> Tuple[int, ...]: + return nncf_graph.get_input_shape_for_insertion_point(target_point) + + @staticmethod + def get_weight_quantization_axes(node: NNCFNode, target_point: PTTargetPoint, ndims: int) -> Tuple[int]: + # TODO(dlyakhov): support transpose conv and other cases + return (0,) + + @staticmethod + def get_statistic_collector( + range_estimator_params: RangeEstimatorParameters, + use_abs_max: bool, + reduction_axes: Optional[Tuple[int, ...]], + aggregation_axes: Optional[Tuple[int, ...]], + inplace: bool, + num_samples: Optional[int] = None, + ) -> TensorCollector: + collector = TensorCollector(MinMaxTensorStatistic) + for params, container_key in zip( + [range_estimator_params.min, range_estimator_params.max], + [MinMaxTensorStatistic.MIN_STAT, MinMaxTensorStatistic.MAX_STAT], + ): + if params.statistics_type not in PT_REDUCERS_MAP: + raise nncf.InternalError( + f"Statistic type: {params.statistics_type} is not supported for Torch PTQ backend yet." + ) + + if params.aggregator_type not in AGGREGATORS_MAP: + raise nncf.InternalError( + f"Aggregator type: {params.aggregator_type} is not supported for Torch PTQ backend yet." + ) + + statistic_type = params.statistics_type + if statistic_type in [StatisticsType.QUANTILE, StatisticsType.ABS_QUANTILE]: + # TODO(dlyakhov): merge two quantile aggregators in one + if container_key == MinMaxTensorStatistic.MIN_STAT: + quantile = params.quantile_outlier_prob + else: + quantile = 1 - params.quantile_outlier_prob + reducer = PT_REDUCERS_MAP[statistic_type](reduction_axes=reduction_axes, quantile=[quantile]) + else: + if use_abs_max and statistic_type == StatisticsType.MAX: + statistic_type = StatisticsType.ABS_MAX + reducer = PT_REDUCERS_MAP[statistic_type](reduction_axes=reduction_axes) + + kwargs = { + "num_samples": num_samples, + "aggregation_axes": aggregation_axes, + } + if params.aggregator_type in [AggregatorType.MEAN_NO_OUTLIERS, AggregatorType.MEDIAN_NO_OUTLIERS]: + kwargs.update({"quantile": params.quantile_outlier_prob}) + aggregator = AGGREGATORS_MAP[params.aggregator_type](**kwargs) + + collector.register_statistic_branch(container_key, reducer, aggregator) + return collector + + @staticmethod + def get_weight_tensor_port_ids(node: NNCFNode, graph: NNCFGraph) -> List[Optional[int]]: + return get_weight_tensor_port_ids(node, graph) + + @staticmethod + def get_weight_name(nncf_graph: NNCFGraph, target_point: PTTargetPoint) -> str: + weighted_node = nncf_graph.get_node_by_name(target_point.target_node_name) + weight_edge = nncf_graph.get_input_edge_by_port_id(weighted_node, target_point.input_port_id) + weight = weight_edge.from_node + return weight.node_name + + @staticmethod + def should_quantize_weight(weight_name: str, quantized_weight_names: Set[str]) -> bool: + # If the nodes share one weight tensor, we should have only one quantizer on that + return weight_name not in quantized_weight_names + + @staticmethod + def _get_input_scale_shape( + nncf_graph: NNCFGraph, target_point: PTTargetPoint, per_channel: bool + ) -> Tuple[Tuple[int, ...], Tuple[int, ...], int]: + is_weights = target_point.is_weight_target_point() + if is_weights: + # TODO(dlyakhov): support transpose conv/ make channel_idx common + channel_idx = 0 + else: + channel_idx = 1 # channel dim for activations + + input_shape = nncf_graph.get_input_shape_for_insertion_point(target_point) + scale_shape = tuple( + get_scale_shape(input_shape, is_weights=is_weights, per_channel=per_channel, channel_idx=channel_idx) + ) + + return input_shape, scale_shape, channel_idx + + @staticmethod + def _create_quantizer( + quantizer_config: QuantizerConfig, + scale_shape: Tuple, + parameters: FakeQuantizeParameters, + target_type: TargetType, + ) -> FakeQuantize: + mode = quantizer_config.mode + quantizer_cls = QUANTIZATION_MODULES.get(mode) + narrow_range = target_type == TargetType.OPERATION_WITH_WEIGHTS and mode == QuantizationMode.SYMMETRIC + quantizer_spec = PTQuantizerSpec.from_config( + quantizer_config, + narrow_range=narrow_range, + scale_shape=scale_shape, + half_range=False, + logarithm_scale=False, + is_quantized_on_export=False, + compression_lr_multiplier=None, + ) + quantizer = quantizer_cls(quantizer_spec) + + # Fill it with minmax + # TODO(dlyakhov) Prevent creation of intermediate objects like nncf quantizer. + FXRangeEstimatorAlgoBackend._fill_quantizer_parameters(quantizer, parameters, quantizer_spec.scale_shape) + # Convert to the torch fake quantizer + torch_fq = convert_to_torch_fakequantizer(quantizer) + return torch_fq + + @staticmethod + def _fill_quantizer_parameters(quantizer: BaseQuantizer, parameters: FakeQuantizeParameters, scale_shape) -> None: + if isinstance(quantizer, AsymmetricQuantizer): + quantizer.input_low = torch.nn.Parameter(parameters.input_low.data.reshape(scale_shape)) + input_range = parameters.input_high - parameters.input_low + # Subtract eps from the input_range to make quantizer parameters equal to + # original parameters on the forward call. + quantizer.input_range = torch.nn.Parameter((input_range.data - quantizer.eps).reshape(scale_shape)) + else: + quantizer.signed = bool(torch.any(parameters.input_low.data < 0)) + # Subtract eps from the scale to make quantizer parameters equal to + # original parameters on the forward call. + quantizer.scale = torch.nn.Parameter((parameters.input_high.data - quantizer.eps).reshape(scale_shape)) + + @staticmethod + def create_quantizer_insertion_command( + nncf_graph: NNCFGraph, + target_point: PTTargetPoint, + quantizer_config: QuantizerConfig, + parameters: FakeQuantizeParameters, + ) -> FXApplyTransformationCommand: + _, scale_shape, _ = FXRangeEstimatorAlgoBackend._get_input_scale_shape( + nncf_graph, target_point, quantizer_config.per_channel + ) + + quantizer = FXRangeEstimatorAlgoBackend._create_quantizer( + quantizer_config, scale_shape, parameters, target_point.target_type + ) + transformation = qdq_insertion_transformation_builder(quantizer, [target_point]) + return FXApplyTransformationCommand(transformation) + + @staticmethod + def create_unified_scales_quantizers_insertion_commands( + nncf_graph: NNCFGraph, + target_points: List[PTTargetPoint], + quantizer_config: QuantizerConfig, + parameters: FakeQuantizeParameters, + ) -> List[PTSharedFnInsertionCommand]: + _, scale_shape, _ = FXRangeEstimatorAlgoBackend._get_input_scale_shape( + nncf_graph, target_points[0], quantizer_config.per_channel + ) + + quantizer = FXRangeEstimatorAlgoBackend._create_quantizer( + quantizer_config, scale_shape, parameters, target_points[0].target_type + ) + + transformations = [] + for tp in target_points: + transformation = qdq_insertion_transformation_builder(quantizer, [tp]) + transformations.append(FXApplyTransformationCommand(transformation)) + return transformations diff --git a/nncf/experimental/torch/fx/quantization/quantize_pt2e.py b/nncf/experimental/torch/fx/quantization/quantize_pt2e.py new file mode 100644 index 00000000000..efa32af48d6 --- /dev/null +++ b/nncf/experimental/torch/fx/quantization/quantize_pt2e.py @@ -0,0 +1,91 @@ +# Copyright (c) 2024 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from copy import deepcopy +from typing import Optional + +import torch +import torch.fx +from torch.ao.quantization.pt2e.duplicate_dq_pass import DuplicateDQPass +from torch.ao.quantization.pt2e.port_metadata_pass import PortNodeMetaForQDQ +from torch.ao.quantization.pt2e.qat_utils import _fold_conv_bn_qat +from torch.ao.quantization.pt2e.utils import _disallow_eval_train +from torch.ao.quantization.quantizer import Quantizer +from torch.fx import GraphModule +from torch.fx.passes.infra.pass_manager import PassManager + +from nncf.common.factory import NNCFGraphFactory +from nncf.common.logging import nncf_logger +from nncf.data import Dataset +from nncf.experimental.common.quantization.algorithms.post_training.algorithm import PostTrainingQuantization +from nncf.experimental.common.quantization.algorithms.quantizer.fx_quantizer import NNCFFXQuantizer +from nncf.experimental.torch.fx.transformations import fuse_conv_bn +from nncf.parameters import ModelType +from nncf.quantization.advanced_parameters import AdvancedQuantizationParameters + +DEFAULT_RANGE_TYPE = "mean_min_max" + + +def quantize_pt2e( + model: torch.fx.GraphModule, + quantizer: Quantizer, + calibration_dataset: Dataset, + subset_size: int = 300, + fast_bias_correction: bool = True, + model_type: Optional[ModelType] = None, + advanced_parameters: Optional[AdvancedQuantizationParameters] = None, +) -> torch.fx.GraphModule: + """ + Implementation of the `quantize()` method for the Torch FX backend. + """ + nncf_logger.warning( + "Experimental Torch FX quantization backend is being used for the given torch.fx.GraphModule model." + " Torch FX PTQ is an experimental feature, consider using Torch or OpenVino PTQ backends" + " in case of errors or a poor model performance." + ) + + original_graph_meta = model.meta + + copied_model = deepcopy(model) + + quantization_algorithm = PostTrainingQuantization( + quantizer=NNCFFXQuantizer(quantizer), + subset_size=subset_size, + fast_bias_correction=fast_bias_correction, + model_type=model_type, + advanced_parameters=advanced_parameters, + ) + + # To make it easier for bias correction algorithms, + # biases are being separated by the followng calls. + fuse_conv_bn(copied_model) + + nncf_graph = NNCFGraphFactory.create(copied_model) + quantized_model = quantization_algorithm.apply(copied_model, nncf_graph, dataset=calibration_dataset) + + # Magic. Without this call compiled model + # is not preformant + quantized_model = GraphModule(quantized_model, quantized_model.graph) + + quantized_model = _fold_conv_bn_qat(quantized_model) + pm = PassManager([DuplicateDQPass()]) + + quantized_model = pm(quantized_model).graph_module + pm = PassManager([PortNodeMetaForQDQ()]) + quantized_model = pm(quantized_model).graph_module + + quantized_model.meta.update(original_graph_meta) + quantized_model = _disallow_eval_train(quantized_model) + # Each transformation adds a duplicate tensor value to the model buffer. + # This step removes the duplicates tensor values from the buffer. + quantized_model = GraphModule(quantized_model, quantized_model.graph) + + return quantized_model diff --git a/torch_fx_experimental_q.py b/torch_fx_experimental_q.py new file mode 100644 index 00000000000..77ca8859442 --- /dev/null +++ b/torch_fx_experimental_q.py @@ -0,0 +1,107 @@ +# Copyright (c) 2024 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +os.environ["TORCHINDUCTOR_FREEZING"] = "1" + +from time import time + +import torch +import torch.fx +from torch._export import capture_pre_autograd_graph +from torch.ao.quantization.quantize_pt2e import convert_pt2e +from torch.ao.quantization.quantize_pt2e import prepare_pt2e +from torch.ao.quantization.quantizer.x86_inductor_quantizer import X86InductorQuantizer +from torch.ao.quantization.quantizer.x86_inductor_quantizer import get_default_x86_inductor_quantization_config +from torchvision import models + +import nncf +import nncf.torch +from nncf.experimental.torch.fx.quantization.quantize_pt2e import quantize_pt2e +from tests.torch.fx.helpers import visualize_fx_model + + +def measure_time(model, example_inputs, num_iters=3000): + with torch.no_grad(): + model(*example_inputs) + total_time = 0 + for _ in range(num_iters): + start_time = time() + model(*example_inputs) + total_time += time() - start_time + average_time = (total_time / num_iters) * 1000 + return average_time + + +def main(model_cls): + model = model_cls() + example_inputs = torch.ones((1, 3, 224, 224)) + exported_model = capture_pre_autograd_graph(model.eval(), (example_inputs,)) + + quantizer = X86InductorQuantizer() + quantizer.set_global(get_default_x86_inductor_quantization_config()) + + nncf_quantizer_model = quantize_pt2e(exported_model, quantizer, calibration_dataset=nncf.Dataset([example_inputs])) + + visualize_fx_model(nncf_quantizer_model, "nncf_quantizer_before_fold_resnet.svg") + return nncf_quantizer_model + + # exported_model = capture_pre_autograd_graph(model.eval(), (example_inputs,)) + # nncf_int8 = nncf.quantize(exported_model, nncf.Dataset([example_inputs])) + # visualize_fx_model(nncf_int8, "nncf_resnet.svg") + + +def main_native(model_cls): + model = model_cls() + example_inputs = torch.ones((1, 3, 224, 224)) + exported_model = capture_pre_autograd_graph(model.eval(), (example_inputs,)) + + quantizer = X86InductorQuantizer() + quantizer.set_global(get_default_x86_inductor_quantization_config()) + + prepared_model = prepare_pt2e(exported_model, quantizer) + prepared_model(example_inputs) + converted_model = convert_pt2e(prepared_model) + visualize_fx_model(converted_model, "x86int8_resnet.svg") + return converted_model + + +def constant_fold(m): + pass + + +if __name__ == "__main__": + with nncf.torch.disable_patching(): + for model_cls in (models.resnet18, models.mobilenet_v3_small, models.vit_b_16, models.swin_v2_s): + # for model_cls in (models.mobilenet_v3_small,): + print(f"{model_cls} check!") + nncf_q_model = main(model_cls) + + constant_fold(nncf_q_model) + visualize_fx_model(nncf_q_model, "nncf_quantizer_after_constant_fold_resnet.svg") + + pt_q_model = main_native(model_cls) + print("benchmarking...") + pt_compiled = torch.compile(model_cls()) + pt_int8_compiled = torch.compile(pt_q_model) + nncf_comipled = torch.compile(nncf_q_model) + + example_inputs = (torch.ones((1, 3, 224, 224)),) + + pt_time = measure_time(pt_compiled, example_inputs) + print(f"PT fp32 performance measured: {pt_time}") + + pt_int8_time = measure_time(pt_int8_compiled, example_inputs) + print(f"PT int8 performance measured: {pt_int8_time}") + + nncf_int8_time = measure_time(nncf_comipled, example_inputs) + print(f"NNCF int8 performance measured: {nncf_int8_time}")