diff --git a/vllm/config.py b/vllm/config.py
index 91bbbfec4b7b3..34c473c68f391 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -392,6 +392,20 @@ def _verify_quantization(self) -> None:
                     f"method specified in the `quantization` argument "
                     f"({self.quantization}).")
 
+        from vllm.model_executor.layers.quantization.mx_quant import (
+            SUPPORTED_MX_CONFIGS)
+        if (self.quantization is not None
+                and self.quantization in SUPPORTED_MX_CONFIGS):
+            dtypes = self.quantization.split("_")
+            # remove 'w' and 'a' prefix
+            weight_dtype = dtypes[0][1:]
+            act_dtype = None if len(dtypes) == 1 else dtypes[1][1:]
+            self.hf_config.quantization_config = {
+                "weight_dtype": weight_dtype,
+                "act_dtype": act_dtype,
+                "quant_method": self.quantization,
+            }
+
         if self.quantization is not None:
             if self.quantization not in supported_quantization:
                 raise ValueError(
diff --git a/vllm/model_executor/layers/quantization/__init__.py b/vllm/model_executor/layers/quantization/__init__.py
index da841d052d728..414543eb6ee9e 100644
--- a/vllm/model_executor/layers/quantization/__init__.py
+++ b/vllm/model_executor/layers/quantization/__init__.py
@@ -24,6 +24,8 @@
 from vllm.model_executor.layers.quantization.ipex_quant import IPEXConfig
 from vllm.model_executor.layers.quantization.marlin import MarlinConfig
 from vllm.model_executor.layers.quantization.modelopt import ModelOptFp8Config
+from vllm.model_executor.layers.quantization.mx_quant import (
+    SUPPORTED_MX_CONFIGS, MXConfig)
 from vllm.model_executor.layers.quantization.neuron_quant import (
     NeuronQuantConfig)
 from vllm.model_executor.layers.quantization.qqq import QQQConfig
@@ -51,6 +53,9 @@
     "experts_int8": ExpertsInt8Config,
     "neuron_quant": NeuronQuantConfig,
     "ipex": IPEXConfig,
+    # Aliases for the different MX quant configs
+    **{config: MXConfig
+       for config in SUPPORTED_MX_CONFIGS},
 }
 
 
diff --git a/vllm/model_executor/layers/quantization/mx_quant.py b/vllm/model_executor/layers/quantization/mx_quant.py
new file mode 100644
index 0000000000000..c38e9105810f0
--- /dev/null
+++ b/vllm/model_executor/layers/quantization/mx_quant.py
@@ -0,0 +1,141 @@
+from typing import Any, Dict, List, Optional
+
+import torch
+from torch.nn import Module
+
+from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
+from vllm.model_executor.layers.quantization.base_config import (
+    QuantizationConfig)
+from vllm.model_executor.parameter import ModelWeightParameter
+
+MXFP8_E5M2 = "fp8e5"
+MXFP8_E4M3 = "fp8e4"
+MXFP6_E3M2 = "fp6e3"
+MXFP6_E2M3 = "fp6e2"
+MXFP4_E2M1 = "fp4e2"
+
+# Supported element dtypes
+# TODO: add support for MXINT8
+SUPPORTED_DTYPES = [
+    MXFP8_E5M2,
+    MXFP8_E4M3,
+    MXFP6_E3M2,
+    MXFP6_E2M3,
+    MXFP4_E2M1,
+]
+
+SUPPORTED_MX_CONFIGS = [
+    *[
+        f"w{wdtype}_a{adtype}"
+        for wdtype, adtype in zip(SUPPORTED_DTYPES, SUPPORTED_DTYPES)
+    ],
+    *[f"w{wdtype}" for wdtype in SUPPORTED_DTYPES],
+]
+
+# https://github.com/pytorch/ao/blob/71a442ae775e0ea5a541dcce637b128070d1243c/torchao/prototype/mx_formats/constants.py#L3-L18
+DTYPE_MAP = {
+    MXFP8_E5M2: torch.float8_e5m2,
+    MXFP8_E4M3: torch.float8_e4m3fn,
+    MXFP6_E3M2: "fp6_e3m2",
+    MXFP6_E2M3: "fp6_e2m3",
+    MXFP4_E2M1: "fp4_e2m1",
+}
+
+
+class MXConfig(QuantizationConfig):
+    """MX Quantization Configuration."""
+
+    def __init__(
+        self,
+        weight_dtype: str,
+        act_dtype: Optional[str],
+    ) -> None:
+        if weight_dtype not in SUPPORTED_DTYPES:
+            raise ValueError(f"Unsupported weight scheme {weight_dtype}")
+        if act_dtype and act_dtype not in SUPPORTED_DTYPES:
+            raise ValueError(f"Unsupported activation scheme {act_dtype}")
+
+        self.weight_dtype = DTYPE_MAP[weight_dtype]
+        self.act_dtype = DTYPE_MAP[act_dtype] if act_dtype else None
+        # Hardcoded for the MX spec
+        self.block_size = 32
+
+    def get_name(self) -> str:
+        return "mx_quant"
+
+    def get_supported_act_dtypes(self) -> List[torch.dtype]:
+        return [torch.half, torch.bfloat16]
+
+    @classmethod
+    def get_min_capability(cls) -> int:
+        return 70
+
+    @staticmethod
+    def get_config_filenames() -> List[str]:
+        return []
+
+    @classmethod
+    def from_config(cls, config: Dict[str, Any]) -> "MXConfig":
+        weight_dtype = cls.get_from_keys(config, ["weight_dtype"])
+        act_dtype = cls.get_from_keys(config, ["act_dtype"])
+        return cls(weight_dtype=weight_dtype, act_dtype=act_dtype)
+
+    def get_quant_method(self, layer: Module,
+                         prefix: str) -> Optional["MXLinearMethod"]:
+        if isinstance(layer, LinearBase):
+            return MXLinearMethod(self)
+        return None
+
+
+class MXLinearMethod(LinearMethodBase):
+    """Linear method for MX quant."""
+
+    def __init__(self, quant_config: MXConfig):
+        try:
+            import torchao  # noqa: F401
+        except ImportError as err:
+            raise ImportError("Please install torchao==0.6.1 via "
+                              "`pip install torchao==0.6.1` to use "
+                              "mx quantization.") from err
+        self.quant_config = quant_config
+
+    def create_weights(self, layer: Module, input_size_per_partition: int,
+                       output_partition_sizes: List[int], input_size: int,
+                       output_size: int, params_dtype: torch.dtype,
+                       **extra_weight_attrs):
+        del input_size, output_size  # Unused.
+        weight_loader = extra_weight_attrs.get("weight_loader")
+        weight = ModelWeightParameter(data=torch.empty(
+            sum(output_partition_sizes),
+            input_size_per_partition,
+            dtype=params_dtype),
+                                      input_dim=1,
+                                      output_dim=0,
+                                      weight_loader=weight_loader)
+        layer.register_parameter("weight", weight)
+
+    def process_weights_after_loading(self, layer: Module) -> None:
+        from torchao.prototype.mx_formats.mx_tensor import MXTensor
+        layer.weight_mx = MXTensor.to_mx(
+            layer.weight.data.t().contiguous().to(torch.float32),
+            self.quant_config.weight_dtype, self.quant_config.block_size)
+        layer.weight = None
+
+    def apply(self,
+              layer: torch.nn.Module,
+              x: torch.Tensor,
+              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+        orig_dtype = x.dtype
+        from torchao.prototype.mx_formats.mx_tensor import MXTensor
+        if not self.quant_config.act_dtype:
+            weight = layer.weight_mx.to_dtype(orig_dtype).t().contiguous()
+            out = torch.nn.functional.linear(x, weight, bias)
+        else:
+            x = MXTensor.to_mx(x.to(torch.float32),
+                               self.quant_config.act_dtype,
+                               self.quant_config.block_size)
+            out = torch.mm(x, layer.weight_mx)
+            if bias is not None:
+                out += bias
+
+        return out.to(orig_dtype)