
Commit

BUG: fix ImportError when optional dependency FlagEmbedding is not installed (#2649)
zjuyzj authored Dec 10, 2024
1 parent 70a0081 commit 15d0978
Showing 1 changed file with 151 additions and 136 deletions.
287 changes: 151 additions & 136 deletions xinference/model/embedding/core.py
@@ -261,154 +261,165 @@ def _fix_langchain_openai_inputs(self, sentences: Union[str, List[str]]):
     def create_embedding(self, sentences: Union[str, List[str]], **kwargs):
         sentences = self._fix_langchain_openai_inputs(sentences)

-        from FlagEmbedding import BGEM3FlagModel
         from sentence_transformers import SentenceTransformer

         kwargs.setdefault("normalize_embeddings", True)

-        @no_type_check
-        def _encode_bgem3(
-            model: Union[SentenceTransformer, BGEM3FlagModel],
-            sentences: Union[str, List[str]],
-            batch_size: int = 32,
-            show_progress_bar: bool = None,
-            output_value: str = "sparse_embedding",
-            convert_to_numpy: bool = True,
-            convert_to_tensor: bool = False,
-            device: str = None,
-            normalize_embeddings: bool = False,
-            **kwargs,
-        ):
-            """
-            Computes sentence embeddings with bge-m3 model
-            Nothing special here, just replace sentence-transformer with FlagEmbedding
-            TODO: think about how to solve the redundant code of encode method in the future
-            :param sentences: the sentences to embed
-            :param batch_size: the batch size used for the computation
-            :param show_progress_bar: Output a progress bar when encode sentences
-            :param output_value: Default sentence_embedding, to get sentence embeddings. Can be set to token_embeddings to get wordpiece token embeddings. Set to None, to get all output values
-            :param convert_to_numpy: If true, the output is a list of numpy vectors. Else, it is a list of pytorch tensors.
-            :param convert_to_tensor: If true, you get one large tensor as return. Overwrites any setting from convert_to_numpy
-            :param device: Which torch.device to use for the computation
-            :param normalize_embeddings: If set to true, returned vectors will have length 1. In that case, the faster dot-product (util.dot_score) instead of cosine similarity can be used.
-
-            :return:
-                By default, a list of tensors is returned. If convert_to_tensor, a stacked tensor is returned. If convert_to_numpy, a numpy matrix is returned.
-            """
-            import torch
-            from tqdm.autonotebook import trange
-
-            if show_progress_bar is None:
-                show_progress_bar = (
-                    logger.getEffectiveLevel() == logging.INFO
-                    or logger.getEffectiveLevel() == logging.DEBUG
-                )
-
-            if convert_to_tensor:
-                convert_to_numpy = False
-
-            if output_value != "sparse_embedding":
-                convert_to_tensor = False
-                convert_to_numpy = False
-
-            input_was_string = False
-            if isinstance(sentences, str) or not hasattr(
-                sentences, "__len__"
-            ):  # Cast an individual sentence to a list with length 1
-                sentences = [sentences]
-                input_was_string = True
-
-            if device is None:
-                # Same as SentenceTransformer.py
-                from sentence_transformers.util import get_device_name
-
-                device = get_device_name()
-                logger.info(f"Use pytorch device_name: {device}")
-
-            all_embeddings = []
-            all_token_nums = 0
-
-            # The original code does not support other inference engines
-            def _text_length(text):
-                if isinstance(text, dict):  # {key: value} case
-                    return len(next(iter(text.values())))
-                elif not hasattr(text, "__len__"):  # Object has no len() method
-                    return 1
-                elif len(text) == 0 or isinstance(
-                    text[0], int
-                ):  # Empty string or list of ints
-                    return len(text)
-                else:
-                    return sum(
-                        [len(t) for t in text]
-                    )  # Sum of length of individual strings
-
-            length_sorted_idx = np.argsort([-_text_length(sen) for sen in sentences])
-            sentences_sorted = [sentences[idx] for idx in length_sorted_idx]
-
-            for start_index in trange(
-                0,
-                len(sentences),
-                batch_size,
-                desc="Batches",
-                disable=not show_progress_bar,
-            ):
-                sentences_batch = sentences_sorted[
-                    start_index : start_index + batch_size
-                ]
-
-                with torch.no_grad():
-                    out_features = model.encode(sentences_batch, **kwargs)
-
-                if output_value == "token_embeddings":
-                    embeddings = []
-                    for token_emb, attention in zip(
-                        out_features[output_value], out_features["attention_mask"]
-                    ):
-                        last_mask_id = len(attention) - 1
-                        while (
-                            last_mask_id > 0 and attention[last_mask_id].item() == 0
-                        ):
-                            last_mask_id -= 1
-
-                        embeddings.append(token_emb[0 : last_mask_id + 1])
-                elif output_value is None:  # Return all outputs
-                    embeddings = []
-                    for sent_idx in range(len(out_features["sentence_embedding"])):
-                        row = {
-                            name: out_features[name][sent_idx]
-                            for name in out_features
-                        }
-                        embeddings.append(row)
-                # for sparse embedding
-                else:
-                    if kwargs.get("return_sparse"):
-                        embeddings = out_features["lexical_weights"]
-                    else:
-                        embeddings = out_features["dense_vecs"]
-
-                if convert_to_numpy:
-                    embeddings = embeddings.cpu()
-
-                all_embeddings.extend(embeddings)
-
-            all_embeddings = [
-                all_embeddings[idx] for idx in np.argsort(length_sorted_idx)
-            ]
-
-            if convert_to_tensor:
-                if len(all_embeddings):
-                    all_embeddings = torch.stack(all_embeddings)
-                else:
-                    all_embeddings = torch.Tensor()
-            elif convert_to_numpy:
-                all_embeddings = np.asarray([emb.numpy() for emb in all_embeddings])
-
-            if input_was_string:
-                all_embeddings = all_embeddings[0]
-
-            return all_embeddings, all_token_nums
+        try:
+            from FlagEmbedding import BGEM3FlagModel
+
+            @no_type_check
+            def _encode_bgem3(
+                model: Union[SentenceTransformer, BGEM3FlagModel],
+                sentences: Union[str, List[str]],
+                batch_size: int = 32,
+                show_progress_bar: bool = None,
+                output_value: str = "sparse_embedding",
+                convert_to_numpy: bool = True,
+                convert_to_tensor: bool = False,
+                device: str = None,
+                normalize_embeddings: bool = False,
+                **kwargs,
+            ):
+                """
+                Computes sentence embeddings with bge-m3 model
+                Nothing special here, just replace sentence-transformer with FlagEmbedding
+                TODO: think about how to solve the redundant code of encode method in the future
+                :param sentences: the sentences to embed
+                :param batch_size: the batch size used for the computation
+                :param show_progress_bar: Output a progress bar when encode sentences
+                :param output_value: Default sentence_embedding, to get sentence embeddings. Can be set to token_embeddings to get wordpiece token embeddings. Set to None, to get all output values
+                :param convert_to_numpy: If true, the output is a list of numpy vectors. Else, it is a list of pytorch tensors.
+                :param convert_to_tensor: If true, you get one large tensor as return. Overwrites any setting from convert_to_numpy
+                :param device: Which torch.device to use for the computation
+                :param normalize_embeddings: If set to true, returned vectors will have length 1. In that case, the faster dot-product (util.dot_score) instead of cosine similarity can be used.
+                :return:
+                    By default, a list of tensors is returned. If convert_to_tensor, a stacked tensor is returned. If convert_to_numpy, a numpy matrix is returned.
+                """
+                import torch
+                from tqdm.autonotebook import trange
+
+                if show_progress_bar is None:
+                    show_progress_bar = (
+                        logger.getEffectiveLevel() == logging.INFO
+                        or logger.getEffectiveLevel() == logging.DEBUG
+                    )
+
+                if convert_to_tensor:
+                    convert_to_numpy = False
+
+                if output_value != "sparse_embedding":
+                    convert_to_tensor = False
+                    convert_to_numpy = False
+
+                input_was_string = False
+                if isinstance(sentences, str) or not hasattr(
+                    sentences, "__len__"
+                ):  # Cast an individual sentence to a list with length 1
+                    sentences = [sentences]
+                    input_was_string = True
+
+                if device is None:
+                    # Same as SentenceTransformer.py
+                    from sentence_transformers.util import get_device_name
+
+                    device = get_device_name()
+                    logger.info(f"Use pytorch device_name: {device}")
+
+                all_embeddings = []
+                all_token_nums = 0
+
+                # The original code does not support other inference engines
+                def _text_length(text):
+                    if isinstance(text, dict):  # {key: value} case
+                        return len(next(iter(text.values())))
+                    elif not hasattr(text, "__len__"):  # Object has no len() method
+                        return 1
+                    elif len(text) == 0 or isinstance(
+                        text[0], int
+                    ):  # Empty string or list of ints
+                        return len(text)
+                    else:
+                        return sum(
+                            [len(t) for t in text]
+                        )  # Sum of length of individual strings
+
+                length_sorted_idx = np.argsort(
+                    [-_text_length(sen) for sen in sentences]
+                )
+                sentences_sorted = [sentences[idx] for idx in length_sorted_idx]
+
+                for start_index in trange(
+                    0,
+                    len(sentences),
+                    batch_size,
+                    desc="Batches",
+                    disable=not show_progress_bar,
+                ):
+                    sentences_batch = sentences_sorted[
+                        start_index : start_index + batch_size
+                    ]
+
+                    with torch.no_grad():
+                        out_features = model.encode(sentences_batch, **kwargs)
+
+                    if output_value == "token_embeddings":
+                        embeddings = []
+                        for token_emb, attention in zip(
+                            out_features[output_value],
+                            out_features["attention_mask"],
+                        ):
+                            last_mask_id = len(attention) - 1
+                            while (
+                                last_mask_id > 0
+                                and attention[last_mask_id].item() == 0
+                            ):
+                                last_mask_id -= 1

+                            embeddings.append(token_emb[0 : last_mask_id + 1])
+                    elif output_value is None:  # Return all outputs
+                        embeddings = []
+                        for sent_idx in range(
+                            len(out_features["sentence_embedding"])
+                        ):
+                            row = {
+                                name: out_features[name][sent_idx]
+                                for name in out_features
+                            }
+                            embeddings.append(row)
+                    # for sparse embedding
+                    else:
+                        if kwargs.get("return_sparse"):
+                            embeddings = out_features["lexical_weights"]
+                        else:
+                            embeddings = out_features["dense_vecs"]
+
+                    if convert_to_numpy:
+                        embeddings = embeddings.cpu()
+
+                    all_embeddings.extend(embeddings)
+
+                all_embeddings = [
+                    all_embeddings[idx] for idx in np.argsort(length_sorted_idx)
+                ]
+
+                if convert_to_tensor:
+                    if len(all_embeddings):
+                        all_embeddings = torch.stack(all_embeddings)
+                    else:
+                        all_embeddings = torch.Tensor()
+                elif convert_to_numpy:
+                    all_embeddings = np.asarray(
+                        [emb.numpy() for emb in all_embeddings]
+                    )
+
+                if input_was_string:
+                    all_embeddings = all_embeddings[0]
+
+                return all_embeddings, all_token_nums
+
+        except ImportError:
+            _encode_bgem3 = None

         # copied from sentence-transformers, and modify it to return tokens num
         @no_type_check
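
Structurally, the hunk above is the usual guard for an optional dependency: perform the import inside try/except ImportError, fall back to a None sentinel, and assert on the single branch that actually needs the helper. A stripped-down sketch of the same pattern, with illustrative names rather than the real core.py code:

    from typing import List

    try:
        # Optional dependency: only the bge-m3 hybrid (sparse) path needs it.
        from FlagEmbedding import BGEM3FlagModel

        def _encode_bgem3(model: "BGEM3FlagModel", sentences: List[str], **kwargs):
            # Placeholder body; the real helper wraps BGEM3FlagModel.encode().
            return model.encode(sentences, **kwargs)

    except ImportError:
        # FlagEmbedding is absent: dense-only deployments keep working,
        # and only the sparse branch below will refuse to run.
        _encode_bgem3 = None

    def create_embedding(model, sentences: List[str], hybrid: bool = False, **kwargs):
        if hybrid:
            assert _encode_bgem3 is not None, "FlagEmbedding is required for hybrid mode"
            return _encode_bgem3(model, sentences, **kwargs)
        return model.encode(sentences, **kwargs)

Before the fix, the equivalent import sat unconditionally at the top of create_embedding, so every embedding request failed with an ImportError on installations that did not have FlagEmbedding.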
@@ -582,6 +593,10 @@ def encode(

             return all_embeddings, all_token_nums

+        is_bge_m3_flag_model = (
+            self._kwargs.get("hybrid_mode")
+            and "m3" in self._model_spec.model_name.lower()
+        )
         if (
             "gte" in self._model_spec.model_name.lower()
             and "qwen2" in self._model_spec.model_name.lower()
@@ -593,7 +608,8 @@ def encode(
                 convert_to_numpy=False,
                 **kwargs,
             )
-        elif isinstance(self._model, BGEM3FlagModel):
+        elif is_bge_m3_flag_model:
+            assert _encode_bgem3 is not None
             all_embeddings, all_token_nums = _encode_bgem3(
                 self._model, sentences, convert_to_numpy=False, **kwargs
             )
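
Because _encode_bgem3 may now be undefined, the dispatch above can no longer rely on isinstance(self._model, BGEM3FlagModel); it keys off the is_bge_m3_flag_model flag built from the hybrid_mode kwarg and the model name, and asserts that the helper was imported. A hedged client-side sketch of the two paths this distinguishes — hybrid_mode and return_sparse are taken from the diff, while the exact launch/request signatures may differ in your xinference version:

    from xinference.client import Client

    client = Client("http://127.0.0.1:9997")

    # bge-m3 with hybrid_mode enabled takes the FlagEmbedding-backed sparse path;
    # every other embedding model now works even when FlagEmbedding is not installed.
    model_uid = client.launch_model(
        model_name="bge-m3",
        model_type="embedding",
        hybrid_mode=True,  # assumed to end up in self._kwargs in core.py
    )
    model = client.get_model(model_uid)

    dense = model.create_embedding("Xorbits Inference makes serving models easy.")
    sparse = model.create_embedding(
        "Xorbits Inference makes serving models easy.", return_sparse=True
    )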
@@ -608,7 +624,7 @@ def encode(
             all_embeddings = [all_embeddings]
         embedding_list = []
         for index, data in enumerate(all_embeddings):
-            if kwargs.get("return_sparse") and isinstance(self._model, BGEM3FlagModel):
+            if kwargs.get("return_sparse") and is_bge_m3_flag_model:
                 embedding_list.append(
                     EmbeddingData(
                         index=index,
@@ -628,8 +644,7 @@ def encode(
         result = Embedding(
             object=(
                 "list"  # type: ignore
-                if not isinstance(self._model, BGEM3FlagModel)
-                and not kwargs.get("return_sparse")
+                if not is_bge_m3_flag_model and not kwargs.get("return_sparse")
                 else "dict"
             ),
             model=self._model_uid,
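
As the last two hunks show, the response object field stays "list" for dense output and switches to "dict" when return_sparse is requested on a bge-m3 hybrid model, in which case each item carries FlagEmbedding's lexical weights (token -> weight) instead of a dense vector. A rough sketch of consuming both shapes; the field names follow the diff and the OpenAI-style embedding schema, so treat them as assumptions rather than a schema reference:

    def summarize_embeddings(result: dict) -> None:
        for item in result["data"]:
            emb = item["embedding"]
            if result["object"] == "dict":
                # Sparse path: emb is assumed to be a mapping of token -> weight.
                top = sorted(emb.items(), key=lambda kv: kv[1], reverse=True)[:5]
                print(f"sparse embedding, top tokens: {top}")
            else:
                # Dense path: emb is a plain list of floats.
                print(f"dense embedding, dim={len(emb)}")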
