Experimental quantization

daniil-lyakhov committed Nov 28, 2024
1 parent e6bf1d5 commit d1941f3

Showing 14 changed files with 9,653 additions and 217 deletions.

nncf/experimental/common/quantization/algorithms/post_training/algorithm.py
@@ -16,61 +16,60 @@
 from nncf.common.graph.graph import NNCFGraph
 from nncf.common.tensor_statistics.statistic_point import StatisticPointsContainer
 from nncf.common.utils.backend import BackendType
-from nncf.experimental.common.quantization.algorithms.post_training.pipeline import create_ptq_pipeline
+from nncf.experimental.common.quantization.algorithms.post_training.pipeline import experimental_create_ptq_pipeline
 from nncf.experimental.common.quantization.algorithms.quantizer.quantizer import NNCFQuantizer
-from nncf.parameters import ModelType
-from nncf.quantization.advanced_parameters import AdvancedQuantizationParameters
+from nncf.quantization.advanced_parameters import AdvancedBiasCorrectionParameters
+from nncf.quantization.advanced_parameters import AdvancedSmoothQuantParameters
+from nncf.quantization.advanced_parameters import RangeEstimatorParameters
 from nncf.quantization.algorithms.algorithm import Algorithm
 
 TModel = TypeVar("TModel")
 TPass = Callable[[TModel], TModel]
 
 
-class PostTrainingQuantization(Algorithm):
+class ExperimentalPostTrainingQuantization(Algorithm):
     """
-    Implements Post-Training Quantization algorithm, which basically includes:
+    Implements Experimental Post-Training Quantization algorithm, which basically includes:
     1) ChannelAlignment
-    2) MinMaxQuantization
+    2) MinMaxRangeInit
     3) FastBiasCorrection or BiasCorrection
     """
 
     def __init__(
         self,
         quantizer: NNCFQuantizer,
         subset_size: int = 300,
-        fast_bias_correction: bool = True,
-        model_type: Optional[ModelType] = None,
-        advanced_parameters: Optional[AdvancedQuantizationParameters] = None,
+        fast_bias_correction: Optional[bool] = True,
+        smooth_quant: bool = False,
+        bias_correction_params: Optional[AdvancedBiasCorrectionParameters] = None,
+        smooth_quant_params: Optional[AdvancedSmoothQuantParameters] = None,
+        activations_range_estimator_params: Optional[RangeEstimatorParameters] = None,
+        weights_range_estimator_params: Optional[RangeEstimatorParameters] = None,
     ):
         """
-        :param mode: Special quantization mode that specify different ways of the optimization.
-        :param preset: A preset controls the quantization mode (symmetric and asymmetric).
-            It can take the following values:
-            - `performance`: Symmetric quantization of weights and activations.
-            - `mixed`: Symmetric quantization of weights and asymmetric quantization of activations.
-            Default value is None. In this case, `mixed` preset is used for `transformer`
-            model type otherwise `performace`.
-        :param target_device: A target device the specificity of which will be taken
-            into account while compressing in order to obtain the best performance
-            for this type of device.
         :param quantizer: NNCFQuantizer to use in the MinMaxRangeInit algorithm.
         :param subset_size: Size of a subset to calculate activations
            statistics used for quantization.
         :param fast_bias_correction: Setting this option to `False` enables a different
            bias correction method which is more accurate, in general, and takes
-           more time but requires less memory.
-        :param model_type: Model type is needed to specify additional patterns
-            in the model. Supported only `transformer` now.
-        :param ignored_scope: An ignored scope that defined the list of model control
-            flow graph nodes to be ignored during quantization.
-        :param advanced_parameters: Advanced quantization parameters for
-            fine-tuning the quantization algorithm
+           more time but requires less memory. None disables the bias correction algorithm.
+        :param smooth_quant: Setting this option to `True` enables the SmoothQuant algorithm.
+        :param bias_correction_params: Contains advanced parameters for fine-tuning the bias correction algorithm.
+        :param smooth_quant_params: Contains advanced alpha parameters for the SmoothQuant algorithm.
+        :param activations_range_estimator_params: Contains parameters for estimating the range
+            of activations of the model.
+        :param weights_range_estimator_params: Contains parameters for estimating the range
+            of weights of the model.
        """
-        self._pipeline = create_ptq_pipeline(
+        self._pipeline = experimental_create_ptq_pipeline(
             quantizer=quantizer,
             subset_size=subset_size,
             fast_bias_correction=fast_bias_correction,
-            model_type=model_type,
-            advanced_parameters=advanced_parameters,
+            smooth_quant=smooth_quant,
+            bias_correction_params=bias_correction_params,
+            smooth_quant_params=smooth_quant_params,
+            activations_range_estimator_params=activations_range_estimator_params,
+            weights_range_estimator_params=weights_range_estimator_params,
         )
 
     @property
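The widened `fast_bias_correction` contract is the subtle change here: `True` keeps the fast algorithm, `False` switches to the slower but more accurate `BiasCorrection`, and `None` now disables the step entirely. A minimal sketch of that tri-state mapping, based only on the docstring above (the helper name is illustrative, not part of NNCF):

```python
from typing import Optional


def bias_correction_mode(fast_bias_correction: Optional[bool]) -> Optional[str]:
    # Mirrors the tri-state contract documented in the docstring above.
    if fast_bias_correction is None:
        return None  # the bias correction step is skipped entirely
    if fast_bias_correction:
        return "FastBiasCorrection"  # default: faster, needs more memory
    return "BiasCorrection"  # more accurate, slower, less memory


assert bias_correction_mode(True) == "FastBiasCorrection"
assert bias_correction_mode(False) == "BiasCorrection"
assert bias_correction_mode(None) is None
```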
nncf/experimental/common/quantization/algorithms/post_training/pipeline.py
@@ -11,14 +11,13 @@
 
 from typing import Optional, TypeVar
 
-from nncf.common.deprecation import warning_deprecated
 from nncf.experimental.common.quantization.algorithms.quantizer.quantizer import NNCFQuantizer
 from nncf.experimental.common.quantization.algorithms.range_estimator.range_estimator import MinMaxRangeEstimator
-from nncf.parameters import ModelType
-from nncf.quantization.advanced_parameters import AdvancedQuantizationParameters
+from nncf.quantization.advanced_parameters import AdvancedBiasCorrectionParameters
+from nncf.quantization.advanced_parameters import AdvancedSmoothQuantParameters
+from nncf.quantization.advanced_parameters import RangeEstimatorParameters
 from nncf.quantization.algorithms.bias_correction.algorithm import BIAS_CORRECTION_THRESHOLD
 from nncf.quantization.algorithms.bias_correction.algorithm import BiasCorrection
-from nncf.quantization.algorithms.channel_alignment.algorithm import ChannelAlignment
 from nncf.quantization.algorithms.fast_bias_correction.algorithm import FAST_BIAS_CORRECTION_THRESHOLD
 from nncf.quantization.algorithms.fast_bias_correction.algorithm import FastBiasCorrection
 from nncf.quantization.algorithms.pipeline import Pipeline
@@ -27,93 +26,67 @@
 
 TModel = TypeVar("TModel")
 
 
-def create_ptq_pipeline(
+def experimental_create_ptq_pipeline(
     quantizer: NNCFQuantizer,
     subset_size: int = 300,
-    fast_bias_correction: bool = True,
-    model_type: Optional[ModelType] = None,
-    advanced_parameters: Optional[AdvancedQuantizationParameters] = None,
+    fast_bias_correction: Optional[bool] = True,
+    smooth_quant: bool = False,
+    bias_correction_params: Optional[AdvancedBiasCorrectionParameters] = None,
+    smooth_quant_params: Optional[AdvancedSmoothQuantParameters] = None,
+    activations_range_estimator_params: Optional[RangeEstimatorParameters] = None,
+    weights_range_estimator_params: Optional[RangeEstimatorParameters] = None,
 ) -> Pipeline:
     """
-    Creates a post-training quantization pipeline.
+    Creates an experimental post-training quantization pipeline.
 
-    The post-training quantization pipeline includes the following steps:
+    The experimental post-training quantization pipeline includes the following steps:
     1) SmoothQuant
-    2) ChannelAlignment
-    3) MinMaxQuantization
-    4) FastBiasCorrection or BiasCorrection
-
-    :param mode: Special quantization mode that specify different ways of the optimization.
-    :param preset: A preset controls the quantization mode (symmetric and asymmetric).
-        It can take the following values:
-        - `performance`: Symmetric quantization of weights and activations.
-        - `mixed`: Symmetric quantization of weights and asymmetric quantization of activations.
-        Default value is None. In this case, `mixed` preset is used for `transformer`
-        model type otherwise `performace`.
-    :param target_device: A target device the specificity of which will be taken
-        into account while compressing in order to obtain the best performance
-        for this type of device.
+    2) MinMaxRangeInit
+    3) FastBiasCorrection or BiasCorrection
+
     :param quantizer: NNCFQuantizer to use in the MinMaxRangeInit algorithm.
     :param subset_size: Size of a subset to calculate activations
         statistics used for quantization.
     :param fast_bias_correction: Setting this option to `False` enables a different
         bias correction method which is more accurate, in general, and takes
-        more time but requires less memory.
-    :param model_type: Model type is needed to specify additional patterns
-        in the model. Supported only `transformer` now.
-    :param advanced_parameters: Advanced quantization parameters for
-        fine-tuning the quantization algorithm
-    :return: A post-training quantization pipeline.
+        more time but requires less memory. None disables the bias correction algorithm.
+    :param smooth_quant: Setting this option to `True` enables the SmoothQuant algorithm.
+    :param bias_correction_params: Contains advanced parameters for fine-tuning the bias correction algorithm.
+    :param smooth_quant_params: Contains advanced alpha parameters for the SmoothQuant algorithm.
+    :param activations_range_estimator_params: Contains parameters for estimating the range
+        of activations of the model.
+    :param weights_range_estimator_params: Contains parameters for estimating the range
+        of weights of the model.
+    :return: An experimental post-training quantization pipeline.
     """
 
-    if advanced_parameters is None:
-        advanced_parameters = AdvancedQuantizationParameters()
-
     # Build the post-training quantization pipeline.
     pipeline_steps = []
 
     # Add the `SmoothQuant` algorithm as the first step of the pipeline.
-    # It is added only for `ModelType.TRANSFORMER`.
-    sq_params = advanced_parameters.smooth_quant_alphas
-    sq_alpha = advanced_parameters.smooth_quant_alpha
-    if sq_alpha is not None:
-        warning_deprecated(
-            "`AdvancedQuantizationParameters(smooth_quant_alpha=..)` is deprecated."
-            "Please, use `AdvancedQuantizationParameters(smooth_quant_alphas)` option "
-            "with AdvancedSmoothQuantParameters(convolution=.., matmul=..) as value instead."
-        )
-        if sq_alpha < 0:
-            sq_params.convolution = -1
-            sq_params.matmul = -1
-        else:
-            sq_params.matmul = sq_alpha
-
-    if model_type == ModelType.TRANSFORMER and (sq_params.convolution >= 0 or sq_params.matmul >= 0):
-        alpha_map = {"convolution": sq_params.convolution, "matmul": sq_params.matmul}
-        pipeline_steps.append([SmoothQuant(subset_size, advanced_parameters.inplace_statistics, alpha_map=alpha_map)])
+    if smooth_quant_params is None:
+        smooth_quant_params = AdvancedSmoothQuantParameters()
 
-    # Add the `ChannelAlignment` algorithm as the second step of the pipeline.
-    if not advanced_parameters.disable_channel_alignment:
-        pipeline_steps.append([ChannelAlignment(subset_size, advanced_parameters.inplace_statistics)])
+    if smooth_quant and (smooth_quant_params.convolution >= 0 or smooth_quant_params.matmul >= 0):
+        alpha_map = {"convolution": smooth_quant_params.convolution, "matmul": smooth_quant_params.matmul}
+        pipeline_steps.append([SmoothQuant(subset_size, False, alpha_map=alpha_map)])
 
     # Add the `MinMaxQuantization` algorithm as the third step of the pipeline.
     pipeline_steps.append(
         [
             MinMaxRangeEstimator(
                 quantizer=quantizer,
                 subset_size=subset_size,
-                inplace_statistics=advanced_parameters.inplace_statistics,
-                batchwise_statistics=advanced_parameters.batchwise_statistics,
-                activations_range_estimator_params=advanced_parameters.activations_range_estimator_params,
-                weights_range_estimator_params=advanced_parameters.weights_range_estimator_params,
+                inplace_statistics=False,
+                activations_range_estimator_params=activations_range_estimator_params,
+                weights_range_estimator_params=weights_range_estimator_params,
             )
         ]
     )
 
-    if not advanced_parameters.disable_bias_correction:
+    if fast_bias_correction is not None:
         # Add the `FastBiasCorrection` or `BiasCorrection` as additional algorithm
         # inside the third step of the pipeline. It is added after `MinMaxQuantization`
         # algorithm.
-        bias_correction_params = advanced_parameters.bias_correction_params
         if fast_bias_correction:
             threshold = FAST_BIAS_CORRECTION_THRESHOLD
             bias_correction_subset_size = subset_size
@@ -123,6 +96,9 @@ def create_ptq_pipeline(
             bias_correction_subset_size = max(int(subset_size * 0.2), 1)
             bias_correction_cls = BiasCorrection
 
+        if bias_correction_params is None:
+            bias_correction_params = AdvancedBiasCorrectionParameters()
+
         if bias_correction_params.threshold is not None:
             threshold = bias_correction_params.threshold
 
@@ -131,8 +107,6 @@ def create_ptq_pipeline(
                 bias_correction_subset_size,
                 threshold,
                 bias_correction_params.apply_for_all_nodes,
-                advanced_parameters.inplace_statistics,
-                advanced_parameters.backend_params,
             )
         )
 
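Two short illustrations of the advanced-parameter objects this pipeline now consumes directly. Nothing below is from the commit, and the field values shown (a disabled convolution alpha of -1, a matmul alpha of 0.95, an optional threshold) are assumptions about NNCF's dataclass defaults rather than verified ones:

```python
from nncf.quantization.advanced_parameters import AdvancedBiasCorrectionParameters
from nncf.quantization.advanced_parameters import AdvancedSmoothQuantParameters

# SmoothQuant: a negative alpha disables that op type; the step is appended
# only when smooth_quant=True and at least one alpha is non-negative.
sq_params = AdvancedSmoothQuantParameters(convolution=-1, matmul=0.95)
alpha_map = {"convolution": sq_params.convolution, "matmul": sq_params.matmul}
assert any(alpha >= 0 for alpha in alpha_map.values())  # matmuls only here

# Bias correction: when a threshold is set, it overrides the
# FAST_BIAS_CORRECTION_THRESHOLD / BIAS_CORRECTION_THRESHOLD constant
# selected by the fast_bias_correction flag.
bc_params = AdvancedBiasCorrectionParameters(apply_for_all_nodes=False, threshold=2.0)
assert bc_params.threshold is not None
```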
nncf/experimental/common/quantization/algorithms/quantizer/fx_quantizer.py
@@ -12,14 +12,13 @@
 
 from collections import defaultdict
 from copy import deepcopy
+from typing import Dict, Tuple, Union
 
 import torch
 import torch.fx
-from torch.ao.quantization.pt2e.prepare import _get_edge_or_node_to_group_id
-from torch.ao.quantization.pt2e.prepare import _get_edge_or_node_to_qspec
-from torch.ao.quantization.pt2e.prepare import _get_obs_or_fq_map
 from torch.ao.quantization.quantizer import Quantizer
 from torch.ao.quantization.quantizer.quantizer import QuantizationSpec
+from torch.ao.quantization.quantizer.quantizer import QuantizationSpecBase
 from torch.ao.quantization.quantizer.quantizer import SharedQuantizationSpec
 
 import nncf
@@ -32,6 +31,8 @@
 from nncf.common.quantization.structs import QuantizerConfig
 from nncf.experimental.common.quantization.algorithms.quantizer.quantizer import NNCFQuantizer
 
+EdgeOrNode = Union[torch.fx.Node, Tuple[torch.fx.Node, torch.fx.Node]]
+
 
 class NNCFFXQuantizer(NNCFQuantizer):
     def __init__(self, quantizer: Quantizer):
@@ -47,12 +48,7 @@ def get_quantization_setup(self, model: torch.fx.GraphModule, nncf_graph: NNCFGr
 
     @staticmethod
     def get_quantizer_config_from_anotated_model(anotated_model: torch.fx.GraphModule) -> SingleConfigQuantizerSetup:
-        is_qat = False
         edge_or_node_to_qspec = _get_edge_or_node_to_qspec(anotated_model)
-        edge_or_node_to_group_id = _get_edge_or_node_to_group_id(edge_or_node_to_qspec)
-        obs_or_fq_map = _get_obs_or_fq_map(edge_or_node_to_group_id, edge_or_node_to_qspec, is_qat)
-        if obs_or_fq_map:
-            pass
 
         q_map = defaultdict(list)
         for edge, qspec in edge_or_node_to_qspec.items():
@@ -108,3 +104,26 @@ def get_quantizer_config_from_anotated_model(anotated_model: torch.fx.GraphModul
             raise nncf.InternalError(f"Unknown torch.ao quantization spec: {qspec}")
 
         return q_setup
+
+
+def _get_edge_or_node_to_qspec(
+    model: torch.fx.GraphModule,
+) -> Dict[EdgeOrNode, QuantizationSpecBase]:
+    """
+    Get a map from EdgeOrNode to quantization spec based on annotations on the nodes.
+
+    :param model: torch.fx.GraphModule instance.
+    :return: A map from EdgeOrNode to quantization spec based on annotations on the nodes.
+    """
+    edge_or_node_to_qspec: Dict[EdgeOrNode, QuantizationSpecBase] = {}
+    for n in model.graph.nodes:
+        if hasattr(n, "meta") and "quantization_annotation" in n.meta:
+            qa = n.meta["quantization_annotation"]
+            for input_to_n, qspec in qa.input_qspec_map.items():
+                input_edge = (input_to_n, n)
+                edge_or_node_to_qspec[input_edge] = qspec
+            if qa.output_qspec is not None:
+                output_node = n
+                qspec = qa.output_qspec
+                edge_or_node_to_qspec[output_node] = qspec
+    return edge_or_node_to_qspec
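To make the two key shapes concrete: the returned map uses an (input, consumer) tuple for input edges and the bare node for outputs. A self-contained sketch, with strings and SimpleNamespace standing in for torch.fx.Node and the torch.ao QuantizationAnnotation (both stand-ins are assumptions of this illustration, not NNCF or torch APIs):

```python
from types import SimpleNamespace

# One fake "node", annotated the way torch.ao quantizers annotate real nodes.
conv = SimpleNamespace(
    name="conv",
    meta={
        "quantization_annotation": SimpleNamespace(
            input_qspec_map={"x": "act_qspec"},  # producer "x" feeds conv
            output_qspec="out_qspec",
        )
    },
)

edge_or_node_to_qspec = {}
qa = conv.meta["quantization_annotation"]
for input_to_n, qspec in qa.input_qspec_map.items():
    edge_or_node_to_qspec[(input_to_n, conv.name)] = qspec  # edge key
if qa.output_qspec is not None:
    edge_or_node_to_qspec[conv.name] = qa.output_qspec  # node key

print(edge_or_node_to_qspec)
# {('x', 'conv'): 'act_qspec', 'conv': 'out_qspec'}
```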
44 changes: 34 additions & 10 deletions nncf/experimental/torch/fx/quantization/quantize_pt2e.py
@@ -25,13 +25,16 @@
 from nncf.common.factory import NNCFGraphFactory
 from nncf.common.logging import nncf_logger
 from nncf.data import Dataset
-from nncf.experimental.common.quantization.algorithms.post_training.algorithm import PostTrainingQuantization
+from nncf.experimental.common.quantization.algorithms.post_training.algorithm import (
+    ExperimentalPostTrainingQuantization,
+)
 from nncf.experimental.common.quantization.algorithms.quantizer.fx_quantizer import NNCFFXQuantizer
+from nncf.experimental.torch.fx.constant_folding import constant_fold
+from nncf.experimental.torch.fx.transformations import QUANTIZE_NODE_TARGETS
 from nncf.experimental.torch.fx.transformations import fuse_conv_bn
-from nncf.parameters import ModelType
-from nncf.quantization.advanced_parameters import AdvancedQuantizationParameters
-
-DEFAULT_RANGE_TYPE = "mean_min_max"
+from nncf.quantization.advanced_parameters import AdvancedBiasCorrectionParameters
+from nncf.quantization.advanced_parameters import AdvancedSmoothQuantParameters
+from nncf.quantization.advanced_parameters import RangeEstimatorParameters
 
 
 def quantize_pt2e(
@@ -40,8 +43,12 @@ def quantize_pt2e(
     calibration_dataset: Dataset,
     subset_size: int = 300,
     fast_bias_correction: bool = True,
-    model_type: Optional[ModelType] = None,
-    advanced_parameters: Optional[AdvancedQuantizationParameters] = None,
+    smooth_quant: bool = False,
+    bias_correction_params: Optional[AdvancedBiasCorrectionParameters] = None,
+    smooth_quant_params: Optional[AdvancedSmoothQuantParameters] = None,
+    activations_range_estimator_params: Optional[RangeEstimatorParameters] = None,
+    weights_range_estimator_params: Optional[RangeEstimatorParameters] = None,
+    fold_quantize: Optional[bool] = False,
 ) -> torch.fx.GraphModule:
     """
     Implementation of the `quantize()` method for the Torch FX backend.
@@ -56,12 +63,15 @@
 
     copied_model = deepcopy(model)
 
-    quantization_algorithm = PostTrainingQuantization(
+    quantization_algorithm = ExperimentalPostTrainingQuantization(
         quantizer=NNCFFXQuantizer(quantizer),
         subset_size=subset_size,
         fast_bias_correction=fast_bias_correction,
-        model_type=model_type,
-        advanced_parameters=advanced_parameters,
+        smooth_quant=smooth_quant,
+        bias_correction_params=bias_correction_params,
+        smooth_quant_params=smooth_quant_params,
+        activations_range_estimator_params=activations_range_estimator_params,
+        weights_range_estimator_params=weights_range_estimator_params,
     )
 
     # To make it easier for bias correction algorithms,
@@ -76,6 +86,9 @@
     quantized_model = GraphModule(quantized_model, quantized_model.graph)
 
     quantized_model = _fold_conv_bn_qat(quantized_model)
+    if fold_quantize:
+        constant_fold(quantized_model, _quant_node_constraint)
+
     pm = PassManager([DuplicateDQPass()])
 
     quantized_model = pm(quantized_model).graph_module
@@ -89,3 +102,14 @@
     quantized_model = GraphModule(quantized_model, quantized_model.graph)
 
     return quantized_model
+
+
+def _quant_node_constraint(n: torch.fx.Node) -> bool:
+    """If there are any pure ops between get_attr and quantize op they will be const propagated,
+    e.g. get_attr(weight) -> transpose -> quantize -> dequantize*
+    (Note: the dequantize op is not going to be constant propagated.)
+
+    This filter is added because we don't want to constant fold things that are not
+    related to quantization.
+    """
+    return n.op == "call_function" and n.target in QUANTIZE_NODE_TARGETS
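For context, a hypothetical end-to-end sketch of the new entry point. Nothing below comes from the commit: the model and calibration data are placeholders, XNNPACKQuantizer is just one torch.ao quantizer that could be plugged in, and the torch.export capture path assumes a recent PyTorch (2.2+):

```python
import torch
from torch.ao.quantization.quantizer.xnnpack_quantizer import XNNPACKQuantizer
from torch.ao.quantization.quantizer.xnnpack_quantizer import get_symmetric_quantization_config

import nncf
from nncf.experimental.torch.fx.quantization.quantize_pt2e import quantize_pt2e


class TwoLayerNet(torch.nn.Module):  # placeholder model
    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(16, 32)
        self.fc2 = torch.nn.Linear(32, 4)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))


example_input = torch.randn(1, 16)
# Capture an FX graph; ExportedProgram.module() yields a torch.fx.GraphModule.
fx_model = torch.export.export(TwoLayerNet().eval(), (example_input,)).module()

quantizer = XNNPACKQuantizer()
quantizer.set_global(get_symmetric_quantization_config())

# Any iterable of model inputs can back the calibration dataset.
calibration_dataset = nncf.Dataset([torch.randn(1, 16) for _ in range(10)])

quantized = quantize_pt2e(
    fx_model,
    quantizer,
    calibration_dataset,
    fast_bias_correction=True,  # False switches to the slower BiasCorrection
    smooth_quant=False,
    fold_quantize=True,  # constant-fold weight-quantize chains (see filter above)
)
```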