Skip to content

Commit

Permalink
x_max -> x_mean and w_max -> w_mean name changes and some comments (#378)
Browse files Browse the repository at this point in the history
  • Loading branch information
OscarSavolainen authored Mar 2, 2024
1 parent d9dc8e5 commit f713b88
Showing 1 changed file with 17 additions and 11 deletions.
28 changes: 17 additions & 11 deletions awq/quantize/quantizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,17 +244,23 @@ def _search_best_scale(
# Put x on the right device
inp = inp.to(next(module2inspect.parameters()).device)

# [STEP 1]: Compute maximum of weight
# [STEP 1]: Compute per-channel mean of normalised weights
# All layer weights are concatenated together along dim 0
weight = torch.cat([_m.weight for _m in layers], dim=0)
org_shape = weight.shape
# The weights are reshaped to be organised by quantization group
weight = weight.view(-1, self.group_size)
# Calculates the relative magnitude of the weights within each of the quantization groups,
# and rescales each group individually so that each group has weights on a 0-1 scale.
w_scale = weight.abs() / weight.abs().amax(dim=1, keepdim=True)
# Resizes the rescaled weight matrix back up to its original dimensions
w_scale = w_scale.view(org_shape)
w_max = w_scale.mean(0)
# Gets the average rescaled magnitude for each input channel (mean taken over dim 0, the output-channel rows)
w_mean = w_scale.mean(0)
clear_memory(weight)

# [STEP 2]: Compute maximum of x
x_max = inp.abs().view(-1, inp.shape[-1]).mean(0)
# [STEP 2]: Compute per-channel mean of the input activation
x_mean = inp.abs().view(-1, inp.shape[-1]).mean(0)

# [STEP 3]: Compute output of module
with torch.no_grad():
Expand All @@ -266,7 +272,7 @@ def _search_best_scale(

# [STEP 4]: Compute loss
best_scales = self._compute_best_scale(
inp, w_max, x_max, module2inspect, layers, fp16_output, module_kwargs
inp, w_mean, x_mean, module2inspect, layers, fp16_output, module_kwargs
)

return (
Expand All @@ -278,8 +284,8 @@ def _search_best_scale(
def _compute_best_scale(
self,
x,
w_max,
x_max,
w_mean,
x_mean,
module2inspect,
linears2scale: List[nn.Linear],
fp16_output,
Expand All @@ -303,18 +309,18 @@ def _compute_best_scale(
org_sd = {k: v.cpu() for k, v in module2inspect.state_dict().items()}

device = x.device
x_max = x_max.view(-1).to(device)
w_max = w_max.view(-1).to(device)
x_mean = x_mean.view(-1).to(device)
w_mean = w_mean.view(-1).to(device)

for ratio in range(n_grid):
# create new scales
ratio = ratio / n_grid

# NOTE: s^-1 * x is fused here, according to paper
if self.duo_scaling:
scales = (x_max.pow(ratio) / w_max.pow(1 - ratio)).clamp(min=1e-4)
scales = (x_mean.pow(ratio) / w_mean.pow(1 - ratio)).clamp(min=1e-4)
else:
scales = x_max.pow(ratio).clamp(min=1e-4).view(-1)
scales = x_mean.pow(ratio).clamp(min=1e-4).view(-1)
scales = scales / (scales.max() * scales.min()).sqrt()
scales_view = scales.view(1, -1).to(device)

Expand Down

0 comments on commit f713b88

Please sign in to comment.