[Quantization] Stricter checks for MoE gate (#2109)
This PR strengthens the MoE gate check to also examine the number of experts: a real MoE gate (router) layer's number of output features equals the number of experts, which is usually very small.

This PR fixes a regression in which a layer in RWKV6 whose name ends with "gate" was mistaken for an MoE gate, even though it is not related to MoE at all.
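As a sketch of the failure mode (the layer names and dimensions below are hypothetical, not taken from the actual models): a genuine MoE router named `...gate` has `out_features` equal to the expert count, e.g. 8 for Mixtral, while the RWKV6 layer that merely ends with "gate" is a hidden-size-wide projection. A name-only check flags both; also checking the output width keeps only the router.

```python
# Sketch only: how the stricter check separates a real MoE router from
# RWKV6's "gate" projection. Names and dimensions are hypothetical.
from dataclasses import dataclass


@dataclass
class Linear:
    in_features: int
    out_features: int


moe_router = Linear(in_features=4096, out_features=8)     # 8 experts, small output
rwkv6_gate = Linear(in_features=2048, out_features=2048)  # ordinary wide projection


def old_check(name: str) -> bool:
    return name.endswith("gate")


def new_check(name: str, node: Linear) -> bool:
    return (
        name.endswith("gate")
        and isinstance(node.out_features, int)
        and node.out_features < 16
    )


# The old name-only check flags both layers; the new one keeps only the router.
assert old_check("moe.gate") and old_check("ffn.gate")
assert new_check("moe.gate", moe_router)
assert not new_check("ffn.gate", rwkv6_gate)
```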
MasterJH5574 authored Apr 9, 2024
1 parent 36d0e6a commit 3e71b70
Showing 4 changed files with 9 additions and 5 deletions.
python/mlc_llm/quantization/awq_quantization.py: 5 additions & 1 deletion
```diff
@@ -117,7 +117,11 @@ def visit_module(self, name: str, node: nn.Module) -> Any:
             The new node to replace current node.
         """
 
-        if isinstance(node, nn.Linear) and not is_final_fc(name) and not is_moe_gate(name):
+        if (
+            isinstance(node, nn.Linear)
+            and not is_final_fc(name)
+            and not is_moe_gate(name, node)
+        ):
             return AWQQuantizeLinear.from_linear(node, self.config)
         return self.visit(name, node)
 
```
python/mlc_llm/quantization/ft_quantization.py: 1 addition & 1 deletion
```diff
@@ -147,7 +147,7 @@ def visit_module(self, name: str, node: nn.Module) -> Any:
                 group_quantize = self.config.fallback_group_quantize()
                 self.quant_map.map_func[weight_name] = group_quantize.quantize_weight
                 return GroupQuantizeLinear.from_linear(node, group_quantize)
-            if not is_moe_gate(name):
+            if not is_moe_gate(name, node):
                 self.quant_map.map_func[weight_name] = self.config.quantize_weight
                 return FTQuantizeLinear.from_linear(node, self.config)
         if isinstance(node, nn.Embedding):
```
python/mlc_llm/quantization/group_quantization.py: 1 addition & 1 deletion
```diff
@@ -113,7 +113,7 @@ def visit_module(self, name: str, node: nn.Module) -> Any:
         if (
             isinstance(node, nn.Linear)
             and (not is_final_fc(name) or self.config.quantize_final_fc)
-            and not is_moe_gate(name)
+            and not is_moe_gate(name, node)
         ):
             weight_name = f"{name}.weight"
             self.quant_map.param_map[weight_name] = [f"{name}.q_weight", f"{name}.q_scale"]
```
python/mlc_llm/quantization/utils.py: 2 additions & 2 deletions
```diff
@@ -54,9 +54,9 @@ def is_final_fc(name: str) -> bool:
     return name in ["head", "lm_head", "lm_head.linear", "embed_out"]
 
 
-def is_moe_gate(name: str) -> bool:
+def is_moe_gate(name: str, node: nn.Linear) -> bool:
     """Check whether the parameter is the MoE gate layer."""
-    return name.endswith("gate")
+    return name.endswith("gate") and isinstance(node.out_features, int) and node.out_features < 16
 
 
 def compile_quantize_func(mod: IRModule, device) -> Callable:
```
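The `isinstance(node.out_features, int)` guard also keeps the helper conservative when `out_features` is not a plain Python integer (for instance a symbolic shape expression, which is an assumption about how dynamic shapes could surface here, not something this diff demonstrates): such layers are treated as regular linear layers and stay eligible for quantization. A short usage sketch with a hypothetical `FakeLinear` stand-in:

```python
# Usage sketch; FakeLinear is a hypothetical stand-in for the real nn.Linear,
# and the symbolic out_features case is an assumption, not shown in this diff.
from dataclasses import dataclass
from typing import Any


@dataclass
class FakeLinear:
    out_features: Any  # usually an int; assumed symbolic for dynamic shapes


def is_moe_gate(name: str, node: FakeLinear) -> bool:
    """Check whether the parameter is the MoE gate layer."""
    return (
        name.endswith("gate")
        and isinstance(node.out_features, int)
        and node.out_features < 16
    )


print(is_moe_gate("model.layers.0.moe.gate", FakeLinear(out_features=8)))     # True
print(is_moe_gate("model.layers.0.ffn.gate", FakeLinear(out_features=2048)))  # False: too wide
print(is_moe_gate("model.layers.0.moe.gate", FakeLinear(out_features="n")))   # False: not a static int
```

The threshold of 16 comfortably covers the small expert counts targeted here (e.g. Mixtral's 8 experts) while rejecting projections whose width matches the hidden size.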
