addressed comments on mamba_mixer2.py
Signed-off-by: Yu Chin Fabian Lim <[email protected]>
fabianlim committed Jan 3, 2025
1 parent 3167671 commit a1de811
Showing 1 changed file with 11 additions and 15 deletions.
26 changes: 11 additions & 15 deletions vllm/model_executor/layers/mamba/mamba_mixer2.py
@@ -28,7 +28,6 @@
 
 
 # Adapted from transformers.models.mamba2.modeling_mamba2.MambaRMSNormGated
-# also referenced https://github.com/vllm-project/vllm/pull/9292
 @CustomOp.register("mixer2_gated_rms_norm")
 class Mixer2RMSNormGated(CustomOp):
 
@@ -40,6 +39,8 @@ def __init__(self, hidden_size, eps=1e-6):
         self.tp_size = get_tensor_model_parallel_world_size()
         set_weight_attrs(self.weight,
                          {"weight_loader": sharded_weight_loader(0)})
+        assert self.hidden_size % self.tp_size == 0,\
+            "Tensor parallel world size must divide hidden size."
 
     def forward_native(
         self,
@@ -198,6 +199,9 @@ def __init__(self,
         self.tp_size = get_tensor_model_parallel_world_size()
         tp_rank = get_tensor_model_parallel_rank()
 
+        assert num_heads % self.tp_size == 0, \
+            "Tensor parallel world size must divide num heads."
+
         self.ssm_state_size = ssm_state_size
         self.use_rms_norm = use_rms_norm
         self.activation = activation
@@ -247,7 +251,7 @@ def __init__(self,
             self.num_heads //
             n_groups,  # ratio for mapping back to original group
         )
-        intemediate_settings = (intermediate_size, 0, 1)
+        intermediate_settings = (intermediate_size, 0, 1)
         head_setings = (self.num_heads, 0, 1)
 
         # - the weight already has a "weight_loader" attribute
@@ -260,7 +264,7 @@ def __init__(self,
                 "weight_loader":
                 mamba_v2_sharded_weight_loader(
                     [
-                        intemediate_settings,
+                        intermediate_settings,
                         group_shard_settings,
                         group_shard_settings,
                     ],
@@ -274,7 +278,7 @@ def __init__(self,
             self.conv1d.weight, {
                 "weight_loader":
                 mamba_v2_sharded_weight_loader([
-                    intemediate_settings,
+                    intermediate_settings,
                     group_shard_settings,
                     group_shard_settings,
                 ], self.tp_size, tp_rank)
@@ -287,8 +291,8 @@ def __init__(self,
                 "weight_loader":
                 mamba_v2_sharded_weight_loader(
                     [
-                        intemediate_settings,  # for gate
-                        intemediate_settings,
+                        intermediate_settings,  # for gate
+                        intermediate_settings,
                         group_shard_settings,
                         group_shard_settings,
                         head_setings,  # for dt
@@ -339,15 +343,7 @@ def forward_cuda(
         seq_len, _ = hidden_states.shape
         groups_time_state_size = self.n_groups * self.ssm_state_size
 
-        # - doing it differently from mixer v1; little confused with its logic
-        # - we need to do is to detect if there is any prefill; if there are
-        #   no prefils, then each example will be coming in one sample at a time
-        # - on the other hand v1 checks for "query_start_loc"
-        #   and "context_lens_tensor" however we have noticed that, even
-        #   when the samples are coming in
-        #   one at a time, they are still not NONE, e.g.,
-        #   * "query_start_loc" = [0, 1, ..]
-        #   * "context_lens_tensor" = [8, ...]
+        # detect if there are prefills
        has_prefill = attn_metadata.num_prefills > 0
 
         # - also need flags to indicate if there are initial states

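Note on the sharded weight loaders touched by this diff: each settings tuple (e.g. intermediate_settings = (intermediate_size, 0, 1)) describes one logical block of a weight that fuses several projections along dim 0, and mamba_v2_sharded_weight_loader slices every block for the current tensor-parallel rank before copying it into the already-sharded parameter. The sketch below is a minimal, self-contained illustration of that slicing idea only; shard_fused_weight and the toy block sizes are invented for the example, and the second and third tuple entries (extra dims and the group-to-head ratio) are not modeled here.

# Illustrative sketch only -- not vLLM's mamba_v2_sharded_weight_loader.
# Assumes every block size is evenly divisible by tp_size and ignores the
# extra-dimension / replication-ratio handling the real loader performs.
import torch


def shard_fused_weight(full_weight: torch.Tensor, block_sizes: list,
                       tp_size: int, tp_rank: int) -> torch.Tensor:
    """Take this rank's slice of each block of a fused (concatenated) weight."""
    shards = []
    offset = 0
    for size in block_sizes:
        block = full_weight.narrow(0, offset, size)  # one logical block
        shard = size // tp_size                      # this rank's share of it
        shards.append(block.narrow(0, tp_rank * shard, shard))
        offset += size
    return torch.cat(shards, dim=0)


# Toy fused projection made of [gate, hidden, B, C, dt] blocks, split over 2 ranks.
blocks = [8, 8, 4, 4, 2]
fused = torch.arange(sum(blocks), dtype=torch.float32).unsqueeze(-1)
print(shard_fused_weight(fused, blocks, tp_size=2, tp_rank=0).shape)  # (13, 1)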