Skip to content

Commit

Permalink
Qualcomm AI Engine Direct - Quantizer refine for qat
Browse files Browse the repository at this point in the history
Differential Revision: D65738212

Pull Request resolved: pytorch#6747
  • Loading branch information
chunit-quic authored Nov 18, 2024
1 parent 19268de commit e95f171
Show file tree
Hide file tree
Showing 12 changed files with 797 additions and 602 deletions.

Large diffs are not rendered by default.

10 changes: 4 additions & 6 deletions backends/qualcomm/quantizer/custom_annotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,12 @@
from typing import Sequence

import torch
from executorch.backends.qualcomm.quantizer.annotators import QUANT_ANNOTATION_KEY
from executorch.backends.qualcomm.quantizer.quantizer import (
get_16a8w_qnn_ptq_config,
get_default_8bit_qnn_ptq_config,
QuantizationConfig,
)
from executorch.backends.qualcomm.quantizer.utils import (
get_8a8w_qnn_ptq_config,
get_ptq_per_channel_quant_config,
QUANT_ANNOTATION_KEY,
QuantizationConfig,
)
from executorch.exir.dialects._ops import ops as exir_ops
from torch.ao.quantization.quantizer import (
Expand Down Expand Up @@ -113,7 +111,7 @@ def annotate_matmul_input1(node: Node, quantization_config: QuantizationConfig):
# Annotate 16a8w for matmul op to get better performance
quantization_config_16a8w = get_16a8w_qnn_ptq_config()
# Annotate 8a8w for second input of matmul until past_kv_cache
quantization_config_8a8w = get_default_8bit_qnn_ptq_config(act_symmetric=True)
quantization_config_8a8w = get_8a8w_qnn_ptq_config(act_symmetric=True)
for node in gm.graph.nodes:
if node.op == "call_function" and node.target == torch.ops.aten.matmul.default:
if "nn_module_stack" in node.meta:
Expand Down
108 changes: 108 additions & 0 deletions backends/qualcomm/quantizer/observers/per_channel_param_observer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
import torch
from torch.ao.quantization.observer import UniformQuantizationObserverBase


# TODO move to torch/ao/quantization/observer.py.
class PerChannelParamObserver(UniformQuantizationObserverBase):
"""
Minimize quantization loss caused by outlier via linear search. More details can be found at https://arxiv.org/pdf/2209.13325
"""

def __init__(
self,
ch_axis=0,
use_mse=True,
steps=100,
dtype=torch.int8,
qscheme=torch.per_channel_symmetric,
reduce_range=False,
quant_min=None,
quant_max=None,
factory_kwargs=None,
eps=torch.finfo(torch.float32).eps, # noqa: B008
is_dynamic=False,
**kwargs,
) -> None:
super().__init__(
dtype=dtype,
qscheme=qscheme,
reduce_range=reduce_range,
quant_min=quant_min,
quant_max=quant_max,
factory_kwargs=factory_kwargs,
eps=eps,
is_dynamic=is_dynamic,
**kwargs,
)

factory_kwargs = torch.nn.factory_kwargs(factory_kwargs)
self.register_buffer("min_val", torch.tensor(float("inf"), **factory_kwargs))
self.register_buffer("max_val", torch.tensor(float("-inf"), **factory_kwargs))
self.ch_axis = ch_axis
self.use_mse = use_mse
self.steps = steps
self.calibrated = False

def to_ch_axis(self, x):
axis_order = list(range(len(x.size())))
axis_order[self.ch_axis], axis_order[0] = 0, self.ch_axis
return torch.flatten(x.permute(axis_order), start_dim=1)

def mse(self, pred, expect):
loss = (pred - expect).abs().pow(2)
return self.to_ch_axis(loss).mean(1)

def cosine(self, pred, expect):
target = torch.ones(pred.shape[self.ch_axis])
pred_n = self.to_ch_axis(pred).reshape(pred.shape[0], -1)
expect_n = self.to_ch_axis(expect).reshape(expect.shape[0], -1)
return torch.nn.CosineEmbeddingLoss()(pred_n, expect_n, target)

def loss_fn(self, x, new_min, new_max):
scale, offset = self._calculate_qparams(new_min, new_max)
x_q = torch.fake_quantize_per_channel_affine(
x,
scale.data,
offset.data.int(),
self.ch_axis,
self.quant_min,
self.quant_max,
)
return self.mse(x_q, x) if self.use_mse else self.cosine(x_q, x)

def line_search(self, x):
x_min, x_max = torch.aminmax(self.to_ch_axis(x), dim=1)
x_range = torch.max(x_min.abs(), x_max)
optimal_loss = torch.zeros_like(x_min) + 1e9

# check which clip range could produce smallest loss
for i in range(1, self.steps + 1):
thres = x_range / self.steps * i
current_loss = self.loss_fn(x, -thres, thres)
x_min = torch.where(current_loss < optimal_loss, -thres, x_min)
x_max = torch.where(current_loss < optimal_loss, thres, x_max)
optimal_loss = torch.min(current_loss, optimal_loss)

return x_min, x_max

def forward(self, x_orig):
# since params are static, one calibration is enough
if not self.calibrated:
x = x_orig.detach().to(self.min_val.dtype)
self.min_val, self.max_val = self.line_search(x)
self.calibrated = True

# return fake-quant result for saturating outliers
scale, zero_point = self._calculate_qparams(self.min_val, self.max_val)
return torch.fake_quantize_per_channel_affine(
x_orig,
scale.data,
zero_point.data.int(),
self.ch_axis,
self.quant_min,
self.quant_max,
)

@torch.jit.export
def calculate_qparams(self):
return self._calculate_qparams(self.min_val, self.max_val)
Loading

0 comments on commit e95f171

Please sign in to comment.