Feat: Add model provider Text Embedding Inference for embedding and rerank (#7132)
Showing 13 changed files with 815 additions and 0 deletions.
api/core/model_runtime/model_providers/huggingface_tei/huggingface_tei.py (11 additions, 0 deletions)
import logging

from core.model_runtime.model_providers.__base.model_provider import ModelProvider

logger = logging.getLogger(__name__)


class HuggingfaceTeiProvider(ModelProvider):

    def validate_provider_credentials(self, credentials: dict) -> None:
        pass
api/core/model_runtime/model_providers/huggingface_tei/huggingface_tei.yaml (36 additions, 0 deletions)
provider: huggingface_tei
label:
  en_US: Text Embedding Inference
description:
  en_US: A blazing fast inference solution for text embeddings models.
  zh_Hans: 用于文本嵌入模型的超快速推理解决方案。
background: "#FFF8DC"
help:
  title:
    en_US: How to deploy Text Embedding Inference
    zh_Hans: 如何部署 Text Embedding Inference
  url:
    en_US: https://github.com/huggingface/text-embeddings-inference
supported_model_types:
  - text-embedding
  - rerank
configurate_methods:
  - customizable-model
model_credential_schema:
  model:
    label:
      en_US: Model Name
      zh_Hans: 模型名称
    placeholder:
      en_US: Enter your model name
      zh_Hans: 输入模型名称
  credential_form_schemas:
    - variable: server_url
      label:
        zh_Hans: 服务器URL
        en_US: Server URL
      type: secret-input
      required: true
      placeholder:
        zh_Hans: 在此输入Text Embedding Inference的服务器地址，如 http://192.168.1.100:8080
        en_US: Enter the URL of your Text Embedding Inference server, e.g. http://192.168.1.100:8080
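The `server_url` credential must point at a running TEI instance. As a quick sanity check before saving the credential, a minimal sketch (the server address is a hypothetical placeholder taken from the form above, not part of this commit) can query TEI's GET /info endpoint, the same endpoint the helper in this commit relies on:

import httpx

# Hypothetical deployment address; adjust to your own TEI server.
SERVER_URL = 'http://192.168.1.100:8080'

# TEI's /info endpoint describes the deployed model (model_type, max_input_length, ...).
resp = httpx.get(f'{SERVER_URL}/info', timeout=10)
resp.raise_for_status()
info = resp.json()

# 'model_type' is a dict with a single key naming the kind of model served,
# e.g. {'embedding': {...}} or {'reranker': {...}}.
print(list(info['model_type'].keys())[0])
print(info.get('max_input_length'))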
api/core/model_runtime/model_providers/huggingface_tei/rerank/rerank.py (137 additions, 0 deletions)
from typing import Optional

import httpx

from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType
from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult
from core.model_runtime.errors.invoke import (
    InvokeAuthorizationError,
    InvokeBadRequestError,
    InvokeConnectionError,
    InvokeError,
    InvokeRateLimitError,
    InvokeServerUnavailableError,
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.rerank_model import RerankModel
from core.model_runtime.model_providers.huggingface_tei.tei_helper import TeiHelper


class HuggingfaceTeiRerankModel(RerankModel):
    """
    Model class for Text Embedding Inference rerank model.
    """

    def _invoke(
        self,
        model: str,
        credentials: dict,
        query: str,
        docs: list[str],
        score_threshold: Optional[float] = None,
        top_n: Optional[int] = None,
        user: Optional[str] = None,
    ) -> RerankResult:
        """
        Invoke rerank model
        :param model: model name
        :param credentials: model credentials
        :param query: search query
        :param docs: docs for reranking
        :param score_threshold: score threshold
        :param top_n: top n
        :param user: unique user id
        :return: rerank result
        """
        if len(docs) == 0:
            return RerankResult(model=model, docs=[])
        server_url = credentials['server_url']

        if server_url.endswith('/'):
            server_url = server_url[:-1]

        try:
            results = TeiHelper.invoke_rerank(server_url, query, docs)

            rerank_documents = []
            for result in results:
                rerank_document = RerankDocument(
                    index=result['index'],
                    text=result['text'],
                    score=result['score'],
                )
                if score_threshold is None or result['score'] >= score_threshold:
                    rerank_documents.append(rerank_document)
                    if top_n is not None and len(rerank_documents) >= top_n:
                        break

            return RerankResult(model=model, docs=rerank_documents)
        except httpx.HTTPStatusError as e:
            raise InvokeServerUnavailableError(str(e))

    def validate_credentials(self, model: str, credentials: dict) -> None:
        """
        Validate model credentials
        :param model: model name
        :param credentials: model credentials
        :return:
        """
        try:
            server_url = credentials['server_url']
            extra_args = TeiHelper.get_tei_extra_parameter(server_url, model)
            if extra_args.model_type != 'reranker':
                raise CredentialsValidateFailedError('Current model is not a rerank model')

            credentials['context_size'] = extra_args.max_input_length

            self.invoke(
                model=model,
                credentials=credentials,
                query='Whose kasumi',
                docs=[
                    'Kasumi is a girl\'s name of Japanese origin meaning "mist".',
                    'Her music is a kawaii bass, a mix of future bass, pop, and kawaii music ',
                    'and she leads a team named PopiParty.',
                ],
                score_threshold=0.8,
            )
        except Exception as ex:
            raise CredentialsValidateFailedError(str(ex))

    @property
    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
        """
        Map model invoke error to unified error
        The key is the error type thrown to the caller
        The value is the error type thrown by the model,
        which needs to be converted into a unified error type for the caller.
        :return: Invoke error mapping
        """
        return {
            InvokeConnectionError: [InvokeConnectionError],
            InvokeServerUnavailableError: [InvokeServerUnavailableError],
            InvokeRateLimitError: [InvokeRateLimitError],
            InvokeAuthorizationError: [InvokeAuthorizationError],
            InvokeBadRequestError: [InvokeBadRequestError, KeyError, ValueError],
        }

    def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
        """
        used to define customizable model schema
        """
        entity = AIModelEntity(
            model=model,
            label=I18nObject(en_US=model),
            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
            model_type=ModelType.RERANK,
            model_properties={
                ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size', 512)),
            },
            parameter_rules=[],
        )

        return entity
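The filtering in `_invoke` is worth spelling out: documents below `score_threshold` are dropped, and collection stops once `top_n` documents have been kept, which assumes the server returns results sorted by descending score. A minimal self-contained sketch of that loop on hypothetical rerank results (the dicts mirror the /rerank response shape; the scores are made up for illustration):

# Hypothetical scores for illustration; real values come from TEI's /rerank endpoint.
results = [
    {'index': 2, 'text': 'doc c', 'score': 0.99},
    {'index': 0, 'text': 'doc a', 'score': 0.75},
    {'index': 1, 'text': 'doc b', 'score': 0.40},
]
score_threshold, top_n = 0.5, 1

kept = []
for result in results:
    # Keep only documents at or above the threshold...
    if result['score'] >= score_threshold:
        kept.append(result)
        # ...and stop as soon as top_n documents have been kept.
        if len(kept) >= top_n:
            break

print([d['index'] for d in kept])  # [2]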
api/core/model_runtime/model_providers/huggingface_tei/tei_helper.py (183 additions, 0 deletions)
from threading import Lock
from time import time
from typing import Optional

import httpx
from requests.adapters import HTTPAdapter
from requests.exceptions import ConnectionError, MissingSchema, Timeout
from requests.sessions import Session
from yarl import URL


class TeiModelExtraParameter:
    model_type: str
    max_input_length: int
    max_client_batch_size: int

    def __init__(self, model_type: str, max_input_length: int, max_client_batch_size: Optional[int] = None) -> None:
        self.model_type = model_type
        self.max_input_length = max_input_length
        self.max_client_batch_size = max_client_batch_size


cache = {}
cache_lock = Lock()


class TeiHelper:
    @staticmethod
    def get_tei_extra_parameter(server_url: str, model_name: str) -> TeiModelExtraParameter:
        TeiHelper._clean_cache()
        with cache_lock:
            if model_name not in cache:
                cache[model_name] = {
                    'expires': time() + 300,
                    'value': TeiHelper._get_tei_extra_parameter(server_url),
                }
            return cache[model_name]['value']

    @staticmethod
    def _clean_cache() -> None:
        try:
            with cache_lock:
                expired_keys = [model_uid for model_uid, model in cache.items() if model['expires'] < time()]
                for model_uid in expired_keys:
                    del cache[model_uid]
        except RuntimeError:
            # cache was modified while iterating; skip cleanup this round
            pass

    @staticmethod
    def _get_tei_extra_parameter(server_url: str) -> TeiModelExtraParameter:
        """
        get tei model extra parameter like model_type, max_input_length, max_batch_requests
        """

        url = str(URL(server_url) / 'info')

        # this method is called while a lock is held, and a default requests call may hang
        # forever, so we mount an HTTPAdapter with max_retries=3
        session = Session()
        session.mount('http://', HTTPAdapter(max_retries=3))
        session.mount('https://', HTTPAdapter(max_retries=3))

        try:
            response = session.get(url, timeout=10)
        except (MissingSchema, ConnectionError, Timeout) as e:
            raise RuntimeError(f'get tei model extra parameter failed, url: {url}, error: {e}')
        if response.status_code != 200:
            raise RuntimeError(
                f'get tei model extra parameter failed, status code: {response.status_code}, response: {response.text}'
            )

        response_json = response.json()

        model_type = response_json.get('model_type', {})
        if len(model_type.keys()) < 1:
            raise RuntimeError('model_type is empty')
        model_type = list(model_type.keys())[0]
        if model_type not in ['embedding', 'reranker']:
            raise RuntimeError(f'invalid model_type: {model_type}')

        max_input_length = response_json.get('max_input_length', 512)
        max_client_batch_size = response_json.get('max_client_batch_size', 1)

        return TeiModelExtraParameter(
            model_type=model_type,
            max_input_length=max_input_length,
            max_client_batch_size=max_client_batch_size,
        )

    @staticmethod
    def invoke_tokenize(server_url: str, texts: list[str]) -> list[list[dict]]:
        """
        Invoke tokenize endpoint

        Example response:
        [
            [
                {
                    "id": 0,
                    "text": "<s>",
                    "special": true,
                    "start": null,
                    "stop": null
                },
                {
                    "id": 7704,
                    "text": "str",
                    "special": false,
                    "start": 0,
                    "stop": 3
                },
                < MORE TOKENS >
            ]
        ]

        :param server_url: server url
        :param texts: texts to tokenize
        """
        resp = httpx.post(
            f'{server_url}/tokenize',
            json={'inputs': texts},
        )
        resp.raise_for_status()
        return resp.json()

    @staticmethod
    def invoke_embeddings(server_url: str, texts: list[str]) -> dict:
        """
        Invoke embeddings endpoint

        Example response:
        {
            "object": "list",
            "data": [
                {
                    "object": "embedding",
                    "embedding": [...],
                    "index": 0
                }
            ],
            "model": "MODEL_NAME",
            "usage": {
                "prompt_tokens": 3,
                "total_tokens": 3
            }
        }

        :param server_url: server url
        :param texts: texts to embed
        """
        # Use the OpenAI-compatible API here, which includes usage tracking
        resp = httpx.post(
            f'{server_url}/v1/embeddings',
            json={'input': texts},
        )
        resp.raise_for_status()
        return resp.json()

    @staticmethod
    def invoke_rerank(server_url: str, query: str, docs: list[str]) -> list[dict]:
        """
        Invoke rerank endpoint

        Example response:
        [
            {
                "index": 0,
                "text": "Deep Learning is ...",
                "score": 0.9950755
            }
        ]

        :param server_url: server url
        :param query: search query
        :param docs: documents to rerank
        """
        params = {'query': query, 'texts': docs, 'return_text': True}

        response = httpx.post(
            server_url + '/rerank',
            json=params,
        )
        response.raise_for_status()
        return response.json()
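Taken together, the three endpoint wrappers cover the whole HTTP surface this provider needs. A minimal end-to-end sketch, assuming a TEI reranker is already deployed at http://localhost:8080 (the URL is a placeholder, not part of this commit):

from core.model_runtime.model_providers.huggingface_tei.tei_helper import TeiHelper

server_url = 'http://localhost:8080'  # hypothetical local TEI deployment

# Tokenize: returns one token list per input text.
tokens = TeiHelper.invoke_tokenize(server_url, ['Hello, world!'])
print(len(tokens[0]), 'tokens')

# Rerank: returns [{'index': ..., 'text': ..., 'score': ...}, ...],
# with the most relevant document first.
ranked = TeiHelper.invoke_rerank(server_url, 'What is deep learning?', [
    'Deep learning is a subset of machine learning.',
    'Paris is the capital of France.',
])
print(ranked[0]['index'], ranked[0]['score'])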