diff --git a/xinference/model/embedding/core.py b/xinference/model/embedding/core.py
index 01ba9da65d..fd143d59a0 100644
--- a/xinference/model/embedding/core.py
+++ b/xinference/model/embedding/core.py
@@ -261,154 +261,165 @@ def _fix_langchain_openai_inputs(self, sentences: Union[str, List[str]]):
     def create_embedding(self, sentences: Union[str, List[str]], **kwargs):
         sentences = self._fix_langchain_openai_inputs(sentences)
-        from FlagEmbedding import BGEM3FlagModel
         from sentence_transformers import SentenceTransformer

         kwargs.setdefault("normalize_embeddings", True)

-        @no_type_check
-        def _encode_bgem3(
-            model: Union[SentenceTransformer, BGEM3FlagModel],
-            sentences: Union[str, List[str]],
-            batch_size: int = 32,
-            show_progress_bar: bool = None,
-            output_value: str = "sparse_embedding",
-            convert_to_numpy: bool = True,
-            convert_to_tensor: bool = False,
-            device: str = None,
-            normalize_embeddings: bool = False,
-            **kwargs,
-        ):
-            """
-            Computes sentence embeddings with bge-m3 model
-            Nothing special here, just replace sentence-transformer with FlagEmbedding
-            TODO: think about how to solve the redundant code of encode method in the future
-
-            :param sentences: the sentences to embed
-            :param batch_size: the batch size used for the computation
-            :param show_progress_bar: Output a progress bar when encode sentences
-            :param output_value: Default sentence_embedding, to get sentence embeddings. Can be set to token_embeddings to get wordpiece token embeddings. Set to None, to get all output values
-            :param convert_to_numpy: If true, the output is a list of numpy vectors. Else, it is a list of pytorch tensors.
-            :param convert_to_tensor: If true, you get one large tensor as return. Overwrites any setting from convert_to_numpy
-            :param device: Which torch.device to use for the computation
-            :param normalize_embeddings: If set to true, returned vectors will have length 1. In that case, the faster dot-product (util.dot_score) instead of cosine similarity can be used.
+        try:
+            from FlagEmbedding import BGEM3FlagModel
+
+            @no_type_check
+            def _encode_bgem3(
+                model: Union[SentenceTransformer, BGEM3FlagModel],
+                sentences: Union[str, List[str]],
+                batch_size: int = 32,
+                show_progress_bar: bool = None,
+                output_value: str = "sparse_embedding",
+                convert_to_numpy: bool = True,
+                convert_to_tensor: bool = False,
+                device: str = None,
+                normalize_embeddings: bool = False,
+                **kwargs,
+            ):
+                """
+                Computes sentence embeddings with bge-m3 model
+                Nothing special here, just replace sentence-transformer with FlagEmbedding
+                TODO: think about how to solve the redundant code of encode method in the future
+
+                :param sentences: the sentences to embed
+                :param batch_size: the batch size used for the computation
+                :param show_progress_bar: Output a progress bar when encode sentences
+                :param output_value: Default sentence_embedding, to get sentence embeddings. Can be set to token_embeddings to get wordpiece token embeddings. Set to None, to get all output values
+                :param convert_to_numpy: If true, the output is a list of numpy vectors. Else, it is a list of pytorch tensors.
+                :param convert_to_tensor: If true, you get one large tensor as return. Overwrites any setting from convert_to_numpy
+                :param device: Which torch.device to use for the computation
+                :param normalize_embeddings: If set to true, returned vectors will have length 1. In that case, the faster dot-product (util.dot_score) instead of cosine similarity can be used.
+
+                :return:
+                    By default, a list of tensors is returned. If convert_to_tensor, a stacked tensor is returned. If convert_to_numpy, a numpy matrix is returned.
+                """
+                import torch
+                from tqdm.autonotebook import trange
+
+                if show_progress_bar is None:
+                    show_progress_bar = (
+                        logger.getEffectiveLevel() == logging.INFO
+                        or logger.getEffectiveLevel() == logging.DEBUG
+                    )

-            :return:
-               By default, a list of tensors is returned. If convert_to_tensor, a stacked tensor is returned. If convert_to_numpy, a numpy matrix is returned.
-            """
-            import torch
-            from tqdm.autonotebook import trange
+                if convert_to_tensor:
+                    convert_to_numpy = False
+
+                if output_value != "sparse_embedding":
+                    convert_to_tensor = False
+                    convert_to_numpy = False
+
+                input_was_string = False
+                if isinstance(sentences, str) or not hasattr(
+                    sentences, "__len__"
+                ):  # Cast an individual sentence to a list with length 1
+                    sentences = [sentences]
+                    input_was_string = True
+
+                if device is None:
+                    # Same as SentenceTransformer.py
+                    from sentence_transformers.util import get_device_name
+
+                    device = get_device_name()
+                    logger.info(f"Use pytorch device_name: {device}")
+
+                all_embeddings = []
+                all_token_nums = 0
+
+                # The original code does not support other inference engines
+                def _text_length(text):
+                    if isinstance(text, dict):  # {key: value} case
+                        return len(next(iter(text.values())))
+                    elif not hasattr(text, "__len__"):  # Object has no len() method
+                        return 1
+                    elif len(text) == 0 or isinstance(
+                        text[0], int
+                    ):  # Empty string or list of ints
+                        return len(text)
+                    else:
+                        return sum(
+                            [len(t) for t in text]
+                        )  # Sum of length of individual strings

-            if show_progress_bar is None:
-                show_progress_bar = (
-                    logger.getEffectiveLevel() == logging.INFO
-                    or logger.getEffectiveLevel() == logging.DEBUG
+                length_sorted_idx = np.argsort(
+                    [-_text_length(sen) for sen in sentences]
                 )
+                sentences_sorted = [sentences[idx] for idx in length_sorted_idx]
+
+                for start_index in trange(
+                    0,
+                    len(sentences),
+                    batch_size,
+                    desc="Batches",
+                    disable=not show_progress_bar,
+                ):
+                    sentences_batch = sentences_sorted[
+                        start_index : start_index + batch_size
+                    ]
+
+                    with torch.no_grad():
+                        out_features = model.encode(sentences_batch, **kwargs)
+
+                        if output_value == "token_embeddings":
+                            embeddings = []
+                            for token_emb, attention in zip(
+                                out_features[output_value],
+                                out_features["attention_mask"],
+                            ):
+                                last_mask_id = len(attention) - 1
+                                while (
+                                    last_mask_id > 0
+                                    and attention[last_mask_id].item() == 0
+                                ):
+                                    last_mask_id -= 1
+
+                                embeddings.append(token_emb[0 : last_mask_id + 1])
+                        elif output_value is None:  # Return all outputs
+                            embeddings = []
+                            for sent_idx in range(
+                                len(out_features["sentence_embedding"])
+                            ):
+                                row = {
+                                    name: out_features[name][sent_idx]
+                                    for name in out_features
+                                }
+                                embeddings.append(row)
+                        # for sparse embedding
+                        else:
+                            if kwargs.get("return_sparse"):
+                                embeddings = out_features["lexical_weights"]
+                            else:
+                                embeddings = out_features["dense_vecs"]

-            if convert_to_tensor:
-                convert_to_numpy = False
-
-            if output_value != "sparse_embedding":
-                convert_to_tensor = False
-                convert_to_numpy = False
-
-            input_was_string = False
-            if isinstance(sentences, str) or not hasattr(
-                sentences, "__len__"
-            ):  # Cast an individual sentence to a list with length 1
-                sentences = [sentences]
-                input_was_string = True
-
-            if device is None:
-                # Same as SentenceTransformer.py
-                from sentence_transformers.util import get_device_name
-
-                device = get_device_name()
-                logger.info(f"Use pytorch device_name: {device}")
-
-            all_embeddings = []
-            all_token_nums = 0
+                        if convert_to_numpy:
+                            embeddings = embeddings.cpu()

-            # The original code does not support other inference engines
-            def _text_length(text):
-                if isinstance(text, dict):  # {key: value} case
-                    return len(next(iter(text.values())))
-                elif not hasattr(text, "__len__"):  # Object has no len() method
-                    return 1
-                elif len(text) == 0 or isinstance(
-                    text[0], int
-                ):  # Empty string or list of ints
-                    return len(text)
-                else:
-                    return sum(
-                        [len(t) for t in text]
-                    )  # Sum of length of individual strings
+                    all_embeddings.extend(embeddings)

-            length_sorted_idx = np.argsort([-_text_length(sen) for sen in sentences])
-            sentences_sorted = [sentences[idx] for idx in length_sorted_idx]
-
-            for start_index in trange(
-                0,
-                len(sentences),
-                batch_size,
-                desc="Batches",
-                disable=not show_progress_bar,
-            ):
-                sentences_batch = sentences_sorted[
-                    start_index : start_index + batch_size
+                all_embeddings = [
+                    all_embeddings[idx] for idx in np.argsort(length_sorted_idx)
                 ]

-                with torch.no_grad():
-                    out_features = model.encode(sentences_batch, **kwargs)
-
-                    if output_value == "token_embeddings":
-                        embeddings = []
-                        for token_emb, attention in zip(
-                            out_features[output_value], out_features["attention_mask"]
-                        ):
-                            last_mask_id = len(attention) - 1
-                            while (
-                                last_mask_id > 0 and attention[last_mask_id].item() == 0
-                            ):
-                                last_mask_id -= 1
-
-                            embeddings.append(token_emb[0 : last_mask_id + 1])
-                    elif output_value is None:  # Return all outputs
-                        embeddings = []
-                        for sent_idx in range(len(out_features["sentence_embedding"])):
-                            row = {
-                                name: out_features[name][sent_idx]
-                                for name in out_features
-                            }
-                            embeddings.append(row)
-                    # for sparse embedding
+                if convert_to_tensor:
+                    if len(all_embeddings):
+                        all_embeddings = torch.stack(all_embeddings)
                     else:
-                        if kwargs.get("return_sparse"):
-                            embeddings = out_features["lexical_weights"]
-                        else:
-                            embeddings = out_features["dense_vecs"]
-
-                    if convert_to_numpy:
-                        embeddings = embeddings.cpu()
-
-                all_embeddings.extend(embeddings)
-
-            all_embeddings = [
-                all_embeddings[idx] for idx in np.argsort(length_sorted_idx)
-            ]
+                        all_embeddings = torch.Tensor()
+                elif convert_to_numpy:
+                    all_embeddings = np.asarray([emb.numpy() for emb in all_embeddings])

-            if convert_to_tensor:
-                if len(all_embeddings):
-                    all_embeddings = torch.stack(all_embeddings)
-                else:
-                    all_embeddings = torch.Tensor()
-            elif convert_to_numpy:
-                all_embeddings = np.asarray([emb.numpy() for emb in all_embeddings])
+                if input_was_string:
+                    all_embeddings = all_embeddings[0]

-            if input_was_string:
-                all_embeddings = all_embeddings[0]
+                return all_embeddings, all_token_nums

-            return all_embeddings, all_token_nums
+        except ImportError:
+            _encode_bgem3 = None

         # copied from sentence-transformers, and modify it to return tokens num
         @no_type_check
@@ -582,6 +593,10 @@ def encode(

             return all_embeddings, all_token_nums

+        is_bge_m3_flag_model = (
+            self._kwargs.get("hybrid_mode")
+            and "m3" in self._model_spec.model_name.lower()
+        )
         if (
             "gte" in self._model_spec.model_name.lower()
             and "qwen2" in self._model_spec.model_name.lower()
@@ -593,7 +608,8 @@ def encode(
                 convert_to_numpy=False,
                 **kwargs,
             )
-        elif isinstance(self._model, BGEM3FlagModel):
+        elif is_bge_m3_flag_model:
+            assert _encode_bgem3 is not None
             all_embeddings, all_token_nums = _encode_bgem3(
                 self._model, sentences, convert_to_numpy=False, **kwargs
             )
@@ -608,7 +624,7 @@ def encode(
             all_embeddings = [all_embeddings]
         embedding_list = []
         for index, data in enumerate(all_embeddings):
-            if kwargs.get("return_sparse") and isinstance(self._model, BGEM3FlagModel):
+            if kwargs.get("return_sparse") and is_bge_m3_flag_model:
                 embedding_list.append(
                     EmbeddingData(
                         index=index,
@@ -628,8 +644,7 @@ def encode(
         result = Embedding(
             object=(
                 "list"  # type: ignore
-                if not isinstance(self._model, BGEM3FlagModel)
-                and not kwargs.get("return_sparse")
+                if not is_bge_m3_flag_model and not kwargs.get("return_sparse")
                 else "dict"
             ),
             model=self._model_uid,
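
Note (not part of the patch): the change above makes `FlagEmbedding` an optional dependency. Below is a minimal, self-contained sketch of the same pattern; the free functions here are hypothetical simplifications of what the diff does inside the embedding model class.

```python
# Sketch of the optional-import pattern applied by this patch.
# _encode_bgem3 and is_bge_m3_flag_model are illustrative stand-ins;
# the real logic lives in xinference's embedding model methods.
from typing import Optional

try:
    from FlagEmbedding import BGEM3FlagModel  # optional dependency

    def _encode_bgem3(model: "BGEM3FlagModel", sentences, **kwargs):
        # FlagEmbedding-specific encode path (heavily abridged here).
        return model.encode(sentences, **kwargs)

except ImportError:
    # FlagEmbedding is not installed: leave a sentinel so callers can
    # assert before taking the bge-m3 code path.
    _encode_bgem3 = None


def is_bge_m3_flag_model(model_name: str, hybrid_mode: Optional[bool]) -> bool:
    # Name-based gate replacing isinstance(self._model, BGEM3FlagModel):
    # it works even when the optional import failed, so all other models
    # keep running without FlagEmbedding installed.
    return bool(hybrid_mode) and "m3" in model_name.lower()
```

The `isinstance()` checks have to go along with the eager import: once `BGEM3FlagModel` may be undefined, any reference to the class outside the `try` block would raise `NameError`. Gating on `hybrid_mode` plus the model name, with `assert _encode_bgem3 is not None` before use, confines the failure to the bge-m3 hybrid path.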