diff --git a/llm2vec/llm2vec.py b/llm2vec/llm2vec.py
index 60ec7c9..98ed9ce 100644
--- a/llm2vec/llm2vec.py
+++ b/llm2vec/llm2vec.py
@@ -2,10 +2,11 @@
 import logging
 import os
 from functools import partial
-from typing import Dict, List, Optional, Union
+from typing import Dict, List, Optional, Union, Tuple
 
 import numpy as np
 import torch
+import torch.nn.functional as F
 import torch.multiprocessing as mp
 from peft import PeftModel
 from torch import Tensor, device, nn
@@ -246,13 +247,8 @@ def get_pooling(self, features, last_hidden_states):  # All models padded from l
         self._skip_instruction(features)
         seq_lengths = features["attention_mask"].sum(dim=-1)
         if self.pooling_mode == "mean":
-            return torch.stack(
-                [
-                    last_hidden_states[i, -length:, :].mean(dim=0)
-                    for i, length in enumerate(seq_lengths)
-                ],
-                dim=0,
-            )
+            input_indices, offsets = self._get_input_offsets(features["attention_mask"])
+            return self._mean_embedding(last_hidden_states, input_indices, offsets)
         elif self.pooling_mode == "weighted_mean":
             bs, l, _ = last_hidden_states.shape
             complete_weights = torch.zeros(bs, l, device=last_hidden_states.device)
@@ -272,6 +268,66 @@ def get_pooling(self, features, last_hidden_states):  # All models padded from l
         else:
             raise ValueError(f"{self.pooling_mode} is not implemented yet.")
 
+    @staticmethod
+    def _get_input_offsets(
+        attention_mask: torch.Tensor,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Compute indices and offsets for mean pooling using EmbeddingBag.
+
+        Args:
+            attention_mask (torch.Tensor): Attention mask of shape (batch_size, seq_len).
+
+        Returns:
+            Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
+                - input_indices: Indices of non-padded tokens in the flattened input.
+                - offsets: Offsets indicating the start index of each sequence in the flattened input.
+        """
+        # Find the indices of non-padded tokens in the flattened hidden_states
+        input_indices = attention_mask.view(-1).nonzero(as_tuple=False).squeeze()
+
+        # Compute the offsets: for each sequence, where it starts in the flattened input
+        non_padded_lengths = attention_mask.sum(
+            dim=1
+        )  # Count non-padded tokens per sequence
+        offsets = torch.cat(
+            [
+                torch.tensor([0], device=attention_mask.device),
+                non_padded_lengths.cumsum(dim=0)[:-1],
+            ]
+        )
+        return input_indices, offsets
+
+    @staticmethod
+    def _mean_embedding(
+        hidden_states: torch.Tensor,
+        input_indices: torch.Tensor,
+        offsets: torch.Tensor,
+    ) -> torch.Tensor:
+        """
+        Compute the mean of non-padded embeddings using `embedding_bag`,
+        properly handling padding with offsets.
+
+        Args:
+            hidden_states (torch.Tensor): Hidden states of shape (batch_size, seq_len, embed_dim).
+            input_indices (torch.Tensor): Indices of non-padded tokens in flattened form.
+            offsets (torch.Tensor): Offsets specifying the start of each sequence.
+
+        Returns:
+            torch.Tensor: Pooled mean embeddings of shape (batch_size, embed_dim).
+        """
+        # Flatten hidden_states to 2D: shape (batch_size * seq_len, embed_dim)
+        batch_size, seq_len, embed_dim = hidden_states.shape
+        token_embeds = hidden_states.view(-1, embed_dim)
+
+        # Use embedding_bag with mode 'mean' and the appropriate indices
+        return F.embedding_bag(
+            input=input_indices,  # Indices of non-padded tokens in flattened form
+            weight=token_embeds,  # The flattened hidden states as embedding matrix
+            offsets=offsets,  # Offsets specifying the start of each sequence
+            mode="mean",  # Aggregation mode
+        )
+
     def _convert_to_str(self, instruction, text):
         tokenized_q = self.tokenizer(
             text,
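
As a sanity check for this change, the following minimal standalone sketch (not part of the patch; the shapes, seed, and variable names are invented for illustration) compares the new F.embedding_bag pooling against the original loop-based mean on a small left-padded batch:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch_size, seq_len, embed_dim = 3, 5, 4
hidden = torch.randn(batch_size, seq_len, embed_dim)

# Left-padded attention mask: 1 marks real tokens, 0 marks padding.
mask = torch.tensor(
    [
        [0, 0, 1, 1, 1],
        [0, 1, 1, 1, 1],
        [1, 1, 1, 1, 1],
    ]
)

# Reference: the original loop, averaging the last `length` tokens per sequence.
lengths = mask.sum(dim=-1)
reference = torch.stack(
    [hidden[i, -l:, :].mean(dim=0) for i, l in enumerate(lengths)], dim=0
)

# embedding_bag version, mirroring _get_input_offsets / _mean_embedding:
# treat the flattened hidden states as an embedding table and let each
# sequence be one "bag" of its non-padded token indices.
input_indices = mask.view(-1).nonzero(as_tuple=False).squeeze()
offsets = torch.cat([torch.tensor([0]), lengths.cumsum(dim=0)[:-1]])
pooled = F.embedding_bag(
    input=input_indices,
    weight=hidden.view(-1, embed_dim),
    offsets=offsets,
    mode="mean",
)

assert torch.allclose(reference, pooled, atol=1e-5)

The embedding_bag path avoids the Python-level loop and per-sequence slicing, performing the masked mean in a single vectorized call.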