diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py
index f0d2a9e7f06be..aea3354cada90 100644
--- a/vllm/model_executor/model_loader/loader.py
+++ b/vllm/model_executor/model_loader/loader.py
@@ -1,6 +1,7 @@
 # ruff: noqa: SIM117
 import collections
 import copy
+import dataclasses
 import fnmatch
 import glob
 import json
@@ -8,7 +9,8 @@
 import os
 from abc import ABC, abstractmethod
 from contextlib import contextmanager
-from typing import Any, Dict, Generator, List, Optional, Tuple, Type
+from typing import (Any, Dict, Generator, Iterable, List, Optional, Tuple,
+                    Type, cast)
 
 import gguf
 import huggingface_hub
@@ -207,6 +209,22 @@ def load_model(self, *, model_config: ModelConfig,
 class DefaultModelLoader(BaseModelLoader):
     """Model loader that can load different file types from disk."""
 
+    @dataclasses.dataclass
+    class Source:
+        """A source for weights."""
+
+        model_or_path: str
+        """The model ID or path."""
+
+        revision: Optional[str]
+        """The optional model revision."""
+
+        prefix: str = ""
+        """A prefix to prepend to all weights."""
+
+        fall_back_to_pt: bool = True
+        """Whether .pt weights can be used."""
+
     def __init__(self, load_config: LoadConfig):
         super().__init__(load_config)
         if load_config.model_loader_extra_config:
@@ -313,17 +331,16 @@ def _prepare_weights(self, model_name_or_path: str,
         return hf_folder, hf_weights_files, use_safetensors
 
     def _get_weights_iterator(
-        self, model_name_or_path: str, revision: Optional[str],
-        fall_back_to_pt: bool
+        self, source: "Source"
     ) -> Generator[Tuple[str, torch.Tensor], None, None]:
         """Get an iterator for the model weights based on the load format."""
         hf_folder, hf_weights_files, use_safetensors = self._prepare_weights(
-            model_name_or_path, revision, fall_back_to_pt)
+            source.model_or_path, source.revision, source.fall_back_to_pt)
         if self.load_config.load_format == LoadFormat.NPCACHE:
             # Currently np_cache only support *.bin checkpoints
             assert use_safetensors is False
             weights_iterator = np_cache_weights_iterator(
-                model_name_or_path, self.load_config.download_dir, hf_folder,
+                source.model_or_path, self.load_config.download_dir, hf_folder,
                 hf_weights_files)
         elif use_safetensors:
             weights_iterator = safetensors_weights_iterator(hf_weights_files)
@@ -341,7 +358,29 @@ def _xla_weights_iterator(iterator: Generator):
                 xm.mark_step()
 
             weights_iterator = _xla_weights_iterator(weights_iterator)
-        return weights_iterator
+
+        # Apply the prefix.
+        return ((source.prefix + name, tensor)
+                for (name, tensor) in weights_iterator)
+
+    def _get_all_weights(
+        self,
+        model_config: ModelConfig,
+        model: nn.Module,
+    ) -> Generator[Tuple[str, torch.Tensor], None, None]:
+
+        primary_weights = DefaultModelLoader.Source(
+            model_config.model,
+            model_config.revision,
+            prefix="",
+            fall_back_to_pt=getattr(model, "fall_back_to_pt_during_load",
+                                    True))
+        yield from self._get_weights_iterator(primary_weights)
+
+        secondary_weights = cast(Iterable[DefaultModelLoader.Source],
+                                 getattr(model, "secondary_weights", ()))
+        for source in secondary_weights:
+            yield from self._get_weights_iterator(source)
 
     def download_model(self, model_config: ModelConfig) -> None:
         self._prepare_weights(model_config.model,
@@ -360,13 +399,8 @@ def load_model(self, *, model_config: ModelConfig,
             model = _initialize_model(model_config, self.load_config,
                                       lora_config, cache_config,
                                       scheduler_config)
-            model.load_weights(
-                self._get_weights_iterator(model_config.model,
-                                           model_config.revision,
-                                           fall_back_to_pt=getattr(
-                                               model,
-                                               "fall_back_to_pt_during_load",
-                                               True)), )
+
+            model.load_weights(self._get_all_weights(model_config, model))
 
         for _, module in model.named_modules():
             quant_method = getattr(module, "quant_method", None)
diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py
index 32a0e895005cb..71808eb4c2719 100644
--- a/vllm/model_executor/models/ultravox.py
+++ b/vllm/model_executor/models/ultravox.py
@@ -25,6 +25,7 @@
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
 from vllm.model_executor.layers.sampler import SamplerOutput
+from vllm.model_executor.model_loader.loader import DefaultModelLoader
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.interfaces import SupportsMultiModal
 from vllm.model_executor.models.utils import (flatten_bn,
@@ -334,14 +335,23 @@ def __init__(self,
         self.multi_modal_config = multimodal_config
         assert self.multi_modal_config
 
+        self.secondary_weights = []
+        self.audio_tower = ModifiedWhisperEncoder(config.audio_config)
         if config.audio_model_id is not None:
-            self.audio_tower = ModifiedWhisperEncoder.from_pretrained(
-                config.audio_model_id)
-        else:
-            self.audio_tower = ModifiedWhisperEncoder(config.audio_config)
+            self.secondary_weights.append(
+                DefaultModelLoader.Source(
+                    model_or_path=config.audio_model_id,
+                    revision=None,
+                    prefix="audio_tower.",
+                ))
         self.multi_modal_projector = UltravoxProjector(config)
         self.language_model = init_vllm_registered_model(
             config.text_config, cache_config, quant_config)
+        if config.text_model_id is not None:
+            self.secondary_weights.append(
+                DefaultModelLoader.Source(model_or_path=config.text_model_id,
+                                          revision=None,
+                                          prefix="language_model."))
 
     def _audio_features_to_embeddings(
             self, input_features: torch.Tensor) -> torch.Tensor:
@@ -466,6 +476,18 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         # prepare weight iterators for components
         weights_group = group_weights_with_prefix(weights)
 
+        # load audio tower weights
+        audio_tower_weights = weights_group["audio_tower"]
+        audio_tower_params_dict = dict(
+            self.audio_tower.named_parameters(
+                prefix=self.audio_tower.base_model_prefix))
+        for name, loaded_weight in audio_tower_weights:
+            if name in audio_tower_params_dict:
+                param = audio_tower_params_dict[name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(param, loaded_weight)
+
         # load projector weights
         projector_weights = weights_group["multi_modal_projector"]
         projector_params_dict = dict(
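
To make the new contract concrete, a minimal sketch of how a model class opts into multi-source loading. Only DefaultModelLoader.Source and the secondary_weights / fall_back_to_pt_during_load attributes (probed via getattr in _get_all_weights) come from this diff; ToyModel and the "my-org/audio-encoder" ID are hypothetical.

import torch.nn as nn

from vllm.model_executor.model_loader.loader import DefaultModelLoader


class ToyModel(nn.Module):
    """Hypothetical model showing the attributes the loader probes."""

    def __init__(self) -> None:
        super().__init__()
        # Read via getattr() in _get_all_weights; controls whether the
        # primary checkpoint may fall back to *.pt files.
        self.fall_back_to_pt_during_load = True
        # Extra checkpoints yielded after the primary weights. Each
        # weight name is prepended with `prefix`, so load_weights()
        # can route the entries to the matching submodule.
        self.secondary_weights = [
            DefaultModelLoader.Source(
                model_or_path="my-org/audio-encoder",  # hypothetical ID
                revision=None,
                prefix="audio_tower.",
            ),
        ]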
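
And a self-contained illustration of the renaming done by the generator expression added to _get_weights_iterator; the weight name is hypothetical, but the mechanics match the diff: a source with prefix="audio_tower." surfaces its weights under audio_tower.*, which is what the new audio-tower branch in load_weights consumes.

from typing import Iterable, Iterator, Tuple

import torch


def with_prefix(
    prefix: str,
    weights: Iterable[Tuple[str, torch.Tensor]],
) -> Iterator[Tuple[str, torch.Tensor]]:
    # Same shape as the generator expression added above.
    return ((prefix + name, tensor) for (name, tensor) in weights)


renamed = list(with_prefix("audio_tower.", [("conv1.weight", torch.zeros(3))]))
assert renamed[0][0] == "audio_tower.conv1.weight"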