From dd6ee68e387daf3f19ec52003d8672adfa5d2b8e Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Tue, 19 Nov 2024 23:52:39 +0800 Subject: [PATCH 1/5] disable mamba weights tracker and fix mpl speculator Signed-off-by: Isotr0py <2037008807@qq.com> --- vllm/model_executor/models/mamba.py | 8 ++------ vllm/model_executor/models/mlp_speculator.py | 3 ++- 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/vllm/model_executor/models/mamba.py b/vllm/model_executor/models/mamba.py index 405b8f7787ba8..ac0d265a961f0 100644 --- a/vllm/model_executor/models/mamba.py +++ b/vllm/model_executor/models/mamba.py @@ -1,5 +1,5 @@ """PyTorch MAMBA model.""" -from typing import Iterable, List, Optional, Set, Tuple +from typing import Iterable, List, Optional, Tuple import torch from torch import nn @@ -243,10 +243,8 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): params_dict = dict(self.named_parameters()) - loaded_params: Set[str] = set() for name, loaded_weight in weights: if "A_log" in name: name = name.replace("A_log", "A") @@ -258,5 +256,3 @@ def load_weights(self, weights: Iterable[Tuple[str, weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) - loaded_params.add(name) - return loaded_params diff --git a/vllm/model_executor/models/mlp_speculator.py b/vllm/model_executor/models/mlp_speculator.py index f2aa2653c4f5c..d49da5f29aa14 100644 --- a/vllm/model_executor/models/mlp_speculator.py +++ b/vllm/model_executor/models/mlp_speculator.py @@ -193,7 +193,8 @@ def load_weights(self, weights: Iterable[Tuple[str, params_dict = dict(self.named_parameters()) loaded_params: Set[str] = set() for name, loaded_weight in weights: - param = params_dict.get(name.replace("speculator.", "")) + name = name.replace("speculator.", "") + param = params_dict.get(name) if param is not None: weight_loader = getattr(param, "weight_loader", default_weight_loader) From d6427023fae900d7b4173df878e7eec90e338027 Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Wed, 20 Nov 2024 00:53:04 +0800 Subject: [PATCH 2/5] revert mamba weights tracking Signed-off-by: Isotr0py <2037008807@qq.com> --- vllm/model_executor/models/mamba.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/models/mamba.py b/vllm/model_executor/models/mamba.py index ac0d265a961f0..fac01ee0c8d99 100644 --- a/vllm/model_executor/models/mamba.py +++ b/vllm/model_executor/models/mamba.py @@ -1,5 +1,5 @@ """PyTorch MAMBA model.""" -from typing import Iterable, List, Optional, Tuple +from typing import Iterable, List, Optional, Set, Tuple import torch from torch import nn @@ -48,7 +48,7 @@ def __init__(self, time_step_rank=config.time_step_rank, use_conv_bias=config.use_conv_bias, use_bias=config.use_bias, - use_rms_norm=self.is_falcon_mamba, + use_rms_norm=False, rms_norm_eps=mixer_rms_eps, activation=config.hidden_act) @@ -243,8 +243,10 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() for name, loaded_weight in weights: if "A_log" in name: name = name.replace("A_log", "A") @@ -256,3 +258,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params From d6ebd1ff69dce8dc55ef8cd65dec33592cad685f Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Wed, 20 Nov 2024 11:19:35 +0800 Subject: [PATCH 3/5] revert mamba Signed-off-by: Isotr0py <2037008807@qq.com> --- vllm/model_executor/models/mamba.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/vllm/model_executor/models/mamba.py b/vllm/model_executor/models/mamba.py index fac01ee0c8d99..1c45559577574 100644 --- a/vllm/model_executor/models/mamba.py +++ b/vllm/model_executor/models/mamba.py @@ -48,7 +48,7 @@ def __init__(self, time_step_rank=config.time_step_rank, use_conv_bias=config.use_conv_bias, use_bias=config.use_bias, - use_rms_norm=False, + use_rms_norm=self.is_falcon_mamba, rms_norm_eps=mixer_rms_eps, activation=config.hidden_act) @@ -246,7 +246,6 @@ def sample( def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: params_dict = dict(self.named_parameters()) - loaded_params: Set[str] = set() for name, loaded_weight in weights: if "A_log" in name: name = name.replace("A_log", "A") @@ -258,5 +257,3 @@ def load_weights(self, weights: Iterable[Tuple[str, weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) - loaded_params.add(name) - return loaded_params From 30f4b2d913c0186b4f0b30079fb8f3f995505705 Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Wed, 20 Nov 2024 11:24:05 +0800 Subject: [PATCH 4/5] code format Signed-off-by: Isotr0py <2037008807@qq.com> --- vllm/model_executor/models/mamba.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/vllm/model_executor/models/mamba.py b/vllm/model_executor/models/mamba.py index 1c45559577574..2f12acff08044 100644 --- a/vllm/model_executor/models/mamba.py +++ b/vllm/model_executor/models/mamba.py @@ -243,8 +243,7 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): params_dict = dict(self.named_parameters()) for name, loaded_weight in weights: if "A_log" in name: From 852f0775ca64943feaea85ac6232c2dcc38ce28d Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Wed, 20 Nov 2024 11:25:22 +0800 Subject: [PATCH 5/5] code format Signed-off-by: Isotr0py <2037008807@qq.com> --- vllm/model_executor/models/mamba.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/models/mamba.py b/vllm/model_executor/models/mamba.py index 2f12acff08044..ac0d265a961f0 100644 --- a/vllm/model_executor/models/mamba.py +++ b/vllm/model_executor/models/mamba.py @@ -1,5 +1,5 @@ """PyTorch MAMBA model.""" -from typing import Iterable, List, Optional, Set, Tuple +from typing import Iterable, List, Optional, Tuple import torch from torch import nn