From 936d727cfffeee01eea8e2fa721098d10b2fb87c Mon Sep 17 00:00:00 2001
From: Fabian David Schmidt
Date: Wed, 2 Oct 2024 20:24:16 +0200
Subject: [PATCH 1/2] feat(pooling): faster avg. with EmbeddingBags

---
 llm2vec/llm2vec.py | 67 ++++++++++++++++++++++++++++++++++++++++++++--
 1 file changed, 65 insertions(+), 2 deletions(-)

diff --git a/llm2vec/llm2vec.py b/llm2vec/llm2vec.py
index 32965ef..afe5588 100644
--- a/llm2vec/llm2vec.py
+++ b/llm2vec/llm2vec.py
@@ -2,10 +2,11 @@
 import logging
 import os
 from functools import partial
-from typing import Dict, List, Optional, Union
+from typing import Dict, List, Optional, Union, Tuple

 import numpy as np
 import torch
+import torch.nn.functional as F
 import torch.multiprocessing as mp
 from peft import PeftModel
 from torch import Tensor, device, nn
@@ -263,6 +264,9 @@ def get_pooling(self, features, last_hidden_states):  # All models padded from l
                     complete_weights[i].sum(), min=1e-9
                 )
             return torch.sum(last_hidden_states * complete_weights.unsqueeze(-1), dim=1)
+        elif self.pooling_mode == "embedding_bag":
+            input_indices, offsets = self._get_input_offsets(features["attention_mask"])
+            return self._mean_embedding(last_hidden_states, input_indices, offsets)
         elif self.pooling_mode == "eos_token" or self.pooling_mode == "last_token":
             return last_hidden_states[:, -1]
         elif self.pooling_mode == "bos_token":
@@ -272,6 +276,66 @@ def get_pooling(self, features, last_hidden_states):  # All models padded from l
         else:
             raise ValueError(f"{self.pooling_mode} is not implemented yet.")

+    @staticmethod
+    def _get_input_offsets(
+        attention_mask: torch.Tensor,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Compute indices and offsets for mean pooling using EmbeddingBag.
+
+        Args:
+            attention_mask (torch.Tensor): Attention mask of shape (batch_size, seq_len).
+
+        Returns:
+            Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
+            - input_indices: Indices of non-padded tokens in the flattened input.
+            - offsets: Offsets indicating the start index of each sequence in the flattened input.
+        """
+        # Find the indices of non-padded tokens in flattened hidden_states
+        input_indices = attention_mask.view(-1).nonzero(as_tuple=False).squeeze()
+
+        # Compute the offsets: for each sequence, where it starts in the flattened input
+        non_padded_lengths = attention_mask.sum(
+            dim=1
+        )  # Count non-padded tokens per sequence
+        offsets = torch.cat(
+            [
+                torch.tensor([0], device=attention_mask.device),
+                non_padded_lengths.cumsum(dim=0)[:-1],
+            ]
+        )
+        return input_indices, offsets
+
+    @staticmethod
+    def _mean_embedding(
+        hidden_states: torch.Tensor,
+        input_indices: torch.Tensor,
+        offsets: torch.Tensor,
+    ) -> torch.Tensor:
+        """
+        Compute the mean of non-padded embeddings using `embedding_bag`,
+        properly handling padding with offsets.
+
+        Args:
+            hidden_states (torch.Tensor): Hidden states of shape (batch_size, seq_len, embed_dim).
+            input_indices (torch.Tensor): Indices of non-padded tokens in flattened form.
+            offsets (torch.Tensor): Offsets specifying the start of each sequence.
+
+        Returns:
+            torch.Tensor: Pooled mean embeddings of shape (batch_size, embed_dim).
+        """
+        # Flatten hidden_states to 2D: shape (batch_size * seq_len, embedding_dim)
+        batch_size, seq_len, embed_dim = hidden_states.shape
+        token_embeds = hidden_states.view(-1, embed_dim)
+
+        # Use embedding_bag with mode 'mean' and appropriate indices
+        return F.embedding_bag(
+            input=input_indices,  # Indices of non-padded tokens in flattened form
+            weight=token_embeds,  # The flattened hidden states as embedding matrix
+            offsets=offsets,  # Offsets specifying start of each sequence
+            mode="mean",  # Aggregation mode
+        )
+
     def _convert_to_str(self, instruction, text):
         tokenized_q = self.tokenizer(
             text,
@@ -372,7 +436,6 @@ def encode(
                 )
                 all_embeddings.append(embeddings)
         else:
-
             num_proc = torch.cuda.device_count()
             cuda_compatible_multiprocess = mp.get_context("spawn")
             with cuda_compatible_multiprocess.Pool(num_proc) as p:

From 478869c20ba5216137cde0c49ee240ca432adcdf Mon Sep 17 00:00:00 2001
From: Fabian David Schmidt
Date: Thu, 3 Oct 2024 20:15:01 +0200
Subject: [PATCH 2/2] feat(mean): make embeddingbag default

---
 llm2vec/llm2vec.py | 12 ++----------
 1 file changed, 2 insertions(+), 10 deletions(-)

diff --git a/llm2vec/llm2vec.py b/llm2vec/llm2vec.py
index afe5588..98ed9ce 100644
--- a/llm2vec/llm2vec.py
+++ b/llm2vec/llm2vec.py
@@ -247,13 +247,8 @@ def get_pooling(self, features, last_hidden_states):  # All models padded from l
         self._skip_instruction(features)
         seq_lengths = features["attention_mask"].sum(dim=-1)
         if self.pooling_mode == "mean":
-            return torch.stack(
-                [
-                    last_hidden_states[i, -length:, :].mean(dim=0)
-                    for i, length in enumerate(seq_lengths)
-                ],
-                dim=0,
-            )
+            input_indices, offsets = self._get_input_offsets(features["attention_mask"])
+            return self._mean_embedding(last_hidden_states, input_indices, offsets)
         elif self.pooling_mode == "weighted_mean":
             bs, l, _ = last_hidden_states.shape
             complete_weights = torch.zeros(bs, l, device=last_hidden_states.device)
@@ -264,9 +259,6 @@ def get_pooling(self, features, last_hidden_states):  # All models padded from l
                 complete_weights[i].sum(), min=1e-9
             )
             return torch.sum(last_hidden_states * complete_weights.unsqueeze(-1), dim=1)
-        elif self.pooling_mode == "embedding_bag":
-            input_indices, offsets = self._get_input_offsets(features["attention_mask"])
-            return self._mean_embedding(last_hidden_states, input_indices, offsets)
         elif self.pooling_mode == "eos_token" or self.pooling_mode == "last_token":
            return last_hidden_states[:, -1]
         elif self.pooling_mode == "bos_token":
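Note for reviewers (not part of the patches above): the sketch below is a minimal, self-contained check of the pooling trick these patches introduce, assuming only torch is installed. It builds a small left-padded batch, derives the flattened token indices and per-sequence offsets the same way _get_input_offsets does, and verifies that F.embedding_bag(..., mode="mean") matches a naive masked mean. The tensor names (hidden, mask) and the toy shapes are illustrative only.

import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch_size, seq_len, embed_dim = 3, 5, 4
hidden = torch.randn(batch_size, seq_len, embed_dim)

# Left-padded attention mask: 1 = real token, 0 = padding.
mask = torch.tensor(
    [
        [0, 0, 1, 1, 1],
        [0, 1, 1, 1, 1],
        [1, 1, 1, 1, 1],
    ]
)

# Flattened indices of real tokens and per-sequence start offsets,
# mirroring _get_input_offsets.
input_indices = mask.view(-1).nonzero(as_tuple=False).squeeze()
lengths = mask.sum(dim=1)
offsets = torch.cat([torch.zeros(1, dtype=torch.long), lengths.cumsum(dim=0)[:-1]])

# embedding_bag treats the flattened hidden states as an embedding matrix and
# averages each bag (one sequence's non-padded tokens) in a single fused op.
bag_mean = F.embedding_bag(
    input=input_indices,
    weight=hidden.view(-1, embed_dim),
    offsets=offsets,
    mode="mean",
)

# Naive masked mean for comparison; for left-padded batches this equals the
# per-sequence mean the replaced torch.stack loop computed.
naive_mean = (hidden * mask.unsqueeze(-1)).sum(dim=1) / lengths.unsqueeze(-1)

assert torch.allclose(bag_mean, naive_mean, atol=1e-6)

The gain over the previous torch.stack implementation is that the Python-level loop over sequences disappears: index/offset construction and the mean reduce to a few vectorized torch calls, which matters for large batches.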