Skip to content

Commit

Permalink
[Misc] Add uninitialized params tracking for AutoWeightsLoader (vll…
Browse files Browse the repository at this point in the history
…m-project#10327)

Signed-off-by: Isotr0py <[email protected]>
Signed-off-by: rickyx <[email protected]>
  • Loading branch information
Isotr0py authored and rickyyx committed Nov 20, 2024
1 parent 48443fb commit 4e9c8f9
Show file tree
Hide file tree
Showing 74 changed files with 454 additions and 185 deletions.
12 changes: 11 additions & 1 deletion vllm/model_executor/model_loader/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,17 @@ def load_model(self, vllm_config: VllmConfig) -> nn.Module:
with target_device:
model = _initialize_model(vllm_config=vllm_config)

model.load_weights(self._get_all_weights(model_config, model))
weights_to_load = {name for name, _ in model.named_parameters()}
loaded_weights = model.load_weights(
self._get_all_weights(model_config, model))
# We only enable strict check for non-quantiized models
# that have loaded weights tracking currently.
if model_config.quantization is None and loaded_weights is not None:
weights_not_loaded = weights_to_load - loaded_weights
if weights_not_loaded:
raise ValueError(
"Following weights were not initialized from "
f"checkpoint: {weights_not_loaded}")

for _, module in model.named_modules():
quant_method = getattr(module, "quant_method", None)
Expand Down
8 changes: 6 additions & 2 deletions vllm/model_executor/models/arctic.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""Inference-only Snowflake Arctic model."""
from typing import Iterable, List, Optional, Tuple, Union
from typing import Iterable, List, Optional, Set, Tuple, Union

import torch
from torch import nn
Expand Down Expand Up @@ -480,7 +480,8 @@ def sample(
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens

def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
Expand Down Expand Up @@ -518,6 +519,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
("ws", f"experts.{expert_id}.w3.weight", expert_id))

params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()

logger.info(
"It will take ~10 minutes loading from the 16-bit weights. "
Expand Down Expand Up @@ -573,3 +575,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
8 changes: 6 additions & 2 deletions vllm/model_executor/models/baichuan.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
# limitations under the License.
"""Inference-only BaiChuan model compatible with HuggingFace weights."""
import math
from typing import Iterable, List, Optional, Tuple, Union
from typing import Iterable, List, Optional, Set, Tuple, Union

import torch
from torch import nn
Expand Down Expand Up @@ -404,13 +404,15 @@ def sample(
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens

def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
Expand Down Expand Up @@ -449,6 +451,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params


class BaichuanForCausalLM(BaiChuanBaseForCausalLM):
Expand Down
8 changes: 6 additions & 2 deletions vllm/model_executor/models/bert.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Iterable, List, Optional, Tuple
from typing import Iterable, List, Optional, Set, Tuple

import torch
from torch import nn
Expand Down Expand Up @@ -337,7 +337,8 @@ def forward(

return self.encoder(hidden_states, kv_caches, attn_metadata)

def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "query", "q"),
Expand All @@ -346,6 +347,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
]

params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
if "pooler" in name:
continue
Expand All @@ -368,6 +370,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params


class BertEmbeddingModel(nn.Module):
Expand Down
12 changes: 8 additions & 4 deletions vllm/model_executor/models/blip.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Minimal implementation of BlipVisionModel intended to be only used
within a vision language model."""
from typing import Iterable, Optional, Tuple, Union
from typing import Iterable, Optional, Set, Tuple, Union

import torch
import torch.nn as nn
Expand Down Expand Up @@ -415,14 +415,16 @@ def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:

return self.post_layernorm(hidden_states)

def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
] if self.shard_weight else []
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
layer_count = len(self.encoder.layers)

for name, loaded_weight in weights:
Expand All @@ -440,8 +442,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue

param = params_dict[name.replace(weight_name, param_name)]
name = name.replace(weight_name, param_name)
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
Expand All @@ -450,3 +452,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
7 changes: 4 additions & 3 deletions vllm/model_executor/models/blip2.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from functools import cached_property
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
TypedDict, Union)

import torch
Expand Down Expand Up @@ -692,6 +692,7 @@ def sample(
) -> Optional[SamplerOutput]:
return self.language_model.sample(logits, sampling_metadata)

def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
loader = AutoWeightsLoader(self)
loader.load_weights(weights)
return loader.load_weights(weights)
8 changes: 6 additions & 2 deletions vllm/model_executor/models/bloom.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
# limitations under the License.
"""Inference-only BLOOM model compatible with HuggingFace weights."""
import math
from typing import Iterable, List, Optional, Tuple, Union
from typing import Iterable, List, Optional, Set, Tuple, Union

import torch
from torch import nn
Expand Down Expand Up @@ -341,8 +341,10 @@ def sample(
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens

def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
params_dict = dict(self.named_parameters(remove_duplicate=False))
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
if name == "lm_head.weight":
continue
Expand Down Expand Up @@ -371,3 +373,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
8 changes: 6 additions & 2 deletions vllm/model_executor/models/chameleon.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from functools import cached_property
from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional,
from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional, Set,
Tuple, TypedDict, Union)

import torch
Expand Down Expand Up @@ -1034,7 +1034,8 @@ def sample(
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens

def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
(".qkv_proj", ".q_proj", "q"),
Expand All @@ -1044,6 +1045,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
(".gate_up_proj", ".up_proj", 1),
]
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
Expand Down Expand Up @@ -1111,3 +1113,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
10 changes: 8 additions & 2 deletions vllm/model_executor/models/chatglm.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
"""Inference-only ChatGLM model compatible with THUDM weights."""
from argparse import Namespace
from array import array
from typing import Dict, Iterable, List, Mapping, Optional, Tuple, TypedDict
from typing import (Dict, Iterable, List, Mapping, Optional, Set, Tuple,
TypedDict)

import torch
from PIL import Image
Expand Down Expand Up @@ -645,7 +646,8 @@ def sample(
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens

def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
# Merge two ColumnParallelLinear into one MergedColumnParallelLinear
merged_weights_dict: Dict[str, Dict[str, Optional[torch.Tensor]]] = {
"transformer.vision.linear_proj.merged_proj.weight": {
Expand All @@ -655,6 +657,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
}

params_dict = dict(self.named_parameters(remove_duplicate=False))
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
is_weight_to_be_merge = False
for _, merged_weight_dict in merged_weights_dict.items():
Expand All @@ -677,6 +680,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)

for combined_name, merged_weight_dict in merged_weights_dict.items():
if combined_name in params_dict:
Expand All @@ -686,3 +690,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, combined_weight)
loaded_params.add(combined_name)
return loaded_params
11 changes: 8 additions & 3 deletions vllm/model_executor/models/clip.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Minimal implementation of CLIPVisionModel intended to be only used
within a vision language model."""
from typing import Iterable, List, Optional, Tuple, Union
from typing import Iterable, List, Optional, Set, Tuple, Union

import numpy as np
import torch
Expand Down Expand Up @@ -483,14 +483,16 @@ def device(self):

# (TODO) Add prefix argument for filtering out weights to be loaded
# ref: https://github.com/vllm-project/vllm/pull/7186#discussion_r1734163986
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
] if self.shard_weight else []
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
layer_count = len(self.vision_model.encoder.layers)

for name, loaded_weight in weights:
Expand All @@ -508,8 +510,9 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)

param = params_dict[name.replace(weight_name, param_name)]
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
Expand All @@ -518,3 +521,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
4 changes: 3 additions & 1 deletion vllm/model_executor/models/commandr.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,7 +402,8 @@ def sample(
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens

def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
Expand Down Expand Up @@ -447,3 +448,4 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
8 changes: 6 additions & 2 deletions vllm/model_executor/models/dbrx.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Iterable, List, Optional, Tuple, Union
from typing import Iterable, List, Optional, Set, Tuple, Union

import torch
import torch.nn as nn
Expand Down Expand Up @@ -417,13 +417,15 @@ def sample(
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens

def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:

expert_params_mapping = [(
"w13_weight" if weight_name in ["w1", "v1"] else "w2_weight",
f"mlp.{weight_name}",
) for weight_name in ["w1", "v1", "w2"]]
params_dict = dict(self.named_parameters(remove_duplicate=False))
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
for param_name, weight_name in expert_params_mapping:
if weight_name not in name:
Expand All @@ -447,3 +449,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
Loading

0 comments on commit 4e9c8f9

Please sign in to comment.