
Feat: Add model provider Text Embedding Inference for embedding and rerank (#7132)
liuyanyi authored Aug 9, 2024
1 parent 4cbeb68 commit 5b32f2e
Showing 13 changed files with 815 additions and 0 deletions.
@@ -0,0 +1,11 @@
import logging

from core.model_runtime.model_providers.__base.model_provider import ModelProvider

logger = logging.getLogger(__name__)


class HuggingfaceTeiProvider(ModelProvider):

    def validate_provider_credentials(self, credentials: dict) -> None:
        pass
@@ -0,0 +1,36 @@
provider: huggingface_tei
label:
  en_US: Text Embedding Inference
description:
  en_US: A blazing fast inference solution for text embeddings models.
  zh_Hans: 用于文本嵌入模型的超快速推理解决方案。
background: "#FFF8DC"
help:
  title:
    en_US: How to deploy Text Embedding Inference
    zh_Hans: 如何部署 Text Embedding Inference
  url:
    en_US: https://github.com/huggingface/text-embeddings-inference
supported_model_types:
  - text-embedding
  - rerank
configurate_methods:
  - customizable-model
model_credential_schema:
  model:
    label:
      en_US: Model Name
      zh_Hans: 模型名称
    placeholder:
      en_US: Enter your model name
      zh_Hans: 输入模型名称
  credential_form_schemas:
    - variable: server_url
      label:
        zh_Hans: 服务器URL
        en_US: Server URL
      type: secret-input
      required: true
      placeholder:
        zh_Hans: 在此输入Text Embedding Inference的服务器地址,如 http://192.168.1.100:8080
        en_US: Enter the URL of your Text Embedding Inference server, e.g. http://192.168.1.100:8080
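For reference (not part of this commit), the schema above resolves at runtime to a plain credentials dict, with the model name kept separately. A minimal sketch, using a made-up model name and server address:

# Sketch only: hypothetical values matching the credential schema above.
credentials = {
    'server_url': 'http://192.168.1.100:8080',  # the secret-input field defined above
}
model = 'BAAI/bge-reranker-base'  # example model name, not taken from this commit
# validate_credentials (below) later adds credentials['context_size'] from the TEI /info endpoint.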
@@ -0,0 +1,137 @@
from typing import Optional

import httpx

from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType
from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult
from core.model_runtime.errors.invoke import (
    InvokeAuthorizationError,
    InvokeBadRequestError,
    InvokeConnectionError,
    InvokeError,
    InvokeRateLimitError,
    InvokeServerUnavailableError,
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.rerank_model import RerankModel
from core.model_runtime.model_providers.huggingface_tei.tei_helper import TeiHelper


class HuggingfaceTeiRerankModel(RerankModel):
    """
    Model class for Text Embedding Inference rerank model.
    """

    def _invoke(
        self,
        model: str,
        credentials: dict,
        query: str,
        docs: list[str],
        score_threshold: Optional[float] = None,
        top_n: Optional[int] = None,
        user: Optional[str] = None,
    ) -> RerankResult:
        """
        Invoke rerank model

        :param model: model name
        :param credentials: model credentials
        :param query: search query
        :param docs: docs for reranking
        :param score_threshold: score threshold
        :param top_n: top n
        :param user: unique user id
        :return: rerank result
        """
        if len(docs) == 0:
            return RerankResult(model=model, docs=[])
        server_url = credentials['server_url']

        if server_url.endswith('/'):
            server_url = server_url[:-1]

        try:
            results = TeiHelper.invoke_rerank(server_url, query, docs)

            rerank_documents = []
            for result in results:
                rerank_document = RerankDocument(
                    index=result['index'],
                    text=result['text'],
                    score=result['score'],
                )
                if score_threshold is None or result['score'] >= score_threshold:
                    rerank_documents.append(rerank_document)
                    if top_n is not None and len(rerank_documents) >= top_n:
                        break

            return RerankResult(model=model, docs=rerank_documents)
        except httpx.HTTPStatusError as e:
            raise InvokeServerUnavailableError(str(e))

    def validate_credentials(self, model: str, credentials: dict) -> None:
        """
        Validate model credentials

        :param model: model name
        :param credentials: model credentials
        :return:
        """
        try:
            server_url = credentials['server_url']
            extra_args = TeiHelper.get_tei_extra_parameter(server_url, model)
            if extra_args.model_type != 'reranker':
                raise CredentialsValidateFailedError('Current model is not a rerank model')

            credentials['context_size'] = extra_args.max_input_length

            self.invoke(
                model=model,
                credentials=credentials,
                query='Whose kasumi',
                docs=[
                    'Kasumi is a girl\'s name of Japanese origin meaning "mist".',
                    'Her music is a kawaii bass, a mix of future bass, pop, and kawaii music ',
                    'and she leads a team named PopiParty.',
                ],
                score_threshold=0.8,
            )
        except Exception as ex:
            raise CredentialsValidateFailedError(str(ex))

    @property
    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
        """
        Map model invoke error to unified error
        The key is the error type thrown to the caller
        The value is the error type thrown by the model,
        which needs to be converted into a unified error type for the caller.

        :return: Invoke error mapping
        """
        return {
            InvokeConnectionError: [InvokeConnectionError],
            InvokeServerUnavailableError: [InvokeServerUnavailableError],
            InvokeRateLimitError: [InvokeRateLimitError],
            InvokeAuthorizationError: [InvokeAuthorizationError],
            InvokeBadRequestError: [InvokeBadRequestError, KeyError, ValueError],
        }

    def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
        """
        used to define customizable model schema
        """
        entity = AIModelEntity(
            model=model,
            label=I18nObject(en_US=model),
            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
            model_type=ModelType.RERANK,
            model_properties={
                ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size', 512)),
            },
            parameter_rules=[],
        )

        return entity
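A rough usage sketch, not part of the commit: the framework normally reaches this model through the public invoke wrapper (as validate_credentials does above), but calling it directly against a TEI reranker at an assumed address illustrates the flow. The server address, model name, and texts are placeholders, and it assumes the class can be instantiated standalone.

# Sketch only: placeholder server address, model name, and documents.
rerank_model = HuggingfaceTeiRerankModel()
result = rerank_model.invoke(
    model='BAAI/bge-reranker-base',
    credentials={'server_url': 'http://192.168.1.100:8080'},
    query='What is Deep Learning?',
    docs=['Deep Learning is a subfield of machine learning.', 'A tree is a plant.'],
    top_n=1,
)
print(result.docs[0].index, result.docs[0].score)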
183 changes: 183 additions & 0 deletions api/core/model_runtime/model_providers/huggingface_tei/tei_helper.py
@@ -0,0 +1,183 @@
from threading import Lock
from time import time
from typing import Optional

import httpx
from requests.adapters import HTTPAdapter
from requests.exceptions import ConnectionError, MissingSchema, Timeout
from requests.sessions import Session
from yarl import URL


class TeiModelExtraParameter:
    model_type: str
    max_input_length: int
    max_client_batch_size: int

    def __init__(self, model_type: str, max_input_length: int, max_client_batch_size: Optional[int] = None) -> None:
        self.model_type = model_type
        self.max_input_length = max_input_length
        self.max_client_batch_size = max_client_batch_size


cache = {}
cache_lock = Lock()


class TeiHelper:
    @staticmethod
    def get_tei_extra_parameter(server_url: str, model_name: str) -> TeiModelExtraParameter:
        TeiHelper._clean_cache()
        with cache_lock:
            if model_name not in cache:
                cache[model_name] = {
                    'expires': time() + 300,
                    'value': TeiHelper._get_tei_extra_parameter(server_url),
                }
            return cache[model_name]['value']

    @staticmethod
    def _clean_cache() -> None:
        try:
            with cache_lock:
                expired_keys = [model_uid for model_uid, model in cache.items() if model['expires'] < time()]
                for model_uid in expired_keys:
                    del cache[model_uid]
        except RuntimeError as e:
            pass

    @staticmethod
    def _get_tei_extra_parameter(server_url: str) -> TeiModelExtraParameter:
        """
        get tei model extra parameters such as model_type, max_input_length, max_client_batch_size
        """

        url = str(URL(server_url) / 'info')

        # this method is called while holding a lock, and a default requests session may hang forever,
        # so mount an HTTPAdapter with max_retries=3
        session = Session()
        session.mount('http://', HTTPAdapter(max_retries=3))
        session.mount('https://', HTTPAdapter(max_retries=3))

        try:
            response = session.get(url, timeout=10)
        except (MissingSchema, ConnectionError, Timeout) as e:
            raise RuntimeError(f'get tei model extra parameter failed, url: {url}, error: {e}')
        if response.status_code != 200:
            raise RuntimeError(
                f'get tei model extra parameter failed, status code: {response.status_code}, response: {response.text}'
            )

        response_json = response.json()

        model_type = response_json.get('model_type', {})
        if len(model_type.keys()) < 1:
            raise RuntimeError('model_type is empty')
        model_type = list(model_type.keys())[0]
        if model_type not in ['embedding', 'reranker']:
            raise RuntimeError(f'invalid model_type: {model_type}')

        max_input_length = response_json.get('max_input_length', 512)
        max_client_batch_size = response_json.get('max_client_batch_size', 1)

        return TeiModelExtraParameter(
            model_type=model_type,
            max_input_length=max_input_length,
            max_client_batch_size=max_client_batch_size
        )

    @staticmethod
    def invoke_tokenize(server_url: str, texts: list[str]) -> list[list[dict]]:
        """
        Invoke tokenize endpoint

        Example response:
        [
            [
                {
                    "id": 0,
                    "text": "<s>",
                    "special": true,
                    "start": null,
                    "stop": null
                },
                {
                    "id": 7704,
                    "text": "str",
                    "special": false,
                    "start": 0,
                    "stop": 3
                },
                < MORE TOKENS >
            ]
        ]

        :param server_url: server url
        :param texts: texts to tokenize
        """
        resp = httpx.post(
            f'{server_url}/tokenize',
            json={'inputs': texts},
        )
        resp.raise_for_status()
        return resp.json()

    @staticmethod
    def invoke_embeddings(server_url: str, texts: list[str]) -> dict:
        """
        Invoke embeddings endpoint

        Example response:
        {
            "object": "list",
            "data": [
                {
                    "object": "embedding",
                    "embedding": [...],
                    "index": 0
                }
            ],
            "model": "MODEL_NAME",
            "usage": {
                "prompt_tokens": 3,
                "total_tokens": 3
            }
        }

        :param server_url: server url
        :param texts: texts to embed
        """
        # Use OpenAI compatible API here, which has usage tracking
        resp = httpx.post(
            f'{server_url}/v1/embeddings',
            json={'input': texts},
        )
        resp.raise_for_status()
        return resp.json()

    @staticmethod
    def invoke_rerank(server_url: str, query: str, docs: list[str]) -> list[dict]:
        """
        Invoke rerank endpoint

        Example response:
        [
            {
                "index": 0,
                "text": "Deep Learning is ...",
                "score": 0.9950755
            }
        ]

        :param server_url: server url
        :param query: query text
        :param docs: documents to rerank
        """
        params = {'query': query, 'texts': docs, 'return_text': True}

        response = httpx.post(
            server_url + '/rerank',
            json=params,
        )
        response.raise_for_status()
        return response.json()
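To tie the helpers together, a rough sketch (not part of the commit) of how they might be exercised against a TEI server at an assumed address; the model name and texts are placeholders:

# Sketch only: assumes a TEI server is reachable at this placeholder address.
server_url = 'http://192.168.1.100:8080'
extra = TeiHelper.get_tei_extra_parameter(server_url, 'my-model')  # cached for 300 s per model name

if extra.model_type == 'embedding':
    resp = TeiHelper.invoke_embeddings(server_url, ['Deep Learning is a subfield of machine learning.'])
    print(len(resp['data'][0]['embedding']), resp['usage']['total_tokens'])
elif extra.model_type == 'reranker':
    ranked = TeiHelper.invoke_rerank(server_url, 'What is Deep Learning?',
                                     ['Deep Learning is ...', 'A tree is a plant.'])
    print(ranked[0]['index'], ranked[0]['score'])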
