diff --git a/csrc/quantization/machete/machete_mm_launcher.cuh b/csrc/quantization/machete/machete_mm_launcher.cuh index e2604d4bed3e2..9d404f138a963 100644 --- a/csrc/quantization/machete/machete_mm_launcher.cuh +++ b/csrc/quantization/machete/machete_mm_launcher.cuh @@ -49,8 +49,7 @@ torch::Tensor run_impl(PyTorchArguments args) { torch::empty({M, N}, torch::TensorOptions() .dtype(equivalent_scalar_type_v) .device(device)); - - auto const &A = args.A, &B = args.B; + *auto const &A = args.A, &B = args.B; auto const &C = args.C, &scales = args.scales, &zeros = args.zeros; auto layout_A = make_cute_layout(A, "A"); diff --git a/examples/offline_inference.py b/examples/offline_inference.py index a1f262208442e..df303229118b9 100644 --- a/examples/offline_inference.py +++ b/examples/offline_inference.py @@ -11,7 +11,12 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.95) # Create an LLM. -llm = LLM(model="nm-testing/tinyllama-oneshot-w4a16-group128-v2") +# GPTQ = "kaitchup/Llama-2-7b-gptq-3bit" +# marlin = "TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ" +# machete = "TheBloke/Llama-2-7B-GPTQ" +# machete/marlin CT = "nm-testing/tinyllama-oneshot-w4a16-group128-v2" +# "nm-testing/tinyllama-oneshot-w4a16-channel-v2" +llm = LLM(model="nm-testing/tinyllama-oneshot-w4a16-channel-v2") # Generate texts from the prompts. The output is a list of RequestOutput objects # that contain the prompt, generated text, and other information. outputs = llm.generate(prompts, sampling_params) diff --git a/vllm/model_executor/layers/quantization/gptq.py b/vllm/model_executor/layers/quantization/gptq.py index aa04fcf8310bf..afa713f4cd43d 100644 --- a/vllm/model_executor/layers/quantization/gptq.py +++ b/vllm/model_executor/layers/quantization/gptq.py @@ -217,6 +217,9 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: ops.gptq_shuffle(layer.qweight, layer.g_idx, self.quant_config.weight_bits) + print(layer.qzeros) + print(hex(layer.qzeros[0][0].to(torch.uint32).item())) + def apply(self, layer: torch.nn.Module, x: torch.Tensor, diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index 8e59e8ff019db..fed393dd97771 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -1,12 +1,9 @@ from typing import Any, Dict, List, Optional import torch -from torch.nn import Parameter from vllm.logger import init_logger from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase -from vllm.model_executor.layers.quantization.kernels import ( - choose_mp_linear_kernel, MPLinearLayerConfig) from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.quantization.kernels import ( diff --git a/vllm/model_executor/layers/quantization/kernels/GPTQLinearKernel.py b/vllm/model_executor/layers/quantization/kernels/GPTQLinearKernel.py new file mode 100644 index 0000000000000..d8b2de6141d63 --- /dev/null +++ b/vllm/model_executor/layers/quantization/kernels/GPTQLinearKernel.py @@ -0,0 +1,85 @@ +from vllm import _custom_ops as ops +from vllm.model_executor.layers.quantization.utils.machete_utils import ( + MACHETE_SUPPORTED_GROUP_SIZES, check_machete_supports_shape, + query_machete_supported_quant_types) +from vllm.model_executor.parameter import (ModelWeightParameter, + PackedvLLMParameter) + +from .MPLinearKernel import * + + +class GPTQLinearKernel(MPLinearKernel): + + @classmethod + def get_min_capability(cls) -> int: + return 60 + + @classmethod + def can_implement(cls, + c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]: + + if c.act_type != torch.half: + return False, f"Act type {c.act_type} currently not supported by GPTQLinearKernel" + + if c.zero_points: + return False, "Zero points currently not supported by GPTQLinearKernel" + + if c.weight_type not in query_machete_supported_quant_types( + c.zero_points): + return False, f"Quant type ({c.weight_type}) not supported by "\ + "Machete, supported types are: "\ + f"{query_machete_supported_quant_types(c.zero_points)}" + + if c.group_size not in MACHETE_SUPPORTED_GROUP_SIZES: + return False, f"Group size ({c.group_size}) not supported by "\ + "Machete, supported group sizes are: "\ + f"{MACHETE_SUPPORTED_GROUP_SIZES}" + + return check_machete_supports_shape(c.partition_weight_shape[0], + c.partition_weight_shape[1]) + + # note assumes that + # `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0} + # `weight_scale` is: {input_dim = 0, output_dim = 1} + def process_weights_after_loading(self, layer: torch.nn.Module): + + def transform_w_q(x): + # TODO (lucas): assert isinstance(x, PackedvLLMParameter) once + # everything is migrated to using weight_loader_v2 + if isinstance(x, PackedvLLMParameter): + x = x.permute_layout(input_dim=0, output_dim=1, packed_dim=0) + return ops.machete_prepack_B(x.t().contiguous().t(), + self.config.weight_type) + + def transform_w_s(x): + # TODO (lucas): assert isinstance(x, PackedvLLMParameter) once + # everything is migrated to using weight_loader_v2 + if isinstance(x, ModelWeightParameter): + x = x.permute_layout(input_dim=0, output_dim=1) + return x.contiguous() + + # Repack weights and scales for Machete + self._transform_param(layer, self.w_q_name, transform_w_q) + self._transform_param(layer, self.w_s_name, transform_w_s) + + def apply_weights(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + c = self.config + w_q, w_s, _, _ = self._get_weight_params(layer) + + x_2d = x.reshape(-1, x.shape[-1]) + out_shape = x.shape[:-1] + (c.partition_weight_shape[1], ) + + output = ops.machete_gemm(a=x_2d, + b_q=w_q, + b_type=c.weight_type, + b_zeros=None, + b_scales=w_s, + b_group_size=c.group_size) + + if bias is not None: + output.add_(bias) # In-place add + + return output.reshape(out_shape) diff --git a/vllm/scalar_type.py b/vllm/scalar_type.py index eb491dd1554a8..373151a5311e5 100644 --- a/vllm/scalar_type.py +++ b/vllm/scalar_type.py @@ -27,6 +27,8 @@ class scalar_types: float6_e3m2f = ScalarType.float_(3, 2, True, NanRepr.NONE.value) # "gptq" types + uint2b2 = ScalarType.uint(2, 2) + uint3b4 = ScalarType.uint(3, 4) uint4b8 = ScalarType.uint(4, 8) uint8b128 = ScalarType.uint(8, 128)