Change the sparse embedding from List[csr_array] to csr_array (#1972)
* Implement batching for bge_m3 and splade.

Signed-off-by: wxywb <[email protected]>

* Change sparse embedding output from List[csr_array] to csr_array.

Signed-off-by: wxywb <[email protected]>

---------

Signed-off-by: wxywb <[email protected]>
wxywb authored Mar 17, 2024
1 parent 436f559 commit 0fd83fb
Showing 3 changed files with 41 additions and 54 deletions.
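What the change means for callers, as a minimal sketch using scipy only (the variable names here are illustrative, not part of the pymilvus API): before this commit the encoders returned a Python list of (1, dim) csr_array rows; after it they return a single csr_array with one row per input text, stacked with scipy.sparse.vstack.

```python
# Minimal sketch of the return-type change; scipy only, names illustrative.
from scipy.sparse import csr_array, vstack

# Old style: one (1, dim) csr_array per input text, collected in a list.
per_text_rows = [
    csr_array(([0.5, 0.2], ([0, 0], [3, 7])), shape=(1, 10)),
    csr_array(([0.9], ([0], [1])), shape=(1, 10)),
]

# New style: a single (n_texts, dim) csr_array, stacked row-wise.
stacked = vstack(per_text_rows)
print(stacked.shape)              # (2, 10)
print(stacked[[1], :].toarray())  # row 1 back as a dense (1, 10) array
```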
pymilvus/model/hybrid/bge_m3.py (3 additions, 1 deletion)

@@ -2,7 +2,7 @@
 from collections import defaultdict
 from typing import Dict, List
 
-from scipy.sparse import csr_array
+from scipy.sparse import csr_array, vstack
 
 from pymilvus.model.base import BaseEmbeddingFunction
 
@@ -49,6 +49,7 @@ def __init__(
             **kwargs,
         )
         _encode_config = {
+            "batch_size": batch_size,
             "return_dense": return_dense,
             "return_sparse": return_sparse,
             "return_colbert_vecs": return_colbert_vecs,
@@ -86,6 +87,7 @@ def _encode(self, texts: List[str]) -> Dict:
                 row_indices = [0] * len(indices)
                 csr = csr_array((values, (row_indices, indices)), shape=(1, sparse_dim))
                 results["sparse"].append(csr)
+            results["sparse"] = vstack(results["sparse"])
         if self._encode_config["return_colbert_vecs"] is True:
             results["colbert_vecs"] = output["colbert_vecs"]
         return results
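For BGEM3EmbeddingFunction users, a hedged usage sketch (assumes the BAAI/bge-m3 checkpoint downloads successfully; the constructor arguments follow this file's __init__):

```python
# Hedged sketch of the bge_m3 change; assumes the model download succeeds.
from pymilvus.model.hybrid import BGEM3EmbeddingFunction

ef = BGEM3EmbeddingFunction(device="cpu", use_fp16=False, return_sparse=True)
out = ef(["hello world", "sparse vectors in milvus"])

# Before this commit, out["sparse"] was a list of (1, sparse_dim) csr_arrays.
# Now it is a single (2, sparse_dim) csr_array; row i corresponds to text i.
print(out["sparse"].shape)
```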
pymilvus/model/sparse/bm25/bm25.py (8 additions, 12 deletions)

@@ -23,7 +23,7 @@
 from typing import Dict, List, Optional
 
 import requests
-from scipy.sparse import csr_array
+from scipy.sparse import csr_array, vstack
 
 from pymilvus.model.base import BaseEmbeddingFunction
 from pymilvus.model.sparse.bm25.tokenizers import Analyzer, build_default_analyzer
@@ -158,21 +158,17 @@ def _encode_document(self, doc: str) -> csr_array:
             values.append(value)
         return csr_array((values, (rows, cols)), shape=(1, len(self.idf)))
 
-    def encode_queries(self, queries: List[str]) -> List[csr_array]:
-        if self.num_workers == 1:
-            return [self._encode_query(query) for query in queries]
-        with Pool(self.num_workers) as pool:
-            return pool.map(self._encode_query, queries)
+    def encode_queries(self, queries: List[str]) -> csr_array:
+        sparse_embs = [self._encode_query(query) for query in queries]
+        return vstack(sparse_embs)
 
-    def __call__(self, texts: List[str]) -> List[csr_array]:
+    def __call__(self, texts: List[str]) -> csr_array:
         error_message = "Unsupported function called, please check the documentation of 'BM25EmbeddingFunction'."
         raise ValueError(error_message)
 
-    def encode_documents(self, documents: List[str]) -> List[csr_array]:
-        if self.num_workers == 1:
-            return [self._encode_document(document) for document in documents]
-        with Pool(self.num_workers) as pool:
-            return pool.map(self._encode_document, documents)
+    def encode_documents(self, documents: List[str]) -> csr_array:
+        sparse_embs = [self._encode_document(document) for document in documents]
+        return vstack(sparse_embs)
 
     def save(self, path: str):
         bm25_params = {}
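Two things to note here: the Pool(self.num_workers) branch is gone, so encoding now runs single-process regardless of num_workers, and callers switch from list indexing to row slicing. A hedged sketch, assuming a small corpus to fit on:

```python
# Hedged sketch; analyzer and fit() usage follow this module's public API.
from pymilvus.model.sparse import BM25EmbeddingFunction
from pymilvus.model.sparse.bm25.tokenizers import build_default_analyzer

analyzer = build_default_analyzer(language="en")
bm25_ef = BM25EmbeddingFunction(analyzer)
bm25_ef.fit(["hello world", "milvus sparse vectors", "bm25 ranking"])

embs = bm25_ef.encode_documents(["hello milvus", "sparse bm25"])
# Before: embs[0] was the first document's (1, vocab) csr_array.
# After: embs is one (2, vocab) csr_array; take a row by slicing instead.
first_doc = embs[[0], :]
print(embs.shape, first_doc.nnz)
```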
pymilvus/model/sparse/splade.py (30 additions, 41 deletions)

@@ -31,7 +31,7 @@
 
 import numpy as np
 import torch
-from scipy.sparse import csr_array
+from scipy.sparse import csr_array, vstack
 
 from pymilvus.model.base import BaseEmbeddingFunction
 
@@ -67,27 +67,23 @@ def __init__(
         self.query_instruction = query_instruction
         self.doc_instruction = doc_instruction
 
-    def __call__(self, texts: List[str]) -> List[csr_array]:
-        embs = self._encode(texts, None)
-        return list(embs)
+    def __call__(self, texts: List[str]) -> csr_array:
+        return self._encode(texts, None)
 
-    def encode_documents(self, documents: List[str]) -> List[csr_array]:
-        embs = self._encode(
+    def encode_documents(self, documents: List[str]) -> csr_array:
+        return self._encode(
             [self.doc_instruction + document for document in documents],
             self.k_tokens_document,
         )
-        return list(embs)
 
-    def _encode(self, texts: List[str], k_tokens: int) -> List[csr_array]:
-        embs = self.model.forward(texts, k_tokens=k_tokens)
-        return list(embs)
+    def _encode(self, texts: List[str], k_tokens: int) -> csr_array:
+        return self.model.forward(texts, k_tokens=k_tokens)
 
-    def encode_queries(self, queries: List[str]) -> List[csr_array]:
-        embs = self._encode(
+    def encode_queries(self, queries: List[str]) -> csr_array:
+        return self._encode(
             [self.query_instruction + query for query in queries],
             self.k_tokens_query,
         )
-        return list(embs)
 
     @property
     def dim(self) -> int:
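A hedged usage sketch of the new signatures (assumes the default SPLADE checkpoint downloads; class name and defaults are as documented for pymilvus.model):

```python
# Hedged sketch; assumes the default SPLADE model download succeeds.
from pymilvus.model.sparse import SpladeEmbeddingFunction

splade_ef = SpladeEmbeddingFunction(device="cpu")
q = splade_ef.encode_queries(["what is a sparse embedding?"])
# Now a (1, vocab_size) csr_array rather than a one-element list.
print(q.shape)
```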
@@ -138,19 +134,27 @@ def _encode(self, texts: List[str]):
         output = self.model(**encoded_input)
         return output.logits
 
-    def forward(self, texts: List[str], k_tokens: int):
-        logits = self._encode(texts=texts)
-        activations = self._get_activation(logits=logits)
-
-        if k_tokens is None:
-            nonzero_indices = [
-                torch.nonzero(activations["sparse_activations"][i]).t()[0]
-                for i in range(len(texts))
-            ]
-            activations["activations"] = nonzero_indices
-        else:
-            activations = self._update_activations(**activations, k_tokens=k_tokens)
-        return self._convert_to_csr_array(activations)
+    def _batchify(self, texts: List[str], batch_size: int) -> List[List[str]]:
+        return [texts[i : i + batch_size] for i in range(0, len(texts), batch_size)]
+
+    def forward(self, texts: List[str], k_tokens: int) -> csr_array:
+        batched_texts = self._batchify(texts, self.batch_size)
+        sparse_embs = []
+        for batch_texts in batched_texts:
+            logits = self._encode(texts=batch_texts)
+            activations = self._get_activation(logits=logits)
+            if k_tokens is None:
+                nonzero_indices = [
+                    torch.nonzero(activations["sparse_activations"][i]).t()[0]
+                    for i in range(len(batch_texts))
+                ]
+                activations["activations"] = nonzero_indices
+            else:
+                activations = self._update_activations(**activations, k_tokens=k_tokens)
+            batch_csr = self._convert_to_csr_array(activations)
+            sparse_embs.extend(batch_csr)
+
+        return vstack(sparse_embs)
 
     def _get_activation(self, logits: torch.Tensor) -> Dict[str, torch.Tensor]:
         return {"sparse_activations": torch.amax(torch.log1p(self.relu(logits)), dim=1)}
@@ -201,18 +205,3 @@ def _convert_to_csr_array(self, activations: Dict):
             )
         )
         return csr_array_list
-
-    def _convert_to_csr_array2(self, activations: Dict):
-        values = (
-            torch.gather(activations["sparse_activations"], 1, activations["activations"])
-            .cpu()
-            .detach()
-            .numpy()
-        )
-        rows, cols = activations["activations"].shape
-        row_indices = np.repeat(np.arange(rows), cols)
-        col_indices = activations["activations"].detach().cpu().numpy().flatten()
-        return csr_array(
-            (values.flatten(), (row_indices, col_indices)),
-            shape=(rows, activations["sparse_activations"].shape[1]),
-        )
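The deleted _convert_to_csr_array2 built one (n, vocab) csr_array directly from gathered top-k indices and values; stacking per-row arrays with vstack yields the same matrix, which is why the batched forward above can drop it. A toy equivalence check (numpy and scipy only, made-up data):

```python
# Toy check that direct COO construction and row-wise vstack agree.
import numpy as np
from scipy.sparse import csr_array, vstack

vocab = 8
indices = np.array([[1, 4], [2, 5]])         # top-k column ids per row
values = np.array([[0.3, 0.7], [0.5, 0.2]])  # matching activation values

rows, k = indices.shape
direct = csr_array(
    (values.ravel(), (np.repeat(np.arange(rows), k), indices.ravel())),
    shape=(rows, vocab),
)
stacked = vstack(
    [csr_array((values[i], ([0] * k, indices[i])), shape=(1, vocab)) for i in range(rows)]
)
assert (direct != stacked).nnz == 0  # no differing entries: identical matrices
```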
