fix SQ auto bug (#1294)
Signed-off-by: wenhuach21 <[email protected]>
(cherry picked from commit 1730de0)
wenhuach21 authored and chensuyue committed Sep 28, 2023
1 parent e9c14a5 commit 35def7b
Showing 1 changed file with 24 additions and 20 deletions.
neural_compressor/adaptor/torch_utils/smooth_quant.py
@@ -794,20 +794,20 @@ def dict_to_list(dic):
                 raise NotImplementedError
         return best_alpha
 
-    def _auto_tune_alpha_new(
+    def _auto_tune_alpha(
         self, input_maxes, calib_sample_num=32, alpha_min=0.3, alpha_max=0.7, alpha_step=0.05, shared_criterion="min"
     ):
         """Perform alpha-tuning to obtain layer-wise optimal alpha values and adjust parameters accordingly.
         This function takes the quantization of preceding layers into account when doing QDQ on one layer.
         It also reduces memory usage, at the cost of increased tuning time.
-        TODO may have compatibility issue when setting folding=True
-        :param input_maxes:
-        :param calib_sample_num:
-        :param alpha_min:
-        :param alpha_max:
-        :param alpha_step:
-        :param shared_criterion:
+        TODO: may have compatibility issues when setting folding=True; check whether there are issues when bs != 1
+        :param input_maxes: calibration data, the input max values
+        :param calib_sample_num: the number of samples used for auto-tuning alpha
+        :param alpha_min: the minimum value of alpha in the search space
+        :param alpha_max: the maximum value of alpha in the search space
+        :param alpha_step: the step size of alpha in the search space
+        :param shared_criterion: the criterion for choosing alpha when multiple layers must share the same alpha
         :return:
         """
         logger.info("start sq auto tuning")
@@ -830,13 +830,16 @@ def _auto_tune_alpha_new(
             self.absorb_to_layer, input_maxes, default_alpha, tuning=True
         )
         self._update_scales_for_auto(absorb_input_scales, weight_scales)
-        cnt = 0
+        total_cnt = 0
+        tmp_cnt = 0
         alpha_update_iter = 0
         # multiply_factor is used to combine samples to calib_sample_num // 4 before summarizing the best alpha
-        multiply_factor = calib_sample_num // 4 if calib_sample_num >= 4 else calib_sample_num
+        tune_cnt = 4
+        multiply_factor = calib_sample_num // tune_cnt if calib_sample_num >= tune_cnt else calib_sample_num
 
         best_alphas = default_alpha
         if not self.dataloader:
+            logger.info(f"Auto-tuning failed due to no dataloader, using {best_alphas} instead.")
             self._qdq_model_unwrapper_for_auto()
             return best_alphas
         try:
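
The counter rework above also sets how often the running best alphas are summarized: with tune_cnt = 4, the calibration budget is split into roughly four windows of multiply_factor samples each. A standalone illustration of just that arithmetic (values assumed from the defaults in this file):

# Standalone sketch of the windowing arithmetic from the hunk above.
calib_sample_num, tune_cnt = 32, 4
multiply_factor = calib_sample_num // tune_cnt if calib_sample_num >= tune_cnt else calib_sample_num
print(multiply_factor)  # 8 -> best alphas are re-summarized every ~8 samples
calib_sample_num = 2    # fewer samples than tune_cnt
multiply_factor = calib_sample_num // tune_cnt if calib_sample_num >= tune_cnt else calib_sample_num
print(multiply_factor)  # 2 -> every sample triggers a summary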
@@ -857,18 +857,19 @@ def _auto_tune_alpha_new(
                     cur_loss = loss_alphas[key]
                     for alpha_key in cur_loss.keys():
                         cur_loss[alpha_key] += loss_tmp[key][alpha_key]
-                cnt += self.dataloader.batch_size
-                if cnt // multiply_factor >= 1:
+                total_cnt += self.dataloader.batch_size
+                tmp_cnt += self.dataloader.batch_size
+                if tmp_cnt // multiply_factor >= 1:
                     alpha_update_iter += 1
-                    cnt = 0
+                    tmp_cnt = 0
                     best_alphas = self._get_best_alpha(self.absorb_to_layer, loss_alphas, shared_criterion)
                     for key in best_alphas.keys():
                         logger.info(f"Auto alpha update iter: {alpha_update_iter}, {key}: {best_alphas[key]}")
                     absorb_input_scales, weight_scales = self._cal_scales(
                         self.absorb_to_layer, input_maxes, best_alphas, tuning=True
                     )
                     self._update_scales_for_auto(absorb_input_scales, weight_scales)
-                if cnt >= calib_sample_num:
+                if total_cnt >= calib_sample_num:
                     break
         except:
             for input in self.dataloader:
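
This hunk is the heart of the fix: the old code used one counter, cnt, both to decide when to refresh the best alphas and when to stop calibrating, yet reset it to zero on every refresh, so the cnt >= calib_sample_num stop check could be starved whenever a refresh fired first. The fix splits it into total_cnt (monotonic, controls termination) and tmp_cnt (reset per refresh window). A self-contained sketch of the fixed control flow (the batch stream and refresh hook are placeholders, not the real calibration code):

# Sketch of the fixed counter logic; names mirror the diff, data is fake.
def tune_loop(batch_sizes, calib_sample_num=32, multiply_factor=8):
    total_cnt = 0  # monotonic: decides when calibration stops
    tmp_cnt = 0    # per-window: decides when best alphas are refreshed
    alpha_update_iter = 0
    for bs in batch_sizes:
        total_cnt += bs
        tmp_cnt += bs
        if tmp_cnt // multiply_factor >= 1:
            alpha_update_iter += 1
            tmp_cnt = 0  # the old single `cnt` was reset here, starving the stop check
        if total_cnt >= calib_sample_num:
            break
    return total_cnt, alpha_update_iter

print(tune_loop([8] * 100))  # (32, 4): stops after 32 samples with 4 refreshes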
@@ -888,10 +892,11 @@ def _auto_tune_alpha_new(
                     cur_loss = loss_alphas[key]
                     for alpha_key in cur_loss.keys():
                         cur_loss[alpha_key] += loss_tmp[key][alpha_key]
-                cnt += self.dataloader.batch_size
-                if cnt // multiply_factor >= 1:
+                total_cnt += self.dataloader.batch_size
+                tmp_cnt += self.dataloader.batch_size
+                if tmp_cnt // multiply_factor >= 1:
                     alpha_update_iter += 1
-                    cnt = 0
+                    tmp_cnt = 0
 
                     best_alphas = self._get_best_alpha(self.absorb_to_layer, loss_alphas, shared_criterion)
                     for key in best_alphas.keys():
@@ -900,7 +905,7 @@ def transform(
                         self.absorb_to_layer, input_maxes, best_alphas, tuning=True
                     )
                     self._update_scales_for_auto(absorb_input_scales, weight_scales)
-                if cnt >= calib_sample_num:
+                if total_cnt >= calib_sample_num:
                     break
 
         best_alphas = self._get_best_alpha(self.absorb_to_layer, loss_alphas, shared_criterion)
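
After the sample budget is exhausted, _get_best_alpha produces the final alpha per absorbed layer. The shared_criterion argument matters when several layers must share one scale; a hypothetical illustration of how a criterion could combine per-layer losses (this helper is invented for exposition and is not the library's implementation):

# Hypothetical stand-in for the shared_criterion handling (not the library's code).
def pick_shared_alpha(loss_by_layer, criterion="min"):
    # loss_by_layer: {layer_name: {alpha: accumulated_loss}}, one shared alpha wanted
    alphas = list(next(iter(loss_by_layer.values())).keys())
    combine = {"min": min, "max": max, "mean": lambda v: sum(v) / len(v)}[criterion]
    # pick the alpha whose combined per-layer loss is smallest
    return min(alphas, key=lambda a: combine([loss_by_layer[n][a] for n in loss_by_layer]))

losses = {"q_proj": {0.4: 1.2, 0.5: 0.9}, "k_proj": {0.4: 0.8, 0.5: 1.1}}
print(pick_shared_alpha(losses, "min"))  # -> 0.4 under the "min" criterion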
@@ -934,7 +939,6 @@ def transform(
             logger.warning("smooth quant is ignored since the model is not a torch module")
             return self.model
 
-        logger.info("call new sq") ##TODO need to remove later
         if folding:
             self.insert_mul, self.allow_absorb = False, True
         else:
@@ -994,7 +998,7 @@ def transform(
             del self.absorb_to_layer[d]
 
         if alpha == "auto":
-            self.alpha_per_layer = self._auto_tune_alpha_new(
+            self.alpha_per_layer = self._auto_tune_alpha(
                 input_maxes_abs, calib_sample_num=32, **auto_alpha_args
             )  ##save the alpha

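With the rename, the auto-tuning path is reached through transform(alpha="auto"). A hedged usage sketch follows; the constructor arguments and auto_alpha_args keys are inferred from this file and may not match other releases, and model/dataloader are placeholders:

# Assumed usage; inferred from the diff, not a verified API reference.
from neural_compressor.adaptor.torch_utils.smooth_quant import TorchSmoothQuant

sq = TorchSmoothQuant(model, dataloader)  # model/dataloader: placeholders
qmodel = sq.transform(
    alpha="auto",  # routes into self._auto_tune_alpha(input_maxes_abs, calib_sample_num=32, **auto_alpha_args)
    auto_alpha_args={"alpha_min": 0.3, "alpha_max": 0.7, "alpha_step": 0.05, "shared_criterion": "min"},
)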