Skip to content

Commit

Permalink
remove unused code
Browse files Browse the repository at this point in the history
  • Loading branch information
xffxff committed Nov 14, 2024
1 parent b2a15af commit 7f70c95
Showing 1 changed file with 0 additions and 72 deletions.
72 changes: 0 additions & 72 deletions gptfast/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -562,78 +562,6 @@ def __init__(self, config: ModelArgs):

self.llm = Transformer(config)

def _merge_input_ids_with_image_features(
self, image_features, inputs_embeds, input_ids
):
num_images, num_image_patches, embed_dim = image_features.shape
batch_size, sequence_length = input_ids.shape

# 1. Create a mask to know where special image tokens are
special_image_token_mask = input_ids == self.config.image_token_index
num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1)
# Compute the maximum embed dimension
max_embed_dim = (
num_special_image_tokens.max() * (num_image_patches - 1)
) + sequence_length
batch_indices, non_image_indices = torch.where(
input_ids != self.config.image_token_index
)

# 2. Compute the positions where text should be written
# Calculate new positions for text tokens in merged image-text sequence.
# `special_image_token_mask` identifies image tokens. Each image token will be replaced by `nb_text_tokens_per_images - 1` text tokens.
# `torch.cumsum` computes how each image token shifts subsequent text token positions.
# - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one.
new_token_positions = (
torch.cumsum((special_image_token_mask * (num_image_patches - 1) + 1), -1)
- 1
)
text_to_overwrite = new_token_positions[batch_indices, non_image_indices]

# 3. Create the full embedding, already padded to the maximum position
final_embedding = torch.zeros(
batch_size,
max_embed_dim,
embed_dim,
dtype=inputs_embeds.dtype,
device=inputs_embeds.device,
)
# In case the Vision model or the Language model has been offloaded to CPU, we need to manually
# set the corresponding tensors into their correct target device.
target_device = inputs_embeds.device
batch_indices, non_image_indices, text_to_overwrite = (
batch_indices.to(target_device),
non_image_indices.to(target_device),
text_to_overwrite.to(target_device),
)

# 4. Fill the embeddings based on the mask. If we have ["hey" "<image>", "how", "are"]
# we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features
final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[
batch_indices, non_image_indices
]

# 5. Fill the embeddings corresponding to the images. Anything that is not `text_positions` needs filling (#29835)
image_to_overwrite = torch.full(
(batch_size, max_embed_dim),
True,
dtype=torch.bool,
device=inputs_embeds.device,
)
image_to_overwrite[batch_indices, text_to_overwrite] = False

if image_to_overwrite.sum() != image_features.shape[:-1].numel():
raise ValueError(
f"The input provided to the model are wrong. The number of image tokens is {torch.sum(special_image_token_mask)} while"
f" the number of image given to the model is {num_images}. This prevents correct indexing and breaks batch generation."
)

final_embedding[image_to_overwrite] = (
image_features.contiguous().reshape(-1, embed_dim).to(target_device)
)

return final_embedding

def prepare_embeddings(self, idx: Tensor, pixel_values: Tensor, pixel_mask: Tensor):
image_outputs, image_attn_mask = self.vision_tower(pixel_values, pixel_mask)
selected_image_feature = image_outputs.last_hidden_state
Expand Down

0 comments on commit 7f70c95

Please sign in to comment.