Commit

fix condition for is_k_full; clean-up

dsikka committed Jan 10, 2025
1 parent cc81732 commit c7a912e
Showing 2 changed files with 46 additions and 28 deletions.
49 changes: 29 additions & 20 deletions vllm/model_executor/layers/fused_moe/layer.py
@@ -240,15 +240,26 @@ def _load_per_tensor_weight_scale(self, shard_id: str,
elif shard_id == "w2":
param_data[expert_id] = loaded_weight

def _load_model_weight_or_group_weight_scale(
self, shard_dim: int, expert_data: torch.Tensor, shard_id: str,
loaded_weight: torch.Tensor, tp_rank: int, load_full_w2: bool):
# Load grouped weight scales for group quantization
# or model weights
# In act_order scenario, we need to load full w2 scales
def _load_model_weight_or_group_weight_scale(self,
shard_dim: int,
expert_data: torch.Tensor,
shard_id: str,
loaded_weight: torch.Tensor,
tp_rank: int,
load_full_w2: bool = False):
"""
Load grouped weight scales for group quantization or model weights
:param shard_dim: dimension to shard
:param expert_data: parameter for a particular expert
:param shard_id: either w1, w2, or w3
:param loaded_weight: checkpoint weight to load into the param
:param tp_rank: tensor parallel rank
:param load_full_w2: whether or not the w2 loaded should be sharded.
"""
if shard_id == "w2":
self._load_w2(shard_id=shard_id,
shard_dim=shard_dim,
# In the case where we have actorder/g_idx, we do not partition the
# w2 scales, as indicated by `load_full` argument, for all tp cases
self._load_w2(shard_dim=shard_dim,
loaded_weight=loaded_weight,
expert_data=expert_data,
tp_rank=tp_rank,
@@ -292,9 +303,12 @@ def _load_w13(self, expert_data: torch.Tensor, shard_dim: int,
expert_data = expert_data.narrow(shard_dim, shard_size, shard_size)
expert_data.copy_(loaded_weight)

def _load_w2(self, expert_data: torch.Tensor, shard_dim: int,
shard_id: str, loaded_weight: torch.Tensor, tp_rank: int,
load_full: bool):
def _load_w2(self,
expert_data: torch.Tensor,
shard_dim: int,
loaded_weight: torch.Tensor,
tp_rank: int,
load_full: bool = False):

# Index the loaded weight for tp sharding.
# down_proj: "RowParallel" so tp sharding on input_dim
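
A minimal standalone sketch of the sharding this comment describes, not vLLM's actual helper: down_proj is RowParallel, so each tensor-parallel rank copies only its slice of the input dim, unless load_full is set (as it is for act-order w2 scales). Names and shapes below are illustrative.

import torch

def shard_w2_for_tp(expert_data: torch.Tensor,
                    loaded_weight: torch.Tensor,
                    shard_dim: int,
                    tp_rank: int,
                    load_full: bool = False) -> None:
    # down_proj is RowParallel: each rank owns a contiguous slice of the
    # input dim; act-order w2 scales are kept whole via load_full=True
    if not load_full:
        shard_size = expert_data.shape[shard_dim]
        loaded_weight = loaded_weight.narrow(shard_dim,
                                             shard_size * tp_rank,
                                             shard_size)
    expert_data.copy_(loaded_weight)

# toy check: input dim 16 split across 4 ranks, rank 2 gets columns 8..11
full = torch.arange(16.0).repeat(8, 1)   # (output_dim=8, input_dim=16)
local = torch.empty(8, 4)
shard_w2_for_tp(local, full, shard_dim=1, tp_rank=2)
assert torch.equal(local, full[:, 8:12])
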
@@ -318,12 +332,10 @@ def _load_g_idx(self, shard_id: str, expert_data: torch.Tensor,
shard_dim: int, loaded_weight: torch.Tensor, tp_rank: int):

if shard_id == "w2":
self._load_w2(shard_id=shard_id,
shard_dim=shard_dim,
self._load_w2(shard_dim=shard_dim,
loaded_weight=loaded_weight,
expert_data=expert_data,
tp_rank=tp_rank,
load_full=False)
tp_rank=tp_rank)
else:
assert shard_id in ("w1", "w3")
expert_data.copy_(loaded_weight)
@@ -403,16 +415,14 @@ def weight_loader(self, param: torch.nn.Parameter,
loaded_weight=loaded_weight,
expert_data=expert_data,
tp_rank=tp_rank)
#TODO load_full_w2 must be set to True only with group act_order
# and tp>1
elif quant_method == FusedMoeWeightScaleSupported.GROUP.value:
self._load_model_weight_or_group_weight_scale(
shard_id=shard_id,
shard_dim=shard_dim,
loaded_weight=loaded_weight,
expert_data=expert_data,
tp_rank=tp_rank,
load_full_w2=True)
load_full_w2=getattr(param, "load_full_w2", False))
elif quant_method == FusedMoeWeightScaleSupported.TENSOR.value:
self._load_per_tensor_weight_scale(shard_id=shard_id,
param=param,
@@ -438,8 +448,7 @@ def weight_loader(self, param: torch.nn.Parameter,
shard_dim=shard_dim,
loaded_weight=loaded_weight,
expert_data=expert_data,
tp_rank=tp_rank,
load_full_w2=False)
tp_rank=tp_rank)
return

@staticmethod
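The weight_loader change above stops hard-coding load_full_w2=True for group-quantized scales and instead reads the flag off the parameter, defaulting to False for quantization methods that never set it. A small sketch of that pattern, with a simplified stand-in for vLLM's set_weight_attrs helper (which, as in the second file below, stashes the hint on w2_weight_scale):

import torch

def set_weight_attrs(weight: torch.Tensor, attrs: dict) -> None:
    # simplified stand-in: stash loader hints directly on the parameter
    for key, value in attrs.items():
        setattr(weight, key, value)

w2_scale = torch.nn.Parameter(torch.ones(4, 8, 16), requires_grad=False)
set_weight_attrs(w2_scale, {"load_full_w2": True})

# the loader later recovers the hint; methods that never set it fall back
# to False and the w2 scales are sharded as usual
assert getattr(w2_scale, "load_full_w2", False) is True
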
25 changes: 17 additions & 8 deletions (second changed file)
@@ -77,6 +77,8 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size: int,
params_dtype: torch.dtype, **extra_weight_attrs):

# not needed by fp8
extra_weight_attrs.pop("intermediate_full")
params_dtype = torch.float8_e4m3fn

# WEIGHTS
@@ -264,12 +266,12 @@ def __init__(

def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size: int,
intermediate_full: int, params_dtype: torch.dtype,
**extra_weight_attrs):
params_dtype: torch.dtype, **extra_weight_attrs):

# Will transpose the loaded weight along the
# intermediate and hidden dim sizes. Will
# shard for TP along the transposed dims
intermediate_full = extra_weight_attrs.pop("intermediate_full")
extra_weight_attrs.update({
"is_transposed": True,
"quant_method": self.strategy
@@ -292,15 +294,20 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,
layer.register_parameter("w2_weight_packed", w2_weight)
set_weight_attrs(w2_weight, extra_weight_attrs)

self.is_k_full = (intermediate_full == intermediate_size)
scales_size = (intermediate_full if self.actorder
and self.group_size != -1 else intermediate_size)
# In the case where we have actorder/g_idx,
# we do not partition the w2 scales
load_full_w2 = self.actorder and self.group_size != -1
w2_scales_size = (intermediate_full
if load_full_w2 else intermediate_size)
# @eliza TODO: is this condition actually needed/is it doing anything?
self.is_k_full = (not self.actorder) or (
self.actorder and intermediate_size == intermediate_full)

if self.strategy == "channel":
num_groups_w2 = num_groups_w13 = 1
self.group_size = -1
else:
num_groups_w2 = scales_size // self.group_size
num_groups_w2 = w2_scales_size // self.group_size
num_groups_w13 = hidden_size // self.group_size

w13_scale = torch.nn.Parameter(torch.ones(num_experts,
@@ -318,6 +325,7 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,
requires_grad=False)
layer.register_parameter("w2_weight_scale", w2_scale)
set_weight_attrs(w2_scale, extra_weight_attrs)
set_weight_attrs(w2_scale, {"load_full_w2": load_full_w2})

w2_weight_shape = torch.nn.Parameter(torch.empty(num_experts, 2),
requires_grad=False)
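
This hunk carries the fix named in the commit title. The old condition marked K as not full whenever the per-rank intermediate size differed from the full one, even with act-order off; the new one reports a partial K only when act-order is actually in use. A minimal restatement of the sizing logic with illustrative numbers (the tp=2 figures are assumptions used only to show intermediate_size shrinking under tensor parallelism):

def w2_scale_sizing(actorder: bool, group_size: int,
                    intermediate_full: int, intermediate_size: int):
    # with act-order and grouped scales, the w2 scales are not partitioned
    load_full_w2 = actorder and group_size != -1
    w2_scales_size = intermediate_full if load_full_w2 else intermediate_size
    # K is only "not full" when act-order is on and TP actually sharded it
    is_k_full = (not actorder) or (intermediate_size == intermediate_full)
    return load_full_w2, w2_scales_size, is_k_full

# no act-order, tp=2: scales sharded, K reported full (old condition said False)
print(w2_scale_sizing(False, 128, 4096, 2048))   # (False, 2048, True)
# grouped act-order, tp=2: full w2 scales kept per rank, K is not full
print(w2_scale_sizing(True, 128, 4096, 2048))    # (True, 4096, False)
# grouped act-order, tp=1: nothing sharded, K is full
print(w2_scale_sizing(True, 128, 4096, 4096))    # (True, 4096, True)
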
@@ -427,6 +435,8 @@ def marlin_moe_permute_scales(s: torch.Tensor, size_k: int,
num_experts = layer.w13_weight_g_idx.shape[0]
device = layer.w13_weight_g_idx.device

# when running models with grouped act order,
# resort g_idx values provided
if self.actorder == "group":
w13_g_idx_sort_indices = torch.empty_like(layer.w13_weight_g_idx)
w2_g_idx_sort_indices = torch.empty_like(layer.w2_weight_g_idx)
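
The comment above ("resort g_idx values provided") refers to preparing act-order metadata for the Marlin kernel: per expert, the group indices are argsorted and both the sorted values and the sort indices are kept. A per-expert sketch of that idea; the real code also handles the w13/w2 pair and the non-act-order path, so treat this as an approximation:

import torch

def sort_g_idx_per_expert(g_idx: torch.Tensor):
    # g_idx: (num_experts, size_k) group index for each input channel
    sort_indices = torch.empty_like(g_idx)
    sorted_g_idx = torch.empty_like(g_idx)
    for e in range(g_idx.shape[0]):
        sort_indices[e] = torch.argsort(g_idx[e])
        sorted_g_idx[e] = g_idx[e][sort_indices[e]]
    return sorted_g_idx, sort_indices

g_idx = torch.tensor([[1, 0, 1, 0, 2, 2]])
sorted_g_idx, idx = sort_g_idx_per_expert(g_idx)
print(sorted_g_idx)   # tensor([[0, 0, 1, 1, 2, 2]])
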
@@ -542,5 +552,4 @@ def apply(
sort_indices1=layer.w13_g_idx_sort_indices,
sort_indices2=layer.w2_g_idx_sort_indices,
num_bits=self.num_bits,
is_k_full=self.is_k_full,
)
is_k_full=self.is_k_full)
