From 0fd83fbf44c25fc247f97afd2703740664d62bce Mon Sep 17 00:00:00 2001
From: wxywb
Date: Sun, 17 Mar 2024 16:03:00 +0800
Subject: [PATCH] Change the sparse embedding from List[csr_array] to csr_array (#1972)

* Implement batchify features to bge_m3 and splade.

Signed-off-by: wxywb

* Change sparse embedding output from List[csr_array] to csr_array.

Signed-off-by: wxywb

---------

Signed-off-by: wxywb
---
 pymilvus/model/hybrid/bge_m3.py    |  4 +-
 pymilvus/model/sparse/bm25/bm25.py | 20 ++++-----
 pymilvus/model/sparse/splade.py    | 71 +++++++++++++-----------------
 3 files changed, 41 insertions(+), 54 deletions(-)

diff --git a/pymilvus/model/hybrid/bge_m3.py b/pymilvus/model/hybrid/bge_m3.py
index 9db69c1c4..7181bc1b2 100644
--- a/pymilvus/model/hybrid/bge_m3.py
+++ b/pymilvus/model/hybrid/bge_m3.py
@@ -2,7 +2,7 @@
 from collections import defaultdict
 from typing import Dict, List
 
-from scipy.sparse import csr_array
+from scipy.sparse import csr_array, vstack
 
 from pymilvus.model.base import BaseEmbeddingFunction
 
@@ -49,6 +49,7 @@ def __init__(
             **kwargs,
         )
         _encode_config = {
+            "batch_size": batch_size,
             "return_dense": return_dense,
             "return_sparse": return_sparse,
             "return_colbert_vecs": return_colbert_vecs,
@@ -86,6 +87,7 @@ def _encode(self, texts: List[str]) -> Dict:
                 row_indices = [0] * len(indices)
                 csr = csr_array((values, (row_indices, indices)), shape=(1, sparse_dim))
                 results["sparse"].append(csr)
+            results["sparse"] = vstack(results["sparse"])
         if self._encode_config["return_colbert_vecs"] is True:
             results["colbert_vecs"] = output["colbert_vecs"]
         return results
diff --git a/pymilvus/model/sparse/bm25/bm25.py b/pymilvus/model/sparse/bm25/bm25.py
index c74f10dfd..204a938a5 100644
--- a/pymilvus/model/sparse/bm25/bm25.py
+++ b/pymilvus/model/sparse/bm25/bm25.py
@@ -23,7 +23,7 @@
 from typing import Dict, List, Optional
 
 import requests
-from scipy.sparse import csr_array
+from scipy.sparse import csr_array, vstack
 
 from pymilvus.model.base import BaseEmbeddingFunction
 from pymilvus.model.sparse.bm25.tokenizers import Analyzer, build_default_analyzer
@@ -158,21 +158,17 @@ def _encode_document(self, doc: str) -> csr_array:
             values.append(value)
         return csr_array((values, (rows, cols)), shape=(1, len(self.idf)))
 
-    def encode_queries(self, queries: List[str]) -> List[csr_array]:
-        if self.num_workers == 1:
-            return [self._encode_query(query) for query in queries]
-        with Pool(self.num_workers) as pool:
-            return pool.map(self._encode_query, queries)
+    def encode_queries(self, queries: List[str]) -> csr_array:
+        sparse_embs = [self._encode_query(query) for query in queries]
+        return vstack(sparse_embs)
 
-    def __call__(self, texts: List[str]) -> List[csr_array]:
+    def __call__(self, texts: List[str]) -> csr_array:
         error_message = "Unsupported function called, please check the documentation of 'BM25EmbeddingFunction'."
         raise ValueError(error_message)
 
-    def encode_documents(self, documents: List[str]) -> List[csr_array]:
-        if self.num_workers == 1:
-            return [self._encode_document(document) for document in documents]
-        with Pool(self.num_workers) as pool:
-            return pool.map(self._encode_document, documents)
+    def encode_documents(self, documents: List[str]) -> csr_array:
+        sparse_embs = [self._encode_document(document) for document in documents]
+        return vstack(sparse_embs)
 
     def save(self, path: str):
         bm25_params = {}
diff --git a/pymilvus/model/sparse/splade.py b/pymilvus/model/sparse/splade.py
index 34e549625..edda3c448 100644
--- a/pymilvus/model/sparse/splade.py
+++ b/pymilvus/model/sparse/splade.py
@@ -31,7 +31,7 @@
 
 import numpy as np
 import torch
-from scipy.sparse import csr_array
+from scipy.sparse import csr_array, vstack
 
 from pymilvus.model.base import BaseEmbeddingFunction
 
@@ -67,27 +67,23 @@ def __init__(
         self.query_instruction = query_instruction
         self.doc_instruction = doc_instruction
 
-    def __call__(self, texts: List[str]) -> List[csr_array]:
-        embs = self._encode(texts, None)
-        return list(embs)
+    def __call__(self, texts: List[str]) -> csr_array:
+        return self._encode(texts, None)
 
-    def encode_documents(self, documents: List[str]) -> List[csr_array]:
-        embs = self._encode(
+    def encode_documents(self, documents: List[str]) -> csr_array:
+        return self._encode(
             [self.doc_instruction + document for document in documents],
             self.k_tokens_document,
         )
-        return list(embs)
 
-    def _encode(self, texts: List[str], k_tokens: int) -> List[csr_array]:
-        embs = self.model.forward(texts, k_tokens=k_tokens)
-        return list(embs)
+    def _encode(self, texts: List[str], k_tokens: int) -> csr_array:
+        return self.model.forward(texts, k_tokens=k_tokens)
 
-    def encode_queries(self, queries: List[str]) -> List[csr_array]:
-        embs = self._encode(
+    def encode_queries(self, queries: List[str]) -> csr_array:
+        return self._encode(
             [self.query_instruction + query for query in queries],
             self.k_tokens_query,
         )
-        return list(embs)
 
     @property
     def dim(self) -> int:
@@ -138,19 +134,27 @@ def _encode(self, texts: List[str]):
         output = self.model(**encoded_input)
         return output.logits
 
-    def forward(self, texts: List[str], k_tokens: int):
-        logits = self._encode(texts=texts)
-        activations = self._get_activation(logits=logits)
-
-        if k_tokens is None:
-            nonzero_indices = [
-                torch.nonzero(activations["sparse_activations"][i]).t()[0]
-                for i in range(len(texts))
-            ]
-            activations["activations"] = nonzero_indices
-        else:
-            activations = self._update_activations(**activations, k_tokens=k_tokens)
-        return self._convert_to_csr_array(activations)
+    def _batchify(self, texts: List[str], batch_size: int) -> List[List[str]]:
+        return [texts[i : i + batch_size] for i in range(0, len(texts), batch_size)]
+
+    def forward(self, texts: List[str], k_tokens: int) -> csr_array:
+        batched_texts = self._batchify(texts, self.batch_size)
+        sparse_embs = []
+        for batch_texts in batched_texts:
+            logits = self._encode(texts=batch_texts)
+            activations = self._get_activation(logits=logits)
+            if k_tokens is None:
+                nonzero_indices = [
+                    torch.nonzero(activations["sparse_activations"][i]).t()[0]
+                    for i in range(len(batch_texts))
+                ]
+                activations["activations"] = nonzero_indices
+            else:
+                activations = self._update_activations(**activations, k_tokens=k_tokens)
+            batch_csr = self._convert_to_csr_array(activations)
+            sparse_embs.extend(batch_csr)
+
+        return vstack(sparse_embs)
 
     def _get_activation(self, logits: torch.Tensor) -> Dict[str, torch.Tensor]:
         return {"sparse_activations": torch.amax(torch.log1p(self.relu(logits)), dim=1)}
@@ -201,18 +205,3 @@ def _convert_to_csr_array(self, activations: Dict):
                 )
             )
         return csr_array_list
-
-    def _convert_to_csr_array2(self, activations: Dict):
-        values = (
-            torch.gather(activations["sparse_activations"], 1, activations["activations"])
-            .cpu()
-            .detach()
-            .numpy()
-        )
-        rows, cols = activations["activations"].shape
-        row_indices = np.repeat(np.arange(rows), cols)
-        col_indices = activations["activations"].detach().cpu().numpy().flatten()
-        return csr_array(
-            (values.flatten(), (row_indices, col_indices)),
-            shape=(rows, activations["sparse_activations"].shape[1]),
-        )