Add embedding models to pymilvus. (#1971)
Signed-off-by: wxywb <[email protected]>
Showing 15 changed files with 1,148 additions and 0 deletions.
@@ -0,0 +1,72 @@
# hello_model.py demonstrates the use of various embedding functions in PyMilvus,
# covering dense, sparse, and hybrid models. This script illustrates:
# - Initializing and using OpenAIEmbeddingFunction for dense embeddings
# - Initializing and using BGEM3EmbeddingFunction for hybrid embeddings
# - Initializing and using SentenceTransformerEmbeddingFunction for dense embeddings
# - Initializing and using BM25EmbeddingFunction for sparse embeddings
# - Initializing and using SpladeEmbeddingFunction for sparse embeddings
import time

from pymilvus.model.dense import OpenAIEmbeddingFunction, SentenceTransformerEmbeddingFunction
from pymilvus.model.hybrid import BGEM3EmbeddingFunction
from pymilvus.model.sparse import BM25EmbeddingFunction, SpladeEmbeddingFunction

fmt = "=== {:30} ==="


def log(msg):
    print(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) + " " + msg)


# OpenAIEmbeddingFunction usage
docs = [
    "Artificial intelligence was founded as an academic discipline in 1956.",
    "Alan Turing was the first person to conduct substantial research in AI.",
    "Born in Maida Vale, London, Turing was raised in southern England.",
]
log(fmt.format("OpenAIEmbeddingFunction Usage"))

ef_openai = OpenAIEmbeddingFunction(api_key="sk-your-api-key")
embs_openai = ef_openai(docs)
log(f"Dimension: {ef_openai.dim} Embedding Shape: {embs_openai[0].shape}")

# -----------------------------------------------------------------------------
# BGEM3EmbeddingFunction usage
log(fmt.format("BGEM3EmbeddingFunction Usage"))
ef_bge = BGEM3EmbeddingFunction(device="cpu", use_fp16=False)
embs_bge = ef_bge(docs)
log("Embedding Shape: {} Dimension: {}".format(embs_bge["dense"][0].shape, ef_bge.dim))

# -----------------------------------------------------------------------------
# SentenceTransformerEmbeddingFunction usage
log(fmt.format("SentenceTransformerEmbeddingFunction Usage"))
ef_sentence_transformer = SentenceTransformerEmbeddingFunction(device="cpu")
embs_sentence_transformer = ef_sentence_transformer(docs)
log(
    "Embedding Shape: {} Dimension: {}".format(
        embs_sentence_transformer[0].shape, ef_sentence_transformer.dim
    )
)

# -----------------------------------------------------------------------------
# BM25EmbeddingFunction usage
log(fmt.format("BM25EmbeddingFunction Usage"))
ef_bm25 = BM25EmbeddingFunction()
docs_bm25 = [
    "Artificial intelligence was founded as an academic discipline in 1956.",
    "Alan Turing was the first person to conduct substantial research in AI.",
    "Born in Maida Vale, London, Turing was raised in southern England.",
]
ef_bm25.load()
embs_bm25 = ef_bm25.encode_documents(docs_bm25)
log(f"Embedding Shape: {embs_bm25[0].shape} Dimension: {ef_bm25.dim}")

# -----------------------------------------------------------------------------
# SpladeEmbeddingFunction usage
log(fmt.format("SpladeEmbeddingFunction Usage"))
ef_splade = SpladeEmbeddingFunction(device="cpu")
embs_splade = ef_splade(["Hello world", "Hello world2"])
log(f"Embedding Shape: {embs_splade[0].shape} Dimension: {ef_splade.dim}")

# -----------------------------------------------------------------------------
log(fmt.format("Demonstrations Finished"))
@@ -0,0 +1,6 @@
from . import dense, hybrid, sparse
from .dense.sentence_transformer import SentenceTransformerEmbeddingFunction

__all__ = ["DefaultEmbeddingFunction", "dense", "sparse", "hybrid"]

DefaultEmbeddingFunction = SentenceTransformerEmbeddingFunction
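Since DefaultEmbeddingFunction is simply an alias for SentenceTransformerEmbeddingFunction, a minimal usage sketch could look like the following (the sample texts are illustrative; the default model is downloaded on first use):

from pymilvus.model import DefaultEmbeddingFunction

# DefaultEmbeddingFunction aliases SentenceTransformerEmbeddingFunction,
# so this runs locally with the default "all-MiniLM-L6-v2" model.
ef = DefaultEmbeddingFunction()
embeddings = ef(["Milvus is a vector database.", "Embeddings map text to vectors."])
print(ef.dim, embeddings[0].shape)  # e.g. 384 (384,)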
@@ -0,0 +1,14 @@
from abc import abstractmethod
from typing import List


class BaseEmbeddingFunction:
    model_name: str

    @abstractmethod
    def __call__(self, texts: List[str]):
        """ """

    @abstractmethod
    def encode_queries(self, queries: List[str]):
        """ """
@@ -0,0 +1,4 @@
from .openai import OpenAIEmbeddingFunction
from .sentence_transformer import SentenceTransformerEmbeddingFunction

__all__ = ["OpenAIEmbeddingFunction", "SentenceTransformerEmbeddingFunction"]
@@ -0,0 +1,63 @@
from collections import defaultdict
from typing import List, Optional

import numpy as np

from pymilvus.model.base import BaseEmbeddingFunction


class OpenAIEmbeddingFunction(BaseEmbeddingFunction):
    def __init__(
        self,
        model_name: str = "text-embedding-ada-002",
        api_key: Optional[str] = None,
        base_url: Optional[str] = None,
        dimensions: Optional[int] = None,
        **kwargs,
    ):
        try:
            from openai import OpenAI
        except ImportError as err:
            error_message = "openai is not installed."
            raise ImportError(error_message) from err

        self._openai_model_meta_info = defaultdict(dict)
        self._openai_model_meta_info["text-embedding-3-small"]["dim"] = 1536
        self._openai_model_meta_info["text-embedding-3-large"]["dim"] = 3072
        self._openai_model_meta_info["text-embedding-ada-002"]["dim"] = 1536

        self._model_config = dict({"api_key": api_key, "base_url": base_url}, **kwargs)
        additional_encode_config = {}
        if dimensions is not None:
            additional_encode_config = {"dimensions": dimensions}
            self._openai_model_meta_info[model_name]["dim"] = dimensions

        self._encode_config = {"model": model_name, **additional_encode_config}
        self.model_name = model_name
        self.client = OpenAI(**self._model_config)

    def encode_queries(self, queries: List[str]) -> List[np.array]:
        return self._encode(queries)

    def encode_documents(self, documents: List[str]) -> List[np.array]:
        return self._encode(documents)

    @property
    def dim(self):
        return self._openai_model_meta_info[self.model_name]["dim"]

    def __call__(self, texts: List[str]) -> List[np.array]:
        return self._encode(texts)

    def _encode_query(self, query: str) -> np.array:
        return self._encode(query)[0]

    def _encode_document(self, document: str) -> np.array:
        return self._encode(document)[0]

    def _call_openai_api(self, texts: List[str]):
        results = self.client.embeddings.create(input=texts, **self._encode_config).data
        return [np.array(data.embedding) for data in results]

    def _encode(self, texts: List[str]):
        return self._call_openai_api(texts)
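A short usage sketch for this class (the API key and the dimensions value below are placeholders; only the text-embedding-3 models accept a reduced output dimensionality, text-embedding-ada-002 does not):

from pymilvus.model.dense import OpenAIEmbeddingFunction

# "sk-your-api-key" is a placeholder for a real OpenAI API key.
ef = OpenAIEmbeddingFunction(
    model_name="text-embedding-3-small",
    api_key="sk-your-api-key",
    dimensions=512,  # optional: shrink the output vectors (3-series models only)
)
doc_embs = ef.encode_documents(["Milvus stores vectors.", "Turing pioneered AI."])
query_embs = ef.encode_queries(["Who pioneered AI?"])
print(ef.dim, doc_embs[0].shape)  # 512 (512,)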
@@ -0,0 +1,77 @@
from typing import List

import numpy as np

from pymilvus.model.base import BaseEmbeddingFunction


class SentenceTransformerEmbeddingFunction(BaseEmbeddingFunction):
    def __init__(
        self,
        model_name: str = "all-MiniLM-L6-v2",
        batch_size: int = 32,
        query_instruction: str = "",
        doc_instruction: str = "",
        device: str = "cpu",
        normalize_embeddings: bool = True,
        **kwargs,
    ):
        try:
            from sentence_transformers import SentenceTransformer
        except ImportError as err:
            error_message = "sentence-transformers is not installed."
            raise ImportError(error_message) from err
        self.model_name = model_name
        self.query_instruction = query_instruction
        self.doc_instruction = doc_instruction
        self.batch_size = batch_size
        self.normalize_embeddings = normalize_embeddings

        _model_config = dict({"model_name_or_path": model_name, "device": device}, **kwargs)
        self.model = SentenceTransformer(**_model_config)

    def __call__(self, texts: List[str]) -> List[np.array]:
        return self._encode(texts)

    def _encode(self, texts: List[str]) -> List[np.array]:
        embs = self.model.encode(
            texts,
            batch_size=self.batch_size,
            show_progress_bar=False,
            convert_to_numpy=True,
        )
        return list(embs)

    @property
    def dim(self):
        return self.model.get_sentence_embedding_dimension()

    def encode_queries(self, queries: List[str]) -> List[np.array]:
        instructed_queries = [self.query_instruction + query for query in queries]
        return self._encode(instructed_queries)

    def encode_documents(self, documents: List[str]) -> List[np.array]:
        instructed_documents = [self.doc_instruction + document for document in documents]
        return self._encode(instructed_documents)

    def _encode_query(self, query: str) -> np.array:
        instructed_query = self.query_instruction + query
        embs = self.model.encode(
            sentences=[instructed_query],
            batch_size=1,
            show_progress_bar=False,
            convert_to_numpy=True,
            normalize_embeddings=self.normalize_embeddings,
        )
        return embs[0]

    def _encode_document(self, document: str) -> np.array:
        instructed_document = self.doc_instruction + document
        embs = self.model.encode(
            sentences=[instructed_document],
            batch_size=1,
            show_progress_bar=False,
            convert_to_numpy=True,
            normalize_embeddings=self.normalize_embeddings,
        )
        return embs[0]
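The query_instruction and doc_instruction prefixes are prepended verbatim before encoding, which is mainly useful for instruction-tuned retrieval models. A usage sketch (the instruction strings below are illustrative and not required by all-MiniLM-L6-v2):

from pymilvus.model.dense import SentenceTransformerEmbeddingFunction

ef = SentenceTransformerEmbeddingFunction(
    model_name="all-MiniLM-L6-v2",
    device="cpu",
    query_instruction="query: ",    # illustrative prefix
    doc_instruction="passage: ",    # illustrative prefix
)
doc_embs = ef.encode_documents(["Milvus is a vector database built for scale."])
query_embs = ef.encode_queries(["What is Milvus?"])
print(ef.dim, doc_embs[0].shape)  # 384 (384,) for all-MiniLM-L6-v2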
@@ -0,0 +1,3 @@
from .bge_m3 import BGEM3EmbeddingFunction

__all__ = ["BGEM3EmbeddingFunction"]
@@ -0,0 +1,97 @@
import logging
from collections import defaultdict
from typing import Dict, List

from scipy.sparse import csr_array

from pymilvus.model.base import BaseEmbeddingFunction

logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)


class BGEM3EmbeddingFunction(BaseEmbeddingFunction):
    def __init__(
        self,
        model_name: str = "BAAI/bge-m3",
        batch_size: int = 16,
        device: str = "",
        normalize_embeddings: bool = True,
        use_fp16: bool = True,
        return_dense: bool = True,
        return_sparse: bool = True,
        return_colbert_vecs: bool = False,
        **kwargs,
    ):
        try:
            from FlagEmbedding import BGEM3FlagModel
        except ImportError as err:
            error_message = "FlagEmbedding is not installed."
            raise ImportError(error_message) from err
        self.model_name = model_name
        self.batch_size = batch_size
        self.normalize_embeddings = normalize_embeddings
        self.device = device
        self.use_fp16 = use_fp16

        if device == "cpu" and use_fp16 is True:
            logger.warning(
                "Using fp16 on CPU can lead to runtime errors such as 'LayerNormKernelImpl'; "
                "it is recommended to set 'use_fp16=False' when running on CPU."
            )

        _model_config = dict(
            {
                "model_name_or_path": model_name,
                "device": device,
                "normalize_embeddings": normalize_embeddings,
                "use_fp16": use_fp16,
            },
            **kwargs,
        )
        _encode_config = {
            "return_dense": return_dense,
            "return_sparse": return_sparse,
            "return_colbert_vecs": return_colbert_vecs,
        }
        self._model_config = _model_config
        self._encode_config = _encode_config

        self.model = BGEM3FlagModel(**self._model_config)
        meta_info = defaultdict(dict)
        meta_info["BAAI/bge-m3"]["dim"] = {
            "dense": 1024,
            "sparse": len(self.model.tokenizer),
            "colbert_vecs": 1024,
        }
        self._meta_info = meta_info

    def __call__(self, texts: List[str]) -> Dict:
        return self._encode(texts)

    @property
    def dim(self) -> Dict:
        return self._meta_info[self.model_name]["dim"]

    def _encode(self, texts: List[str]) -> Dict:
        output = self.model.encode(sentences=texts, **self._encode_config)
        results = {}
        if self._encode_config["return_dense"] is True:
            results["dense"] = list(output["dense_vecs"])
        if self._encode_config["return_sparse"] is True:
            sparse_dim = self.dim["sparse"]
            results["sparse"] = []
            for sparse_vec in output["lexical_weights"]:
                indices = [int(k) for k in sparse_vec]
                values = list(sparse_vec.values())
                row_indices = [0] * len(indices)
                csr = csr_array((values, (row_indices, indices)), shape=(1, sparse_dim))
                results["sparse"].append(csr)
        if self._encode_config["return_colbert_vecs"] is True:
            results["colbert_vecs"] = output["colbert_vecs"]
        return results

    def encode_queries(self, queries: List[str]) -> Dict:
        return self._encode(queries)

    def encode_documents(self, documents: List[str]) -> Dict:
        return self._encode(documents)
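Because BGEM3EmbeddingFunction returns a dict rather than a flat list, dense and sparse outputs are accessed by key. A sketch of typical use on CPU (the sample texts are illustrative):

from pymilvus.model.hybrid import BGEM3EmbeddingFunction

# use_fp16=False avoids the fp16-on-CPU issue warned about in the constructor above.
ef = BGEM3EmbeddingFunction(device="cpu", use_fp16=False, return_colbert_vecs=False)
out = ef.encode_documents(["BGE-M3 produces dense and sparse vectors.", "Milvus can index both."])

dense_vecs = out["dense"]    # list of 1024-dim numpy arrays
sparse_vecs = out["sparse"]  # list of 1 x vocab-size scipy csr_array rows
print(ef.dim)                # {"dense": 1024, "sparse": <tokenizer size>, "colbert_vecs": 1024}
print(dense_vecs[0].shape, sparse_vecs[0].nnz)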
@@ -0,0 +1,4 @@
from .bm25 import BM25EmbeddingFunction
from .splade import SpladeEmbeddingFunction

__all__ = ["SpladeEmbeddingFunction", "BM25EmbeddingFunction"]
@@ -0,0 +1,9 @@
from .bm25 import BM25EmbeddingFunction
from .tokenizers import Analyzer, build_analyer_from_yaml, build_default_analyzer

__all__ = [
    "BM25EmbeddingFunction",
    "Analyzer",
    "build_analyer_from_yaml",
    "build_default_analyzer",
]
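The bm25.py and tokenizers.py implementations themselves are not shown in this excerpt (the remaining files of the commit did not load). Based on the hello_model.py example above and the names exported here, a rough usage sketch might look like the following; the fit() step, the BM25EmbeddingFunction(analyzer) constructor argument, and the language parameter of build_default_analyzer are assumptions rather than facts confirmed by the visible diff:

from pymilvus.model.sparse.bm25 import BM25EmbeddingFunction, build_default_analyzer

corpus = [
    "Artificial intelligence was founded as an academic discipline in 1956.",
    "Alan Turing was the first person to conduct substantial research in AI.",
]

# Assumed workflow: build an analyzer, fit corpus statistics, then encode.
analyzer = build_default_analyzer(language="en")  # assumption: signature not shown in this diff
ef = BM25EmbeddingFunction(analyzer)              # assumption: constructor not shown in this diff
ef.fit(corpus)                                    # assumption: hello_model.py uses load() instead

doc_vecs = ef.encode_documents(corpus)            # sparse vectors, as in hello_model.py
query_vecs = ef.encode_queries(["Who researched AI?"])
print(ef.dim, doc_vecs[0].shape)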