From 5ae9cada7396995e3c40557749444673aec447e6 Mon Sep 17 00:00:00 2001
From: DarkLight1337
Date: Fri, 29 Nov 2024 13:58:50 +0000
Subject: [PATCH] Simplify code

Signed-off-by: DarkLight1337
---
 vllm/model_executor/models/adapters.py | 14 ++++----------
 1 file changed, 4 insertions(+), 10 deletions(-)

diff --git a/vllm/model_executor/models/adapters.py b/vllm/model_executor/models/adapters.py
index bc93808ac5718..1e031c54be306 100644
--- a/vllm/model_executor/models/adapters.py
+++ b/vllm/model_executor/models/adapters.py
@@ -9,11 +9,6 @@
 _T = TypeVar("_T", bound=type[nn.Module])
 
 
-def _is_paramless(module: nn.Module):
-    # NOTE: all([]) returns True
-    return all(False for _ in module.parameters())
-
-
 def as_embedding_model(cls: _T) -> _T:
     """Subclass an existing vLLM model to support embeddings."""
     # Avoid modifying existing embedding models
@@ -40,10 +35,9 @@ def __init__(
             super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs)
 
             # These are not used in embedding models
-            if hasattr(self, "lm_head"):
-                del self.lm_head
-            if hasattr(self, "logits_processor"):
-                del self.logits_processor
+            for attr in ("lm_head", "logits_processor"):
+                if hasattr(self, attr):
+                    delattr(self, attr)
 
             pooler_config = vllm_config.model_config.pooler_config
             assert pooler_config is not None
@@ -77,7 +71,7 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
             if hasattr(self, "model") and hasattr(self.model, "load_weights"):
                 # Whether only `self.model` contains parameters
                 model_is_only_param = all(
-                    name == "model" or _is_paramless(child)
+                    name == "model" or next(child.parameters(), None) is None
                     for name, child in self.named_children())
 
                 if model_is_only_param:
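
Note: below is a minimal, self-contained sketch (not part of the patch) of the
emptiness-check idiom this change adopts. `next(module.parameters(), None) is
None` is equivalent to the deleted `_is_paramless` helper: both are true
exactly when the module and its submodules hold no parameters. The `TinyModel`
and `ParamlessHead` classes are hypothetical, for illustration only.

    import torch.nn as nn

    class ParamlessHead(nn.Module):
        """A child module that holds no parameters."""

        def forward(self, x):
            return x

    class TinyModel(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.model = nn.Linear(4, 4)  # the only child with parameters
            self.head = ParamlessHead()   # paramless child

    m = TinyModel()

    # Old helper: all(False for _ in gen) is True iff the generator is empty,
    # since all([]) returns True.
    # New idiom: next(gen, None) yields the first parameter, or None if empty.
    assert next(m.head.parameters(), None) is None       # paramless
    assert next(m.model.parameters(), None) is not None  # has parameters

    # The check in load_weights(): does only `self.model` hold parameters?
    assert all(
        name == "model" or next(child.parameters(), None) is None
        for name, child in m.named_children())

The `delattr` loop in the second hunk is behaviorally identical to the paired
hasattr/del statements it replaces; it only removes the repetition.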