From d3311562fbe740a883e7f03f0b59620587cabb29 Mon Sep 17 00:00:00 2001 From: wnma Date: Wed, 4 Sep 2024 18:55:37 +0800 Subject: [PATCH] [Bugfix] remove post_layernorm in siglip (#8106) --- vllm/model_executor/models/siglip.py | 27 ++++++++++++++++++++++++--- 1 file changed, 24 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/models/siglip.py b/vllm/model_executor/models/siglip.py index 114dbf09b0c53..0bee75e2f0cbb 100644 --- a/vllm/model_executor/models/siglip.py +++ b/vllm/model_executor/models/siglip.py @@ -443,14 +443,27 @@ def __init__( self.config = config embed_dim = config.hidden_size + if (num_hidden_layers_override is None + or num_hidden_layers_override == config.num_hidden_layers): + self.need_post_layernorm = True + elif num_hidden_layers_override > config.num_hidden_layers: + raise ValueError( + "num_hidden_layers_override cannot be greater than " + "num_hidden_layers") + else: + self.need_post_layernorm = False + self.embeddings = SiglipVisionEmbeddings(config) self.encoder = SiglipEncoder( config, quant_config=quant_config, num_hidden_layers_override=num_hidden_layers_override, ) - self.post_layernorm = nn.LayerNorm(embed_dim, - eps=config.layer_norm_eps) + if self.need_post_layernorm: + self.post_layernorm = nn.LayerNorm(embed_dim, + eps=config.layer_norm_eps) + else: + self.post_layernorm = nn.Identity() self.use_head = (True if not hasattr(config, "vision_use_head") else config.vision_use_head) if self.use_head: @@ -470,7 +483,6 @@ def forward( encoder_outputs = self.encoder(inputs_embeds=hidden_states) last_hidden_state = self.post_layernorm(encoder_outputs) - # TODO: add this back when pooled_output is used in inference # if self.use_head: # pooled_output = self.head(last_hidden_state) @@ -499,6 +511,10 @@ def __init__( num_hidden_layers_override=num_hidden_layers_override, ) + @property + def need_post_layernorm(self): + return self.vision_model.need_post_layernorm + def get_input_embeddings(self) -> nn.Module: return self.vision_model.embeddings.patch_embedding @@ -517,6 +533,11 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): layer_count = len(self.vision_model.encoder.layers) for name, loaded_weight in weights: + # post_layernorm is optional in SiglipVisionModel + if ("vision_model.post_layernorm" in name + and not self.need_post_layernorm): + continue + # omit layers when num_hidden_layers_override is set if "vision_model.encoder.layers." in name: layer_idx = int(name.split(".")[3])