From d28f82b2c8608ed25758c5b394c9d0716eaf185f Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sat, 23 Nov 2024 22:22:54 -0800 Subject: [PATCH] [model][utils] add extract_layer_index utility function (#10599) Signed-off-by: youkaichao --- vllm/model_executor/models/arctic.py | 41 +++++++++++-------------- vllm/model_executor/models/deepseek.py | 19 +++++++----- vllm/model_executor/models/gemma2.py | 15 +++------ vllm/model_executor/models/olmoe.py | 8 ++--- vllm/model_executor/models/qwen2_moe.py | 6 ++-- vllm/model_executor/models/utils.py | 21 +++++++++++++ 6 files changed, 59 insertions(+), 51 deletions(-) diff --git a/vllm/model_executor/models/arctic.py b/vllm/model_executor/models/arctic.py index ac4c464aa10ac..fd6b5659df5d1 100644 --- a/vllm/model_executor/models/arctic.py +++ b/vllm/model_executor/models/arctic.py @@ -33,7 +33,7 @@ from vllm.transformers_utils.configs.arctic import ArcticConfig from .interfaces import SupportsPP -from .utils import (is_pp_missing_parameter, +from .utils import (extract_layer_index, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) @@ -44,15 +44,14 @@ class ArcticMLP(nn.Module): def __init__(self, config: ArcticConfig, - layer_id: int, expert_id: int = -1, is_residual_mlp: bool = False, quant_config: Optional[QuantizationConfig] = None, - reduce_results: bool = True): + reduce_results: bool = True, + prefix: str = ""): super().__init__() self.hidden_size = config.hidden_size self.expert_id = expert_id - self.layer_id = layer_id self.ffn_dim = config.intermediate_size if not is_residual_mlp \ else self.hidden_size @@ -85,13 +84,14 @@ class ArcticMoE(nn.Module): def __init__(self, config: ArcticConfig, - layer_id: int, tp_size: Optional[int] = None, params_dtype: Optional[torch.dtype] = None, quant_config: Optional[QuantizationConfig] = None, - reduce_results: bool = True): + reduce_results: bool = True, + prefix: str = ""): super().__init__() + layer_id = extract_layer_index(prefix) self.tp_size = tp_size or get_tensor_model_parallel_world_size() self.hidden_size = config.hidden_size self.num_experts = config.num_local_experts @@ -109,15 +109,16 @@ def __init__(self, if not self.is_moe_layer: self.mlp = ArcticMLP(config, - layer_id=layer_id, quant_config=quant_config, - reduce_results=reduce_results) + reduce_results=reduce_results, + prefix=f"{prefix}.mlp") else: self.gate = ReplicatedLinear(self.hidden_size, self.num_experts, bias=False, params_dtype=self.params_dtype, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.gate") if self.is_quant: self.ws = DeepSpeedFPParameter( torch.Size((self.num_experts, 2 * self.intermediate_size, @@ -220,14 +221,12 @@ class ArcticAttention(nn.Module): def __init__( self, config: ArcticConfig, - layer_idx: Optional[int] = None, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", ): super().__init__() self.config = config - self.layer_idx = layer_idx self.hidden_size = config.hidden_size tp_size = get_tensor_model_parallel_world_size() @@ -298,26 +297,25 @@ class ArcticDecoderLayer(nn.Module): def __init__( self, config: ArcticConfig, - layer_idx: int, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", ) -> None: super().__init__() - self.layer_idx = layer_idx self.hidden_size = config.hidden_size + layer_idx = extract_layer_index(prefix) is_moe_layer = (layer_idx + 1) % config.moe_layer_frequency == 0 self.use_residual = config.use_residual and is_moe_layer self.self_attn = ArcticAttention(config, - layer_idx, cache_config, quant_config=quant_config, prefix=f"{prefix}.self_attn") self.block_sparse_moe = ArcticMoE( config, - layer_id=layer_idx, quant_config=quant_config, - reduce_results=(not self.use_residual)) + reduce_results=(not self.use_residual), + prefix=f"{prefix}.block_sparse_moe", + ) self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -328,9 +326,9 @@ def __init__( self.residual_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.residual_mlp = ArcticMLP(config, - layer_id=layer_idx, is_residual_mlp=True, - reduce_results=False) + reduce_results=False, + prefix=f"{prefix}.residual_mlp") def forward( self, @@ -384,11 +382,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): org_num_embeddings=self.vocab_size) self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, - lambda prefix: ArcticDecoderLayer(config, - int(prefix.split(".")[-1]), - cache_config, - quant_config, - prefix=prefix), + lambda prefix: ArcticDecoderLayer( + config, cache_config, quant_config, prefix=prefix), prefix=f"{prefix}.layers") self._attn_implementation = config._attn_implementation self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) diff --git a/vllm/model_executor/models/deepseek.py b/vllm/model_executor/models/deepseek.py index 32488d931ea1c..74b6bfdf21909 100644 --- a/vllm/model_executor/models/deepseek.py +++ b/vllm/model_executor/models/deepseek.py @@ -49,7 +49,7 @@ from vllm.sequence import IntermediateTensors from .interfaces import SupportsPP -from .utils import (is_pp_missing_parameter, +from .utils import (extract_layer_index, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) @@ -63,6 +63,7 @@ def __init__( hidden_act: str, quant_config: Optional[QuantizationConfig] = None, reduce_results: bool = True, + prefix: str = "", ) -> None: super().__init__() self.gate_up_proj = MergedColumnParallelLinear( @@ -92,6 +93,7 @@ def __init__( self, config: PretrainedConfig, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() self.config = config @@ -260,12 +262,12 @@ class DeepseekDecoderLayer(nn.Module): def __init__( self, config: PretrainedConfig, - layer_idx: int, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", ) -> None: super().__init__() + layer_idx = extract_layer_index(prefix) self.hidden_size = config.hidden_size rope_theta = getattr(config, "rope_theta", 10000) rope_scaling = getattr(config, "rope_scaling", None) @@ -285,13 +287,16 @@ def __init__( if (config.n_routed_experts is not None and layer_idx >= config.first_k_dense_replace and layer_idx % config.moe_layer_freq == 0): - self.mlp = DeepseekMoE(config=config, quant_config=quant_config) + self.mlp = DeepseekMoE(config=config, + quant_config=quant_config, + prefix=f"{prefix}.mlp") else: self.mlp = DeepseekMLP( hidden_size=config.hidden_size, intermediate_size=config.intermediate_size, hidden_act=config.hidden_act, quant_config=quant_config, + prefix=f"{prefix}.mlp", ) self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -347,11 +352,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): ) self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, - lambda prefix: DeepseekDecoderLayer(config, - int(prefix.split(".")[-1]), - cache_config, - quant_config=quant_config, - prefix=prefix), + lambda prefix: DeepseekDecoderLayer( + config, cache_config, quant_config=quant_config, prefix=prefix + ), prefix=f"{prefix}.layers") self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.make_empty_intermediate_tensors = ( diff --git a/vllm/model_executor/models/gemma2.py b/vllm/model_executor/models/gemma2.py index 9309cced61bb3..fd8223dd9be1b 100644 --- a/vllm/model_executor/models/gemma2.py +++ b/vllm/model_executor/models/gemma2.py @@ -42,7 +42,8 @@ from vllm.sequence import IntermediateTensors, PoolerOutput from .interfaces import SupportsLoRA, SupportsPP -from .utils import (AutoWeightsLoader, is_pp_missing_parameter, +from .utils import (AutoWeightsLoader, extract_layer_index, + is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) @@ -85,7 +86,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class Gemma2Attention(nn.Module): def __init__(self, - layer_idx: int, config: Gemma2Config, hidden_size: int, num_heads: int, @@ -98,7 +98,6 @@ def __init__(self, attn_logits_soft_cap: Optional[float] = None, prefix: str = "") -> None: super().__init__() - self.layer_idx = layer_idx self.config = config self.hidden_size = hidden_size tp_size = get_tensor_model_parallel_world_size() @@ -145,6 +144,7 @@ def __init__(self, # reference: # https://github.com/huggingface/transformers/blob/54be2d7ae87e873482b984cc956e165ca4dc0ba3/src/transformers/models/gemma2/modeling_gemma2.py#L312 # noqa + layer_idx = extract_layer_index(prefix) use_sliding_window = (layer_idx % 2 == 0 and config.interleaved_sliding_window is not None) sliding_window = config.interleaved_sliding_window if \ @@ -178,7 +178,6 @@ class Gemma2DecoderLayer(nn.Module): def __init__( self, - layer_idx: int, config: Gemma2Config, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, @@ -187,7 +186,6 @@ def __init__( super().__init__() self.hidden_size = config.hidden_size self.self_attn = Gemma2Attention( - layer_idx=layer_idx, config=config, hidden_size=self.hidden_size, num_heads=config.num_attention_heads, @@ -262,11 +260,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): ) self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, - lambda prefix: Gemma2DecoderLayer(int(prefix.split(".")[-1]), - config, - cache_config, - quant_config, - prefix=prefix), + lambda prefix: Gemma2DecoderLayer( + config, cache_config, quant_config, prefix=prefix), prefix=f"{prefix}.layers") self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) diff --git a/vllm/model_executor/models/olmoe.py b/vllm/model_executor/models/olmoe.py index 5b5b3ef48b035..5d9091cfb9311 100644 --- a/vllm/model_executor/models/olmoe.py +++ b/vllm/model_executor/models/olmoe.py @@ -181,7 +181,6 @@ class OlmoeDecoderLayer(nn.Module): def __init__( self, config: PretrainedConfig, - layer_idx: int, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", @@ -264,11 +263,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): ) self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, - lambda prefix: OlmoeDecoderLayer(config, - int(prefix.split(".")[-1]), - cache_config, - quant_config, - prefix=prefix), + lambda prefix: OlmoeDecoderLayer( + config, cache_config, quant_config, prefix=prefix), prefix=f"{prefix}.layers") self.norm = RMSNorm(config.hidden_size, eps=1e-5) diff --git a/vllm/model_executor/models/qwen2_moe.py b/vllm/model_executor/models/qwen2_moe.py index 1091f88ab2534..ba70243c6533d 100644 --- a/vllm/model_executor/models/qwen2_moe.py +++ b/vllm/model_executor/models/qwen2_moe.py @@ -53,7 +53,7 @@ from vllm.utils import print_warning_once from .interfaces import SupportsPP -from .utils import (is_pp_missing_parameter, +from .utils import (extract_layer_index, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) @@ -244,7 +244,6 @@ class Qwen2MoeDecoderLayer(nn.Module): def __init__( self, config: PretrainedConfig, - layer_idx: int, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", @@ -269,6 +268,7 @@ def __init__( # Note: Qwen/Qwen2-57B-A14B-Instruct does not have # `mlp_only_layers` in the config. + layer_idx = extract_layer_index(prefix) mlp_only_layers = ([] if not hasattr(config, "mlp_only_layers") else config.mlp_only_layers) if (layer_idx not in mlp_only_layers) and ( @@ -337,8 +337,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, lambda prefix: Qwen2MoeDecoderLayer(config=config, - layer_idx=int( - prefix.split(".")[-1]), cache_config=cache_config, quant_config=quant_config, prefix=prefix), diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index 2ab9b19e22068..dcfd2cb7d2622 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -629,3 +629,24 @@ def maybe_prefix(prefix: str, name: str) -> str: The string "prefix.name" if prefix was non-empty, otherwise just "name". """ return name if not prefix else f"{prefix}.{name}" + + +def extract_layer_index(layer_name: str) -> int: + """ + Extract the layer index from the module name. + Examples: + - "encoder.layers.0" -> 0 + - "encoder.layers.1.self_attn" -> 1 + - "2.self_attn" -> 2 + - "model.encoder.layers.0.sub.1" -> ValueError + """ + subnames = layer_name.split(".") + int_vals: List[int] = [] + for subname in subnames: + try: + int_vals.append(int(subname)) + except ValueError: + continue + assert len(int_vals) == 1, (f"layer name {layer_name} should" + " only contain one integer") + return int_vals[0]