De-constant weight bits and pack factor
Signed-off-by: ElizaWszola <[email protected]>
ElizaWszola committed Nov 5, 2024
1 parent 293ca37 commit 7692ff0
Showing 1 changed file with 35 additions and 23 deletions.
58 changes: 35 additions & 23 deletions vllm/model_executor/layers/quantization/hqq_marlin.py
@@ -35,8 +35,9 @@ def __init__(
         weight_bits: int,
         group_size: int,
     ) -> None:
-        self.pack_factor = 8 // weight_bits  # packed into uint8
+        self.weight_bits = weight_bits
         self.group_size = group_size
+        self.pack_factor = 32 // weight_bits  # packed into int32 in GPTQ format
         self.quant_type = self.TYPE_MAP[(weight_bits)]
 
     def __repr__(self) -> str:
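For reference: the pack factor now follows the GPTQ convention of storing quantized values in 32-bit words rather than two 4-bit values per uint8. A minimal sketch of the arithmetic (illustration only, not part of this diff; the helper name is made up):

    def gptq_pack_factor(weight_bits: int) -> int:
        # Number of quantized values stored in one 32-bit word (GPTQ layout).
        return 32 // weight_bits

    assert gptq_pack_factor(4) == 8  # eight 4-bit values per int32
    assert gptq_pack_factor(8) == 4  # four 8-bit values per int32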
@@ -107,40 +108,49 @@ class HQQQweightParameter(PackedvLLMParameter):
     def unpack_4bit_u8(self,
                        W_q: torch.Tensor) -> torch.Tensor:  # uint8/2 > uint8
         dtype = torch.uint8
-        _step = W_q.shape[0]
-        tmp = torch.empty([2 * _step, W_q.shape[1]],
+        step = W_q.shape[0]
+        tmp = torch.empty([2 * step, W_q.shape[1]],
                           dtype=dtype,
                           device=W_q.device)
-        tmp[:_step] = (W_q & 0b11110000) >> 4
-        tmp[_step:] = W_q & 0b00001111
+        tmp[:step] = (W_q & 0b11110000) >> 4
+        tmp[step:] = W_q & 0b00001111
         return tmp
 
-    def __init__(self, packed_factor: int, packed_dim: int, **kwargs):
+    def unpack_u8(self, W_q: torch.Tensor) -> torch.Tensor:
+        assert self.weight_bits == 4, "Unsupported quant bitsize (must be 4)"
+        return self.unpack_4bit_u8(W_q)
+
+    def __init__(self, packed_factor: int, packed_dim: int, weight_bits: int,
+                 **kwargs):
         super().__init__(packed_factor, packed_dim, None, **kwargs)
+        self.weight_bits = weight_bits
         self.input_shape = self.shape[self.input_dim] * self.packed_factor
         self.output_shape = self.shape[self.output_dim]
 
     def load_merged_column_weight(self, loaded_weight: torch.Tensor, **kwargs):
-        loaded_weight = self.unpack_4bit_u8(loaded_weight)
+        loaded_weight = self.unpack_u8(loaded_weight)
         loaded_weight = loaded_weight.reshape(-1, self.input_shape).transpose(
             1, 0)
-        loaded_weight = gptq_pack(loaded_weight, 4, loaded_weight.shape[0],
+        loaded_weight = gptq_pack(loaded_weight, self.weight_bits,
+                                  loaded_weight.shape[0],
                                   loaded_weight.shape[1])
         super().load_merged_column_weight(loaded_weight, **kwargs)
 
     def load_row_parallel_weight(self, loaded_weight: torch.Tensor):
-        loaded_weight = self.unpack_4bit_u8(loaded_weight)
+        loaded_weight = self.unpack_u8(loaded_weight)
         loaded_weight = loaded_weight.reshape(self.output_shape,
                                               -1).transpose(1, 0)
-        loaded_weight = gptq_pack(loaded_weight, 4, loaded_weight.shape[0],
+        loaded_weight = gptq_pack(loaded_weight, self.weight_bits,
+                                  loaded_weight.shape[0],
                                   loaded_weight.shape[1])
         super().load_row_parallel_weight(loaded_weight)
 
     def load_qkv_weight(self, loaded_weight: torch.Tensor, **kwargs):
-        loaded_weight = self.unpack_4bit_u8(loaded_weight)
+        loaded_weight = self.unpack_u8(loaded_weight)
         loaded_weight = loaded_weight.reshape(-1, self.input_shape).transpose(
             1, 0)
-        loaded_weight = gptq_pack(loaded_weight, 4, loaded_weight.shape[0],
+        loaded_weight = gptq_pack(loaded_weight, self.weight_bits,
+                                  loaded_weight.shape[0],
                                   loaded_weight.shape[1])
         super().load_qkv_weight(loaded_weight, **kwargs)
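The loaders above rely on unpack_4bit_u8, which splits each HQQ-packed uint8 into two 4-bit values: the high nibble fills the first half of the output rows and the low nibble the second half. A standalone illustration of that behavior (toy data assumed, not the module's own test):

    import torch

    packed = torch.tensor([[0xAB], [0xCD]], dtype=torch.uint8)  # shape (2, 1)

    step = packed.shape[0]
    unpacked = torch.empty([2 * step, packed.shape[1]], dtype=torch.uint8)
    unpacked[:step] = (packed & 0b11110000) >> 4  # high nibbles -> first half
    unpacked[step:] = packed & 0b00001111         # low nibbles  -> second half

    assert unpacked.flatten().tolist() == [0xA, 0xC, 0xB, 0xD]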

@@ -190,16 +200,18 @@ def create_weights(
         self.scales_and_zp_size = (input_size_per_partition //
                                    self.quant_config.group_size)
 
-        qweight = HQQQweightParameter(data=torch.empty(
-            self.input_size_per_partition // 8,
-            self.output_size_per_partition,
-            dtype=torch.int32,
-        ),
-                                      input_dim=0,
-                                      output_dim=1,
-                                      packed_dim=0,
-                                      packed_factor=8,
-                                      weight_loader=weight_loader)
+        qweight = HQQQweightParameter(
+            data=torch.empty(
+                self.input_size_per_partition // self.quant_config.pack_factor,
+                self.output_size_per_partition,
+                dtype=torch.int32,
+            ),
+            input_dim=0,
+            output_dim=1,
+            packed_dim=0,
+            packed_factor=self.quant_config.pack_factor,
+            weight_bits=self.quant_config.weight_bits,
+            weight_loader=weight_loader)
 
         zeros = HQQZeroScaleParameter(data=torch.empty(
             self.output_size_per_partition,
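With this allocation, the packed weight's input dimension shrinks by the configured pack factor instead of a hard-coded 8. A quick shape check with made-up partition sizes (illustration only):

    weight_bits = 4
    pack_factor = 32 // weight_bits          # 8 values per int32 word
    input_size_per_partition = 4096          # hypothetical
    output_size_per_partition = 11008        # hypothetical

    qweight_shape = (input_size_per_partition // pack_factor,
                     output_size_per_partition)
    assert qweight_shape == (512, 11008)     # int32 storage, packed along the input dim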
@@ -248,7 +260,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
             sort_indices,
             self.input_size_per_partition,
             self.output_size_per_partition,
-            4,
+            self.quant_config.weight_bits,
         ).to(dev)
         marlin_s = marlin_permute_scales(layer.scale.transpose(1, 0),
                                          self.input_size_per_partition,
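The repack call above now receives the configured bit width instead of a literal 4. For intuition, a minimal GPTQ-style packer (an illustrative stand-in, not vllm's gptq_pack or the Marlin repack kernel) showing how weight_bits determines how many values share one int32 word:

    import torch

    def pack_rows_int32(q: torch.Tensor, weight_bits: int) -> torch.Tensor:
        # Pack consecutive n-bit values along the input (k) dimension into int32 words.
        pack_factor = 32 // weight_bits
        size_k, size_n = q.shape
        assert size_k % pack_factor == 0
        q = q.to(torch.int32).reshape(size_k // pack_factor, pack_factor, size_n)
        packed = torch.zeros(size_k // pack_factor, size_n, dtype=torch.int32)
        for i in range(pack_factor):
            packed |= q[:, i, :] << (i * weight_bits)
        return packed

    vals = torch.randint(0, 16, (16, 2))             # 4-bit values, k=16, n=2
    assert pack_rows_int32(vals, weight_bits=4).shape == (2, 2)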
