From a1de811e7ddfedf5a528f6e34326771ce234318c Mon Sep 17 00:00:00 2001
From: Yu Chin Fabian Lim
Date: Fri, 3 Jan 2025 08:26:47 +0000
Subject: [PATCH] addressed comments on mamba_mixer2.py

Signed-off-by: Yu Chin Fabian Lim
---
 .../layers/mamba/mamba_mixer2.py | 26 ++++++++-----------
 1 file changed, 11 insertions(+), 15 deletions(-)

diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py
index 74fbfcf1523df..ee1961d73434d 100644
--- a/vllm/model_executor/layers/mamba/mamba_mixer2.py
+++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py
@@ -28,7 +28,6 @@
 
 
 # Adapted from transformers.models.mamba2.modeling_mamba2.MambaRMSNormGated
-# also referenced https://github.com/vllm-project/vllm/pull/9292
 @CustomOp.register("mixer2_gated_rms_norm")
 class Mixer2RMSNormGated(CustomOp):
 
@@ -40,6 +39,8 @@ def __init__(self, hidden_size, eps=1e-6):
         self.tp_size = get_tensor_model_parallel_world_size()
         set_weight_attrs(self.weight,
                          {"weight_loader": sharded_weight_loader(0)})
+        assert self.hidden_size % self.tp_size == 0, \
+            "Tensor parallel world size must divide hidden size."
 
     def forward_native(
         self,
@@ -198,6 +199,9 @@ def __init__(self,
         self.tp_size = get_tensor_model_parallel_world_size()
         tp_rank = get_tensor_model_parallel_rank()
 
+        assert num_heads % self.tp_size == 0, \
+            "Tensor parallel world size must divide num heads."
+
         self.ssm_state_size = ssm_state_size
         self.use_rms_norm = use_rms_norm
         self.activation = activation
@@ -247,7 +251,7 @@ def __init__(self,
             self.num_heads // n_groups,  # ratio for mapping back to original group
         )
 
-        intemediate_settings = (intermediate_size, 0, 1)
+        intermediate_settings = (intermediate_size, 0, 1)
         head_setings = (self.num_heads, 0, 1)
 
         # - the weight already has a "weight_loader" attribute
@@ -260,7 +264,7 @@ def __init__(self,
                 "weight_loader":
                 mamba_v2_sharded_weight_loader(
                     [
-                        intemediate_settings,
+                        intermediate_settings,
                         group_shard_settings,
                         group_shard_settings,
                     ],
@@ -274,7 +278,7 @@ def __init__(self,
             self.conv1d.weight, {
                 "weight_loader":
                 mamba_v2_sharded_weight_loader([
-                    intemediate_settings,
+                    intermediate_settings,
                     group_shard_settings,
                     group_shard_settings,
                 ], self.tp_size, tp_rank)
@@ -287,8 +291,8 @@ def __init__(self,
                 "weight_loader":
                 mamba_v2_sharded_weight_loader(
                     [
-                        intemediate_settings,  # for gate
-                        intemediate_settings,
+                        intermediate_settings,  # for gate
+                        intermediate_settings,
                         group_shard_settings,
                         group_shard_settings,
                         head_setings,  # for dt
@@ -339,15 +343,7 @@ def forward_cuda(
         seq_len, _ = hidden_states.shape
         groups_time_state_size = self.n_groups * self.ssm_state_size
 
-        # - doing it differently from mixer v1; little confused with its logic
-        # - we need to do is to detect if there is any prefill; if there are
-        #   no prefils, then each example will be coming in one sample at a time
-        # - on the other hand v1 checks for "query_start_loc"
-        #   and "context_lens_tensor" however we have noticed that, even
-        #   when the samples are coming in
-        #   one at a time, they are still not NONE, e.g.,
-        #   * "query_start_loc" = [0, 1, ..]
-        #   * "context_lens_tensor" = [8, ...]
+        # detect if there are prefills
         has_prefill = attn_metadata.num_prefills > 0
 
         # - also need flags to indicate if there are initial states