diff --git a/vllm/model_executor/layers/quantization/awq_marlin.py b/vllm/model_executor/layers/quantization/awq_marlin.py index 1b0453c2bd6f8..fc259589d777d 100644 --- a/vllm/model_executor/layers/quantization/awq_marlin.py +++ b/vllm/model_executor/layers/quantization/awq_marlin.py @@ -9,10 +9,11 @@ set_weight_attrs) from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) +from vllm.model_executor.layers.quantization.utils import replace_tensor from vllm.model_executor.layers.quantization.utils.marlin_utils import ( apply_awq_marlin_linear, awq_to_marlin_zero_points, check_marlin_supported, marlin_make_empty_g_idx, marlin_make_workspace, marlin_permute_scales, - replace_tensor, verify_marlin_supported, verify_marlin_supports_shape) + verify_marlin_supported, verify_marlin_supports_shape) from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.scalar_type import scalar_types diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py index 11566a8ffe852..f086bce9666d3 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py @@ -1,20 +1,15 @@ -from typing import Callable, List, Optional, Tuple, Type +from typing import Callable, List, Optional import torch -from abc import ABC - -from vllm import _custom_ops as ops from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( CompressedTensorsScheme) -from vllm.model_executor.layers.quantization.utils.marlin_utils import ( - apply_gptq_marlin_linear, marlin_make_empty_g_idx, marlin_make_workspace, - marlin_permute_scales, replace_tensor, verify_marlin_supported, - verify_marlin_supports_shape) from vllm.model_executor.parameter import (BasevLLMParameter, ChannelQuantScaleParameter, GroupQuantScaleParameter, PackedvLLMParameter) +from vllm.model_executor.layers.quantization.kernels import ( + MPLinearLayerConfig, choose_mp_linear_kernel) from vllm.scalar_type import scalar_types __all__ = ["CompressedTensorsWNA16"] @@ -25,199 +20,7 @@ WNA16_SUPPORTED_BITS = list(WNA16_SUPPORTED_TYPES_MAP.keys()) -class QuantLinearKernel(ABC): - @classmethod - def get_min_capability(cls) -> int: - raise NotImplementedError("apply_weights not implemented") - - @classmethod - def can_implement(cls, - full_weight_shape: Tuple[int, int], # [in, out] - partition_weight_shape: Tuple[int, int], - quant_type: ScalarType, - act_type: torch.dtype, - group_size: int, - zero_points: bool, - act_reordering: bool) -> Tuple[bool, str]: - raise NotImplementedError("can_implement not implemented") - - def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - pass - - def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor, - bias: Optional[torch.Tensor]) -> torch.Tensor: - raise NotImplementedError("apply_weights not implemented") - -class MarlinKernel(QuantLinearKernel): - - @classmethod - def get_min_capability(cls) -> int: - return 80 - - @classmethod - def can_implement(cls, - full_weight_shape: Tuple[int, int], # [in, out] - partition_weight_shape: Tuple[int, int], - quant_type: ScalarType, - act_type: torch.dtype, - group_size: int, - zero_points: bool, - act_reordering: bool) -> Tuple[bool, str]: - - if zero_points: - return False, "Zero points currently not supported by 
"\ - " Compressed Tensors + Marlin. (Kernel supports it"\ - " but CompressedTensorsWNA16 does not so support has"\ - " not been addes to MarlinWNA16Kernel yet" - - if quant_type not in query_marlin_supported_quant_types(zero_points): - return False, f"Quant type ({quant_type}) not supported by Marlin,"\ - " supported types are: "\ - f"{query_marlin_supported_quant_types(zero_points)}" - - if group_size not in MARLIN_SUPPORTED_GROUP_SIZES: - return False, f"Group size ({group_size}) not supported by Marlin,"\ - " supported group sizes are: "\ - f"{MARLIN_SUPPORTED_GROUP_SIZES}" - - return check_marlin_supports_shape( - partition_weight_shape[0], - partition_weight_shape[1], - full_weight_shape[1], - group_size) - - # Checkpoints are serialized in compressed-tensors format, which is - # different from marlin format. Handle repacking here. - def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - device = layer.weight_packed.device - - # Allocate marlin workspace. - layer.workspace = marlin_make_workspace( - layer.output_size_per_partition, device) - - # Act-order not supported in compressed-tensors yet, so set to empty. - layer.g_idx = marlin_make_empty_g_idx(device) - layer.g_idx_sort_indices = marlin_make_empty_g_idx(device) - - # No zero-point - layer.weight_zp = marlin_make_empty_g_idx(device) - - # Repack weights from compressed-tensors format to marlin format. - marlin_qweight = ops.gptq_marlin_repack( - layer.weight_packed.t().contiguous(), - perm=layer.g_idx_sort_indices, - size_k=layer.input_size_per_partition, - size_n=layer.output_size_per_partition, - num_bits=layer.weight_type.size_bits) - replace_tensor(layer, "weight_packed", marlin_qweight) - - # Permute scales from compressed-tensors format to marlin format. - marlin_scales = marlin_permute_scales( - layer.weight_scale.squeeze().t().contiguous(), - size_k=layer.input_size_per_partition, - size_n=layer.output_size_per_partition, - group_size=layer.group_size) - replace_tensor(layer, "weight_scale", marlin_scales) - - def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor, - bias: Optional[torch.Tensor]) -> torch.Tensor: - print("running marlin") - return apply_gptq_marlin_linear( - input=x, - weight=layer.weight_packed, - weight_scale=layer.weight_scale, - weight_zp=layer.weight_zp, - g_idx=layer.g_idx, - g_idx_sort_indices=layer.g_idx_sort_indices, - workspace=layer.workspace, - wtype=layer.weight_type, - output_size_per_partition=layer.output_size_per_partition, - input_size_per_partition=layer.input_size_per_partition, - is_k_full=True, - bias=bias) - - -class MacheteKernel(QuantLinearKernel): - - @classmethod - def get_min_capability(cls) -> int: - return 90 - - @classmethod - def can_implement(cls, - full_weight_shape: Tuple[int, int], # [in, out] - partition_weight_shape: Tuple[int, int], - quant_type: ScalarType, - act_type: torch.dtype, - group_size: int, - zero_points: bool, - act_reordering: bool) -> Tuple[bool, str]: - if act_reordering: - return False, "Act reordering currently not supported by Machete" - - if zero_points: - return False, "Zero points currently not supported by "\ - " Compressed Tensors + Machete. 
(Kernel supports it"\ - " but CompressedTensorsWNA16 does not so support has"\ - " not been addes to MacheteWNA16Kernel yet" - - if quant_type not in query_machete_supported_quant_types(zero_points): - return False, f"Quant type ({quant_type}) not supported by "\ - "Machete, supported types are: "\ - f"{query_machete_supported_quant_types(zero_points)}" - - if group_size not in MACHETE_SUPPORTED_GROUP_SIZES: - return False, f"Group size ({group_size}) not supported by "\ - "Machete, supported group sizes are: "\ - f"{MACHETE_SUPPORTED_GROUP_SIZES}" - - return check_machete_supports_shape( - partition_weight_shape[0], - partition_weight_shape[1]) - - - # Checkpoints are serialized in compressed-tensors format, which is - # different from marlin format. Handle repacking here. - def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - machete_qweight = ops.machete_prepack_B( - layer.weight_packed.t(), - layer.weight_type - ) - replace_tensor(layer, "weight_packed", machete_qweight) - replace_tensor(layer, "weight_scale", - layer.weight_scale.clone().t()) - - - def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor, - bias: Optional[torch.Tensor]) -> torch.Tensor: - assert layer.weight_scale.dtype == x.dtype - # print("running machete") - # print(layer.weight_packed.shape) - # print(layer.weight_scale.dtype, x.dtype, layer.group_size) - x_2d = x.reshape(-1, x.shape[-1]) - out_shape = x.shape[:-1] + (layer.output_size_per_partition, ) - - output = ops.machete_gemm( - a=x_2d, - b_q=layer.weight_packed, - b_type=layer.weight_type, - b_zeros=None, - b_scales=layer.weight_scale, - b_group_size=layer.group_size - ) - - if bias is not None: - output.add_(bias) # In-place add - - return output.reshape(out_shape) - class CompressedTensorsWNA16(CompressedTensorsScheme): - - # in order of priority (i.e. 
performance if available) - possible_kernels: List[Type[QuantLinearKernel]] = [ - MacheteKernel, - #MarlinKernel, - ] def __init__(self, strategy: str, @@ -245,54 +48,27 @@ def get_min_capability(cls) -> int: # ampere and up return 80 - def create_weights(self, layer: torch.nn.Module, - output_size: int, - input_size: int, - output_partition_sizes: List[int], + def create_weights(self, layer: torch.nn.Module, output_size: int, + input_size: int, output_partition_sizes: List[int], input_size_per_partition: int, - params_dtype: torch.dtype, - weight_loader: Callable, + params_dtype: torch.dtype, weight_loader: Callable, **kwargs): + output_size_per_partition = sum(output_partition_sizes) - - failure_reasons = [] - - full_weight_shape = (input_size, output_size) - partition_weight_shape = \ - (input_size_per_partition, output_size_per_partition) - capability = current_platform.get_device_capability() - capability = capability[0] * 10 + capability[1] - - for kernel in self.possible_kernels: - if kernel.get_min_capability() > capability: - failure_reasons.append( - (kernel.__name__, - f"requires capability {kernel.get_min_capability()}, " - f"current capability is {capability}")) - - can_implement, failure_reason = kernel.can_implement( - full_weight_shape=full_weight_shape, - partition_weight_shape=partition_weight_shape, - quant_type=self.quant_type, - act_type=params_dtype, - group_size=self.group_size, - zero_points=False, - act_reordering=False) + mp_linear_kernel_config = MPLinearLayerConfig( + full_weight_shape=(input_size, output_size), + partition_weight_shape=\ + (input_size_per_partition, output_size_per_partition), + weight_type=self.quant_type, + act_type=params_dtype, + group_size=self.group_size, + zero_points=False, + act_reordering=False + ) - if can_implement: - self.kernel = kernel() - break - else: - failure_reasons.append( - (kernel.__name__, failure_reason)) - - if not hasattr(self, "kernel"): - raise ValueError( - f"Failed to find a kernel that can implement the "\ - "WNA16 linear layer. Reasons: \n" - + '\n'.join([f' {x} cannot implement due to: {r}' - for x, r in failure_reasons])) + kernel_type = choose_mp_linear_kernel(mp_linear_kernel_config) + self.kernel = kernel_type(mp_linear_kernel_config) # If group_size is -1, we are in channelwise case. channelwise = (self.group_size == -1) @@ -302,12 +78,6 @@ def create_weights(self, layer: torch.nn.Module, # scales across all gpus. partition_scales = (row_parallel and not channelwise) - verify_marlin_supports_shape( - output_size_per_partition=output_size_per_partition, - input_size_per_partition=input_size_per_partition, - input_size=input_size, - group_size=group_size) - scales_and_zp_size = input_size // group_size if partition_scales: @@ -336,6 +106,7 @@ def create_weights(self, layer: torch.nn.Module, dtype=params_dtype, ) } + if self.group_size == -1: weight_scale = ChannelQuantScaleParameter(output_dim=0, **weight_scale_args) @@ -354,49 +125,12 @@ def create_weights(self, layer: torch.nn.Module, layer.register_parameter("weight_scale", weight_scale) layer.register_parameter("weight_shape", weight_shape) - layer.input_size_per_partition = input_size_per_partition - layer.output_size_per_partition = output_size_per_partition - layer.input_size = input_size - layer.group_size = group_size - # Checkpoints are serialized in compressed-tensors format, which is # different from the format the kernel may want. Handle repacking here. 
def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - device = layer.weight_packed.device - - # Allocate marlin workspace. - layer.workspace = marlin_make_workspace( - layer.output_size_per_partition, device) - - # Act-order not supported in compressed-tensors yet, so set to empty. - layer.g_idx = marlin_make_empty_g_idx(device) - layer.g_idx_sort_indices = marlin_make_empty_g_idx(device) - - # No zero-point - layer.weight_zp = marlin_make_empty_g_idx(device) - # Update for kernel - layer.weight_packed = torch.nn.Parameter( - layer.weight_packed.t().contiguous(), requires_grad=False) - layer.weight_scale = torch.nn.Parameter( - layer.weight_scale.squeeze().t().contiguous(), requires_grad=False) - - # Repack weights from compressed-tensors format to marlin format. - marlin_qweight = ops.gptq_marlin_repack( - layer.weight_packed, - perm=layer.g_idx_sort_indices, - size_k=layer.input_size_per_partition, - size_n=layer.output_size_per_partition, - num_bits=self.quant_type.size_bits) - replace_tensor(layer, "weight_packed", marlin_qweight) + self.kernel.process_weights_after_loading(layer) - # Permute scales from compressed-tensors format to marlin format. - marlin_scales = marlin_permute_scales( - layer.weight_scale, - size_k=layer.input_size_per_partition, - size_n=layer.output_size_per_partition, - group_size=layer.group_size) - replace_tensor(layer, "weight_scale", marlin_scales) def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor]) -> torch.Tensor: - return self.kernel.apply_weights(layer, x, bias) \ No newline at end of file + return self.kernel.apply_weights(layer, x, bias) diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index b92697531c299..7f62e50e70ec9 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -9,10 +9,11 @@ set_weight_attrs) from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) +from vllm.model_executor.layers.quantization.utils import replace_tensor from vllm.model_executor.layers.quantization.utils.marlin_utils import ( apply_gptq_marlin_linear, check_marlin_supported, marlin_is_k_full, marlin_make_empty_g_idx, marlin_make_workspace, marlin_permute_scales, - marlin_repeat_scales_on_all_ranks, marlin_sort_g_idx, replace_tensor, + marlin_repeat_scales_on_all_ranks, marlin_sort_g_idx, verify_marlin_supported, verify_marlin_supports_shape) from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.scalar_type import scalar_types diff --git a/vllm/model_executor/layers/quantization/kernels/MPLinearKernel.py b/vllm/model_executor/layers/quantization/kernels/MPLinearKernel.py new file mode 100644 index 0000000000000..e1759e1429ef0 --- /dev/null +++ b/vllm/model_executor/layers/quantization/kernels/MPLinearKernel.py @@ -0,0 +1,48 @@ +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Optional, Tuple + +import torch + +from vllm.scalar_type import ScalarType + + +@dataclass +class MPLinearLayerConfig: + full_weight_shape: Tuple[int, int] # [in, out] + partition_weight_shape: Tuple[int, int] + weight_type: ScalarType + act_type: torch.dtype + group_size: int + zero_points: bool + act_reordering: bool + + +class MPLinearKernel(ABC): + + @classmethod + @abstractmethod + def get_min_capability(cls) -> int: + raise NotImplementedError + + @classmethod + @abstractmethod + 
def can_implement(cls, + c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]: + raise NotImplementedError + + def __init__(self, c: MPLinearLayerConfig) -> None: + assert self.can_implement(c) + self.config = c + + # note assumes that + # `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0} + # `weight_scale` is: {input_dim = 0, output_dim = 1} + @abstractmethod + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + raise NotImplementedError + + @abstractmethod + def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor, + bias: Optional[torch.Tensor]) -> torch.Tensor: + raise NotImplementedError diff --git a/vllm/model_executor/layers/quantization/kernels/MacheteLinearKernel.py b/vllm/model_executor/layers/quantization/kernels/MacheteLinearKernel.py new file mode 100644 index 0000000000000..11e79967206b8 --- /dev/null +++ b/vllm/model_executor/layers/quantization/kernels/MacheteLinearKernel.py @@ -0,0 +1,75 @@ +from vllm import _custom_ops as ops +from vllm.model_executor.layers.quantization.utils import replace_tensor +from vllm.model_executor.layers.quantization.utils.machete_utils import ( + MACHETE_SUPPORTED_GROUP_SIZES, check_machete_supports_shape, + query_machete_supported_quant_types) + +from .MPLinearKernel import * + + +class MacheteLinearKernel(MPLinearKernel): + + @classmethod + def get_min_capability(cls) -> int: + return 90 + + @classmethod + def can_implement(cls, + c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]: + if c.act_reordering: + return False, "Act reordering currently not supported by Machete" + + if c.zero_points: + return False, "Zero points currently not supported by "\ + " Compressed Tensors + Machete. (Kernel supports it"\ + " but CompressedTensorsWNA16 does not so support has"\ + " not been added to MacheteWNA16Kernel yet" + + if c.weight_type not in query_machete_supported_quant_types( + c.zero_points): + return False, f"Quant type ({c.weight_type}) not supported by "\ + "Machete, supported types are: "\ + f"{query_machete_supported_quant_types(c.zero_points)}" + + if c.group_size not in MACHETE_SUPPORTED_GROUP_SIZES: + return False, f"Group size ({c.group_size}) not supported by "\ + "Machete, supported group sizes are: "\ + f"{MACHETE_SUPPORTED_GROUP_SIZES}" + + return check_machete_supports_shape(c.partition_weight_shape[0], + c.partition_weight_shape[1]) + + # note assumes that + # `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0} + # `weight_scale` is: {input_dim = 0, output_dim = 1} + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + # Repack weights and scales for Machete + + replace_tensor( + layer, "weight_packed", + ops.machete_prepack_B(layer.weight_packed.t(), + self.config.weight_type)) + replace_tensor(layer, "weight_scale", + layer.weight_scale.clone().contiguous().t()) + + def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor, + bias: Optional[torch.Tensor]) -> torch.Tensor: + c = self.config + + # print("running machete") + # print(layer.weight_packed.shape) + # print(layer.weight_scale.dtype, x.dtype, layer.group_size) + x_2d = x.reshape(-1, x.shape[-1]) + out_shape = x.shape[:-1] + (c.partition_weight_shape[1], ) + + output = ops.machete_gemm(a=x_2d, + b_q=layer.weight_packed, + b_type=c.weight_type, + b_zeros=None, + b_scales=layer.weight_scale, + b_group_size=c.group_size) + + if bias is not None: + output.add_(bias) # In-place add + + return output.reshape(out_shape) diff --git 
a/vllm/model_executor/layers/quantization/kernels/MarlinLinearKernel.py b/vllm/model_executor/layers/quantization/kernels/MarlinLinearKernel.py new file mode 100644 index 0000000000000..632785c10c31f --- /dev/null +++ b/vllm/model_executor/layers/quantization/kernels/MarlinLinearKernel.py @@ -0,0 +1,93 @@ +from vllm import _custom_ops as ops +from vllm.model_executor.layers.quantization.utils import replace_tensor +from vllm.model_executor.layers.quantization.utils.marlin_utils import ( + MARLIN_SUPPORTED_GROUP_SIZES, apply_gptq_marlin_linear, + check_marlin_supports_shape, marlin_make_empty_g_idx, + marlin_make_workspace, marlin_permute_scales, + query_marlin_supported_quant_types) + +from .MPLinearKernel import * + + +class MarlinLinearKernel(MPLinearKernel): + + @classmethod + def get_min_capability(cls) -> int: + return 80 + + @classmethod + def can_implement(cls, + c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]: + if c.zero_points: + return False, "Zero points currently not supported by "\ + " Compressed Tensors + Marlin. (Kernel supports it"\ + " but CompressedTensorsWNA16 does not so support has"\ + " not been added to MarlinWNA16Kernel yet" + + quant_types = query_marlin_supported_quant_types(c.zero_points) + if c.weight_type not in quant_types: + return False, f"Quant type ({c.weight_type}) not supported by"\ + f" Marlin, supported types are: {quant_types}" + + if c.group_size not in MARLIN_SUPPORTED_GROUP_SIZES: + return False, f"Group size ({c.group_size}) not supported by "\ + "Marlin, supported group sizes are: "\ + f"{MARLIN_SUPPORTED_GROUP_SIZES}" + + return check_marlin_supports_shape(c.partition_weight_shape[0], + c.partition_weight_shape[1], + c.full_weight_shape[1], + c.group_size) + + # note assumes that + # `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0} + # `weight_scale` is: {input_dim = 0, output_dim = 1} + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + device = layer.weight_packed.device + c = self.config + + # Allocate marlin workspace. + self.workspace = marlin_make_workspace(c.partition_weight_shape[1], + device) + + # Act-order not supported in compressed-tensors yet, so set to empty. + layer.g_idx = marlin_make_empty_g_idx(device) + layer.g_idx_sort_indices = marlin_make_empty_g_idx(device) + + # No zero-point + layer.weight_zp = marlin_make_empty_g_idx(device) + + # Repack weights from compressed-tensors format to marlin format. + marlin_qweight = ops.gptq_marlin_repack( + layer.weight_packed.t().contiguous(), + perm=layer.g_idx_sort_indices, + size_k=c.partition_weight_shape[0], + size_n=c.partition_weight_shape[1], + num_bits=c.weight_type.size_bits) + replace_tensor(layer, "weight_packed", marlin_qweight) + + # Permute scales from compressed-tensors format to marlin format. 
+        marlin_scales = marlin_permute_scales(
+            layer.weight_scale.squeeze().t().contiguous(),
+            size_k=c.partition_weight_shape[0],
+            size_n=c.partition_weight_shape[1],
+            group_size=c.group_size)
+        replace_tensor(layer, "weight_scale", marlin_scales)
+
+    def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
+                      bias: Optional[torch.Tensor]) -> torch.Tensor:
+        c = self.config
+
+        return apply_gptq_marlin_linear(
+            input=x,
+            weight=layer.weight_packed,
+            weight_scale=layer.weight_scale,
+            weight_zp=layer.weight_zp,
+            g_idx=layer.g_idx,
+            g_idx_sort_indices=layer.g_idx_sort_indices,
+            workspace=self.workspace,
+            wtype=c.weight_type,
+            input_size_per_partition=c.partition_weight_shape[0],
+            output_size_per_partition=c.partition_weight_shape[1],
+            is_k_full=True,
+            bias=bias)
diff --git a/vllm/model_executor/layers/quantization/kernels/__init__.py b/vllm/model_executor/layers/quantization/kernels/__init__.py
new file mode 100644
index 0000000000000..22172771e5b64
--- /dev/null
+++ b/vllm/model_executor/layers/quantization/kernels/__init__.py
@@ -0,0 +1,45 @@
+from typing import List, Optional, Type
+
+from vllm.platforms import current_platform
+
+from .MacheteLinearKernel import MacheteLinearKernel
+from .MarlinLinearKernel import MarlinLinearKernel
+from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig
+
+# in priority/performance order (when available)
+_POSSIBLE_KERNELS: List[Type[MPLinearKernel]] = [
+    MacheteLinearKernel,
+    MarlinLinearKernel,
+]
+
+
+def choose_mp_linear_kernel(
+        config: MPLinearLayerConfig,
+        compute_capability: Optional[int] = None) -> Type[MPLinearKernel]:
+    if compute_capability is None:
+        if current_platform is None:
+            raise ValueError("Cannot determine compute capability")
+        _cc = current_platform.get_device_capability()
+        compute_capability = _cc[0] * 10 + _cc[1]
+
+    failure_reasons = []
+    for kernel in _POSSIBLE_KERNELS:
+        if kernel.get_min_capability() > compute_capability:
+            failure_reasons.append(
+                f"{kernel.__name__} requires capability "
+                f"{kernel.get_min_capability()}, current compute capability "
+                f"is {compute_capability}")
+            continue
+
+        can_implement, failure_reason = kernel.can_implement(config)
+        if can_implement:
+            return kernel
+        else:
+            failure_reasons.append(
+                f' {kernel.__name__} cannot implement due to: {failure_reason}'
+            )
+
+    raise ValueError(
+        "Failed to find a kernel that can implement the "\
+        "WNA16 linear layer. 
Reasons: \n" + + '\n'.join(failure_reasons)) diff --git a/vllm/model_executor/layers/quantization/utils/__init__.py b/vllm/model_executor/layers/quantization/utils/__init__.py index e69de29bb2d1d..3a88224040530 100644 --- a/vllm/model_executor/layers/quantization/utils/__init__.py +++ b/vllm/model_executor/layers/quantization/utils/__init__.py @@ -0,0 +1,3 @@ +from .layer_utils import replace_tensor + +__all__ = ['replace_tensor'] diff --git a/vllm/model_executor/layers/quantization/utils/layer_utils.py b/vllm/model_executor/layers/quantization/utils/layer_utils.py new file mode 100644 index 0000000000000..8345bd860ca29 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/layer_utils.py @@ -0,0 +1,12 @@ +import torch + + +# Newly generated tensors need to replace existing tensors that are +# already registered as parameters by vLLM (and won't be freed) +def replace_tensor(layer: torch.nn.Module, name: str, + new_t: torch.Tensor) -> None: + # It is important to use resize_() here since it ensures + # the same buffer is reused + getattr(layer, name).resize_(new_t.shape) + getattr(layer, name).copy_(new_t) + del new_t diff --git a/vllm/model_executor/layers/quantization/utils/machete_utils.py b/vllm/model_executor/layers/quantization/utils/machete_utils.py index 956d239bcac25..18e1332050cdd 100644 --- a/vllm/model_executor/layers/quantization/utils/machete_utils.py +++ b/vllm/model_executor/layers/quantization/utils/machete_utils.py @@ -1,10 +1,12 @@ +from typing import List, Optional, Tuple + import torch -from typing import List, Tuple -from vllm.scalar_type import scalar_types, ScalarType +from vllm.scalar_type import ScalarType, scalar_types + +MACHETE_SUPPORTED_GROUP_SIZES = [-1, 128] +MACHETE_PREPACKED_BLOCK_SHAPE = [64, 128] -MACHETE_SUPPORTED_GROUP_SIZES=[-1, 128] -MACHETE_PREPACKED_BLOCK_SHAPE=[64, 128] def query_machete_supported_quant_types(zero_points: bool) -> List[ScalarType]: if zero_points: @@ -18,11 +20,11 @@ def query_machete_supported_act_types(zero_points: bool) -> List[ScalarType]: def check_machete_supports_shape(in_features: int, out_featrues: int) \ - -> Tuple[bool, str]: + -> Tuple[bool, Optional[str]]: if in_features % MACHETE_PREPACKED_BLOCK_SHAPE[0] != 0: return False, "Input features size must be divisible by "\ f"{MACHETE_PREPACKED_BLOCK_SHAPE[0]}" if out_featrues % MACHETE_PREPACKED_BLOCK_SHAPE[1] != 0: return False, "Output features size must be divisible by "\ f"{MACHETE_PREPACKED_BLOCK_SHAPE[1]}" - return True, None \ No newline at end of file + return True, None diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils.py b/vllm/model_executor/layers/quantization/utils/marlin_utils.py index b62fdffab4be4..fb2fdb5ca28df 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils.py @@ -121,7 +121,7 @@ def verify_marlin_supports_shape(output_size_per_partition: int, def check_marlin_supports_shape(output_size_per_partition: int, input_size_per_partition: int, input_size: int, group_size: int) \ - -> Tuple[bool, str]: + -> Tuple[bool, Optional[str]]: try: verify_marlin_supports_shape(output_size_per_partition, input_size_per_partition, input_size, @@ -234,17 +234,6 @@ def awq_to_marlin_zero_points(q_zp_packed: torch.Tensor, size_k: int, return marlin_zp -# Newly generated tensors need to replace existing tensors that are -# already registered as parameters by vLLM (and won't be freed) -def replace_tensor(layer: torch.nn.Module, name: str, - new_t: 
torch.Tensor) -> None:
-    # It is important to use resize_() here since it ensures
-    # the same buffer is reused
-    getattr(layer, name).resize_(new_t.shape)
-    getattr(layer, name).copy_(new_t)
-    del new_t
-
-
 def apply_gptq_marlin_linear(
         input: torch.Tensor,
         weight: torch.Tensor,
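
Usage sketch (reviewer aid, not part of the patch): the snippet below shows how a quantization scheme is expected to drive the kernel-selection API added in this diff (MPLinearLayerConfig -> choose_mp_linear_kernel -> kernel lifecycle), mirroring what CompressedTensorsWNA16.create_weights now does. The 4096x4096 shapes, float16 activations, group size 128, and the scalar_types.uint4b8 weight type are illustrative assumptions, not values taken from the patch.

import torch

from vllm.model_executor.layers.quantization.kernels import (
    MPLinearLayerConfig, choose_mp_linear_kernel)
from vllm.scalar_type import scalar_types

# Describe the (sub)layer: full weight is [in=4096, out=4096], this rank owns
# the whole output dim (no TP split), 4-bit weights, group size 128, no zero
# points, no activation reordering.
cfg = MPLinearLayerConfig(
    full_weight_shape=(4096, 4096),
    partition_weight_shape=(4096, 4096),
    weight_type=scalar_types.uint4b8,  # assumed WNA16-supported type
    act_type=torch.float16,
    group_size=128,
    zero_points=False,
    act_reordering=False,
)

# Picks MacheteLinearKernel on compute capability >= 90, falls back to
# MarlinLinearKernel on >= 80, and otherwise raises a ValueError listing why
# each candidate kernel was rejected.
kernel_cls = choose_mp_linear_kernel(cfg)
kernel = kernel_cls(cfg)

# Once the checkpoint has been loaded into `layer` (weight_packed /
# weight_scale laid out as documented in MPLinearKernel), repack once for the
# chosen kernel, then use it in the forward pass:
#   kernel.process_weights_after_loading(layer)
#   out = kernel.apply_weights(layer, x, bias=None)

Selecting the kernel class once at weight-creation time means unsupported-hardware or unsupported-shape configurations fail loudly during model load rather than in the forward pass.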