Asymmetric activations are forced
daniil-lyakhov committed Jun 3, 2024
1 parent 35ec4b0 commit 3b1e7f0
Showing 2 changed files with 15 additions and 5 deletions.
13 changes: 13 additions & 0 deletions nncf/experimental/torch_fx/quantization/quantize_model.py
@@ -24,6 +24,7 @@
 import nncf
 from nncf.common.factory import NNCFGraphFactory
 from nncf.common.quantization.structs import QuantizationPreset
+from nncf.common.quantization.structs import QuantizationScheme
 from nncf.data import Dataset
 from nncf.parameters import CompressWeightsMode
 from nncf.parameters import ModelType
@@ -32,6 +33,7 @@
 from nncf.parameters import TargetDevice
 from nncf.quantization.advanced_parameters import AdvancedCompressionParameters
 from nncf.quantization.advanced_parameters import AdvancedQuantizationParameters
+from nncf.quantization.advanced_parameters import QuantizationParameters
 from nncf.quantization.algorithms.post_training.algorithm import PostTrainingQuantization
 from nncf.quantization.algorithms.weight_compression.algorithm import WeightCompression
 from nncf.scopes import IgnoredScope
@@ -66,6 +68,17 @@ def quantize_impl(
     copied_model = deepcopy(model)
     # copied_model = model

+    if advanced_parameters is None:
+        advanced_parameters = AdvancedQuantizationParameters()
+    # torch.fx supports only asymmetric quantization of activations;
+    # force this quantization scheme
+    activations_quantization_params = advanced_parameters.activations_quantization_params
+    if activations_quantization_params is None:
+        activations_quantization_params = QuantizationParameters()
+
+    activations_quantization_params.mode = QuantizationScheme.ASYMMETRIC
+    advanced_parameters.activations_quantization_params = activations_quantization_params
+
     quantization_algorithm = PostTrainingQuantization(
         preset=preset,
         target_device=target_device,
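The new block above pins the activation quantization scheme no matter what the caller passed. A minimal standalone sketch of that logic (assuming an installed nncf with the dataclasses named in the diff; force_asymmetric_activations is a hypothetical helper for illustration):

from nncf.common.quantization.structs import QuantizationScheme
from nncf.quantization.advanced_parameters import AdvancedQuantizationParameters
from nncf.quantization.advanced_parameters import QuantizationParameters

def force_asymmetric_activations(advanced_parameters):
    # Mirror of the new quantize_impl logic: build defaults when absent,
    # then set the activation scheme to ASYMMETRIC unconditionally.
    if advanced_parameters is None:
        advanced_parameters = AdvancedQuantizationParameters()
    params = advanced_parameters.activations_quantization_params
    if params is None:
        params = QuantizationParameters()
    params.mode = QuantizationScheme.ASYMMETRIC
    advanced_parameters.activations_quantization_params = params
    return advanced_parameters

# Even a caller that explicitly requests symmetric activations is overridden:
ap = AdvancedQuantizationParameters(
    activations_quantization_params=QuantizationParameters(mode=QuantizationScheme.SYMMETRIC)
)
ap = force_asymmetric_activations(ap)
assert ap.activations_quantization_params.mode == QuantizationScheme.ASYMMETRIC

Note that the override is silent: a caller asking for symmetric activations gets asymmetric ones without a warning.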
7 changes: 2 additions & 5 deletions nncf/experimental/torch_fx/transformations.py
@@ -108,11 +108,8 @@ def insert_one_qdq(

     # Quantized functions accept only uint8 as input
     if target_point.target_type != TargetType.OPERATION_WITH_WEIGHTS and qparams["_dtype_"] == torch.int8:
-        qparams["_zero_point_"] = qparams["_zero_point_"] - qparams["_quant_min_"]
-        quants_len = qparams["_quant_max_"] - qparams["_quant_min_"]
-        qparams["_quant_min_"] = 0
-        qparams["_quant_max_"] = quants_len
-        qparams["_dtype_"] = torch.uint8
+        raise RuntimeError("Wrong parameters: activations should always be uint8")

     # TODO: map FakeQuantizeParameters to qparams for quantize/dequantize
     # 2. replace activation_post_process node with quantize and dequantize
     graph = model.graph
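The deleted branch used to rebase int8 activation parameters onto the uint8 grid instead of rejecting them; after this commit such inputs are treated as a configuration error, since asymmetric (uint8) activations are now forced in quantize_impl. For reference, a sketch of the arithmetic the removed lines performed (dictionary keys taken from the deleted code, sample values assumed): shifting both the quantized value and the zero point by -quant_min leaves the dequantized result (q - zero_point) * scale unchanged.

import torch

qparams = {"_scale_": 0.5, "_zero_point_": -10,
           "_quant_min_": -128, "_quant_max_": 127, "_dtype_": torch.int8}

offset = qparams["_quant_min_"]                             # -128 for int8
rebased = dict(qparams)
rebased["_zero_point_"] = qparams["_zero_point_"] - offset  # -10 + 128 = 118
rebased["_quant_min_"] = 0                                  # uint8 lower bound
rebased["_quant_max_"] = qparams["_quant_max_"] - offset    # 127 + 128 = 255
rebased["_dtype_"] = torch.uint8

# Any int8 value q dequantizes identically after the shift:
q = -5
assert (q - qparams["_zero_point_"]) * qparams["_scale_"] == \
       ((q - offset) - rebased["_zero_point_"]) * rebased["_scale_"]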
