[Model][Quantization] HQQ support through Marlin kernel expansion (#9766)

Signed-off-by: ElizaWszola <[email protected]>
ElizaWszola authored Nov 19, 2024
1 parent efa9084 commit b00b33d
Showing 11 changed files with 632 additions and 89 deletions.
3 changes: 2 additions & 1 deletion benchmarks/kernels/benchmark_machete.py
@@ -210,7 +210,8 @@ def marlin_create_bench_fn(bt: BenchmarkTensors) -> Callable:
size_m=bt.a.shape[0],
size_n=bt.w_ref.shape[1],
size_k=bt.w_ref.shape[0],
is_k_full=True)
is_k_full=True,
is_zp_float=False)
else:
assert bt.a.dtype == torch.int8
assert bt.wtype == scalar_types.uint4b8
4 changes: 2 additions & 2 deletions benchmarks/kernels/benchmark_marlin.py
@@ -131,7 +131,7 @@ def bench_run(results: List[benchmark.Measurement], model: str,
results.append(
benchmark.Timer(
stmt=
"output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, False)", # noqa: E501
"output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, False, False)", # noqa: E501
globals=globals,
label=label,
sub_label=sub_label,
@@ -141,7 +141,7 @@ def bench_run(results: List[benchmark.Measurement], model: str,
results.append(
benchmark.Timer(
stmt=
"output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, True)", # noqa: E501
"output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, True, False)", # noqa: E501
globals=globals,
label=label,
sub_label=sub_label,
277 changes: 200 additions & 77 deletions csrc/quantization/gptq_marlin/gptq_marlin.cu

Large diffs are not rendered by default.
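The bulk of the change lives in this kernel, whose diff is not rendered. Judging from the new is_zp_float flag and the reference computation in the test added below, the expansion adds a dequantization path for HQQ-style floating-point, group-wise zero points. A minimal sketch of that reference computation follows; the helper name and shapes are illustrative and not part of the commit.

import torch

def dequant_hqq_ref(q: torch.Tensor, scale: torch.Tensor, zero: torch.Tensor,
                    group_size: int) -> torch.Tensor:
    # q: (size_n, size_k) 4-bit codes stored as uint8
    # scale, zero: (size_n, size_k // group_size), both floating point
    q_groups = q.reshape(-1, group_size).to(scale.dtype)
    s = scale.reshape(-1, 1)
    zp = zero.reshape(-1, 1)
    # With a float zero point, the dequantized weight is (q - zp) * s per group,
    # mirroring the reference used by test_hqq_marlin_gemm in this commit.
    return ((q_groups - zp) * s).reshape(q.shape)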

2 changes: 1 addition & 1 deletion csrc/torch_bindings.cpp
@@ -244,7 +244,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"Tensor b_zeros, Tensor g_idx, Tensor perm, Tensor workspace, "
"int b_q_type, "
"SymInt size_m, SymInt size_n, SymInt size_k, bool is_k_full, "
"bool has_zp, bool use_fp32_reduce) -> Tensor");
"bool has_zp, bool use_fp32_reduce, bool is_zp_float) -> Tensor");
// conditionally compiled so impl registration is in source file

// gptq_marlin repack from GPTQ.
88 changes: 87 additions & 1 deletion tests/kernels/test_marlin_gemm.py
@@ -29,6 +29,7 @@
marlin_qqq_quantize)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
awq_pack, gptq_pack, gptq_quantize_weights, quantize_weights, sort_weights)
from vllm.scalar_type import scalar_types

ACT_ORDER_OPTS = [False, True]
K_FULL_OPTS = [False, True]
@@ -40,6 +41,8 @@
MARLIN_24_K_CHUNKS = [128]
MARLIN_24_N_CHUNKS = [512]

HQQ_SUPPORTED_GROUP_SIZES = [64]

MNK_FACTORS = [
(1, 1, 1),
(1, 4, 8),
@@ -226,7 +229,7 @@ def test_gptq_marlin_gemm(
torch.ops._C.gptq_marlin_gemm,
(a_input, marlin_q_w, marlin_s, marlin_zp, g_idx, sort_indices,
workspace.scratch, quant_type.id, a_input.shape[0], b_weight.shape[1],
a_input.shape[1], is_k_full, False, use_fp32_reduce),
a_input.shape[1], is_k_full, False, use_fp32_reduce, False),
test_utils=DEFAULT_OPCHECK_TEST_UTILS)

output = ops.gptq_marlin_gemm(
@@ -244,6 +247,7 @@
is_k_full=is_k_full,
has_zp=False,
use_fp32_reduce=use_fp32_reduce,
is_zp_float=False,
)
output_ref = torch.matmul(a_input, w_ref)

@@ -441,6 +445,7 @@ def test_awq_marlin_gemm(
is_k_full=is_k_full,
has_zp=has_zp,
use_fp32_reduce=use_fp32_reduce,
is_zp_float=False,
)
output_ref = torch.matmul(a_input, w_ref)

@@ -451,6 +456,87 @@
assert max_diff < 0.04


@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"),
reason="Marlin is not supported on this GPU type.")
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("group_size", HQQ_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
@pytest.mark.parametrize("use_fp32_reduce", USE_FP32_REDUCE_OPTS)
def test_hqq_marlin_gemm(
k_chunk,
n_chunk,
group_size,
mnk_factors,
use_fp32_reduce,
):
m_factor, n_factor, k_factor = mnk_factors

size_m = m_factor
size_k = k_chunk * k_factor
size_n = n_chunk * n_factor

quant_type = scalar_types.uint4

a_input = rand_data((size_m, size_k))
dev = a_input.device

b_weight = torch.randint(0,
10, (size_n, size_k),
dtype=torch.uint8,
device=dev)
scale = rand_data((size_n, size_k // group_size))
zero = rand_data((size_n, size_k // group_size))

gptq_w_q = gptq_pack(b_weight.transpose(1, 0), 4, size_k, size_n)

sort_indices = torch.empty(0, dtype=torch.int, device=dev)
marlin_w_q = ops.gptq_marlin_repack(gptq_w_q, sort_indices, size_k, size_n,
4).to(dev)
marlin_s = marlin_permute_scales(scale.transpose(1, 0), size_k, size_n,
group_size).to(dev)
marlin_zp = marlin_permute_scales(zero.transpose(1, 0), size_k, size_n,
group_size).to(dev)

g_idx = marlin_make_empty_g_idx(dev)
g_idx_sort_indices = marlin_make_empty_g_idx(dev)

workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N,
GPTQ_MARLIN_MAX_PARALLEL)

output = ops.gptq_marlin_gemm(
a_input,
marlin_w_q,
marlin_s,
marlin_zp,
g_idx,
g_idx_sort_indices,
workspace.scratch,
quant_type,
a_input.shape[0],
b_weight.shape[0],
a_input.shape[1],
is_k_full=True,
has_zp=True,
use_fp32_reduce=use_fp32_reduce,
is_zp_float=True,
)

b_flat = b_weight.reshape(-1, group_size)
zp_flat = zero.reshape(-1, 1)
s_flat = scale.reshape(-1, 1)
dequant = (b_flat - zp_flat) * s_flat

output_ref = torch.matmul(a_input,
dequant.reshape(b_weight.shape).transpose(1, 0))

torch.cuda.synchronize()

max_diff = compute_max_diff(output, output_ref)

assert max_diff < 0.04


@pytest.mark.skipif(not is_quant_method_supported("qqq"),
reason="Marlin is not supported on this GPU type.")
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
3 changes: 2 additions & 1 deletion tests/weight_loading/models.txt
@@ -27,4 +27,5 @@ fp8, neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV, main
marlin, nm-testing/zephyr-beta-7b-marlin-g128, main
marlin, robertgshaw2/zephyr-7b-beta-channelwise-marlin, main
qqq, HandH1998/QQQ-Llama-3-8b-g128, main
qqq, HandH1998/QQQ-Llama-3-8b, main
qqq, HandH1998/QQQ-Llama-3-8b, main
hqq, nm-testing/Llama-3.2-1B-Instruct-HQQ, main
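The new entry exercises end-to-end weight loading for an HQQ checkpoint. For reference, loading that checkpoint through vLLM's public API would look roughly like the sketch below; it assumes a CUDA build that includes this commit, and in practice the quantization method may also be detected from the checkpoint config rather than passed explicitly.

from vllm import LLM, SamplingParams

# Sketch: serve the HQQ-quantized checkpoint listed above via the newly
# registered "hqq" quantization method (Marlin-backed).
llm = LLM(model="nm-testing/Llama-3.2-1B-Instruct-HQQ", quantization="hqq")
out = llm.generate(["The capital of France is"], SamplingParams(max_tokens=16))
print(out[0].outputs[0].text)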
8 changes: 5 additions & 3 deletions vllm/_custom_ops.py
@@ -343,7 +343,8 @@ def _gptq_marlin_gemm_fake(a: torch.Tensor,
size_k: torch.SymInt,
is_k_full: bool,
has_zp: bool = False,
use_fp32_reduce: bool = False) -> torch.Tensor:
use_fp32_reduce: bool = False,
is_zp_float: bool = False) -> torch.Tensor:
return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype)

@register_fake("_C::ggml_dequantize")
@@ -601,11 +602,12 @@ def gptq_marlin_gemm(a: torch.Tensor,
size_k: int,
is_k_full: bool,
has_zp: bool = False,
use_fp32_reduce: bool = False) -> torch.Tensor:
use_fp32_reduce: bool = False,
is_zp_float: bool = False) -> torch.Tensor:
return torch.ops._C.gptq_marlin_gemm(a, b_q_weight, b_scales, b_zeros,
g_idx, perm, workspace, b_q_type.id,
size_m, size_n, size_k, is_k_full,
has_zp, use_fp32_reduce)
has_zp, use_fp32_reduce, is_zp_float)


# fp8 marlin
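Because the new keyword is appended with a False default, existing callers of the Python wrapper are unaffected. A quick sanity check of the updated signature, assuming a vLLM build that includes this change:

import inspect

from vllm import _custom_ops as ops

# The new flag defaults to False, so older call sites keep their behavior.
assert inspect.signature(ops.gptq_marlin_gemm).parameters["is_zp_float"].default is False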
3 changes: 2 additions & 1 deletion vllm/model_executor/layers/linear.py
@@ -27,7 +27,8 @@
"AWQLinearMethod", "GPTQMarlinLinearMethod", "Fp8LinearMethod",
"MarlinLinearMethod", "QQQLinearMethod", "GPTQMarlin24LinearMethod",
"TPUInt8LinearMethod", "GPTQLinearMethod", "FBGEMMFp8LinearMethod",
"ModelOptFp8LinearMethod", "IPEXAWQLinearMethod", "IPEXGPTQLinearMethod"
"ModelOptFp8LinearMethod", "IPEXAWQLinearMethod", "IPEXGPTQLinearMethod",
"HQQMarlinMethod"
]


2 changes: 2 additions & 0 deletions vllm/model_executor/layers/quantization/__init__.py
@@ -21,6 +21,7 @@
GPTQMarlinConfig)
from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
GPTQMarlin24Config)
from vllm.model_executor.layers.quantization.hqq_marlin import HQQMarlinConfig
from vllm.model_executor.layers.quantization.ipex_quant import IPEXConfig
from vllm.model_executor.layers.quantization.marlin import MarlinConfig
from vllm.model_executor.layers.quantization.modelopt import ModelOptFp8Config
@@ -48,6 +49,7 @@
"compressed-tensors": CompressedTensorsConfig,
"bitsandbytes": BitsAndBytesConfig,
"qqq": QQQConfig,
"hqq": HQQMarlinConfig,
"experts_int8": ExpertsInt8Config,
"neuron_quant": NeuronQuantConfig,
"ipex": IPEXConfig,
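With the "hqq" entry registered above, the method name resolves to HQQMarlinConfig. A minimal check, assuming the module's existing get_quantization_config helper (not shown in this diff):

from vllm.model_executor.layers.quantization import get_quantization_config

# "hqq" now maps to the Marlin-backed HQQ config class introduced by this commit.
assert get_quantization_config("hqq").__name__ == "HQQMarlinConfig"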
