diff --git a/examples/hello_model.py b/examples/hello_model.py
new file mode 100644
index 000000000..ff61bda62
--- /dev/null
+++ b/examples/hello_model.py
@@ -0,0 +1,72 @@
+# hello_model.py demonstrates how to use the embedding functions in PyMilvus,
+# covering dense, sparse, and hybrid models. This script illustrates:
+# - Initializing and using OpenAIEmbeddingFunction for dense embeddings
+# - Initializing and using BGEM3EmbeddingFunction for hybrid embeddings
+# - Initializing and using SentenceTransformerEmbeddingFunction for dense embeddings
+# - Initializing and using BM25EmbeddingFunction for sparse embeddings
+# - Initializing and using SpladeEmbeddingFunction for sparse embeddings
+import time
+
+from pymilvus.model.dense import OpenAIEmbeddingFunction, SentenceTransformerEmbeddingFunction
+from pymilvus.model.hybrid import BGEM3EmbeddingFunction
+from pymilvus.model.sparse import BM25EmbeddingFunction, SpladeEmbeddingFunction
+
+fmt = "=== {:30} ==="
+
+
+def log(msg):
+    print(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) + " " + msg)
+
+
+# OpenAIEmbeddingFunction usage
+docs = [
+    "Artificial intelligence was founded as an academic discipline in 1956.",
+    "Alan Turing was the first person to conduct substantial research in AI.",
+    "Born in Maida Vale, London, Turing was raised in southern England.",
+]
+log(fmt.format("OpenAIEmbeddingFunction Usage"))
+
+ef_openai = OpenAIEmbeddingFunction(api_key="sk-your-api-key")
+embs_openai = ef_openai(docs)
+log(f"Dimension: {ef_openai.dim} Embedding Shape: {embs_openai[0].shape}")
+
+# -----------------------------------------------------------------------------
+# BGEM3EmbeddingFunction usage
+log(fmt.format("BGEM3EmbeddingFunction Usage"))
+ef_bge = BGEM3EmbeddingFunction(device="cpu", use_fp16=False)
+embs_bge = ef_bge(docs)
+log("Embedding Shape: {} Dimension: {}".format(embs_bge["dense"][0].shape, ef_bge.dim))
+
+# -----------------------------------------------------------------------------
+# SentenceTransformerEmbeddingFunction usage
+log(fmt.format("SentenceTransformerEmbeddingFunction Usage"))
+ef_sentence_transformer = SentenceTransformerEmbeddingFunction(device="cpu")
+embs_sentence_transformer = ef_sentence_transformer(docs)
+log(
+    "Embedding Shape: {} Dimension: {}".format(
+        embs_sentence_transformer[0].shape, ef_sentence_transformer.dim
+    )
+)
+
+# -----------------------------------------------------------------------------
+# BM25EmbeddingFunction usage
+log(fmt.format("BM25EmbeddingFunction Usage"))
+ef_bm25 = BM25EmbeddingFunction()
+docs_bm25 = [
+    "Artificial intelligence was founded as an academic discipline in 1956.",
+    "Alan Turing was the first person to conduct substantial research in AI.",
+    "Born in Maida Vale, London, Turing was raised in southern England.",
+]
+ef_bm25.load()
+embs_bm25 = ef_bm25.encode_documents(docs_bm25)
+log(f"Embedding Shape: {embs_bm25[0].shape} Dimension: {ef_bm25.dim}")
+
+# -----------------------------------------------------------------------------
+# SpladeEmbeddingFunction usage
+log(fmt.format("SpladeEmbeddingFunction Usage"))
+ef_splade = SpladeEmbeddingFunction(device="cpu")
+embs_splade = ef_splade(["Hello world", "Hello world2"])
+log(f"Embedding Shape: {embs_splade[0].shape} Dimension: {ef_splade.dim}")
+
+# -----------------------------------------------------------------------------
+log(fmt.format("Demonstrations Finished"))
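Not part of the diff, but for orientation: a hedged sketch of wiring the dense embeddings above into Milvus through `MilvusClient`. The server URI and the collection name `demo` are assumptions, and a reachable Milvus instance is required.

```python
# Sketch only: assumes a Milvus server at localhost:19530 and a valid OpenAI key.
from pymilvus import MilvusClient
from pymilvus.model.dense import OpenAIEmbeddingFunction

docs = ["Artificial intelligence was founded as an academic discipline in 1956."]
ef = OpenAIEmbeddingFunction(api_key="sk-your-api-key")

client = MilvusClient(uri="http://localhost:19530")
client.create_collection(collection_name="demo", dimension=ef.dim)

# One row per document: auto schema gives us "id", "vector", and dynamic fields.
rows = [
    {"id": i, "vector": vec.tolist(), "text": text}
    for i, (vec, text) in enumerate(zip(ef.encode_documents(docs), docs))
]
client.insert(collection_name="demo", data=rows)

hits = client.search(
    collection_name="demo",
    data=[v.tolist() for v in ef.encode_queries(["When was AI founded?"])],
    limit=1,
    output_fields=["text"],
)
print(hits)
```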
diff --git a/pymilvus/model/__init__.py b/pymilvus/model/__init__.py
new file mode 100644
index 000000000..ca6d9c1d5
--- /dev/null
+++ b/pymilvus/model/__init__.py
@@ -0,0 +1,6 @@
+from . import dense, hybrid, sparse
+from .dense.sentence_transformer import SentenceTransformerEmbeddingFunction
+
+__all__ = ["DefaultEmbeddingFunction", "dense", "sparse", "hybrid"]
+
+DefaultEmbeddingFunction = SentenceTransformerEmbeddingFunction
diff --git a/pymilvus/model/base.py b/pymilvus/model/base.py
new file mode 100644
index 000000000..111d391be
--- /dev/null
+++ b/pymilvus/model/base.py
@@ -0,0 +1,14 @@
+from abc import abstractmethod
+from typing import List
+
+
+class BaseEmbeddingFunction:
+    model_name: str
+
+    @abstractmethod
+    def __call__(self, texts: List[str]):
+        """Embed a list of texts; the concrete subclass decides the return type."""
+
+    @abstractmethod
+    def encode_queries(self, queries: List[str]):
+        """Embed a list of query strings."""
diff --git a/pymilvus/model/dense/__init__.py b/pymilvus/model/dense/__init__.py
new file mode 100644
index 000000000..90e84e2ac
--- /dev/null
+++ b/pymilvus/model/dense/__init__.py
@@ -0,0 +1,4 @@
+from .openai import OpenAIEmbeddingFunction
+from .sentence_transformer import SentenceTransformerEmbeddingFunction
+
+__all__ = ["OpenAIEmbeddingFunction", "SentenceTransformerEmbeddingFunction"]
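For contributors adding new backends, a minimal sketch of the contract `BaseEmbeddingFunction` implies. The `ToyEmbeddingFunction` below is hypothetical, using seeded pseudo-random vectors instead of a real model, purely to show the expected surface:

```python
import zlib
from typing import List

import numpy as np

from pymilvus.model.base import BaseEmbeddingFunction


class ToyEmbeddingFunction(BaseEmbeddingFunction):
    """Hypothetical example backend: deterministic pseudo-embeddings, no ML model."""

    model_name = "toy-hash-embedder"

    def __init__(self, dim: int = 8):
        self._dim = dim

    @property
    def dim(self) -> int:
        # All shipped implementations expose a 'dim' property; mirror that here.
        return self._dim

    def _embed(self, text: str) -> np.ndarray:
        # CRC32 of the text seeds the RNG, so equal texts map to equal vectors.
        rng = np.random.default_rng(zlib.crc32(text.encode("utf-8")))
        return rng.standard_normal(self._dim)

    def __call__(self, texts: List[str]) -> List[np.ndarray]:
        return [self._embed(t) for t in texts]

    def encode_queries(self, queries: List[str]) -> List[np.ndarray]:
        return self(queries)

    def encode_documents(self, documents: List[str]) -> List[np.ndarray]:
        return self(documents)
```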
diff --git a/pymilvus/model/dense/openai.py b/pymilvus/model/dense/openai.py
new file mode 100644
index 000000000..f09f471d6
--- /dev/null
+++ b/pymilvus/model/dense/openai.py
@@ -0,0 +1,63 @@
+from collections import defaultdict
+from typing import List, Optional
+
+import numpy as np
+
+from pymilvus.model.base import BaseEmbeddingFunction
+
+
+class OpenAIEmbeddingFunction(BaseEmbeddingFunction):
+    def __init__(
+        self,
+        model_name: str = "text-embedding-ada-002",
+        api_key: Optional[str] = None,
+        base_url: Optional[str] = None,
+        dimensions: Optional[int] = None,
+        **kwargs,
+    ):
+        try:
+            from openai import OpenAI
+        except ImportError as err:
+            error_message = "openai is not installed. Install it with 'pip install openai'."
+            raise ImportError(error_message) from err
+
+        self._openai_model_meta_info = defaultdict(dict)
+        self._openai_model_meta_info["text-embedding-3-small"]["dim"] = 1536
+        self._openai_model_meta_info["text-embedding-3-large"]["dim"] = 3072
+        self._openai_model_meta_info["text-embedding-ada-002"]["dim"] = 1536
+
+        self._model_config = dict({"api_key": api_key, "base_url": base_url}, **kwargs)
+        additional_encode_config = {}
+        if dimensions is not None:
+            additional_encode_config = {"dimensions": dimensions}
+            self._openai_model_meta_info[model_name]["dim"] = dimensions
+
+        self._encode_config = {"model": model_name, **additional_encode_config}
+        self.model_name = model_name
+        self.client = OpenAI(**self._model_config)
+
+    def encode_queries(self, queries: List[str]) -> List[np.ndarray]:
+        return self._encode(queries)
+
+    def encode_documents(self, documents: List[str]) -> List[np.ndarray]:
+        return self._encode(documents)
+
+    @property
+    def dim(self):
+        return self._openai_model_meta_info[self.model_name]["dim"]
+
+    def __call__(self, texts: List[str]) -> List[np.ndarray]:
+        return self._encode(texts)
+
+    def _encode_query(self, query: str) -> np.ndarray:
+        return self._encode([query])[0]
+
+    def _encode_document(self, document: str) -> np.ndarray:
+        return self._encode([document])[0]
+
+    def _call_openai_api(self, texts: List[str]):
+        results = self.client.embeddings.create(input=texts, **self._encode_config).data
+        return [np.array(data.embedding) for data in results]
+
+    def _encode(self, texts: List[str]):
+        return self._call_openai_api(texts)
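A usage sketch for the class above. It assumes `OPENAI_API_KEY` is set in the environment (the openai client falls back to it when `api_key` is None) and shows the `dimensions` override supported by the text-embedding-3 models:

```python
from pymilvus.model.dense import OpenAIEmbeddingFunction

# api_key=None lets the openai client read OPENAI_API_KEY from the environment.
ef = OpenAIEmbeddingFunction(model_name="text-embedding-3-small", dimensions=512)

doc_embs = ef.encode_documents(["Milvus is a vector database."])
query_embs = ef.encode_queries(["What is Milvus?"])

print(ef.dim)              # 512, because dimensions overrides the model default
print(doc_embs[0].shape)   # (512,)
```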
diff --git a/pymilvus/model/dense/sentence_transformer.py b/pymilvus/model/dense/sentence_transformer.py
new file mode 100644
index 000000000..bf9eb9bdb
--- /dev/null
+++ b/pymilvus/model/dense/sentence_transformer.py
@@ -0,0 +1,78 @@
+from typing import List
+
+import numpy as np
+
+from pymilvus.model.base import BaseEmbeddingFunction
+
+
+class SentenceTransformerEmbeddingFunction(BaseEmbeddingFunction):
+    def __init__(
+        self,
+        model_name: str = "all-MiniLM-L6-v2",
+        batch_size: int = 32,
+        query_instruction: str = "",
+        doc_instruction: str = "",
+        device: str = "cpu",
+        normalize_embeddings: bool = True,
+        **kwargs,
+    ):
+        try:
+            from sentence_transformers import SentenceTransformer
+        except ImportError as err:
+            error_message = "sentence-transformers is not installed. Install it with 'pip install sentence-transformers'."
+            raise ImportError(error_message) from err
+        self.model_name = model_name
+        self.query_instruction = query_instruction
+        self.doc_instruction = doc_instruction
+        self.batch_size = batch_size
+        self.normalize_embeddings = normalize_embeddings
+
+        _model_config = dict({"model_name_or_path": model_name, "device": device}, **kwargs)
+        self.model = SentenceTransformer(**_model_config)
+
+    def __call__(self, texts: List[str]) -> List[np.ndarray]:
+        return self._encode(texts)
+
+    def _encode(self, texts: List[str]) -> List[np.ndarray]:
+        embs = self.model.encode(
+            texts,
+            batch_size=self.batch_size,
+            show_progress_bar=False,
+            convert_to_numpy=True,
+            normalize_embeddings=self.normalize_embeddings,
+        )
+        return list(embs)
+
+    @property
+    def dim(self):
+        return self.model.get_sentence_embedding_dimension()
+
+    def encode_queries(self, queries: List[str]) -> List[np.ndarray]:
+        instructed_queries = [self.query_instruction + query for query in queries]
+        return self._encode(instructed_queries)
+
+    def encode_documents(self, documents: List[str]) -> List[np.ndarray]:
+        instructed_documents = [self.doc_instruction + document for document in documents]
+        return self._encode(instructed_documents)
+
+    def _encode_query(self, query: str) -> np.ndarray:
+        instructed_query = self.query_instruction + query
+        embs = self.model.encode(
+            sentences=[instructed_query],
+            batch_size=1,
+            show_progress_bar=False,
+            convert_to_numpy=True,
+            normalize_embeddings=self.normalize_embeddings,
+        )
+        return embs[0]
+
+    def _encode_document(self, document: str) -> np.ndarray:
+        instructed_document = self.doc_instruction + document
+        embs = self.model.encode(
+            sentences=[instructed_document],
+            batch_size=1,
+            show_progress_bar=False,
+            convert_to_numpy=True,
+            normalize_embeddings=self.normalize_embeddings,
+        )
+        return embs[0]
diff --git a/pymilvus/model/hybrid/__init__.py b/pymilvus/model/hybrid/__init__.py
new file mode 100644
index 000000000..391736d04
--- /dev/null
+++ b/pymilvus/model/hybrid/__init__.py
@@ -0,0 +1,3 @@
+from .bge_m3 import BGEM3EmbeddingFunction
+
+__all__ = ["BGEM3EmbeddingFunction"]
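A sketch of the instruction-prefix feature: the prefixes are plain string prepends, so BGE-style retrieval prompts work, but the specific model and prompt below are example choices, not defaults of this class:

```python
import numpy as np

from pymilvus.model.dense import SentenceTransformerEmbeddingFunction

ef = SentenceTransformerEmbeddingFunction(
    model_name="BAAI/bge-small-en-v1.5",
    query_instruction="Represent this sentence for searching relevant passages: ",
    device="cpu",
)

doc_embs = ef.encode_documents(["Deep learning uses multilayer neural networks."])
query_embs = ef.encode_queries(["What is deep learning?"])

# Embeddings are L2-normalized by default, so dot product equals cosine similarity.
print(float(np.dot(query_embs[0], doc_embs[0])))
```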
" + ) + + _model_config = dict( + { + "model_name_or_path": model_name, + "device": device, + "normalize_embeddings": normalize_embeddings, + "use_fp16": use_fp16, + }, + **kwargs, + ) + _encode_config = { + "return_dense": return_dense, + "return_sparse": return_sparse, + "return_colbert_vecs": return_colbert_vecs, + } + self._model_config = _model_config + self._encode_config = _encode_config + + self.model = BGEM3FlagModel(**self._model_config) + meta_info = defaultdict(dict) + meta_info["BAAI/bge-m3"]["dim"] = { + "dense": 1024, + "sparse": len(self.model.tokenizer), + "colbert_vecs": 1024, + } + self._meta_info = meta_info + + def __call__(self, texts: List[str]) -> Dict: + return self._encode(texts) + + @property + def dim(self) -> Dict: + return self._meta_info[self.model_name]["dim"] + + def _encode(self, texts: List[str]) -> Dict: + output = self.model.encode(sentences=texts, **self._encode_config) + results = {} + if self._encode_config["return_dense"] is True: + results["dense"] = list(output["dense_vecs"]) + if self._encode_config["return_sparse"] is True: + sparse_dim = self.dim["sparse"] + results["sparse"] = [] + for sparse_vec in output["lexical_weights"]: + indices = [int(k) for k in sparse_vec] + values = list(sparse_vec.values()) + row_indices = [0] * len(indices) + csr = csr_array((values, (row_indices, indices)), shape=(1, sparse_dim)) + results["sparse"].append(csr) + if self._encode_config["return_colbert_vecs"] is True: + results["colbert_vecs"] = output["colbert_vecs"] + return results + + def encode_queries(self, queries: List[str]) -> Dict: + return self._encode(queries) + + def encode_documents(self, documents: List[str]) -> Dict: + return self._encode(documents) diff --git a/pymilvus/model/sparse/__init__.py b/pymilvus/model/sparse/__init__.py new file mode 100644 index 000000000..6f94fe783 --- /dev/null +++ b/pymilvus/model/sparse/__init__.py @@ -0,0 +1,4 @@ +from .bm25 import BM25EmbeddingFunction +from .splade import SpladeEmbeddingFunction + +__all__ = ["SpladeEmbeddingFunction", "BM25EmbeddingFunction"] diff --git a/pymilvus/model/sparse/bm25/__init__.py b/pymilvus/model/sparse/bm25/__init__.py new file mode 100644 index 000000000..b64ebeecb --- /dev/null +++ b/pymilvus/model/sparse/bm25/__init__.py @@ -0,0 +1,9 @@ +from .bm25 import BM25EmbeddingFunction +from .tokenizers import Analyzer, build_analyer_from_yaml, build_default_analyzer + +__all__ = [ + "BM25EmbeddingFunction", + "Analyzer", + "build_analyer_from_yaml", + "build_default_analyzer", +] diff --git a/pymilvus/model/sparse/bm25/bm25.py b/pymilvus/model/sparse/bm25/bm25.py new file mode 100644 index 000000000..c74f10dfd --- /dev/null +++ b/pymilvus/model/sparse/bm25/bm25.py @@ -0,0 +1,227 @@ +""" +This file incorporates components from the 'rank_bm25' project by Dorian Brown: +https://github.com/dorianbrown/rank_bm25 +Specifically, the rank_bm25.py file. + +The incorporated components are licensed under the Apache License, Version 2.0 (the "License"); +you may not use these components except in compliance with the License. +You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +import json +import logging +import math +from collections import defaultdict +from multiprocessing import Pool, cpu_count +from pathlib import Path +from typing import Dict, List, Optional + +import requests +from scipy.sparse import csr_array + +from pymilvus.model.base import BaseEmbeddingFunction +from pymilvus.model.sparse.bm25.tokenizers import Analyzer, build_default_analyzer + +logger = logging.getLogger(__name__) +logger.setLevel(logging.DEBUG) +console_handler = logging.StreamHandler() +console_handler.setLevel(logging.INFO) +logger.addHandler(console_handler) + + +class BM25EmbeddingFunction(BaseEmbeddingFunction): + def __init__( + self, + analyzer: Analyzer = None, + corpus: Optional[List] = None, + k1: float = 1.5, + b: float = 0.75, + epsilon: float = 0.25, + num_workers: Optional[int] = None, + ): + if analyzer is None: + analyzer = build_default_analyzer(language="en") + self.analyzer = analyzer + self.corpus_size = 0 + self.avgdl = 0 + self.idf = {} + self.k1 = k1 + self.b = b + self.epsilon = epsilon + if num_workers is None: + self.num_workers = cpu_count() + self.num_workers = num_workers + + if analyzer and corpus is not None: + self.fit(corpus) + + def _calc_term_indices(self): + for index, word in enumerate(self.idf): + self.idf[word][1] = index + + def _compute_statistics(self, corpus: List[str]): + term_document_frequencies = defaultdict(int) + total_word_count = 0 + for document in corpus: + total_word_count += len(document) + + frequencies = defaultdict(int) + for word in document: + frequencies[word] += 1 + + for word, _ in frequencies.items(): + term_document_frequencies[word] += 1 + self.corpus_size += 1 + self.avgdl = total_word_count / self.corpus_size + return term_document_frequencies + + def _tokenize_corpus(self, corpus: List[str]): + if self.num_workers == 1: + return [self.analyzer(text) for text in corpus] + pool = Pool(self.num_workers) + return pool.map(self.analyzer, corpus) + + def _calc_idf(self, term_document_frequencies: Dict): + # collect idf sum to calculate an average idf for epsilon value + idf_sum = 0 + # collect words with negative idf to set them a special epsilon value. 
+    def _calc_idf(self, term_document_frequencies: Dict):
+        # collect idf sum to calculate an average idf for the epsilon value
+        idf_sum = 0
+        # collect words with negative idf to assign them a special epsilon value;
+        # idf can be negative if a word is contained in more than half of the documents
+        negative_idfs = []
+        for word, freq in term_document_frequencies.items():
+            idf = math.log(self.corpus_size - freq + 0.5) - math.log(freq + 0.5)
+            if word not in self.idf:
+                self.idf[word] = [0.0, 0]
+            self.idf[word][0] = idf
+            idf_sum += idf
+            if idf < 0:
+                negative_idfs.append(word)
+        self.average_idf = idf_sum / len(self.idf)
+
+        eps = self.epsilon * self.average_idf
+        for word in negative_idfs:
+            self.idf[word][0] = eps
+
+    def _rebuild(self, corpus: List[str]):
+        self._clear()
+        corpus = self._tokenize_corpus(corpus)
+        term_document_frequencies = self._compute_statistics(corpus)
+        self._calc_idf(term_document_frequencies)
+        self._calc_term_indices()
+
+    def _clear(self):
+        self.corpus_size = 0
+        # self.idf maps each term to its (idf value, term index) pair
+        self.idf = defaultdict(list)
+
+    @property
+    def dim(self):
+        return len(self.idf)
+
+    def fit(self, corpus: List[str]):
+        self._rebuild(corpus)
+
+    def _encode_query(self, query: str) -> csr_array:
+        terms = self.analyzer(query)
+        values, rows, cols = [], [], []
+        for term in terms:
+            if term in self.idf:
+                values.append(self.idf[term][0])
+                rows.append(0)
+                cols.append(self.idf[term][1])
+        return csr_array((values, (rows, cols)), shape=(1, len(self.idf)))
+
+    def _encode_document(self, doc: str) -> csr_array:
+        terms = self.analyzer(doc)
+        frequencies = defaultdict(int)
+        doc_len = len(terms)
+        term_set = set()
+        for term in terms:
+            frequencies[term] += 1
+            term_set.add(term)
+        values, rows, cols = [], [], []
+        for term in term_set:
+            if term in self.idf:
+                term_freq = frequencies[term]
+                value = (
+                    term_freq
+                    * (self.k1 + 1)
+                    / (term_freq + self.k1 * (1 - self.b + self.b * doc_len / self.avgdl))
+                )
+                rows.append(0)
+                cols.append(self.idf[term][1])
+                values.append(value)
+        return csr_array((values, (rows, cols)), shape=(1, len(self.idf)))
+
+    def encode_queries(self, queries: List[str]) -> List[csr_array]:
+        if self.num_workers == 1:
+            return [self._encode_query(query) for query in queries]
+        with Pool(self.num_workers) as pool:
+            return pool.map(self._encode_query, queries)
+    def __call__(self, texts: List[str]) -> List[csr_array]:
+        error_message = "Unsupported function called: use 'encode_queries' or 'encode_documents' of 'BM25EmbeddingFunction' instead."
+        raise ValueError(error_message)
+
+    def encode_documents(self, documents: List[str]) -> List[csr_array]:
+        if self.num_workers == 1:
+            return [self._encode_document(document) for document in documents]
+        with Pool(self.num_workers) as pool:
+            return pool.map(self._encode_document, documents)
+
+    def save(self, path: str):
+        bm25_params = {}
+        bm25_params["version"] = "v1"
+        bm25_params["corpus_size"] = self.corpus_size
+        bm25_params["avgdl"] = self.avgdl
+        bm25_params["idf_word"] = [None for _ in range(len(self.idf))]
+        bm25_params["idf_value"] = [None for _ in range(len(self.idf))]
+        for word, values in self.idf.items():
+            bm25_params["idf_word"][values[1]] = word
+            bm25_params["idf_value"][values[1]] = values[0]
+
+        bm25_params["k1"] = self.k1
+        bm25_params["b"] = self.b
+        bm25_params["epsilon"] = self.epsilon
+
+        with Path(path).open("w") as json_file:
+            json.dump(bm25_params, json_file)
+
+    def load(self, path: Optional[str] = None):
+        default_meta_filename = "bm25_msmarco_v1.json"
+        default_meta_url = "https://github.com/milvus-io/pymilvus-assets/releases/download/v0.1-bm25v1/bm25_msmarco_v1.json"
+        if path is None:
+            logger.info(f"path is None, using default {default_meta_filename}.")
+            if not Path(default_meta_filename).exists():
+                try:
+                    logger.info(
+                        f"{default_meta_filename} not found, start downloading from {default_meta_url} to ./{default_meta_filename}."
+                    )
+                    response = requests.get(default_meta_url, timeout=30)
+                    response.raise_for_status()
+                    with Path(default_meta_filename).open("wb") as f:
+                        f.write(response.content)
+                    logger.info(f"{default_meta_filename} has been downloaded successfully.")
+                except requests.exceptions.RequestException as e:
+                    error_message = f"Failed to download the file: {e}"
+                    raise RuntimeError(error_message) from e
+            path = default_meta_filename
+        try:
+            with Path(path).open() as json_file:
+                bm25_params = json.load(json_file)
+        except OSError as e:
+            error_message = f"Error opening file {path}: {e}"
+            raise RuntimeError(error_message) from e
+        self.corpus_size = bm25_params["corpus_size"]
+        self.avgdl = bm25_params["avgdl"]
+        self.idf = {}
+        for i in range(len(bm25_params["idf_word"])):
+            self.idf[bm25_params["idf_word"][i]] = [bm25_params["idf_value"][i], i]
+        self.k1 = bm25_params["k1"]
+        self.b = bm25_params["b"]
+        self.epsilon = bm25_params["epsilon"]
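For reference, `_encode_document` implements the standard BM25 term weight tf·(k1+1)/(tf + k1·(1 − b + b·|d|/avgdl)), while the query side carries the idf. A fit-and-encode sketch follows; the corpus and file path are illustrative, and the default analyzer needs the NLTK punkt and stopwords data downloaded beforehand:

```python
from pymilvus.model.sparse.bm25 import BM25EmbeddingFunction, build_default_analyzer

corpus = [
    "Artificial intelligence was founded as an academic discipline in 1956.",
    "Alan Turing was the first person to conduct substantial research in AI.",
]

analyzer = build_default_analyzer(language="en")
ef = BM25EmbeddingFunction(analyzer=analyzer, num_workers=1)
ef.fit(corpus)  # builds the idf table; ef.dim == vocabulary size afterwards

doc_vecs = ef.encode_documents(corpus)
query_vecs = ef.encode_queries(["Who pioneered AI research?"])
print(ef.dim, doc_vecs[0].shape)

# Statistics can be persisted and restored instead of refitting.
ef.save("bm25_demo.json")  # hypothetical path
ef2 = BM25EmbeddingFunction(num_workers=1)
ef2.load("bm25_demo.json")
```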
diff --git a/pymilvus/model/sparse/bm25/lang.yaml b/pymilvus/model/sparse/bm25/lang.yaml
new file mode 100644
index 000000000..4eb8037ea
--- /dev/null
+++ b/pymilvus/model/sparse/bm25/lang.yaml
@@ -0,0 +1,137 @@
+en:
+  tokenizer:
+    class: StandardTokenizer
+    params: {}
+  filters:
+    - class: LowercaseFilter
+      params: {}
+    - class: PunctuationFilter
+      params: {}
+    - class: StopwordFilter
+      params:
+        language: 'english'
+    - class: StemmingFilter
+      params:
+        language: 'english'
+de:
+  tokenizer:
+    class: StandardTokenizer
+    params: {}
+  filters:
+    - class: LowercaseFilter
+      params: {}
+    - class: StopwordFilter
+      params:
+        language: 'german'
+    - class: PunctuationFilter
+      params: {}
+    - class: StemmingFilter
+      params:
+        language: 'german'
+fr:
+  tokenizer:
+    class: StandardTokenizer
+    params: {}
+  filters:
+    - class: LowercaseFilter
+      params: {}
+    - class: PunctuationFilter
+      params: {}
+    - class: StopwordFilter
+      params:
+        language: 'french'
+    - class: StemmingFilter
+      params:
+        language: 'french'
+ru:
+  tokenizer:
+    class: StandardTokenizer
+    params: {}
+  filters:
+    - class: LowercaseFilter
+      params: {}
+    - class: PunctuationFilter
+      params: {}
+    - class: StopwordFilter
+      params:
+        language: 'russian'
+    - class: StemmingFilter
+      params:
+        language: 'russian'
+sp:
+  tokenizer:
+    class: StandardTokenizer
+    params: {}
+  filters:
+    - class: LowercaseFilter
+      params: {}
+    - class: PunctuationFilter
+      params:
+        extras: '¡¿'
+    - class: StopwordFilter
+      params:
+        language: 'spanish'
+    - class: StemmingFilter
+      params:
+        language: 'spanish'
+it:
+  tokenizer:
+    class: StandardTokenizer
+    params: {}
+  filters:
+    - class: LowercaseFilter
+      params: {}
+    - class: PunctuationFilter
+      params: {}
+    - class: StopwordFilter
+      params:
+        language: 'italian'
+    - class: StemmingFilter
+      params:
+        language: 'italian'
+pt:
+  tokenizer:
+    class: StandardTokenizer
+    params: {}
+  filters:
+    - class: LowercaseFilter
+      params: {}
+    - class: PunctuationFilter
+      params: {}
+    - class: StopwordFilter
+      params:
+        language: 'portuguese'
+    - class: StemmingFilter
+      params:
+        language: 'portuguese'
+zh:
+  tokenizer:
+    class: JiebaTokenizer
+    params: {}
+  filters:
+    - class: StopwordFilter
+      params:
+        language: 'chinese'
+    - class: PunctuationFilter
+      params:
+        extras: ' 、"#$%&'()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、〃〈〉《》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰–—‘’‛“”„‟…‧﹏﹑﹔·.!?。。'
+jp:
+  tokenizer:
+    class: MecabTokenizer
+    params: {}
+  preprocessors:
+    - class: CharacterfilterPreprocessor
+      params:
+        chars_to_replace: ['、', '。', '「', '」', '『', '』', '【', '】', '(', ')', '{', '}', '・', ':', ';', '!', '?', 'ー', '〜', '…', '‥', '[', ']']
+  filters:
+    - class: StopwordFilter
+      params: {}
+    - class: PunctuationFilter
+      params: {}
+kr:
+  tokenizer:
+    class: KonlpyTokenizer
+    params: {}
+  filters:
+    - class: StopwordFilter
+      params: {}
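Each YAML entry above is turned into a tokenizer-plus-filters pipeline by the loader in tokenizers.py. A quick sketch of the English pipeline (assumes the NLTK punkt and stopwords data are available locally):

```python
from pymilvus.model.sparse.bm25 import build_default_analyzer

analyzer = build_default_analyzer(language="en")
# Lowercased, punctuation- and stopword-stripped, stemmed tokens, e.g.
# ['quick', 'brown', 'fox', 'jump', 'lazi', 'dog']
print(analyzer("The quick brown foxes are jumping over the lazy dogs!"))
```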
diff --git a/pymilvus/model/sparse/bm25/tokenizers.py b/pymilvus/model/sparse/bm25/tokenizers.py
new file mode 100644
index 000000000..9ffc4a1fe
--- /dev/null
+++ b/pymilvus/model/sparse/bm25/tokenizers.py
@@ -0,0 +1,204 @@
+import logging
+import re
+import string
+from importlib.util import find_spec
+from pathlib import Path
+from typing import Any, Dict, List, Match, Optional, Type
+
+import yaml
+from nltk import word_tokenize
+from nltk.corpus import stopwords
+from nltk.stem.snowball import SnowballStemmer
+
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.DEBUG)
+
+_class_registry = {}
+
+
+def register_class(register_as: str):
+    def decorator(cls: Type[Any]):
+        _class_registry[register_as] = cls
+        return cls
+
+    return decorator
+
+
+class Preprocessor:
+    def apply(self, text: str):
+        error_message = "Each preprocessor must implement its 'apply' method."
+        raise NotImplementedError(error_message)
+
+
+@register_class("CharacterfilterPreprocessor")
+class CharacterfilterPreprocessor:
+    def __init__(self, chars_to_replace: str):
+        self.replacement_table = str.maketrans({char: " " for char in chars_to_replace})
+
+    def apply(self, text: str):
+        return text.translate(self.replacement_table)
+
+
+@register_class("ReplacePreprocessor")
+class ReplacePreprocessor:
+    def __init__(self, replacement_mapping: Dict[str, str]):
+        self.replacement_mapping = replacement_mapping
+        self.pattern = re.compile("|".join(map(re.escape, replacement_mapping.keys())))
+
+    def _replacement_function(self, match: Match):
+        return self.replacement_mapping[match.group(0)]
+
+    def apply(self, text: str):
+        return self.pattern.sub(self._replacement_function, text)
+
+
+@register_class("StandardTokenizer")
+class StandardTokenizer:
+    def tokenize(self, text: str):
+        return word_tokenize(text)
+
+
+class TextFilter:
+    def apply(self, tokens: List[str]):
+        error_message = "Each filter must implement the 'apply' method."
+        raise NotImplementedError(error_message)
+
+
+@register_class("LowercaseFilter")
+class LowercaseFilter(TextFilter):
+    def apply(self, tokens: List[str]):
+        return [token.lower() for token in tokens]
+
+
+@register_class("StopwordFilter")
+class StopwordFilter(TextFilter):
+    def __init__(self, language: str = "english", stopword_list: Optional[List[str]] = None):
+        if stopword_list is None:
+            stopword_list = []
+        self.stopwords = set(stopwords.words(language) + stopword_list)
+
+    def apply(self, tokens: List[str]):
+        return [token for token in tokens if token not in self.stopwords]
+
+
+@register_class("PunctuationFilter")
+class PunctuationFilter(TextFilter):
+    def __init__(self, extras: str = ""):
+        self.punctuation = set(string.punctuation + extras)
+
+    def apply(self, tokens: List[str]):
+        return [token for token in tokens if token not in self.punctuation]
+
+
+@register_class("StemmingFilter")
+class StemmingFilter(TextFilter):
+    def __init__(self, language: str = "english"):
+        self.stemmer = SnowballStemmer(language)
+
+    def apply(self, tokens: List[str]):
+        return [self.stemmer.stem(token) for token in tokens]
+
+
+class Tokenizer:
+    def tokenize(self, text: str):
+        error_message = "Each tokenizer must implement its 'tokenize' method."
+        raise NotImplementedError(error_message)
+
+
+@register_class("JiebaTokenizer")
+class JiebaTokenizer(Tokenizer):
+    def __init__(self):
+        if find_spec("jieba") is None:
+            error_message = "jieba is required for JiebaTokenizer but is not installed. Please install it using 'pip install jieba'."
+            logger.error(error_message)
+            raise ImportError(error_message)
+
+    def tokenize(self, text: str):
+        import jieba
+
+        return jieba.lcut(text)
+
+
+@register_class("MecabTokenizer")
+class MecabTokenizer(Tokenizer):
+    def __init__(self):
+        if find_spec("MeCab") is None:
+            error_message = "MeCab is required for MecabTokenizer but is not installed. Please install it using 'pip install mecab-python3'."
+            logger.error(error_message)
+            raise ImportError(error_message)
+
+    def tokenize(self, text: str):
+        import MeCab
+
+        wakati = MeCab.Tagger("-Owakati")
+        return wakati.parse(text).split()
+@register_class("KonlpyTokenizer")
+class KonlpyTokenizer(Tokenizer):
+    def __init__(self):
+        if find_spec("konlpy") is None:
+            error_message = "konlpy is required for KonlpyTokenizer but is not installed. Please install it using 'pip install konlpy'."
+            logger.error(error_message)
+            raise ImportError(error_message)
+
+    def tokenize(self, text: str):
+        from konlpy.tag import Kkma
+
+        return Kkma().nouns(text)
+
+
+class Analyzer:
+    def __init__(
+        self,
+        name: str,
+        tokenizer: Tokenizer,
+        preprocessors: Optional[List[Preprocessor]] = None,
+        filters: Optional[List[TextFilter]] = None,
+    ):
+        self.name = name
+        self.tokenizer = tokenizer
+        self.preprocessors = preprocessors or []
+        self.filters = filters or []
+
+    def __call__(self, text: str):
+        for preprocessor in self.preprocessors:
+            text = preprocessor.apply(text)
+        tokens = self.tokenizer.tokenize(text)
+        for _filter in self.filters:
+            tokens = _filter.apply(tokens)
+        return tokens
+
+
+def build_default_analyzer(language: str = "en"):
+    default_config_path = Path(__file__).parent / "lang.yaml"
+    return build_analyzer_from_yaml(default_config_path, language)
+
+
+def build_analyzer_from_yaml(filepath: str, name: str):
+    with Path(filepath).open() as file:
+        config = yaml.safe_load(file)
+
+    lang_config = config.get(name)
+    if not lang_config:
+        error_message = f"No configuration found for language '{name}'."
+        raise ValueError(error_message)
+
+    tokenizer_class_type = _class_registry[lang_config["tokenizer"]["class"]]
+    tokenizer_params = lang_config["tokenizer"]["params"]
+
+    tokenizer = tokenizer_class_type(**tokenizer_params)
+    preprocessors = []
+    filters = []
+    if "preprocessors" in lang_config:
+        preprocessors = [
+            _class_registry[filter_config["class"]](**filter_config["params"])
+            for filter_config in lang_config["preprocessors"]
+        ]
+    if "filters" in lang_config:
+        filters = [
+            _class_registry[filter_config["class"]](**filter_config["params"])
+            for filter_config in lang_config["filters"]
+        ]
+
+    return Analyzer(name=name, tokenizer=tokenizer, preprocessors=preprocessors, filters=filters)
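Analyzers can also be composed directly instead of going through the YAML config. A sketch with hypothetical extra stopwords (the `stopword_list` values are illustrative; preprocessors may be omitted since they default to an empty list):

```python
from pymilvus.model.sparse.bm25.tokenizers import (
    Analyzer,
    LowercaseFilter,
    PunctuationFilter,
    StandardTokenizer,
    StemmingFilter,
    StopwordFilter,
)

analyzer = Analyzer(
    name="custom-en",
    tokenizer=StandardTokenizer(),
    filters=[
        LowercaseFilter(),
        PunctuationFilter(),
        StopwordFilter(language="english", stopword_list=["figure", "table"]),
        StemmingFilter(language="english"),
    ],
)
print(analyzer("Figure 3 shows the tokenized documents."))
```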
+""" + +import logging +from typing import Dict, List, Optional + +import numpy as np +import torch +from scipy.sparse import csr_array + +from pymilvus.model.base import BaseEmbeddingFunction + +logger = logging.getLogger(__name__) +logger.setLevel(logging.DEBUG) + + +class SpladeEmbeddingFunction(BaseEmbeddingFunction): + model_name: str + + def __init__( + self, + model_name: str = "naver/splade-cocondenser-ensembledistil", + batch_size: int = 32, + query_instruction: str = "", + doc_instruction: str = "", + device: Optional[str] = "cpu", + k_tokens_query: Optional[int] = None, + k_tokens_document: Optional[int] = None, + **kwargs, + ): + self.model_name = model_name + + _model_config = dict( + {"model_name_or_path": model_name, "batch_size": batch_size, "device": device}, + **kwargs, + ) + self._model_config = _model_config + self.model = _SpladeImplementation(**self._model_config) + self.device = device + self.k_tokens_query = k_tokens_query + self.k_tokens_document = k_tokens_document + self.query_instruction = query_instruction + self.doc_instruction = doc_instruction + + def __call__(self, texts: List[str]) -> List[csr_array]: + embs = self._encode(texts, None) + return list(embs) + + def encode_documents(self, documents: List[str]) -> List[csr_array]: + embs = self._encode( + [self.doc_instruction + document for document in documents], + self.k_tokens_document, + ) + return list(embs) + + def _encode(self, texts: List[str], k_tokens: int) -> List[csr_array]: + embs = self.model.forward(texts, k_tokens=k_tokens) + return list(embs) + + def encode_queries(self, queries: List[str]) -> List[csr_array]: + embs = self._encode( + [self.query_instruction + query for query in queries], + self.k_tokens_query, + ) + return list(embs) + + @property + def dim(self) -> int: + return len(self.model.tokenizer) + + def _encode_query(self, query: str) -> csr_array: + return self.model.forward([self.query_instruction + query], k_tokens=self.k_tokens_query)[0] + + def _encode_document(self, document: str) -> csr_array: + return self.model.forward( + [self.doc_instruction + document], k_tokens=self.k_tokens_document + )[0] + + +class _SpladeImplementation: + def __init__( + self, + model_name_or_path: Optional[str] = None, + device: Optional[str] = None, + batch_size: int = 32, + **kwargs, + ): + try: + from transformers import AutoModelForMaskedLM, AutoTokenizer + except ImportError as _: + logger.error("transformers is not installed.") + + self.device = device + self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path) + self.model = AutoModelForMaskedLM.from_pretrained(model_name_or_path, **kwargs) + self.model.to(self.device) + self.batch_size = batch_size + + self.relu = torch.nn.ReLU() + self.relu.to(self.device) + self.model.config.output_hidden_states = True + + def _encode(self, texts: List[str]): + encoded_input = self.tokenizer.batch_encode_plus( + texts, + truncation=True, + max_length=self.tokenizer.model_max_length, + return_tensors="pt", + add_special_tokens=True, + padding=True, + ) + encoded_input = {key: val.to(self.device) for key, val in encoded_input.items()} + output = self.model(**encoded_input) + return output.logits + + def forward(self, texts: List[str], k_tokens: int): + logits = self._encode(texts=texts) + activations = self._get_activation(logits=logits) + + if k_tokens is None: + nonzero_indices = [ + torch.nonzero(activations["sparse_activations"][i]).t()[0] + for i in range(len(texts)) + ] + activations["activations"] = nonzero_indices + else: + activations = 
+    def _get_activation(self, logits: torch.Tensor) -> Dict[str, torch.Tensor]:
+        return {"sparse_activations": torch.amax(torch.log1p(self.relu(logits)), dim=1)}
+
+    def _update_activations(self, sparse_activations: torch.Tensor, k_tokens: int) -> Dict[str, torch.Tensor]:
+        activations = torch.topk(input=sparse_activations, k=k_tokens, dim=1).indices
+
+        # Zero out the sparse_activations entries that are not among the top k per row.
+        sparse_activations = sparse_activations * torch.zeros(
+            (sparse_activations.shape[0], sparse_activations.shape[1]),
+            dtype=int,
+            device=self.device,
+        ).scatter_(dim=1, index=activations.long(), value=1)
+
+        return {
+            "activations": activations,
+            "sparse_activations": sparse_activations,
+        }
+
+    def _filter_activations(
+        self, activations: torch.Tensor, k_tokens: int, **kwargs
+    ) -> torch.Tensor:
+        _, activations = torch.topk(input=activations, k=k_tokens, dim=1, **kwargs)
+        return activations
+
+    def _convert_to_csr_array(self, activations: Dict):
+        csr_array_list = []
+
+        if activations["sparse_activations"].shape[0] != len(activations["activations"]):
+            error_msg = (
+                "The shape of 'sparse_activations' does not match the length of 'activations'"
+            )
+            raise ValueError(error_msg)
+
+        for i, column_indices in enumerate(activations["activations"]):
+            values = (
+                torch.gather(activations["sparse_activations"][i], 0, column_indices)
+                .cpu()
+                .detach()
+                .numpy()
+            )
+            row_indices = np.zeros(len(activations["activations"][i]))
+            col_indices = activations["activations"][i].cpu().detach().numpy()
+            csr_array_list.append(
+                csr_array(
+                    (values.flatten(), (row_indices, col_indices)),
+                    shape=(1, activations["sparse_activations"].shape[1]),
+                )
+            )
+        return csr_array_list
+
+    def _convert_to_csr_array2(self, activations: Dict):
+        values = (
+            torch.gather(activations["sparse_activations"], 1, activations["activations"])
+            .cpu()
+            .detach()
+            .numpy()
+        )
+        rows, cols = activations["activations"].shape
+        row_indices = np.repeat(np.arange(rows), cols)
+        col_indices = activations["activations"].detach().cpu().numpy().flatten()
+        return csr_array(
+            (values.flatten(), (row_indices, col_indices)),
+            shape=(rows, activations["sparse_activations"].shape[1]),
+        )
diff --git a/pyproject.toml b/pyproject.toml
index 903907b18..049050b64 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -40,6 +40,19 @@ dynamic = ["version"]
 
 [project.urls]
 "repository" = 'https://github.com/milvus-io/pymilvus'
 
+[project.optional-dependencies]
+model = [
+    "openai >= 1.12.0",
+    "sentence-transformers",
+    "FlagEmbedding >= 1.2.2",
+    "nltk",
+    "transformers >= 4.33.0",
+    "jieba",
+    "konlpy",
+    "mecab-python3",
+    "scipy >= 1.10.0",
+]
+
 [tool.setuptools.dynamic]
 version = { attr = "_version_helper.version"}
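With the optional dependency group in place, the sparse path can be exercised end to end. A hedged sketch, assuming `pip install "pymilvus[model]"` has pulled in transformers and that torch is installed (torch is not listed in the extras):

```python
# pip install "pymilvus[model]" pulls in openai, sentence-transformers,
# FlagEmbedding, nltk, transformers, etc. via the new optional dependency group.
from pymilvus.model.sparse import SpladeEmbeddingFunction

ef = SpladeEmbeddingFunction(device="cpu")
docs = ["SPLADE expands documents into weighted vocabulary terms."]

doc_vecs = ef.encode_documents(docs)            # list of 1 x vocab_size csr_array
query_vecs = ef.encode_queries(["what is splade"])

# Relevance is the dot product between the sparse query and document vectors.
score = (query_vecs[0] @ doc_vecs[0].T).toarray()[0][0]
print(ef.dim, doc_vecs[0].nnz, score)
```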