Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Misc] Add uninitialized params tracking for AutoWeightsLoader #10327

Merged
merged 28 commits into from
Nov 18, 2024
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions vllm/model_executor/models/arctic.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""Inference-only Snowflake Arctic model."""
from typing import Iterable, List, Optional, Tuple, Union
from typing import Iterable, List, Optional, Set, Tuple, Union

import torch
from torch import nn
Expand Down Expand Up @@ -468,7 +468,8 @@ def sample(
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens

def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
Expand Down Expand Up @@ -506,6 +507,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
("ws", f"experts.{expert_id}.w3.weight", expert_id))

params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()

logger.info(
"It will take ~10 minutes loading from the 16-bit weights. "
Expand Down Expand Up @@ -561,3 +563,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
8 changes: 6 additions & 2 deletions vllm/model_executor/models/baichuan.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
# limitations under the License.
"""Inference-only BaiChuan model compatible with HuggingFace weights."""
import math
from typing import Iterable, List, Optional, Tuple, Union
from typing import Iterable, List, Optional, Set, Tuple, Union

import torch
from torch import nn
Expand Down Expand Up @@ -392,13 +392,15 @@ def sample(
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens

def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
Expand Down Expand Up @@ -437,6 +439,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params


class BaichuanForCausalLM(BaiChuanBaseForCausalLM):
Expand Down
13 changes: 9 additions & 4 deletions vllm/model_executor/models/bert.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Iterable, List, Optional, Tuple
from typing import Iterable, List, Optional, Set, Tuple

import torch
from torch import nn
Expand Down Expand Up @@ -337,7 +337,8 @@ def forward(

return self.encoder(hidden_states, kv_caches, attn_metadata)

def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "query", "q"),
Expand All @@ -346,6 +347,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
]

params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
if "pooler" in name:
continue
Expand All @@ -368,6 +370,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params


class BertEmbeddingModel(nn.Module):
Expand Down Expand Up @@ -411,8 +415,9 @@ def pooler(
) -> Optional[PoolerOutput]:
return self._pooler(hidden_states, pooling_metadata)

def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
self.model.load_weights(weights)
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
return self.model.load_weights(weights)

def _build_model(self,
vllm_config: VllmConfig,
Expand Down
8 changes: 6 additions & 2 deletions vllm/model_executor/models/blip.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Minimal implementation of BlipVisionModel intended to be only used
within a vision language model."""
from typing import Iterable, Optional, Tuple, Union
from typing import Iterable, Optional, Set, Tuple, Union

import torch
import torch.nn as nn
Expand Down Expand Up @@ -415,14 +415,16 @@ def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:

return self.post_layernorm(hidden_states)

def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
] if self.shard_weight else []
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
layer_count = len(self.encoder.layers)

for name, loaded_weight in weights:
Expand Down Expand Up @@ -450,3 +452,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
7 changes: 4 additions & 3 deletions vllm/model_executor/models/blip2.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from functools import cached_property
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
TypedDict, Union)

import torch
Expand Down Expand Up @@ -692,6 +692,7 @@ def sample(
) -> Optional[SamplerOutput]:
return self.language_model.sample(logits, sampling_metadata)

def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
loader = AutoWeightsLoader(self)
loader.load_weights(weights)
return loader.load_weights(weights)
8 changes: 6 additions & 2 deletions vllm/model_executor/models/bloom.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
# limitations under the License.
"""Inference-only BLOOM model compatible with HuggingFace weights."""
import math
from typing import Iterable, List, Optional, Tuple, Union
from typing import Iterable, List, Optional, Set, Tuple, Union

import torch
from torch import nn
Expand Down Expand Up @@ -330,8 +330,10 @@ def sample(
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens

def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
params_dict = dict(self.named_parameters(remove_duplicate=False))
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
if name == "lm_head.weight":
continue
Expand Down Expand Up @@ -360,3 +362,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
8 changes: 6 additions & 2 deletions vllm/model_executor/models/chameleon.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from functools import cached_property
from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional,
from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional, Set,
Tuple, TypedDict, Union)

import torch
Expand Down Expand Up @@ -1034,7 +1034,8 @@ def sample(
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens

def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
(".qkv_proj", ".q_proj", "q"),
Expand All @@ -1044,6 +1045,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
(".gate_up_proj", ".up_proj", 1),
]
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
Expand Down Expand Up @@ -1111,3 +1113,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
10 changes: 8 additions & 2 deletions vllm/model_executor/models/chatglm.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
"""Inference-only ChatGLM model compatible with THUDM weights."""
from argparse import Namespace
from array import array
from typing import Dict, Iterable, List, Mapping, Optional, Tuple, TypedDict
from typing import (Dict, Iterable, List, Mapping, Optional, Set, Tuple,
TypedDict)

import torch
from PIL import Image
Expand Down Expand Up @@ -645,7 +646,8 @@ def sample(
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens

def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
# Merge two ColumnParallelLinear into one MergedColumnParallelLinear
merged_weights_dict: Dict[str, Dict[str, Optional[torch.Tensor]]] = {
"transformer.vision.linear_proj.merged_proj.weight": {
Expand All @@ -655,6 +657,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
}

params_dict = dict(self.named_parameters(remove_duplicate=False))
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
is_weight_to_be_merge = False
for _, merged_weight_dict in merged_weights_dict.items():
Expand All @@ -677,6 +680,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)

for combined_name, merged_weight_dict in merged_weights_dict.items():
if combined_name in params_dict:
Expand All @@ -686,3 +690,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, combined_weight)
loaded_params.add(combined_name)
return loaded_params
8 changes: 6 additions & 2 deletions vllm/model_executor/models/clip.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Minimal implementation of CLIPVisionModel intended to be only used
within a vision language model."""
from typing import Iterable, List, Optional, Tuple, Union
from typing import Iterable, List, Optional, Set, Tuple, Union

import numpy as np
import torch
Expand Down Expand Up @@ -483,14 +483,16 @@ def device(self):

# (TODO) Add prefix argument for filtering out weights to be loaded
# ref: https://github.com/vllm-project/vllm/pull/7186#discussion_r1734163986
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
] if self.shard_weight else []
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
layer_count = len(self.vision_model.encoder.layers)

for name, loaded_weight in weights:
Expand Down Expand Up @@ -518,3 +520,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
4 changes: 3 additions & 1 deletion vllm/model_executor/models/commandr.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,7 +390,8 @@ def sample(
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens

def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
Expand Down Expand Up @@ -435,3 +436,4 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
8 changes: 6 additions & 2 deletions vllm/model_executor/models/dbrx.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Iterable, List, Optional, Tuple, Union
from typing import Iterable, List, Optional, Set, Tuple, Union

import torch
import torch.nn as nn
Expand Down Expand Up @@ -405,13 +405,15 @@ def sample(
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens

def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:

expert_params_mapping = [(
"w13_weight" if weight_name in ["w1", "v1"] else "w2_weight",
f"mlp.{weight_name}",
) for weight_name in ["w1", "v1", "w2"]]
params_dict = dict(self.named_parameters(remove_duplicate=False))
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
for param_name, weight_name in expert_params_mapping:
if weight_name not in name:
Expand All @@ -435,3 +437,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
8 changes: 6 additions & 2 deletions vllm/model_executor/models/decilm.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
# limitations under the License.
"""Inference-only DeciLM model compatible with HuggingFace weights."""

from typing import Iterable, Tuple
from typing import Iterable, Set, Tuple

import torch

Expand Down Expand Up @@ -57,7 +57,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
delattr(config, "num_key_value_heads_per_layer")
super().__init__(vllm_config=vllm_config)

def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
Expand All @@ -67,6 +68,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
Expand Down Expand Up @@ -97,6 +99,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params

def _degroup_weight(self, loaded_weight: torch.Tensor) -> torch.Tensor:
hidden_size = self.config.hidden_size
Expand Down
Loading