diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py
index 50c7a1d3fbb20..2e425e2375539 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py
@@ -134,20 +134,6 @@ def create_weights(self, layer: torch.nn.Module, output_size: int,
     # Checkpoints are serialized in compressed-tensors format, which is
     # different from the format the kernel may want. Handle repacking here.
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
-        # convert `weight_packed` from:
-        #     {input_dim = 1, output_dim = 0, packed_dim = 1}
-        # to:
-        #     {input_dim = 0, output_dim = 1, packed_dim = 0}
-        # expected the kernels `process_weights_after_loading`
-        replace_parameter(layer, "weight_packed", layer.weight_packed.t())
-
-        # convert `weight_scale` from:
-        #     {input_dim = 1, output_dim = 0}
-        # to:
-        #     {input_dim = 0, output_dim = 1}
-        # expected the kernels `process_weights_after_loading`
-        replace_parameter(layer, "weight_scale", layer.weight_scale.t())
-
         self.kernel.process_weights_after_loading(layer)
diff --git a/vllm/model_executor/layers/quantization/kernels/MPLinearKernel.py b/vllm/model_executor/layers/quantization/kernels/MPLinearKernel.py
index a8cf3c0d669d7..dcbc9f0e69e54 100644
--- a/vllm/model_executor/layers/quantization/kernels/MPLinearKernel.py
+++ b/vllm/model_executor/layers/quantization/kernels/MPLinearKernel.py
@@ -45,7 +45,7 @@ def __init__(self,
         self.w_zp_name = w_zp_param_name
         self.w_gidx_name = w_gidx_param_name
 
-        # note assumes that
+        # note: assumes that (if they are not ModelWeightParameters)
         #  `getattr(layer, w_q_name)` is:
         #    {input_dim = 0, output_dim = 1, packed_dim = 0}
         #  `getattr(layer, w_s_name)` is:
diff --git a/vllm/model_executor/layers/quantization/kernels/MacheteLinearKernel.py b/vllm/model_executor/layers/quantization/kernels/MacheteLinearKernel.py
index d936549a912e6..e583b1df3c784 100644
--- a/vllm/model_executor/layers/quantization/kernels/MacheteLinearKernel.py
+++ b/vllm/model_executor/layers/quantization/kernels/MacheteLinearKernel.py
@@ -1,9 +1,9 @@
 from vllm import _custom_ops as ops
-from vllm.model_executor.layers.quantization.utils import replace_parameter
 from vllm.model_executor.layers.quantization.utils.machete_utils import (
     MACHETE_SUPPORTED_GROUP_SIZES, check_machete_supports_shape,
     query_machete_supported_quant_types)
-
+from vllm.model_executor.parameter import (ModelWeightParameter,
+                                           PackedvLLMParameter)
 from .MPLinearKernel import *
@@ -43,11 +43,23 @@ def can_implement(cls,
     #  `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0}
     #  `weight_scale` is: {input_dim = 0, output_dim = 1}
     def process_weights_after_loading(self, layer: torch.nn.Module):
+        def transform_w_q(x):
+            # TODO (lucas): assert isinstance(x, PackedvLLMParameter) once
+            # everything is migrated to using weight_loader_v2
+            if isinstance(x, PackedvLLMParameter):
+                x = x.permute_layout(input_dim=0, output_dim=1, packed_dim=0)
+            return ops.machete_prepack_B(x.t().contiguous().t(),
+                                         self.config.weight_type)
+        def transform_w_s(x):
+            # TODO (lucas): assert isinstance(x, ModelWeightParameter) once
+            # everything is migrated to using weight_loader_v2
+            if isinstance(x, ModelWeightParameter):
+                x = x.permute_layout(input_dim=0, output_dim=1)
+            return x.contiguous()
-        self._transform_param(
-            layer, self.w_q_name, lambda x: ops.machete_prepack_B(
-                x.t().contiguous().t(), self.config.weight_type))
-        self._transform_param(layer, self.w_s_name, lambda x: x.contiguous())
+        # Repack weights and scales for Machete
+        self._transform_param(layer, self.w_q_name, transform_w_q)
+        self._transform_param(layer, self.w_s_name, transform_w_s)
 
     def apply_weights(self,
                       layer: torch.nn.Module,
diff --git a/vllm/model_executor/layers/quantization/kernels/MarlinLinearKernel.py b/vllm/model_executor/layers/quantization/kernels/MarlinLinearKernel.py
index 83e31f85c48b7..9da712b7e0dda 100644
--- a/vllm/model_executor/layers/quantization/kernels/MarlinLinearKernel.py
+++ b/vllm/model_executor/layers/quantization/kernels/MarlinLinearKernel.py
@@ -6,7 +6,8 @@
     marlin_is_k_full, query_marlin_supported_quant_types)
 
 from .MPLinearKernel import *
-
+from vllm.model_executor.parameter import (ModelWeightParameter,
+                                           PackedvLLMParameter)
 
 class MarlinLinearKernel(MPLinearKernel):
 
@@ -80,20 +81,31 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         else:
             setattr(layer, self.w_zp_name, marlin_make_empty_g_idx(device))
 
-        self._transform_param(layer, self.w_q_name, lambda x: \
-            ops.gptq_marlin_repack(
+        def transform_w_q(x):
+            # TODO (lucas): assert isinstance(x, PackedvLLMParameter) once
+            # everything is migrated to using weight_loader_v2
+            if isinstance(x, PackedvLLMParameter):
+                x = x.permute_layout(input_dim=0, output_dim=1, packed_dim=0)
+            return ops.gptq_marlin_repack(
                 x.contiguous(),
                 perm=layer.g_idx_sort_indices,
                 size_k=c.partition_weight_shape[0],
                 size_n=c.partition_weight_shape[1],
-                num_bits=c.weight_type.size_bits))
-
-        self._transform_param(layer, self.w_s_name, lambda x: \
-            marlin_permute_scales(
+                num_bits=c.weight_type.size_bits)
+
+        def transform_w_s(x):
+            # TODO (lucas): assert isinstance(x, ModelWeightParameter) once
+            # everything is migrated to using weight_loader_v2
+            if isinstance(x, ModelWeightParameter):
+                x = x.permute_layout(input_dim=0, output_dim=1)
+            return marlin_permute_scales(
                 x.contiguous(),
                 size_k=c.partition_weight_shape[0],
                 size_n=c.partition_weight_shape[1],
-                group_size=c.group_size))
+                group_size=c.group_size)
+
+        self._transform_param(layer, self.w_q_name, transform_w_q)
+        self._transform_param(layer, self.w_s_name, transform_w_s)
 
     def apply_weights(self,
                       layer: torch.nn.Module,
diff --git a/vllm/model_executor/parameter.py b/vllm/model_executor/parameter.py
index 10239843b3222..47b36bae93a8e 100644
--- a/vllm/model_executor/parameter.py
+++ b/vllm/model_executor/parameter.py
@@ -142,6 +142,25 @@ class ModelWeightParameter(_ColumnvLLMParameter):
     def __init__(self, input_dim: int, **kwargs):
         self._input_dim = input_dim
         super().__init__(**kwargs)
+
+    def permute_layout(self, input_dim: int, output_dim: int, **kwargs) \
+            -> 'ModelWeightParameter':
+
+        # build a permutation that moves self.input_dim to input_dim and
+        # self.output_dim to output_dim, preserving the relative order of
+        # any other dimensions
+        perm = [i for i in range(self.data.dim())
+                if i not in [self.input_dim, self.output_dim]
+                ]
+        perm.insert(input_dim, self.input_dim)
+        perm.insert(output_dim, self.output_dim)
+
+        return ModelWeightParameter(
+            data=self.data.permute(*perm).contiguous(),
+            weight_loader=self.weight_loader,
+            input_dim=input_dim,
+            output_dim=output_dim,
+            **kwargs)
 
     @property
     def input_dim(self):
         return self._input_dim
@@ -253,6 +272,23 @@ def __init__(self,
         self._marlin_tile = marlin_tile_size
         super().__init__(**kwargs)
 
+    def permute_layout(self, input_dim: int, output_dim: int,
+                       packed_dim: int = 0, **kwargs) \
+            -> 'PackedvLLMParameter':
+
+        assert packed_dim == (input_dim if self.packed_dim == self.input_dim
+                              else output_dim)
+
+        return PackedvLLMParameter(
+            data=ModelWeightParameter.permute_layout(
+                self, input_dim, output_dim).data,
+            weight_loader=self.weight_loader,
+            input_dim=input_dim,
+            output_dim=output_dim,
+            packed_dim=packed_dim,
+            packed_factor=self.packed_factor,
+            marlin_tile_size=self.marlin_tile)
+
     @property
     def packed_dim(self):
         return self._packed_dim
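
Illustrative sketch (not part of the patch; plain torch, toy shapes and helper
names assumed) of the dim permutation the new `permute_layout` helpers perform,
which replaces the hard-coded `.t()` repacking removed from
`compressed_tensors_wNa16.py`:

    import torch

    def permute_layout(data, cur_input_dim, cur_output_dim, input_dim, output_dim):
        # Build a permutation that moves cur_input_dim -> input_dim and
        # cur_output_dim -> output_dim, preserving any remaining dims.
        perm = [i for i in range(data.dim())
                if i not in (cur_input_dim, cur_output_dim)]
        perm.insert(input_dim, cur_input_dim)
        perm.insert(output_dim, cur_output_dim)
        return data.permute(*perm).contiguous()

    # Toy packed weight in the compressed-tensors checkpoint layout
    # {input_dim = 1, output_dim = 0, packed_dim = 1}:
    # shape = (output_size, input_size // pack_factor) = (8, 4)
    w_packed = torch.arange(32, dtype=torch.int32).reshape(8, 4)

    # The kernels expect {input_dim = 0, output_dim = 1, packed_dim = 0}, i.e.
    # the transpose of the checkpoint layout, stored contiguously.
    w_kernel = permute_layout(w_packed, cur_input_dim=1, cur_output_dim=0,
                              input_dim=0, output_dim=1)
    assert w_kernel.shape == (4, 8)
    assert torch.equal(w_kernel, w_packed.t().contiguous())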