diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py
index 3d822fc0c7f99..9d1a1e4f24ff7 100644
--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -289,13 +289,19 @@ def __init__(
             self.quant_method = quant_config.get_quant_method(self, prefix)
         assert self.quant_method is not None
 
-        self.quant_method.create_weights(
-            layer=self,
-            num_experts=num_experts,
-            hidden_size=hidden_size,
-            intermediate_size=self.intermediate_size_per_partition,
-            params_dtype=params_dtype,
-            weight_loader=self.weight_loader)
+        moe_quant_params = {
+            "num_experts": num_experts,
+            "hidden_size": hidden_size,
+            "intermediate_size": self.intermediate_size_per_partition,
+            "params_dtype": params_dtype,
+            "weight_loader": self.weight_loader,
+        }
+        # need full intermediate size pre-sharding for WNA16 act order
+        if (self.quant_method.__class__.__name__ ==
+                "CompressedTensorsWNA16MoEMethod"):
+            moe_quant_params["intermediate_full"] = intermediate_size
+
+        self.quant_method.create_weights(layer=self, **moe_quant_params)
 
     def _load_per_tensor_weight_scale(self, shard_id: str,
                                       param: torch.nn.Parameter,
@@ -312,19 +318,30 @@ def _load_per_tensor_weight_scale(self, shard_id: str,
         elif shard_id == "w2":
             param_data[expert_id] = loaded_weight
 
-    def _load_model_weight_or_group_weight_scale(self, shard_dim: int,
+    def _load_model_weight_or_group_weight_scale(self,
+                                                 shard_dim: int,
                                                  expert_data: torch.Tensor,
                                                  shard_id: str,
                                                  loaded_weight: torch.Tensor,
-                                                 tp_rank: int):
-        # Load grouped weight scales for group quantization
-        # or model weights
+                                                 tp_rank: int,
+                                                 load_full_w2: bool = False):
+        """
+        Load grouped weight scales for group quantization or model weights
+        :param shard_dim: dimension to shard
+        :param expert_data: parameter for a particular expert
+        :param shard_id: either w1, w2, or w3
+        :param loaded_weight: checkpoint weight to load into the param
+        :param tp_rank: tensor parallel rank
+        :param load_full_w2: whether or not the w2 loaded should be sharded.
+        """
         if shard_id == "w2":
-            self._load_w2(shard_id=shard_id,
-                          shard_dim=shard_dim,
+            # In the case where we have actorder/g_idx, we do not partition the
+            # w2 scales, as indicated by `load_full` argument, for all tp cases
+            self._load_w2(shard_dim=shard_dim,
                           loaded_weight=loaded_weight,
                           expert_data=expert_data,
-                          tp_rank=tp_rank)
+                          tp_rank=tp_rank,
+                          load_full=load_full_w2)
         elif shard_id in ("w1", "w3"):
             self._load_w13(shard_id=shard_id,
                            shard_dim=shard_dim,
@@ -364,15 +381,21 @@ def _load_w13(self, expert_data: torch.Tensor, shard_dim: int,
             expert_data = expert_data.narrow(shard_dim, shard_size, shard_size)
         expert_data.copy_(loaded_weight)
 
-    def _load_w2(self, expert_data: torch.Tensor, shard_dim: int,
-                 shard_id: str, loaded_weight: torch.Tensor, tp_rank: int):
+    def _load_w2(self,
+                 expert_data: torch.Tensor,
+                 shard_dim: int,
+                 loaded_weight: torch.Tensor,
+                 tp_rank: int,
+                 load_full: bool = False):
 
         # Index the loaded weight for tp sharding.
         # down_proj: "RowParallel" so tp sharding on input_dim
         # Narrow parameter and load.
         shard_size = expert_data.shape[shard_dim]
-        loaded_weight = loaded_weight.narrow(shard_dim, shard_size * tp_rank,
-                                             shard_size)
+        if not load_full:
+            loaded_weight = loaded_weight.narrow(shard_dim,
+                                                 shard_size * tp_rank,
+                                                 shard_size)
         # w2, down_proj: Load into only logical weight of w2.
         expert_data.copy_(loaded_weight)
 
@@ -387,8 +410,7 @@ def _load_g_idx(self, shard_id: str, expert_data: torch.Tensor,
                     shard_dim: int, loaded_weight: torch.Tensor, tp_rank: int):
 
         if shard_id == "w2":
-            self._load_w2(shard_id=shard_id,
-                          shard_dim=shard_dim,
+            self._load_w2(shard_dim=shard_dim,
                           loaded_weight=loaded_weight,
                           expert_data=expert_data,
                           tp_rank=tp_rank)
@@ -428,7 +450,7 @@ def weight_loader(self, param: torch.nn.Parameter,
         is_transposed = getattr(param, "is_transposed", False)
         shard_dim = SHARD_ID_TO_SHARDED_DIM[shard_id]
         if is_transposed:
-            shard_dim = ~shard_dim
+            shard_dim = int(not shard_dim)
 
         # Case input scale: input_scale loading is only supported for fp8
         if "input_scale" in weight_name:
@@ -480,7 +502,8 @@ def weight_loader(self, param: torch.nn.Parameter,
                     shard_dim=shard_dim,
                     loaded_weight=loaded_weight,
                     expert_data=expert_data,
-                    tp_rank=tp_rank)
+                    tp_rank=tp_rank,
+                    load_full_w2=getattr(param, "load_full_w2", False))
             elif quant_method == FusedMoeWeightScaleSupported.TENSOR.value:
                 self._load_per_tensor_weight_scale(shard_id=shard_id,
                                                    param=param,
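Note on the `layer.py` changes above: `create_weights` now receives its arguments as a dict so that `intermediate_full` (the unsharded intermediate size) can be passed only to the WNA16 MoE method, and `_load_w2` gains a `load_full` flag so w2 group scales can be copied without tensor-parallel narrowing when activation reordering is used. The standalone sketch below illustrates that loading decision; the `load_w2` helper name and tensor shapes are made up for the example and are not vLLM's API.

```python
# Minimal sketch (assumed names, not vLLM's API) of how the new
# `load_full` flag changes w2 loading.
import torch


def load_w2(expert_data: torch.Tensor,
            loaded_weight: torch.Tensor,
            shard_dim: int,
            tp_rank: int,
            load_full: bool = False) -> None:
    """Copy a checkpoint w2 tensor into one expert's parameter slice."""
    shard_size = expert_data.shape[shard_dim]
    if not load_full:
        # Default path: down_proj is row-parallel, so each rank copies
        # only its slice of the input dimension.
        loaded_weight = loaded_weight.narrow(shard_dim,
                                             shard_size * tp_rank,
                                             shard_size)
    # With act-order / g_idx, w2 scales are kept full on every rank,
    # so the checkpoint tensor is copied as-is.
    expert_data.copy_(loaded_weight)


# Example: tp_size=2, rank 1 -> the sharded load takes the second half.
full = torch.arange(8, dtype=torch.float32).reshape(1, 8)
param = torch.empty(1, 4)
load_w2(param, full, shard_dim=1, tp_rank=1)
assert torch.equal(param, full[:, 4:])
```

The full-copy path is requested through the `load_full_w2` attribute that the WNA16 method sets on `w2_weight_scale`, which `weight_loader` reads via `getattr(param, "load_full_w2", False)`.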
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
index 4fb8fd84e92d4..c66fce3d4316a 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
@@ -13,6 +13,7 @@
     FusedMoeWeightScaleSupported)
 from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
     WNA16_SUPPORTED_BITS)
+from vllm.model_executor.layers.quantization.utils import replace_parameter
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     all_close_1d, normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize)
 from vllm.model_executor.utils import set_weight_attrs
@@ -254,6 +255,7 @@ def __init__(
         self.packed_factor = 32 // config.num_bits
         self.strategy = config.strategy
         self.group_size = config.group_size
+        self.actorder = config.actorder
         assert config.symmetric, (
             "Only symmetric quantization is supported for MoE")
 
@@ -269,9 +271,14 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,
                        hidden_size: int, intermediate_size: int,
                        params_dtype: torch.dtype, **extra_weight_attrs):
 
+        assert params_dtype == torch.float16, (
+            "float16 is required for MoE compressed models. Set dtype=torch.float16"  # noqa: E501
+        )
+
         # Will transpose the loaded weight along the
         # intermediate and hidden dim sizes. Will
         # shard for TP along the transposed dims
+        intermediate_full = extra_weight_attrs.pop("intermediate_full")
         extra_weight_attrs.update({
             "is_transposed": True,
             "quant_method": self.strategy
@@ -294,11 +301,20 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,
         layer.register_parameter("w2_weight_packed", w2_weight)
         set_weight_attrs(w2_weight, extra_weight_attrs)
 
+        # In the case where we have actorder/g_idx,
+        # we do not partition the w2 scales
+        load_full_w2 = self.actorder and self.group_size != -1
+        w2_scales_size = (intermediate_full
+                          if load_full_w2 else intermediate_size)
+
+        self.is_k_full = (not self.actorder) or (intermediate_size
+                                                 == intermediate_full)
+
         if self.strategy == "channel":
             num_groups_w2 = num_groups_w13 = 1
             self.group_size = -1
         else:
-            num_groups_w2 = intermediate_size // self.group_size
+            num_groups_w2 = w2_scales_size // self.group_size
             num_groups_w13 = hidden_size // self.group_size
 
         w13_scale = torch.nn.Parameter(torch.ones(num_experts,
@@ -316,6 +332,7 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,
                                       requires_grad=False)
         layer.register_parameter("w2_weight_scale", w2_scale)
         set_weight_attrs(w2_scale, extra_weight_attrs)
+        set_weight_attrs(w2_scale, {"load_full_w2": load_full_w2})
 
         w2_weight_shape = torch.nn.Parameter(torch.empty(num_experts, 2),
                                              requires_grad=False)
@@ -335,7 +352,7 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,
             ),
             requires_grad=False,
         )
-        layer.register_parameter("w13_g_idx", w13_g_idx)
+        layer.register_parameter("w13_weight_g_idx", w13_g_idx)
         set_weight_attrs(w13_g_idx, extra_weight_attrs)
 
         w2_g_idx = torch.nn.Parameter(
@@ -346,7 +363,7 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,
             ),
             requires_grad=False,
         )
-        layer.register_parameter("w2_g_idx", w2_g_idx)
+        layer.register_parameter("w2_weight_g_idx", w2_g_idx)
         set_weight_attrs(w2_g_idx, extra_weight_attrs)
 
         w13_g_idx_sort_indices = torch.nn.Parameter(
@@ -422,24 +439,55 @@ def marlin_moe_permute_scales(s: torch.Tensor, size_k: int,
         size_k2 = layer.w2_weight_packed.shape[2]
         size_k13 = layer.w13_weight_packed.shape[2]
 
-        num_experts = layer.w13_g_idx.shape[0]
-        device = layer.w13_g_idx.device
-        layer.w13_g_idx = torch.nn.Parameter(
-            torch.empty((num_experts, 0), dtype=torch.int32, device=device),
-            requires_grad=False,
-        )
-        layer.w2_g_idx = torch.nn.Parameter(
-            torch.empty((num_experts, 0), dtype=torch.int32, device=device),
-            requires_grad=False,
-        )
-        layer.w13_g_idx_sort_indices = torch.nn.Parameter(
-            torch.empty((num_experts, 0), dtype=torch.int32, device=device),
-            requires_grad=False,
-        )
-        layer.w2_g_idx_sort_indices = torch.nn.Parameter(
-            torch.empty((num_experts, 0), dtype=torch.int32, device=device),
-            requires_grad=False,
-        )
+        num_experts = layer.w13_weight_g_idx.shape[0]
+        device = layer.w13_weight_g_idx.device
+
+        # when running models with grouped act order,
+        # resort to g_idx values provided in checkpoint
+        if self.actorder == "group":
+            w13_g_idx_sort_indices = torch.empty_like(layer.w13_weight_g_idx)
+            w2_g_idx_sort_indices = torch.empty_like(layer.w2_weight_g_idx)
+            w13_sorted_g_idx = torch.empty_like(layer.w13_weight_g_idx)
+            w2_sorted_g_idx = torch.empty_like(layer.w2_weight_g_idx)
+
+            for e in range(num_experts):
+                w13_g_idx_sort_indices[e] = torch.argsort(
+                    layer.w13_weight_g_idx[e]).to(torch.int32)
+                w2_g_idx_sort_indices[e] = torch.argsort(
+                    layer.w2_weight_g_idx[e]).to(torch.int32)
+                w13_sorted_g_idx[e] = layer.w13_weight_g_idx[e][
+                    w13_g_idx_sort_indices[e]]
+                w2_sorted_g_idx[e] = layer.w2_weight_g_idx[e][
+                    w2_g_idx_sort_indices[e]]
+
+            replace_parameter(layer, "w13_weight_g_idx", w13_sorted_g_idx)
+            replace_parameter(layer, "w2_weight_g_idx", w2_sorted_g_idx)
+            replace_parameter(layer, "w13_g_idx_sort_indices",
+                              w13_g_idx_sort_indices)
+            replace_parameter(layer, "w2_g_idx_sort_indices",
+                              w2_g_idx_sort_indices)
+
+        else:
+            layer.w13_weight_g_idx = torch.nn.Parameter(
+                torch.empty((num_experts, 0), dtype=torch.int32,
+                            device=device),
+                requires_grad=False,
+            )
+            layer.w2_weight_g_idx = torch.nn.Parameter(
+                torch.empty((num_experts, 0), dtype=torch.int32,
+                            device=device),
+                requires_grad=False,
+            )
+            layer.w13_g_idx_sort_indices = torch.nn.Parameter(
+                torch.empty((num_experts, 0), dtype=torch.int32,
+                            device=device),
+                requires_grad=False,
+            )
+            layer.w2_g_idx_sort_indices = torch.nn.Parameter(
+                torch.empty((num_experts, 0), dtype=torch.int32,
+                            device=device),
+                requires_grad=False,
+            )
 
         marlin_w13_qweight = ops.gptq_marlin_moe_repack(
             layer.w13_weight_packed,
@@ -511,9 +559,9 @@ def apply(
             router_logits,
             topk_weights,
             topk_ids,
-            g_idx1=layer.w13_g_idx,
-            g_idx2=layer.w2_g_idx,
+            g_idx1=layer.w13_weight_g_idx,
+            g_idx2=layer.w2_weight_g_idx,
             sort_indices1=layer.w13_g_idx_sort_indices,
             sort_indices2=layer.w2_g_idx_sort_indices,
             num_bits=self.num_bits,
-        )
+            is_k_full=self.is_k_full)
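Note on the `compressed_tensors_moe.py` changes above: with `actorder == "group"`, the checkpoint's g_idx tensors are kept, argsorted per expert, and handed to the Marlin MoE kernel along with the sort indices, while `is_k_full` records whether the reduction dimension is left unsplit (no act-order, or `intermediate_size == intermediate_full`, i.e. no TP sharding of that dim). The sketch below illustrates both computations with made-up shapes and values; it is not part of the patch.

```python
# Sketch (assumed shapes/values) of per-expert g_idx sorting and is_k_full.
import torch

num_experts, k = 2, 8
w_g_idx = torch.randint(0, 4, (num_experts, k), dtype=torch.int32)

sort_indices = torch.empty_like(w_g_idx)
sorted_g_idx = torch.empty_like(w_g_idx)
for e in range(num_experts):
    order = torch.argsort(w_g_idx[e])        # ascending group ids
    sort_indices[e] = order.to(torch.int32)  # kernel expects int32 indices
    sorted_g_idx[e] = w_g_idx[e][order]

# is_k_full: True when there is no act-order, or when the per-partition
# intermediate size equals the unsharded one (i.e. no TP split of that dim).
actorder = "group"
intermediate_size, intermediate_full = 1408, 2816  # e.g. TP=2
is_k_full = (not actorder) or (intermediate_size == intermediate_full)
assert not is_k_full
```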
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py
index 61d1c911cd1ad..2e1b5e3c2d3b1 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py
@@ -62,7 +62,7 @@ def create_weights(self, layer: torch.nn.Module, input_size: int,
                        **kwargs):
 
         assert params_dtype == torch.float16, (
-            "float16 is required for marlin24 compressd models. Set dtype=torch.float16"  # noqa: E501
+            "float16 is required for marlin24 compressed models. Set dtype=torch.float16"  # noqa: E501
         )
 
         pack_factor = 32 // self.quant_type.size_bits
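For reference, one possible way to exercise these paths end to end, assuming a local compressed-tensors WNA16 MoE checkpoint (the model path below is a placeholder): the new assertions require float16, and running with tensor parallelism hits the unsharded w2-scale loading.

```python
# Hypothetical usage sketch; the checkpoint path is a placeholder.
from vllm import LLM

llm = LLM(
    model="/path/to/compressed-tensors-wna16-moe-checkpoint",  # placeholder
    dtype="float16",         # required by the create_weights assertions
    tensor_parallel_size=2,  # w2 scales stay full per rank with act-order
)
print(llm.generate("Hello, my name is")[0].outputs[0].text)
```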