AutoAWQ smooth + INC RTN (HPU) #637

Draft: wants to merge 1 commit into base: main
4 changes: 4 additions & 0 deletions Makefile
@@ -0,0 +1,4 @@
build:
CUDA_VISIBLE_DEVICES=-1 pip install -e . -vvv

.PHONY: build
9 changes: 7 additions & 2 deletions awq/modules/linear/gemm.py
@@ -5,6 +5,8 @@
from awq.utils.module import try_import
from awq.utils.utils import get_best_device
from awq.utils.packing_utils import dequantize_gemm
import logging
logger = logging.getLogger(__name__)

# NOTE: We check if awq_ext or triton is available. awq_ext will be preferred if both are installed.

@@ -199,6 +201,7 @@ def from_linear(
/ awq_linear.scales[idx // group_size]
).to(torch.int)[:, None]
)
logger.warning("Got int weight...")
intweight = torch.cat(intweight, dim=1)
intweight = intweight.t().contiguous()
intweight = intweight.to(dtype=torch.int32)
@@ -225,7 +228,7 @@ def from_linear(
qweight[:, col] |= qweight_col << (i * awq_linear.w_bit)
awq_linear.qweight = qweight

zeros = zeros.to(dtype=torch.int32, device=best_device)
zeros = zeros.to(dtype=torch.int32, device="cpu")

if "mps" in best_device:
zeros = zeros.to("cpu")
@@ -235,7 +238,7 @@ def from_linear(
dtype=torch.int32,
device=zeros.device,
)

logger.warning("PACK Qzeros...")
for col in range(zeros.shape[1] // pack_num):
if awq_linear.w_bit == 4:
order_map = [0, 2, 4, 6, 1, 3, 5, 7]
@@ -244,7 +247,9 @@ def from_linear(
for i in range(pack_num):
qzero_col = zeros[:, col * pack_num + order_map[i]]
qzeros[:, col] |= qzero_col << (i * awq_linear.w_bit)
logger.warning("PACK Qzeros done...")
awq_linear.qzeros = qzeros
awq_linear = awq_linear.to(best_device)

return awq_linear
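Note: the change above keeps the qzeros packing on the CPU (zeros is now forced to device="cpu") and only moves the assembled module to the accelerator at the very end via awq_linear.to(best_device), so the per-column bit-shift loop never runs on HPU. A minimal, standalone sketch of the same pack-on-CPU-then-move pattern for the 4-bit case (the helper name and tensor shapes are illustrative, not part of the PR):

```python
import torch
from awq.utils.utils import get_best_device

def pack_int4_to_int32(zeros: torch.Tensor, w_bit: int = 4) -> torch.Tensor:
    """Pack 4-bit zero-points column-wise into int32 words, entirely on CPU."""
    pack_num = 32 // w_bit                       # 8 values per int32 word
    order_map = [0, 2, 4, 6, 1, 3, 5, 7]         # interleaved layout used by the GEMM kernel
    zeros = zeros.to(dtype=torch.int32, device="cpu")
    qzeros = torch.zeros(
        (zeros.shape[0], zeros.shape[1] // pack_num), dtype=torch.int32, device="cpu"
    )
    for col in range(zeros.shape[1] // pack_num):
        for i in range(pack_num):
            qzero_col = zeros[:, col * pack_num + order_map[i]]
            qzeros[:, col] |= qzero_col << (i * w_bit)
    return qzeros

# Pack on CPU first, then move the packed buffer to the target device once.
zeros = torch.randint(0, 16, (4096, 32))         # hypothetical zero-point matrix
qzeros = pack_int4_to_int32(zeros).to(get_best_device())
```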

127 changes: 82 additions & 45 deletions awq/quantize/quantizer.py
@@ -22,7 +22,9 @@
set_op_by_name,
exclude_layers_to_not_quantize,
)

import logging
logger = logging.getLogger(__name__)
import habana_frameworks.torch.core as htcore

class AwqQuantizer:
def __init__(
@@ -70,13 +72,17 @@ def __init__(
n_samples=self.max_calib_samples, max_seq_len=self.max_calib_seq_len
)

def pseudo_quantize_tensor(self, w: torch.Tensor):
def pseudo_quantize_tensor(self, w: torch.Tensor, return_int=False):
org_w_shape = w.shape
if self.group_size > 0:
assert org_w_shape[-1] % self.group_size == 0
w = w.reshape(-1, self.group_size)
if torch.isnan(w).sum() > 0:
breakpoint()
logging.error(f"Found {torch.isnan(w).sum()} NaNs in weight matrix")
assert w.dim() == 2
assert torch.isnan(w).sum() == 0
# breakpoint()

# zero point quantization
if self.zero_point:
@@ -86,9 +92,8 @@ def pseudo_quantize_tensor(self, w: torch.Tensor):
min_int = 0
scales = (max_val - min_val).clamp(min=1e-5) / max_int
zeros = (-torch.round(min_val / scales)).clamp_(min_int, max_int)
w = (
torch.clamp(torch.round(w / scales) + zeros, min_int, max_int) - zeros
) * scales
w_int = torch.clamp(torch.round(w / scales) + zeros, min_int, max_int) - zeros
w = w_int * scales
zeros = zeros.view(org_w_shape[0], -1)
else:
max_val = w.abs().amax(dim=1, keepdim=True)
@@ -97,14 +102,19 @@ def pseudo_quantize_tensor(self, w: torch.Tensor):
min_int = -(2 ** (self.w_bit - 1))
scales = max_val / max_int
zeros = None
w = torch.clamp(torch.round(w / scales), min_int, max_int) * scales

w_int = torch.clamp(torch.round(w / scales), min_int, max_int)
w = w_int * scales
if torch.isnan(w).sum() > 0:
breakpoint()
logging.error(f"Found {torch.isnan(w).sum()} NaNs in weight matrix {w.shape}")
assert torch.isnan(scales).sum() == 0
assert torch.isnan(w).sum() == 0

scales = scales.view(org_w_shape[0], -1)
w = w.reshape(org_w_shape)


if return_int:
return w, scales, zeros, w_int.reshape(org_w_shape)
return w, scales, zeros
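For reference, a self-contained sketch of the grouped zero-point round-trip that pseudo_quantize_tensor performs, with an integer-code return analogous to the new return_int path (group size and shapes are illustrative; the codes here are kept unshifted, whereas the diff's w_int already has the zero-point subtracted):

```python
import torch

def pseudo_quantize(w: torch.Tensor, w_bit: int = 4, group_size: int = 128):
    """Grouped asymmetric fake-quantization: returns dequantized weights, scales, zeros, int codes."""
    org_shape = w.shape
    w = w.reshape(-1, group_size)                       # one quantization group per row
    max_val = w.amax(dim=1, keepdim=True)
    min_val = w.amin(dim=1, keepdim=True)
    max_int, min_int = 2 ** w_bit - 1, 0
    scales = (max_val - min_val).clamp(min=1e-5) / max_int
    zeros = (-torch.round(min_val / scales)).clamp_(min_int, max_int)
    w_int = torch.clamp(torch.round(w / scales) + zeros, min_int, max_int)
    w_dq = (w_int - zeros) * scales                     # dequantize back to float
    return (
        w_dq.reshape(org_shape),
        scales.view(org_shape[0], -1),
        zeros.view(org_shape[0], -1),
        w_int.reshape(org_shape),
    )

w = torch.randn(512, 512)
w_dq, scales, zeros, w_int = pseudo_quantize(w)
print((w - w_dq).abs().max())                           # error bounded by roughly half a step
```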

def pseudo_dequantize_tensor(
Expand All @@ -124,7 +134,10 @@ def pseudo_dequantize_tensor(
return w

def quantize(self):
self._num_modules = len(self.modules)
for i in tqdm(range(len(self.modules)), desc="AWQ"):
# if i > 1:
# return
# Move module and inputs to correct device
common_device = next(self.modules[i].parameters()).device
if common_device is None or str(common_device) == "cpu":
@@ -171,6 +184,7 @@ def quantize(self):
scales_list = append_str_prefix(
scales_list, get_op_name(self.model, self.modules[i]) + "."
)
logger.warning(f"Applied scales: {scales_list}")

# [STEP 3]: Compute and apply clipping list
if self.apply_clip:
@@ -199,43 +213,61 @@ def pack(self):

def _apply_quant(self, module, named_linears: Dict[str, nn.Linear]):
for name, linear_layer in named_linears.items():
# NOTE: small regression in perplexity if linear layer uses .cpu().float()
linear_layer = linear_layer.to(get_best_device()).half()

linear_layer.weight.data, scales, zeros = self.pseudo_quantize_tensor(
linear_layer.weight.data
)

if self.version == "gemm":
scales = scales.t().contiguous()
if zeros is not None:
zeros = zeros.t().contiguous()
q_linear_module = WQLinear_GEMM

elif self.version == "gemv":
q_linear_module = WQLinear_GEMV

elif self.version == "marlin":
q_linear_module = WQLinear_Marlin

elif self.version == "gemv_fast":
q_linear_module = WQLinear_GEMVFast

else:
raise ValueError(f"Unknown version {self.version}")

q_linear = q_linear_module.from_linear(
linear=linear_layer,
w_bit=self.w_bit,
group_size=self.group_size,
init_only=False,
scales=scales,
zeros=zeros,
)

linear_layer.cpu()
q_linear.to(next(module.parameters()).device)
logger.warning(f"Quantizing {name}")
# linear_layer = linear_layer.cpu().half()
# # NOTE: small regression in perplexity if linear layer uses .cpu().float()
# # linear_layer = linear_layer.to(get_best_device()).half()

# linear_layer.weight.data, scales, zeros = self.pseudo_quantize_tensor(
# linear_layer.weight.data
# )

# if self.version == "gemm":
# scales = scales.t().contiguous()
# if zeros is not None:
# zeros = zeros.t().contiguous()
# q_linear_module = WQLinear_GEMM

# elif self.version == "gemv":
# q_linear_module = WQLinear_GEMV

# elif self.version == "marlin":
# q_linear_module = WQLinear_Marlin

# elif self.version == "gemv_fast":
# q_linear_module = WQLinear_GEMVFast

# else:
# raise ValueError(f"Unknown version {self.version}")
# linear_layer = linear_layer.cpu()
# from neural_compressor.torch.algorithms.weight_only.rtn import RTNQuantizer
# from neural_compressor.torch.quantization.config import RTNConfig
# config = RTNConfig(group_size=self.group_size, bits=self.w_bit)
# config_dict = config.to_dict()
# config_dict["scheme"] = "sym" # ?
# rtn_quantizer = RTNQuantizer(quant_config={'': config_dict})
# q_linear = rtn_quantizer.quantize(linear_layer)
# # breakpoint()
# # breakpoint()

# # q_linear = linear_layer

# # q_linear = q_linear_module.from_linear(
# # linear=linear_layer,
# # w_bit=self.w_bit,
# # group_size=self.group_size,
# # init_only=False,
# # scales=scales,
# # zeros=zeros,
# # )
# logger.warning(f"got q_linear {q_linear}")

# linear_layer.cpu()
# q_linear.to(next(module.parameters()).device)
q_linear = linear_layer
set_op_by_name(module, name, q_linear)
# set_op_by_name(module, name, q_linear)
logger.warning(f"update {name} to {q_linear}")
clear_memory()
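The commented-out block above is the Intel Neural Compressor RTN path that gives this draft its title; as committed, _apply_quant simply assigns q_linear = linear_layer and logs the update. Reassembled from those commented lines, the intended replacement would look roughly like the sketch below (class and module paths are copied verbatim from the draft, the helper name is mine, and the "sym" scheme override is marked as untested in the draft itself):

```python
# Reconstructed from the commented-out draft above; not verified against a specific INC release.
from neural_compressor.torch.algorithms.weight_only.rtn import RTNQuantizer
from neural_compressor.torch.quantization.config import RTNConfig

def rtn_quantize_linear(linear_layer, w_bit: int, group_size: int):
    config = RTNConfig(group_size=group_size, bits=w_bit)
    config_dict = config.to_dict()
    config_dict["scheme"] = "sym"                     # draft forces symmetric quantization here
    rtn_quantizer = RTNQuantizer(quant_config={"": config_dict})
    return rtn_quantizer.quantize(linear_layer)       # RTN-quantized replacement module
```

The returned module would then replace the original layer via the still-commented set_op_by_name(module, name, q_linear) call.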

@torch.no_grad()
@@ -325,11 +357,13 @@ def _search_best_scale(
with torch.no_grad():
module_kwargs = self._sanitize_kwargs(kwargs, module2inspect)
fp16_output = self._module_forward(inp, module2inspect, module_kwargs)
htcore.mark_step()

# [STEP 4]: Compute loss
best_scales = self._compute_best_scale(
inp, w_mean, x_mean, module2inspect, layers, fp16_output, module_kwargs
)
htcore.mark_step()

return (
get_op_name(module, prev_op),
@@ -367,7 +401,8 @@ def _compute_best_scale(
device = x.device
x_mean = x_mean.view(-1).to(device)
w_mean = w_mean.view(-1).to(device)


logger.warning("Searching for best scale")
for ratio in range(n_grid):
# create new scales
ratio = ratio / n_grid
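The grid search that follows (truncated in this hunk) sweeps n_grid candidate ratios and keeps the scale vector whose quantized output is closest to the cached fp16 output. A schematic of that loop, with the scale formula and loss simplified rather than copied from upstream AutoAWQ, and quant_forward as a hypothetical callable standing in for the scale-apply-quantize-forward step:

```python
import torch

def grid_search_scales(x_mean, w_mean, quant_forward, fp16_output, n_grid=20):
    """Sweep ratio in [0, 1); keep the scales whose quantized output best matches fp16_output."""
    best_scales, best_loss = None, float("inf")
    for step in range(n_grid):
        ratio = step / n_grid
        # Balance activation vs. weight magnitudes; small epsilons guard against zeros.
        scales = (x_mean.pow(ratio) / (w_mean.pow(1 - ratio) + 1e-4)).clamp(min=1e-4)
        scales = scales / (scales.max() * scales.min()).sqrt()
        loss = (fp16_output - quant_forward(scales)).float().pow(2).mean().item()
        if loss < best_loss:
            best_loss, best_scales = loss, scales.clone()
    return best_scales
```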
@@ -450,6 +485,7 @@ def _search_best_clip(self, layer, named_linears, input_feat):
avoid_clipping = ["q_", "k_", "query", "key", "Wqkv"]

for name in named_linears:
logger.warning(f"Searching for best clip: {name}")
# due to qk bmm, it is hard to clip precisely
if any([_ in name for _ in avoid_clipping]):
continue
@@ -594,6 +630,7 @@ def forward(self, *args, **kwargs):
return modules, layer_kwargs, inps

def _get_input_feat(self, layer, named_linears):
logger.warning("Computing input features for layer %s", layer)
# firstly, get input features of all linear layers
def cache_input_hook(m, x, y, name, feat_dict):
x = x[0]
6 changes: 5 additions & 1 deletion awq/quantize/scale.py
@@ -1,3 +1,5 @@
import logging
logger = logging.getLogger(__name__)
import torch
import torch.nn as nn
from typing import Tuple, List
@@ -10,6 +12,7 @@
from transformers.models.gemma2.modeling_gemma2 import Gemma2RMSNorm
from transformers.models.cohere.modeling_cohere import CohereLayerNorm
from transformers.activations import NewGELUActivation, PytorchGELUTanh, GELUActivation
import habana_frameworks.torch.core as htcore

allowed_norms = [nn.LayerNorm, LlamaRMSNorm, GemmaRMSNorm, Gemma2RMSNorm, CohereLayerNorm]
allowed_act_fns = [
@@ -36,6 +39,7 @@ def apply_clip(module, clip_list: Tuple[str, torch.Tensor]):

def apply_scale(module, scales_list, input_feat_dict=None):
for prev_op_name, layer_names, scales in scales_list:
logger.warning(f"Apply scale {prev_op_name} -> {layer_names}")
prev_op = get_op_by_name(module, prev_op_name)
layers = [get_op_by_name(module, name) for name in layer_names]

@@ -77,7 +81,7 @@ def apply_scale(module, scales_list, input_feat_dict=None):
if layer_name in input_feat_dict:
inp = input_feat_dict[layer_name]
inp.div_(scales.view(1, -1).to(inp.device))

htcore.mark_step()
prev_op.cpu()
for layer in layers:
layer.cpu()
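htcore.mark_step() flushes the accumulated lazy-mode graph on Gaudi, so calling it after each applied scale keeps graph size bounded during calibration. Because scale.py (and quantizer.py) now import habana_frameworks unconditionally, a guarded helper along these lines could keep the modules importable on non-HPU machines (a sketch, not part of the PR):

```python
# Hypothetical guard so awq/quantize/scale.py still imports when Habana frameworks are absent.
try:
    import habana_frameworks.torch.core as htcore
except ImportError:
    htcore = None

def hpu_mark_step():
    """Break the lazy-mode graph on HPU; no-op elsewhere."""
    if htcore is not None:
        htcore.mark_step()

# apply_scale would then call hpu_mark_step() in place of the bare htcore.mark_step().
```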
15 changes: 15 additions & 0 deletions awq/utils/utils.py
@@ -73,6 +73,10 @@ def clear_memory(weight=None):
if weight is not None:
del weight
gc.collect()
if is_hpex_available():
# import habana_frameworks.torch.core as htcore
# torch.hpu.empty_cache()
return
torch.cuda.empty_cache()


@@ -86,9 +90,20 @@ def compute_memory_used_pct(device):
return memory_pct


def is_hpex_available():
try:
import habana_frameworks.torch.core as htcore
HPEX_AVAILABLE = True
except ImportError:
HPEX_AVAILABLE = False
return HPEX_AVAILABLE

def get_best_device():
if torch.backends.mps.is_available():
return "mps"
elif is_hpex_available():
# FIXME: return device name with index?
return "hpu"
elif torch.cuda.is_available():
return "cuda:0"
else:
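With these helpers, get_best_device() prefers "hpu" whenever habana_frameworks imports, and clear_memory() skips torch.cuda.empty_cache() on such systems. A small usage sketch (the torch.hpu.empty_cache() call mirrors the line left commented out in the diff, so treat it as an assumption about the Habana API):

```python
import torch
from awq.utils.utils import get_best_device, clear_memory, is_hpex_available

device = get_best_device()            # "hpu" on Gaudi, else "mps" / "cuda:0" / "cpu"
x = torch.randn(1024, 1024).to(device)

del x
clear_memory()                        # gc.collect(); skips cuda.empty_cache() on HPU

if is_hpex_available():
    # Left commented in the diff; presumed Habana counterpart of torch.cuda.empty_cache().
    torch.hpu.empty_cache()
```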