diff --git a/vllm/model_executor/models/aria.py b/vllm/model_executor/models/aria.py index 0356435e9c257..fa6b95f5481ad 100644 --- a/vllm/model_executor/models/aria.py +++ b/vllm/model_executor/models/aria.py @@ -29,7 +29,7 @@ LlamaModel) from vllm.model_executor.models.utils import (AutoWeightsLoader, WeightsMapper, is_pp_missing_parameter, - make_layers, maybe_prefix, + maybe_prefix, merge_multimodal_embeddings) from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.base import MultiModalInputs @@ -363,27 +363,9 @@ class AriaMoELMModel(LlamaModel): """ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__(vllm_config=vllm_config, prefix=prefix) - - config = vllm_config.model_config.hf_config - cache_config = vllm_config.cache_config - quant_config = vllm_config.quant_config - - # FIXME: this is a hack to disable the compilation of the model - self.do_not_compile = True - - self.layers = None - - self.start_layer, self.end_layer, self.layers = make_layers( - config.num_hidden_layers, - lambda prefix: MoEDecoderLayer( - config=config, - cache_config=cache_config, - quant_config=quant_config, - prefix=prefix, - ), - prefix=f"{prefix}.layers", - ) + super().__init__(vllm_config=vllm_config, + prefix=prefix, + layer_type=MoEDecoderLayer) # Adapted from LlamaModel.load_weights with the modification of adding # the expert weights mapping to `stacked_params_mapping` diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 33d78d74129c8..355b2f3ef8b28 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -20,7 +20,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only LLaMA model compatible with HuggingFace weights.""" -from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union +from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Type, Union import torch from torch import nn @@ -273,7 +273,11 @@ def forward( @support_torch_compile class LlamaModel(nn.Module): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + def __init__(self, + *, + vllm_config: VllmConfig, + prefix: str = "", + layer_type: Type[LlamaDecoderLayer] = LlamaDecoderLayer): super().__init__() config = vllm_config.model_config.hf_config @@ -299,10 +303,10 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.embed_tokens = PPMissingLayer() self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, - lambda prefix: LlamaDecoderLayer(config=config, - cache_config=cache_config, - quant_config=quant_config, - prefix=prefix), + lambda prefix: layer_type(config=config, + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix), prefix=f"{prefix}.layers", ) if get_pp_group().is_last_rank: