Skip to content

Commit

Permalink
[Bugfix] Fix Mamba model initialization and MLP Speculator weights lo…
Browse files Browse the repository at this point in the history
…ading (vllm-project#10456)

Signed-off-by: Isotr0py <[email protected]>
  • Loading branch information
Isotr0py authored and weilong.yu committed Dec 13, 2024
1 parent f646207 commit a9219a1
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 7 deletions.
8 changes: 2 additions & 6 deletions vllm/model_executor/models/mamba.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""PyTorch MAMBA model."""
from typing import Iterable, List, Optional, Set, Tuple
from typing import Iterable, List, Optional, Tuple

import torch
from torch import nn
Expand Down Expand Up @@ -243,10 +243,8 @@ def sample(
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens

def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
if "A_log" in name:
name = name.replace("A_log", "A")
Expand All @@ -258,5 +256,3 @@ def load_weights(self, weights: Iterable[Tuple[str,
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
3 changes: 2 additions & 1 deletion vllm/model_executor/models/mlp_speculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,8 @@ def load_weights(self, weights: Iterable[Tuple[str,
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
param = params_dict.get(name.replace("speculator.", ""))
name = name.replace("speculator.", "")
param = params_dict.get(name)
if param is not None:
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
Expand Down

0 comments on commit a9219a1

Please sign in to comment.