Skip to content

Commit

Permalink
[VLM] Minor space optimization for ClipVisionModel (vllm-project#6436)
Browse files Browse the repository at this point in the history
  • Loading branch information
ywang96 authored Jul 15, 2024
1 parent 22e79ee commit 6ae1597
Show file tree
Hide file tree
Showing 4 changed files with 66 additions and 39 deletions.
46 changes: 25 additions & 21 deletions vllm/model_executor/models/clip.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,22 +214,24 @@ class CLIPEncoder(nn.Module):

def __init__(self,
config: CLIPVisionConfig,
quant_config: Optional[QuantizationConfig] = None):
quant_config: Optional[QuantizationConfig] = None,
num_hidden_layers_override: Optional[int] = None):
super().__init__()
self.config = config

if num_hidden_layers_override is None:
num_hidden_layers = config.num_hidden_layers
else:
num_hidden_layers = num_hidden_layers_override
self.layers = nn.ModuleList([
CLIPEncoderLayer(config=config, quant_config=quant_config)
for _ in range(config.num_hidden_layers)
for _ in range(num_hidden_layers)
])

def forward(self,
inputs_embeds: torch.Tensor,
vision_feature_layer: int = -1):
def forward(self, inputs_embeds: torch.Tensor):

# Encoder forward pass only up to the required layer
num_layer = len(self.layers) + vision_feature_layer + 1
hidden_states = inputs_embeds
for encoder_layer in self.layers[:num_layer]:
for encoder_layer in self.layers:
hidden_states = encoder_layer(hidden_states)

return hidden_states
Expand All @@ -239,7 +241,8 @@ class CLIPVisionTransformer(nn.Module):

def __init__(self,
config: CLIPVisionConfig,
quant_config: Optional[QuantizationConfig] = None):
quant_config: Optional[QuantizationConfig] = None,
num_hidden_layers_override: Optional[int] = None):
super().__init__()
self.config = config
embed_dim = config.hidden_size
Expand All @@ -249,18 +252,19 @@ def __init__(self,
# NOTE: This typo of "layrnorm" is not fixed on purpose to match
# the original transformers code and name of the model weights.
self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
self.encoder = CLIPEncoder(config=config, quant_config=quant_config)
self.encoder = CLIPEncoder(
config=config,
quant_config=quant_config,
num_hidden_layers_override=num_hidden_layers_override)

def forward(
self,
pixel_values: torch.Tensor,
vision_feature_layer: int = -1,
) -> torch.Tensor:

hidden_states = self.embeddings(pixel_values)
hidden_states = self.pre_layrnorm(hidden_states)
hidden_states = self.encoder(inputs_embeds=hidden_states,
vision_feature_layer=vision_feature_layer)
hidden_states = self.encoder(inputs_embeds=hidden_states)

return hidden_states

Expand All @@ -272,17 +276,17 @@ class CLIPVisionModel(nn.Module):

def __init__(self,
config: CLIPVisionConfig,
quant_config: Optional[QuantizationConfig] = None):
quant_config: Optional[QuantizationConfig] = None,
num_hidden_layers_override: Optional[int] = None):
super().__init__()
self.vision_model = CLIPVisionTransformer(config=config,
quant_config=quant_config)
self.vision_model = CLIPVisionTransformer(
config=config,
quant_config=quant_config,
num_hidden_layers_override=num_hidden_layers_override)

def forward(self,
pixel_values: Optional[torch.Tensor] = None,
vision_feature_layer: int = -1):
def forward(self, pixel_values: Optional[torch.Tensor] = None):

return self.vision_model(pixel_values=pixel_values,
vision_feature_layer=vision_feature_layer)
return self.vision_model(pixel_values=pixel_values)

@property
def device(self):
Expand Down
16 changes: 12 additions & 4 deletions vllm/model_executor/models/llava.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,8 +128,17 @@ def __init__(self,
self.config = config
self.multimodal_config = multimodal_config

# Initialize the vision tower only up to the required feature layer
vision_feature_layer = config.vision_feature_layer
if vision_feature_layer < 0:
num_hidden_layers = config.vision_config.num_hidden_layers \
+ vision_feature_layer + 1
else:
num_hidden_layers = vision_feature_layer + 1

# TODO: Optionally initializes this for supporting embeddings.
self.vision_tower = CLIPVisionModel(config.vision_config)
self.vision_tower = CLIPVisionModel(
config.vision_config, num_hidden_layers_override=num_hidden_layers)
self.multi_modal_projector = LlavaMultiModalProjector(
vision_hidden_size=config.vision_config.hidden_size,
text_hidden_size=config.text_config.hidden_size,
Expand Down Expand Up @@ -193,8 +202,7 @@ def _image_pixels_to_features(self, vision_tower: CLIPVisionModel,

# NOTE: we skip the step to select the vision feature layer since
# this is already done inside the vision tower
image_features = vision_tower(pixel_values,
self.config.vision_feature_layer)
image_features = vision_tower(pixel_values)

return self._select_image_features(
image_features,
Expand Down Expand Up @@ -333,7 +341,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
break
else:
use_default_weight_loading = True
if use_default_weight_loading:
if use_default_weight_loading and name in params_dict:
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
Expand Down
16 changes: 12 additions & 4 deletions vllm/model_executor/models/llava_next.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,8 +222,17 @@ def __init__(self,
self.config = config
self.multimodal_config = multimodal_config

# Initialize the vision tower only up to the required feature layer
vision_feature_layer = config.vision_feature_layer
if vision_feature_layer < 0:
num_hidden_layers = config.vision_config.num_hidden_layers \
+ vision_feature_layer + 1
else:
num_hidden_layers = vision_feature_layer + 1

# TODO: Optionally initializes this for supporting embeddings.
self.vision_tower = CLIPVisionModel(config=config.vision_config)
self.vision_tower = CLIPVisionModel(
config.vision_config, num_hidden_layers_override=num_hidden_layers)
self.multi_modal_projector = LlavaMultiModalProjector(
vision_hidden_size=config.vision_config.hidden_size,
text_hidden_size=config.text_config.hidden_size,
Expand Down Expand Up @@ -312,8 +321,7 @@ def _image_pixels_to_features(self, vision_tower: CLIPVisionModel,

# NOTE: we skip the step to select the vision feature layer since
# this is already done inside the vision tower
image_features = vision_tower(pixel_values,
self.config.vision_feature_layer)
image_features = vision_tower(pixel_values)

return self._select_image_features(
image_features,
Expand Down Expand Up @@ -561,7 +569,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
break
else:
use_default_weight_loading = True
if use_default_weight_loading:
if use_default_weight_loading and name in params_dict:
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
Expand Down
27 changes: 17 additions & 10 deletions vllm/model_executor/models/phi3v.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,13 +80,11 @@ def __init__(self, wte=None) -> None:

def get_img_features(self,
img_embeds: torch.FloatTensor) -> torch.FloatTensor:
LAYER_IDX = self.layer_idx
TYPE_FEATURE = self.type_feature

# NOTE: we skip the step to select the vision feature layer since
# this is already done inside the img_processor
img_feature = self.img_processor(img_embeds,
vision_feature_layer=LAYER_IDX)
img_feature = self.img_processor(img_embeds)

if TYPE_FEATURE == "patch":
patch_feature = img_feature[:, 1:]
Expand All @@ -111,7 +109,17 @@ def __init__(self, config: PretrainedConfig, wte=None) -> None:
config, 'n_embd') else config.hidden_size

clip_config = CLIP_VIT_LARGE_PATCH14_336_CONFIG
self.img_processor = CLIPVisionModel(clip_config)
self.layer_idx = config.img_processor.get('layer_idx', -2)

# Initialize the CLIP only up to the required feature layer
if self.layer_idx < 0:
num_hidden_layers = clip_config.num_hidden_layers + \
self.layer_idx + 1
else:
num_hidden_layers = self.layer_idx + 1

self.img_processor = CLIPVisionModel(
clip_config, num_hidden_layers_override=num_hidden_layers)
image_dim_out = config.img_processor['image_dim_out']
self.num_img_tokens = config.img_processor['num_img_tokens']

Expand Down Expand Up @@ -142,8 +150,6 @@ def __init__(self, config: PretrainedConfig, wte=None) -> None:
self.img_projection = nn.Sequential(*layers)

self.vocab_size = config.vocab_size

self.layer_idx = config.img_processor.get('layer_idx', -2)
self.type_feature = config.img_processor.get('type_feature', 'patch')

def forward(self, input_ids: torch.LongTensor,
Expand Down Expand Up @@ -588,7 +594,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
if name in params_dict:
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)

0 comments on commit 6ae1597

Please sign in to comment.