From f41d5e5d25b69bf6b9cfd2b9f663dee59ce580ed Mon Sep 17 00:00:00 2001
From: Dipika Sikka
Date: Mon, 23 Dec 2024 13:23:38 -0500
Subject: [PATCH] [Misc] Add assertion and helpful message for marlin24
 compressed models (#11388)

---
 .../compressed_tensors/schemes/compressed_tensors_w4a16_24.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py
index 9ad61a64e406c..61d1c911cd1ad 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py
@@ -61,6 +61,10 @@ def create_weights(self, layer: torch.nn.Module, input_size: int,
                        params_dtype: torch.dtype,
                        weight_loader: Callable, **kwargs):
+        assert params_dtype == torch.float16, (
+            "float16 is required for marlin24 compressed models. Set dtype=torch.float16"  # noqa: E501
+        )
+
         pack_factor = 32 // self.quant_type.size_bits

         output_size_per_partition = sum(output_partition_sizes)
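
A minimal usage sketch of what this assertion implies for callers, assuming a
hypothetical marlin24 compressed-tensors checkpoint name: passing
dtype="float16" to vllm.LLM keeps params_dtype at torch.float16 so the new
check in create_weights() passes.

    from vllm import LLM

    # The assertion added above requires params_dtype == torch.float16, so the
    # model must be loaded with dtype="float16"; the default dtype="auto" may
    # resolve to bfloat16 from the checkpoint config and trip the assert.
    llm = LLM(
        model="nm-testing/example-w4a16-2of4-marlin24",  # hypothetical model name
        dtype="float16",
    )
    print(llm.generate("Hello, marlin24!")[0].outputs[0].text)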