
Commit

[Bugfix] Fix GGUF inference with FP16 unquantized checkpoint (vllm-project#10675)

Signed-off-by: Isotr0py <[email protected]>
Isotr0py authored Nov 27, 2024
1 parent c411def commit b98c62b
Showing 1 changed file with 60 additions and 9 deletions.
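As background for the fix, a minimal usage sketch of the path it repairs: serving an unquantized F16 GGUF checkpoint through vLLM's LLM entry point. The file name and tokenizer repo below are placeholders, not taken from this commit.

# Hypothetical usage sketch (placeholder paths): run an unquantized F16 GGUF
# checkpoint, the case this commit fixes.
from vllm import LLM, SamplingParams

llm = LLM(
    model="./model-f16.gguf",                        # local unquantized GGUF file
    tokenizer="TinyLlama/TinyLlama-1.1B-Chat-v1.0",  # matching HF tokenizer
)
outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)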
69 changes: 60 additions & 9 deletions vllm/model_executor/layers/quantization/gguf.py
@@ -2,6 +2,7 @@
 
 import gguf
 import torch
+from gguf import GGMLQuantizationType as WeightType
 from torch.nn.parameter import Parameter, UninitializedParameter
 
 from vllm import _custom_ops as ops
@@ -49,19 +50,65 @@ def get_quant_method(self, layer: torch.nn.Module,
         return None
 
 
+UNQUANTIZED_TYPES = {WeightType.F32, WeightType.F16, WeightType.BF16}
+STANDARD_QUANT_TYPES = {
+    WeightType.Q4_0,
+    WeightType.Q4_1,
+    WeightType.Q5_0,
+    WeightType.Q5_1,
+    WeightType.Q8_0,
+    WeightType.Q8_1,
+}
+KQUANT_TYPES = {
+    WeightType.Q2_K,
+    WeightType.Q3_K,
+    WeightType.Q4_K,
+    WeightType.Q5_K,
+    WeightType.Q6_K,
+}
+IMATRIX_QUANT_TYPES = {
+    WeightType.IQ1_M,
+    WeightType.IQ1_S,
+    WeightType.IQ2_XXS,
+    WeightType.IQ2_XS,
+    WeightType.IQ2_S,
+    WeightType.IQ3_XXS,
+    WeightType.IQ3_S,
+    WeightType.IQ4_XS,
+    WeightType.IQ4_NL,
+}
+# TODO(Isotr0py): Currently, we don't have MMQ kernel for I-Matrix quantization.
+# Consolidate DEQUANT_TYPES, MMVQ_QUANT_TYPES and MMQ_QUANT_TYPES after we add
+# MMQ kernel for I-Matrix quantization.
+DEQUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES | IMATRIX_QUANT_TYPES
+MMVQ_QUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES | IMATRIX_QUANT_TYPES
+MMQ_QUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES
+
+
 def _fuse_mul_mat(x: torch.Tensor, qweight: torch.Tensor,
                   qweight_type: int) -> torch.Tensor:
-    # use dequantize mulmat for IQmatrix, mmq for k-quants
-    if x.shape[0] == 1:
-        # enable mmvq in contiguous batching
+    # there is no need to call any kernel for fp16/bf16
+    if qweight_type in UNQUANTIZED_TYPES:
+        return x @ qweight.T
+    # enable MMVQ in contiguous batching with batch_size=1
+    if x.shape[0] == 1 and qweight_type in MMVQ_QUANT_TYPES:
         y = ops.ggml_mul_mat_vec_a8(qweight, x, qweight_type, qweight.shape[0])
-    elif qweight_type >= 16:
+    # Use MMQ Kernel if it's available (standard + k-quants)
+    elif qweight_type in MMQ_QUANT_TYPES:
+        y = ops.ggml_mul_mat_a8(qweight, x, qweight_type, qweight.shape[0])
+    # If there is no available MMQ kernel, fallback to dequantize
+    elif qweight_type in DEQUANT_TYPES:
         block_size, type_size = gguf.GGML_QUANT_SIZES[qweight_type]
         shape = (qweight.shape[0], qweight.shape[1] // type_size * block_size)
         weight = ops.ggml_dequantize(qweight, qweight_type, *shape)
         y = x @ weight.T
     else:
-        y = ops.ggml_mul_mat_a8(qweight, x, qweight_type, qweight.shape[0])
+        # Raise an error if the quantization type is not supported.
+        # Might be useful if llama.cpp adds a new quantization type.
+        # Wrap to GGMLQuantizationType IntEnum to make sure it's a valid type.
+        qweight_type = WeightType(qweight_type)
+        raise NotImplementedError(
+            f"Unsupported GGUF quantization type: {qweight_type}")
     return y


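The dequantize fallback above recovers the logical weight shape from the packed byte layout before falling back to a plain matmul. A rough illustration of that arithmetic, assuming only the gguf package's GGML_QUANT_SIZES table (the tensor sizes below are made up):

# Illustrative sketch, not part of the commit: compute the dequantized shape
# of a packed GGUF tensor the same way the fallback branch does.
import gguf
from gguf import GGMLQuantizationType as WeightType

def dequantized_shape(packed_rows: int, packed_bytes_per_row: int,
                      qweight_type: WeightType) -> tuple:
    # Every `type_size` bytes in a row encode one block of `block_size` weights.
    block_size, type_size = gguf.GGML_QUANT_SIZES[qweight_type]
    return (packed_rows, packed_bytes_per_row // type_size * block_size)

# e.g. a Q4_K weight stored as 4096 rows of 2304 packed bytes dequantizes
# to a (4096, 4096) tensor: 2304 // 144 * 256 == 4096.
print(dequantized_shape(4096, 2304, WeightType.Q4_K))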
@@ -121,9 +168,9 @@ def apply(self,
             shard_id = ["q", "k", "v"] if "q" in shard_id else shard_id
             qweight = layer.qweight.unbind(0)
             result = []
-            for id in shard_id:
-                q_idx = layer.qweight.shard_id_map[id]
-                qweight_type = layer.qweight_type.shard_weight_type[id]
+            for idx in shard_id:
+                q_idx = layer.qweight.shard_id_map[idx]
+                qweight_type = layer.qweight_type.shard_weight_type[idx]
                 result.append(_fuse_mul_mat(x, qweight[q_idx], qweight_type))
             out = torch.cat(result, axis=1)
         else:
@@ -163,9 +210,13 @@ class GGUFUninitializedParameter(UninitializedParameter):
     data_container: List[torch.Tensor]
 
     def materialize_nested(self) -> Parameter:
+        dtype = {data.dtype for data in self.data_container}
+        assert len(dtype) == 1, ValueError(
+            f"Data container has mixed dtypes: {dtype}")
+        dtype = next(iter(dtype))
         nested_data = torch.nested.nested_tensor(self.data_container,
                                                  device=self.device,
-                                                 dtype=torch.uint8)
+                                                 dtype=dtype)
         self.data_container.clear()
         param = torch.Tensor._make_subclass(self.cls_to_become,
                                             nested_data,

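The materialize_nested change above is the substance of the bugfix: the nested tensor's dtype now follows the loaded shards instead of being hard-coded to torch.uint8, which is only appropriate for packed quantized data and broke loading of unquantized F16 checkpoints. A standalone sketch of the new behaviour (shard shapes are made up):

# Standalone sketch, not the vLLM code path: derive the nested-tensor dtype
# from the shards themselves, mirroring the fixed materialize_nested.
import torch

shards = [torch.randn(8, 16, dtype=torch.float16),  # e.g. unquantized F16 shards
          torch.randn(4, 16, dtype=torch.float16)]

dtypes = {t.dtype for t in shards}
assert len(dtypes) == 1, f"Data container has mixed dtypes: {dtypes}"
dtype = next(iter(dtypes))

nested = torch.nested.nested_tensor(shards, dtype=dtype, device="cpu")
print(nested.dtype)  # torch.float16 (previously forced to torch.uint8)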