[BugFix] Fix parameter names and process_after_weight_loading for W4A16 MoE Group Act Order #11528

Open · wants to merge 7 commits into main
67 changes: 45 additions & 22 deletions vllm/model_executor/layers/fused_moe/layer.py
@@ -289,13 +289,19 @@ def __init__(
self.quant_method = quant_config.get_quant_method(self, prefix)
assert self.quant_method is not None

self.quant_method.create_weights(
layer=self,
num_experts=num_experts,
hidden_size=hidden_size,
intermediate_size=self.intermediate_size_per_partition,
params_dtype=params_dtype,
weight_loader=self.weight_loader)
moe_quant_params = {
"num_experts": num_experts,
"hidden_size": hidden_size,
"intermediate_size": self.intermediate_size_per_partition,
"params_dtype": params_dtype,
"weight_loader": self.weight_loader,
}
# need full intermediate size pre-sharding for WNA16 act order
if (self.quant_method.__class__.__name__ ==
"CompressedTensorsWNA16MoEMethod"):
moe_quant_params["intermediate_full"] = intermediate_size

self.quant_method.create_weights(layer=self, **moe_quant_params)

def _load_per_tensor_weight_scale(self, shard_id: str,
param: torch.nn.Parameter,
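As an aside on the hunk above (a self-contained sketch, not the actual vLLM classes): the shared weight-creation arguments are collected in a dict, the extra `intermediate_full` key is added only for the WNA16 act-order method, and that method pops it back out of `**extra_weight_attrs` before the remaining attrs are attached to the parameters.

class WNA16LikeMethod:  # illustrative stand-in for CompressedTensorsWNA16MoEMethod
    def create_weights(self, layer, num_experts, hidden_size,
                       intermediate_size, params_dtype,
                       **extra_weight_attrs):
        # The act-order path needs the full, pre-TP-sharding intermediate
        # size, so it is popped before the attrs are set on the weights.
        intermediate_full = extra_weight_attrs.pop("intermediate_full")
        return intermediate_size, intermediate_full


def create_moe_weights(quant_method, layer, num_experts, hidden_size,
                       intermediate_size_per_partition, intermediate_size,
                       params_dtype, weight_loader):
    moe_quant_params = {
        "num_experts": num_experts,
        "hidden_size": hidden_size,
        "intermediate_size": intermediate_size_per_partition,
        "params_dtype": params_dtype,
        "weight_loader": weight_loader,
    }
    # Only the WNA16 act-order method consumes the full intermediate size.
    if quant_method.__class__.__name__ == "WNA16LikeMethod":
        moe_quant_params["intermediate_full"] = intermediate_size
    return quant_method.create_weights(layer=layer, **moe_quant_params)


method = WNA16LikeMethod()
print(create_moe_weights(method, layer=None, num_experts=8, hidden_size=1024,
                         intermediate_size_per_partition=1408,
                         intermediate_size=2816, params_dtype="float16",
                         weight_loader=None))   # -> (1408, 2816)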
@@ -312,19 +318,30 @@ def _load_per_tensor_weight_scale(self, shard_id: str,
elif shard_id == "w2":
param_data[expert_id] = loaded_weight

def _load_model_weight_or_group_weight_scale(self, shard_dim: int,
def _load_model_weight_or_group_weight_scale(self,
shard_dim: int,
expert_data: torch.Tensor,
shard_id: str,
loaded_weight: torch.Tensor,
tp_rank: int):
# Load grouped weight scales for group quantization
# or model weights
tp_rank: int,
load_full_w2: bool = False):
"""
Load grouped weight scales for group quantization or model weights
:param shard_dim: dimension to shard
:param expert_data: parameter for a particular expert
:param shard_id: either w1, w2, or w3
:param loaded_weight: checkpoint weight to load into the param
:param tp_rank: tensor parallel rank
:param load_full_w2: whether to load the full, unsharded w2 weight (skip TP sharding).
"""
if shard_id == "w2":
self._load_w2(shard_id=shard_id,
shard_dim=shard_dim,
# In the case where we have actorder/g_idx, the w2 scales are not
# partitioned across TP ranks, as indicated by the `load_full` argument
self._load_w2(shard_dim=shard_dim,
loaded_weight=loaded_weight,
expert_data=expert_data,
tp_rank=tp_rank)
tp_rank=tp_rank,
load_full=load_full_w2)
elif shard_id in ("w1", "w3"):
self._load_w13(shard_id=shard_id,
shard_dim=shard_dim,
@@ -364,15 +381,21 @@ def _load_w13(self, expert_data: torch.Tensor, shard_dim: int,
expert_data = expert_data.narrow(shard_dim, shard_size, shard_size)
expert_data.copy_(loaded_weight)

def _load_w2(self, expert_data: torch.Tensor, shard_dim: int,
shard_id: str, loaded_weight: torch.Tensor, tp_rank: int):
def _load_w2(self,
expert_data: torch.Tensor,
shard_dim: int,
loaded_weight: torch.Tensor,
tp_rank: int,
load_full: bool = False):

# Index the loaded weight for tp sharding.
# down_proj: "RowParallel" so tp sharding on input_dim
# Narrow parameter and load.
shard_size = expert_data.shape[shard_dim]
loaded_weight = loaded_weight.narrow(shard_dim, shard_size * tp_rank,
shard_size)
if not load_full:
loaded_weight = loaded_weight.narrow(shard_dim,
shard_size * tp_rank,
shard_size)
# w2, down_proj: Load into only logical weight of w2.
expert_data.copy_(loaded_weight)
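To make the `load_full` switch concrete, here is a standalone sketch with toy sizes (not the vLLM code path itself): on the regular tensor-parallel path the checkpoint tensor is narrowed to this rank's slice along the shard dim, while the act-order w2 scales are copied whole.

import torch

def load_w2(expert_data: torch.Tensor, loaded_weight: torch.Tensor,
            shard_dim: int, tp_rank: int, load_full: bool = False) -> None:
    """Copy a checkpoint tensor into one expert's parameter slice."""
    if not load_full:
        # RowParallel-style sharding: take this rank's contiguous slice.
        shard_size = expert_data.shape[shard_dim]
        loaded_weight = loaded_weight.narrow(shard_dim, shard_size * tp_rank,
                                             shard_size)
    expert_data.copy_(loaded_weight)

# Toy example: a full (4, 8) w2 scale tensor and tensor-parallel size 2.
full = torch.arange(32, dtype=torch.float32).reshape(4, 8)

sharded_param = torch.empty(2, 8)      # this rank's slice
load_w2(sharded_param, full, shard_dim=0, tp_rank=1)                  # rows 2..3

unsharded_param = torch.empty(4, 8)    # act-order scales, loaded in full
load_w2(unsharded_param, full, shard_dim=0, tp_rank=1, load_full=True)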

@@ -387,8 +410,7 @@ def _load_g_idx(self, shard_id: str, expert_data: torch.Tensor,
shard_dim: int, loaded_weight: torch.Tensor, tp_rank: int):

if shard_id == "w2":
self._load_w2(shard_id=shard_id,
shard_dim=shard_dim,
self._load_w2(shard_dim=shard_dim,
loaded_weight=loaded_weight,
expert_data=expert_data,
tp_rank=tp_rank)
@@ -428,7 +450,7 @@ def weight_loader(self, param: torch.nn.Parameter,
is_transposed = getattr(param, "is_transposed", False)
shard_dim = SHARD_ID_TO_SHARDED_DIM[shard_id]
if is_transposed:
shard_dim = ~shard_dim
shard_dim = int(not shard_dim)

# Case input scale: input_scale loading is only supported for fp8
if "input_scale" in weight_name:
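A side note on the `~shard_dim` -> `int(not shard_dim)` change above: Python's bitwise complement maps 0/1 to the negative indices -1/-2, which happen to address the same axes on a 2-D tensor but are an odd way to express "the other dim"; `int(not shard_dim)` (or `1 - shard_dim`) flips explicitly between 0 and 1. Plain Python, nothing vLLM-specific:

for shard_dim in (0, 1):
    print(shard_dim, "->", ~shard_dim, int(not shard_dim), 1 - shard_dim)
# 0 -> -1 1 1
# 1 -> -2 0 0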
@@ -480,7 +502,8 @@ def weight_loader(self, param: torch.nn.Parameter,
shard_dim=shard_dim,
loaded_weight=loaded_weight,
expert_data=expert_data,
tp_rank=tp_rank)
tp_rank=tp_rank,
load_full_w2=getattr(param, "load_full_w2", False))
elif quant_method == FusedMoeWeightScaleSupported.TENSOR.value:
self._load_per_tensor_weight_scale(shard_id=shard_id,
param=param,
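Note how the new `load_full_w2` flag reaches the loader: it is stored as a plain attribute on the scale Parameter by the quant method (via `set_weight_attrs`, shown in the next file) and read back here with `getattr(param, "load_full_w2", False)`, so parameters created by other quant methods keep the default sharding behavior. A minimal sketch of that hand-off, using a simplified stand-in for `set_weight_attrs`:

import torch

def set_weight_attrs(weight: torch.Tensor, attrs: dict) -> None:
    # Simplified stand-in for vllm.model_executor.utils.set_weight_attrs:
    # stash extra metadata directly on the parameter object.
    for key, value in attrs.items():
        setattr(weight, key, value)

# Quant method side: mark the w2 scales as "load in full, do not shard".
w2_scale = torch.nn.Parameter(torch.ones(8, 4), requires_grad=False)
set_weight_attrs(w2_scale, {"load_full_w2": True})

# Loader side: read the flag back with a safe default.
assert getattr(w2_scale, "load_full_w2", False) is True
assert getattr(torch.nn.Parameter(torch.ones(2)), "load_full_w2", False) is False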
@@ -13,6 +13,7 @@
FusedMoeWeightScaleSupported)
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
WNA16_SUPPORTED_BITS)
from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
all_close_1d, normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize)
from vllm.model_executor.utils import set_weight_attrs
@@ -254,6 +255,7 @@ def __init__(
self.packed_factor = 32 // config.num_bits
self.strategy = config.strategy
self.group_size = config.group_size
self.actorder = config.actorder
assert config.symmetric, (
"Only symmetric quantization is supported for MoE")

@@ -269,9 +271,14 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size: int,
params_dtype: torch.dtype, **extra_weight_attrs):

assert params_dtype == torch.float16, (
"float16 is required for MoE compressed models. Set dtype=torch.float16" # noqa: E501
)

# Will transpose the loaded weight along the
# intermediate and hidden dim sizes. Will
# shard for TP along the transposed dims
intermediate_full = extra_weight_attrs.pop("intermediate_full")
Contributor review comment:
nit: maybe rename intermediate_size -> intermediate_size_per_partition and intermediate_full -> intermediate_size?
This would make the names consistent with other quant configs, e.g. vllm/model_executor/layers/quantization/gptq.py

extra_weight_attrs.update({
"is_transposed": True,
"quant_method": self.strategy
@@ -294,11 +301,20 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,
layer.register_parameter("w2_weight_packed", w2_weight)
set_weight_attrs(w2_weight, extra_weight_attrs)

# In the case where we have actorder/g_idx,
# we do not partition the w2 scales
load_full_w2 = self.actorder and self.group_size != -1
w2_scales_size = (intermediate_full
if load_full_w2 else intermediate_size)

self.is_k_full = (not self.actorder) or (intermediate_size
== intermediate_full)

if self.strategy == "channel":
num_groups_w2 = num_groups_w13 = 1
self.group_size = -1
else:
num_groups_w2 = intermediate_size // self.group_size
num_groups_w2 = w2_scales_size // self.group_size
num_groups_w13 = hidden_size // self.group_size

w13_scale = torch.nn.Parameter(torch.ones(num_experts,
@@ -316,6 +332,7 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,
requires_grad=False)
layer.register_parameter("w2_weight_scale", w2_scale)
set_weight_attrs(w2_scale, extra_weight_attrs)
set_weight_attrs(w2_scale, {"load_full_w2": load_full_w2})

w2_weight_shape = torch.nn.Parameter(torch.empty(num_experts, 2),
requires_grad=False)
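To make the scale sizing concrete, a small worked example with hypothetical numbers (not taken from any particular model): full intermediate size 5632, tensor-parallel size 2, group_size 128, grouped act order enabled.

# Hypothetical sizes, for illustration only.
intermediate_full = 5632                           # pre-sharding intermediate size
tp_size = 2
intermediate_size = intermediate_full // tp_size   # 2816 per partition
group_size = 128
actorder = "group"

# With grouped act order the w2 scales stay unsharded, so they are sized
# from the full intermediate dimension (mirrors load_full_w2 above).
load_full_w2 = bool(actorder) and group_size != -1           # True
w2_scales_size = intermediate_full if load_full_w2 else intermediate_size

num_groups_w2 = w2_scales_size // group_size                 # 44, not 22
is_k_full = (not actorder) or (intermediate_size == intermediate_full)  # False
print(load_full_w2, num_groups_w2, is_k_full)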
@@ -335,7 +352,7 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,
),
requires_grad=False,
)
layer.register_parameter("w13_g_idx", w13_g_idx)
layer.register_parameter("w13_weight_g_idx", w13_g_idx)
set_weight_attrs(w13_g_idx, extra_weight_attrs)

w2_g_idx = torch.nn.Parameter(
Expand All @@ -346,7 +363,7 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,
),
requires_grad=False,
)
layer.register_parameter("w2_g_idx", w2_g_idx)
layer.register_parameter("w2_weight_g_idx", w2_g_idx)
set_weight_attrs(w2_g_idx, extra_weight_attrs)

w13_g_idx_sort_indices = torch.nn.Parameter(
@@ -422,24 +439,55 @@ def marlin_moe_permute_scales(s: torch.Tensor, size_k: int,
size_k2 = layer.w2_weight_packed.shape[2]
size_k13 = layer.w13_weight_packed.shape[2]

num_experts = layer.w13_g_idx.shape[0]
device = layer.w13_g_idx.device
layer.w13_g_idx = torch.nn.Parameter(
torch.empty((num_experts, 0), dtype=torch.int32, device=device),
requires_grad=False,
)
layer.w2_g_idx = torch.nn.Parameter(
torch.empty((num_experts, 0), dtype=torch.int32, device=device),
requires_grad=False,
)
layer.w13_g_idx_sort_indices = torch.nn.Parameter(
torch.empty((num_experts, 0), dtype=torch.int32, device=device),
requires_grad=False,
)
layer.w2_g_idx_sort_indices = torch.nn.Parameter(
torch.empty((num_experts, 0), dtype=torch.int32, device=device),
requires_grad=False,
)
num_experts = layer.w13_weight_g_idx.shape[0]
device = layer.w13_weight_g_idx.device

# When running models with grouped act order, re-sort according to
# the g_idx values provided in the checkpoint
if self.actorder == "group":
w13_g_idx_sort_indices = torch.empty_like(layer.w13_weight_g_idx)
w2_g_idx_sort_indices = torch.empty_like(layer.w2_weight_g_idx)
w13_sorted_g_idx = torch.empty_like(layer.w13_weight_g_idx)
w2_sorted_g_idx = torch.empty_like(layer.w2_weight_g_idx)

for e in range(num_experts):
w13_g_idx_sort_indices[e] = torch.argsort(
layer.w13_weight_g_idx[e]).to(torch.int32)
w2_g_idx_sort_indices[e] = torch.argsort(
layer.w2_weight_g_idx[e]).to(torch.int32)
w13_sorted_g_idx[e] = layer.w13_weight_g_idx[e][
w13_g_idx_sort_indices[e]]
w2_sorted_g_idx[e] = layer.w2_weight_g_idx[e][
w2_g_idx_sort_indices[e]]

replace_parameter(layer, "w13_weight_g_idx", w13_sorted_g_idx)
replace_parameter(layer, "w2_weight_g_idx", w2_sorted_g_idx)
replace_parameter(layer, "w13_g_idx_sort_indices",
w13_g_idx_sort_indices)
replace_parameter(layer, "w2_g_idx_sort_indices",
w2_g_idx_sort_indices)

else:
layer.w13_weight_g_idx = torch.nn.Parameter(
torch.empty((num_experts, 0), dtype=torch.int32,
device=device),
requires_grad=False,
)
layer.w2_weight_g_idx = torch.nn.Parameter(
torch.empty((num_experts, 0), dtype=torch.int32,
device=device),
requires_grad=False,
)
layer.w13_g_idx_sort_indices = torch.nn.Parameter(
torch.empty((num_experts, 0), dtype=torch.int32,
device=device),
requires_grad=False,
)
layer.w2_g_idx_sort_indices = torch.nn.Parameter(
torch.empty((num_experts, 0), dtype=torch.int32,
device=device),
requires_grad=False,
)

marlin_w13_qweight = ops.gptq_marlin_moe_repack(
layer.w13_weight_packed,
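The per-expert loop above reduces to: compute the permutation that sorts each expert's g_idx, keep it as sort indices for the marlin repack/kernel, and store the g_idx itself in sorted order. A toy, single-expert version of the same idea (made-up g_idx values):

import torch

# Group index per input channel, as it might appear in an act-order checkpoint.
g_idx = torch.tensor([2, 0, 1, 0, 2, 1], dtype=torch.int32)

sort_indices = torch.argsort(g_idx)            # permutation grouping channels
sorted_g_idx = g_idx[sort_indices]             # tensor([0, 0, 1, 1, 2, 2], ...)
sort_indices = sort_indices.to(torch.int32)    # match the int32 dtype used above

print(sort_indices, sorted_g_idx)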
@@ -511,9 +559,9 @@ def apply(
router_logits,
topk_weights,
topk_ids,
g_idx1=layer.w13_g_idx,
g_idx2=layer.w2_g_idx,
g_idx1=layer.w13_weight_g_idx,
g_idx2=layer.w2_weight_g_idx,
sort_indices1=layer.w13_g_idx_sort_indices,
sort_indices2=layer.w2_g_idx_sort_indices,
num_bits=self.num_bits,
)
is_k_full=self.is_k_full)
@@ -62,7 +62,7 @@ def create_weights(self, layer: torch.nn.Module, input_size: int,
**kwargs):

assert params_dtype == torch.float16, (
"float16 is required for marlin24 compressd models. Set dtype=torch.float16" # noqa: E501
"float16 is required for marlin24 compressed models. Set dtype=torch.float16" # noqa: E501
)

pack_factor = 32 // self.quant_type.size_bits