Skip to content

Commit

Permalink
Add embedding models to pymilvus. (#1971)
Browse files Browse the repository at this point in the history
Signed-off-by: wxywb <[email protected]>
  • Loading branch information
wxywb authored Mar 15, 2024
1 parent 818f290 commit 436f559
Show file tree
Hide file tree
Showing 15 changed files with 1,148 additions and 0 deletions.
72 changes: 72 additions & 0 deletions examples/hello_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
# hello_model.py demonstrates how to use the embedding functions shipped with PyMilvus,
# covering dense, sparse, and hybrid models. This script illustrates:
# - Initializing and using OpenAIEmbeddingFunction for dense embeddings
# - Initializing and using BGEM3EmbeddingFunction for hybrid embeddings
# - Initializing and using SentenceTransformerEmbeddingFunction for dense embeddings
# - Initializing and using BM25EmbeddingFunction for sparse embeddings
# - Initializing and using SpladeEmbeddingFunction for sparse embeddings
import time

from pymilvus.model.dense import OpenAIEmbeddingFunction, SentenceTransformerEmbeddingFunction
from pymilvus.model.hybrid import BGEM3EmbeddingFunction
from pymilvus.model.sparse import BM25EmbeddingFunction, SpladeEmbeddingFunction

# Banner template for section headings; the title is padded to at least 30 characters.
fmt = "=== {:30} ==="


def log(msg):
    """Print ``msg`` to stdout, prefixed with the current local time."""
    timestamp = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
    print(timestamp + " " + msg)


# OpenAIEmbeddingFunction usage
# Shared sample corpus; also reused by the BGE-M3 and sentence-transformer sections below.
docs = [
    "Artificial intelligence was founded as an academic discipline in 1956.",
    "Alan Turing was the first person to conduct substantial research in AI.",
    "Born in Maida Vale, London, Turing was raised in southern England.",
]
log(fmt.format("OpenAIEmbeddingFunction Usage"))

# NOTE: "sk-your-api-key" is a placeholder; substitute a real key before running.
ef_openai = OpenAIEmbeddingFunction(api_key="sk-your-api-key")
embs_openai = ef_openai(docs)
log(f"Dimension: {ef_openai.dim} Embedding Shape: {embs_openai[0].shape}")

# -----------------------------------------------------------------------------
# BGEM3EmbeddingFunction usage
log(fmt.format("BGEM3EmbeddingFunction Usage"))
# fp16 is disabled because the model is run on CPU here.
ef_bge = BGEM3EmbeddingFunction(device="cpu", use_fp16=False)
embs_bge = ef_bge(docs)
log(f"Embedding Shape: {embs_bge['dense'][0].shape} Dimension: {ef_bge.dim}")

# -----------------------------------------------------------------------------
# SentenceTransformerEmbeddingFunction usage
log(fmt.format("SentenceTransformerEmbeddingFunction Usage"))
ef_sentence_transformer = SentenceTransformerEmbeddingFunction(device="cpu")
embs_sentence_transformer = ef_sentence_transformer(docs)
log(
    f"Embedding Shape: {embs_sentence_transformer[0].shape} "
    f"Dimension: {ef_sentence_transformer.dim}"
)

# -----------------------------------------------------------------------------
# BM25EmbeddingFunction usage
log(fmt.format("BM25EmbeddingFunction Usage"))
ef_bm25 = BM25EmbeddingFunction()
docs_bm25 = [
    "Artificial intelligence was founded as an academic discipline in 1956.",
    "Alan Turing was the first person to conduct substantial research in AI.",
    "Born in Maida Vale, London, Turing was raised in southern England.",
]
# NOTE(review): load() is called without arguments; presumably it loads default
# pre-fitted BM25 parameters -- confirm against BM25EmbeddingFunction.
ef_bm25.load()
# Encode the corpus declared for this section (it was previously left unused).
embs_bm25 = ef_bm25.encode_documents(docs_bm25)
log(f"Embedding Shape: {embs_bm25[0].shape} Dimension: {ef_bm25.dim}")

# -----------------------------------------------------------------------------
# SpladeEmbeddingFunction usage
log(fmt.format("SpladeEmbeddingFunction Usage"))
ef_splade = SpladeEmbeddingFunction(device="cpu")
embs_splade = ef_splade(["Hello world", "Hello world2"])
log(f"Embedding Shape: {embs_splade[0].shape} Dimension: {ef_splade.dim}")

# -----------------------------------------------------------------------------
log(fmt.format("Demonstrations Finished"))
6 changes: 6 additions & 0 deletions pymilvus/model/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from . import dense, hybrid, sparse
from .dense.sentence_transformer import SentenceTransformerEmbeddingFunction

# The sentence-transformers implementation is exposed as the package default.
DefaultEmbeddingFunction = SentenceTransformerEmbeddingFunction

__all__ = [
    "DefaultEmbeddingFunction",
    "dense",
    "sparse",
    "hybrid",
]
14 changes: 14 additions & 0 deletions pymilvus/model/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from abc import abstractmethod
from typing import List


class BaseEmbeddingFunction:
    """Common interface for the embedding functions in this package.

    The methods are marked with ``@abstractmethod``, but the class does not
    use ``ABCMeta``, so the markers are not enforced at instantiation time.
    """

    # Name of the underlying model; assigned by the concrete subclasses.
    model_name: str

    @abstractmethod
    def __call__(self, texts: List[str]):
        """Embed ``texts``; concrete subclasses supply the implementation."""

    @abstractmethod
    def encode_queries(self, queries: List[str]):
        """Embed ``queries``; concrete subclasses supply the implementation."""
4 changes: 4 additions & 0 deletions pymilvus/model/dense/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from .openai import OpenAIEmbeddingFunction
from .sentence_transformer import SentenceTransformerEmbeddingFunction

__all__ = ["OpenAIEmbeddingFunction", "SentenceTransformerEmbeddingFunction"]
63 changes: 63 additions & 0 deletions pymilvus/model/dense/openai.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
from collections import defaultdict
from typing import List, Optional

import numpy as np

from pymilvus.model.base import BaseEmbeddingFunction


class OpenAIEmbeddingFunction(BaseEmbeddingFunction):
    """Dense embedding function that calls the OpenAI embeddings endpoint.

    Args:
        model_name: Name of the OpenAI embedding model to request.
        api_key: Forwarded to the ``OpenAI`` client constructor.
        base_url: Forwarded to the ``OpenAI`` client constructor.
        dimensions: Optional output size. When given, it is sent with every
            embedding request and reported by :attr:`dim`.
        **kwargs: Extra keyword arguments forwarded to the ``OpenAI`` client.

    Raises:
        ImportError: If the ``openai`` package is not installed.
    """

    def __init__(
        self,
        model_name: str = "text-embedding-ada-002",
        api_key: Optional[str] = None,
        base_url: Optional[str] = None,
        dimensions: Optional[int] = None,
        **kwargs,
    ):
        # Imported lazily so the package can be used without the optional dependency.
        try:
            from openai import OpenAI
        except ImportError as err:
            error_message = "openai is not installed."
            raise ImportError(error_message) from err

        # Output dimensionality of the models this wrapper knows about.
        self._openai_model_meta_info = defaultdict(dict)
        self._openai_model_meta_info["text-embedding-3-small"]["dim"] = 1536
        self._openai_model_meta_info["text-embedding-3-large"]["dim"] = 3072
        self._openai_model_meta_info["text-embedding-ada-002"]["dim"] = 1536

        self._model_config = dict({"api_key": api_key, "base_url": base_url}, **kwargs)
        additional_encode_config = {}
        if dimensions is not None:
            # An explicit size overrides (or registers) the dimension for this model.
            additional_encode_config = {"dimensions": dimensions}
            self._openai_model_meta_info[model_name]["dim"] = dimensions

        self._encode_config = {"model": model_name, **additional_encode_config}
        self.model_name = model_name
        self.client = OpenAI(**self._model_config)

    def encode_queries(self, queries: List[str]) -> List[np.ndarray]:
        """Return one embedding per entry of ``queries``."""
        return self._encode(queries)

    def encode_documents(self, documents: List[str]) -> List[np.ndarray]:
        """Return one embedding per entry of ``documents``."""
        return self._encode(documents)

    @property
    def dim(self):
        """Embedding size recorded for ``model_name``.

        Raises:
            KeyError: If no size is recorded for the model, i.e. it is not one
                of the built-in entries and ``dimensions`` was not given.
        """
        return self._openai_model_meta_info[self.model_name]["dim"]

    def __call__(self, texts: List[str]) -> List[np.ndarray]:
        """Return one embedding per entry of ``texts``."""
        return self._encode(texts)

    def _encode_query(self, query: str) -> np.ndarray:
        """Embed a single query string."""
        # Wrap in a list so ``_encode`` always receives the List[str] it declares.
        return self._encode([query])[0]

    def _encode_document(self, document: str) -> np.ndarray:
        """Embed a single document string."""
        return self._encode([document])[0]

    def _call_openai_api(self, texts: List[str]) -> List[np.ndarray]:
        """Send ``texts`` to the embeddings endpoint and convert each result to an array."""
        results = self.client.embeddings.create(input=texts, **self._encode_config).data
        return [np.array(data.embedding) for data in results]

    def _encode(self, texts: List[str]) -> List[np.ndarray]:
        """Shared encoding path used by every public entry point."""
        return self._call_openai_api(texts)
77 changes: 77 additions & 0 deletions pymilvus/model/dense/sentence_transformer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
from typing import List

import numpy as np

from pymilvus.model.base import BaseEmbeddingFunction


class SentenceTransformerEmbeddingFunction(BaseEmbeddingFunction):
    """Dense embedding function backed by a ``SentenceTransformer`` model.

    Args:
        model_name: Passed to ``SentenceTransformer`` as ``model_name_or_path``.
        batch_size: Batch size used when encoding.
        query_instruction: Prefix prepended to every query before encoding.
        doc_instruction: Prefix prepended to every document before encoding.
        device: Device passed to ``SentenceTransformer``.
        normalize_embeddings: Forwarded to ``SentenceTransformer.encode`` on
            every encoding call.
        **kwargs: Extra keyword arguments for the ``SentenceTransformer``
            constructor.

    Raises:
        ImportError: If ``sentence-transformers`` is not installed.
    """

    def __init__(
        self,
        model_name: str = "all-MiniLM-L6-v2",
        batch_size: int = 32,
        query_instruction: str = "",
        doc_instruction: str = "",
        device: str = "cpu",
        normalize_embeddings: bool = True,
        **kwargs,
    ):
        # Imported lazily so the package can be used without the optional dependency.
        try:
            from sentence_transformers import SentenceTransformer
        except ImportError as err:
            error_message = "sentence-transformers is not installed."
            raise ImportError(error_message) from err
        self.model_name = model_name
        self.query_instruction = query_instruction
        self.doc_instruction = doc_instruction
        self.batch_size = batch_size
        self.normalize_embeddings = normalize_embeddings

        _model_config = dict({"model_name_or_path": model_name, "device": device}, **kwargs)
        self.model = SentenceTransformer(**_model_config)

    def __call__(self, texts: List[str]) -> List[np.ndarray]:
        """Embed ``texts`` as given, without any instruction prefix."""
        return self._encode(texts)

    def _encode(self, texts: List[str]) -> List[np.ndarray]:
        """Encode ``texts`` and return one array per input."""
        embs = self.model.encode(
            texts,
            batch_size=self.batch_size,
            show_progress_bar=False,
            convert_to_numpy=True,
            # Previously omitted here, so the constructor flag was ignored by
            # __call__, encode_queries and encode_documents.
            normalize_embeddings=self.normalize_embeddings,
        )
        return list(embs)

    @property
    def dim(self):
        """Embedding size as reported by the underlying model."""
        return self.model.get_sentence_embedding_dimension()

    def encode_queries(self, queries: List[str]) -> List[np.ndarray]:
        """Embed ``queries``, each prefixed with ``query_instruction``."""
        instructed_queries = [self.query_instruction + query for query in queries]
        return self._encode(instructed_queries)

    def encode_documents(self, documents: List[str]) -> List[np.ndarray]:
        """Embed ``documents``, each prefixed with ``doc_instruction``."""
        instructed_documents = [self.doc_instruction + document for document in documents]
        return self._encode(instructed_documents)

    def _encode_query(self, query: str) -> np.ndarray:
        """Embed a single query, prefixed with ``query_instruction``."""
        # Reuses the shared path; batch size is irrelevant for a single sentence.
        return self._encode([self.query_instruction + query])[0]

    def _encode_document(self, document: str) -> np.ndarray:
        """Embed a single document, prefixed with ``doc_instruction``."""
        return self._encode([self.doc_instruction + document])[0]
3 changes: 3 additions & 0 deletions pymilvus/model/hybrid/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .bge_m3 import BGEM3EmbeddingFunction

__all__ = ["BGEM3EmbeddingFunction"]
97 changes: 97 additions & 0 deletions pymilvus/model/hybrid/bge_m3.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
import logging
from collections import defaultdict
from typing import Dict, List

from scipy.sparse import csr_array

from pymilvus.model.base import BaseEmbeddingFunction

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# NOTE(review): this sets the logger's level explicitly at import time instead of
# inheriting it from the application's logging configuration -- confirm this is
# intended for library code.
logger.setLevel(logging.DEBUG)


class BGEM3EmbeddingFunction(BaseEmbeddingFunction):
    """Hybrid (dense / sparse / ColBERT) embedding function backed by ``BGEM3FlagModel``.

    Args:
        model_name: Passed to ``BGEM3FlagModel`` as ``model_name_or_path``.
        batch_size: Batch size forwarded to ``BGEM3FlagModel.encode``.
        device: Device passed to ``BGEM3FlagModel``.
        normalize_embeddings: Passed to ``BGEM3FlagModel``.
        use_fp16: Passed to ``BGEM3FlagModel``; a warning is logged when it is
            combined with ``device="cpu"``.
        return_dense: Include ``"dense"`` vectors in the result.
        return_sparse: Include ``"sparse"`` vectors in the result.
        return_colbert_vecs: Include ``"colbert_vecs"`` in the result.
        **kwargs: Extra keyword arguments for the ``BGEM3FlagModel`` constructor.

    Raises:
        ImportError: If ``FlagEmbedding`` is not installed.
    """

    def __init__(
        self,
        model_name: str = "BAAI/bge-m3",
        batch_size: int = 16,
        device: str = "",
        normalize_embeddings: bool = True,
        use_fp16: bool = True,
        return_dense: bool = True,
        return_sparse: bool = True,
        return_colbert_vecs: bool = False,
        **kwargs,
    ):
        # Imported lazily so the package can be used without the optional dependency.
        try:
            from FlagEmbedding import BGEM3FlagModel
        except ImportError as err:
            error_message = "FlagEmbedding is not installed."
            raise ImportError(error_message) from err
        self.model_name = model_name
        self.batch_size = batch_size
        self.normalize_embeddings = normalize_embeddings
        self.device = device
        self.use_fp16 = use_fp16

        if device == "cpu" and use_fp16 is True:
            logger.warning(
                "Using fp16 with CPU can lead to runtime errors such as 'LayerNormKernelImpl', It's recommended to set 'use_fp16 = False' when using cpu. "
            )

        _model_config = dict(
            {
                "model_name_or_path": model_name,
                "device": device,
                "normalize_embeddings": normalize_embeddings,
                "use_fp16": use_fp16,
            },
            **kwargs,
        )
        _encode_config = {
            # Previously stored on the instance but never handed to encode().
            "batch_size": batch_size,
            "return_dense": return_dense,
            "return_sparse": return_sparse,
            "return_colbert_vecs": return_colbert_vecs,
        }
        self._model_config = _model_config
        self._encode_config = _encode_config

        self.model = BGEM3FlagModel(**self._model_config)
        # Width of the sparse vectors; kept separately so sparse encoding works
        # for any model_name, not only the one registered in the meta info below.
        self._sparse_dim = len(self.model.tokenizer)
        meta_info = defaultdict(dict)
        meta_info["BAAI/bge-m3"]["dim"] = {
            "dense": 1024,
            "sparse": self._sparse_dim,
            "colbert_vecs": 1024,
        }
        self._meta_info = meta_info

    def __call__(self, texts: List[str]) -> Dict:
        """Embed ``texts``; see :meth:`_encode` for the result layout."""
        return self._encode(texts)

    @property
    def dim(self) -> Dict:
        """Per-output dimensions recorded for ``model_name``.

        Raises:
            KeyError: If ``model_name`` is not ``"BAAI/bge-m3"``, the only
                model with recorded dimensions.
        """
        return self._meta_info[self.model_name]["dim"]

    def _encode(self, texts: List[str]) -> Dict:
        """Encode ``texts`` and collect the requested outputs.

        Returns:
            A dict holding, for each enabled output: ``"dense"`` (a list built
            from the model's ``dense_vecs``), ``"sparse"`` (a list of
            ``1 x sparse_dim`` ``csr_array`` rows built from
            ``lexical_weights``) and ``"colbert_vecs"`` (passed through).
        """
        output = self.model.encode(sentences=texts, **self._encode_config)
        results = {}
        if self._encode_config["return_dense"] is True:
            results["dense"] = list(output["dense_vecs"])
        if self._encode_config["return_sparse"] is True:
            sparse_dim = self._sparse_dim
            results["sparse"] = []
            for sparse_vec in output["lexical_weights"]:
                # Keys are converted to integer column indices of a single-row matrix.
                indices = [int(k) for k in sparse_vec]
                values = list(sparse_vec.values())
                row_indices = [0] * len(indices)
                csr = csr_array((values, (row_indices, indices)), shape=(1, sparse_dim))
                results["sparse"].append(csr)
        if self._encode_config["return_colbert_vecs"] is True:
            results["colbert_vecs"] = output["colbert_vecs"]
        return results

    def encode_queries(self, queries: List[str]) -> Dict:
        """Embed ``queries``; same result layout as :meth:`_encode`."""
        return self._encode(queries)

    def encode_documents(self, documents: List[str]) -> Dict:
        """Embed ``documents``; same result layout as :meth:`_encode`."""
        return self._encode(documents)
4 changes: 4 additions & 0 deletions pymilvus/model/sparse/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from .bm25 import BM25EmbeddingFunction
from .splade import SpladeEmbeddingFunction

__all__ = ["SpladeEmbeddingFunction", "BM25EmbeddingFunction"]
9 changes: 9 additions & 0 deletions pymilvus/model/sparse/bm25/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from .bm25 import BM25EmbeddingFunction
from .tokenizers import Analyzer, build_analyer_from_yaml, build_default_analyzer

__all__ = [
"BM25EmbeddingFunction",
"Analyzer",
"build_analyer_from_yaml",
"build_default_analyzer",
]
Loading

0 comments on commit 436f559

Please sign in to comment.