diff --git a/vllm/model_executor/models/mamba.py b/vllm/model_executor/models/mamba.py index 405b8f7787ba8..ac0d265a961f0 100644 --- a/vllm/model_executor/models/mamba.py +++ b/vllm/model_executor/models/mamba.py @@ -1,5 +1,5 @@ """PyTorch MAMBA model.""" -from typing import Iterable, List, Optional, Set, Tuple +from typing import Iterable, List, Optional, Tuple import torch from torch import nn @@ -243,10 +243,8 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): params_dict = dict(self.named_parameters()) - loaded_params: Set[str] = set() for name, loaded_weight in weights: if "A_log" in name: name = name.replace("A_log", "A") @@ -258,5 +256,3 @@ def load_weights(self, weights: Iterable[Tuple[str, weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) - loaded_params.add(name) - return loaded_params diff --git a/vllm/model_executor/models/mlp_speculator.py b/vllm/model_executor/models/mlp_speculator.py index f2aa2653c4f5c..d49da5f29aa14 100644 --- a/vllm/model_executor/models/mlp_speculator.py +++ b/vllm/model_executor/models/mlp_speculator.py @@ -193,7 +193,8 @@ def load_weights(self, weights: Iterable[Tuple[str, params_dict = dict(self.named_parameters()) loaded_params: Set[str] = set() for name, loaded_weight in weights: - param = params_dict.get(name.replace("speculator.", "")) + name = name.replace("speculator.", "") + param = params_dict.get(name) if param is not None: weight_loader = getattr(param, "weight_loader", default_weight_loader)