From 0e627c920fd571bb8d76d113707eb123be342b52 Mon Sep 17 00:00:00 2001 From: takatost Date: Sat, 25 Nov 2023 03:56:00 +0800 Subject: [PATCH] feat: xinference rerank model support (#1615) --- .../console/workspace/model_providers.py | 6 +- .../models/reranking/xinference_reranking.py | 58 ++++++++++++++ .../providers/xinference_provider.py | 8 ++ .../model_providers/rules/xinference.json | 3 +- api/requirements.txt | 2 +- api/tests/integration_tests/.env.example | 5 +- .../models/reranking/__init__.py | 0 .../models/reranking/test_cohere_reranking.py | 61 +++++++++++++++ .../reranking/test_xinference_reranking.py | 78 +++++++++++++++++++ 9 files changed, 215 insertions(+), 6 deletions(-) create mode 100644 api/core/model_providers/models/reranking/xinference_reranking.py create mode 100644 api/tests/integration_tests/models/reranking/__init__.py create mode 100644 api/tests/integration_tests/models/reranking/test_cohere_reranking.py create mode 100644 api/tests/integration_tests/models/reranking/test_xinference_reranking.py diff --git a/api/controllers/console/workspace/model_providers.py b/api/controllers/console/workspace/model_providers.py index 749ecd64229d47..0cfa8d17ddf34c 100644 --- a/api/controllers/console/workspace/model_providers.py +++ b/api/controllers/console/workspace/model_providers.py @@ -115,7 +115,7 @@ def post(self, provider_name: str): parser = reqparse.RequestParser() parser.add_argument('model_name', type=str, required=True, nullable=False, location='json') parser.add_argument('model_type', type=str, required=True, nullable=False, - choices=['text-generation', 'embeddings', 'speech2text'], location='json') + choices=['text-generation', 'embeddings', 'speech2text', 'reranking'], location='json') parser.add_argument('config', type=dict, required=True, nullable=False, location='json') args = parser.parse_args() @@ -155,7 +155,7 @@ def post(self, provider_name: str): parser = reqparse.RequestParser() parser.add_argument('model_name', type=str, required=True, nullable=False, location='json') parser.add_argument('model_type', type=str, required=True, nullable=False, - choices=['text-generation', 'embeddings', 'speech2text'], location='json') + choices=['text-generation', 'embeddings', 'speech2text', 'reranking'], location='json') parser.add_argument('config', type=dict, required=True, nullable=False, location='json') args = parser.parse_args() @@ -184,7 +184,7 @@ def delete(self, provider_name: str): parser = reqparse.RequestParser() parser.add_argument('model_name', type=str, required=True, nullable=False, location='args') parser.add_argument('model_type', type=str, required=True, nullable=False, - choices=['text-generation', 'embeddings', 'speech2text'], location='args') + choices=['text-generation', 'embeddings', 'speech2text', 'reranking'], location='args') args = parser.parse_args() provider_service = ProviderService() diff --git a/api/core/model_providers/models/reranking/xinference_reranking.py b/api/core/model_providers/models/reranking/xinference_reranking.py new file mode 100644 index 00000000000000..0efcf189f01d0c --- /dev/null +++ b/api/core/model_providers/models/reranking/xinference_reranking.py @@ -0,0 +1,58 @@ +import logging +from typing import Optional, List + +from langchain.schema import Document +from xinference_client.client.restful.restful_client import Client + +from core.model_providers.error import LLMBadRequestError +from core.model_providers.models.reranking.base import BaseReranking +from core.model_providers.providers.base import BaseModelProvider + + +class XinferenceReranking(BaseReranking): + + def __init__(self, model_provider: BaseModelProvider, name: str): + self.credentials = model_provider.get_model_credentials( + model_name=name, + model_type=self.type + ) + + client = Client(self.credentials['server_url']) + + super().__init__(model_provider, client, name) + + def rerank(self, query: str, documents: List[Document], score_threshold: Optional[float], top_k: Optional[int]) -> Optional[List[Document]]: + docs = [] + doc_id = [] + for document in documents: + if document.metadata['doc_id'] not in doc_id: + doc_id.append(document.metadata['doc_id']) + docs.append(document.page_content) + + model = self.client.get_model(self.credentials['model_uid']) + response = model.rerank(query=query, documents=docs, top_n=top_k) + rerank_documents = [] + + for idx, result in enumerate(response['results']): + # format document + index = result['index'] + rerank_document = Document( + page_content=result['document'], + metadata={ + "doc_id": documents[index].metadata['doc_id'], + "doc_hash": documents[index].metadata['doc_hash'], + "document_id": documents[index].metadata['document_id'], + "dataset_id": documents[index].metadata['dataset_id'], + 'score': result['relevance_score'] + } + ) + # score threshold check + if score_threshold is not None: + if result.relevance_score >= score_threshold: + rerank_documents.append(rerank_document) + else: + rerank_documents.append(rerank_document) + return rerank_documents + + def handle_exceptions(self, ex: Exception) -> Exception: + return LLMBadRequestError(f"Xinference rerank: {str(ex)}") diff --git a/api/core/model_providers/providers/xinference_provider.py b/api/core/model_providers/providers/xinference_provider.py index af1f050b87a8db..133c5e7cf8bd9f 100644 --- a/api/core/model_providers/providers/xinference_provider.py +++ b/api/core/model_providers/providers/xinference_provider.py @@ -2,11 +2,13 @@ from typing import Type import requests +from xinference_client.client.restful.restful_client import Client from core.helper import encrypter from core.model_providers.models.embedding.xinference_embedding import XinferenceEmbedding from core.model_providers.models.entity.model_params import KwargRule, ModelKwargsRules, ModelType, ModelMode from core.model_providers.models.llm.xinference_model import XinferenceModel +from core.model_providers.models.reranking.xinference_reranking import XinferenceReranking from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError from core.model_providers.models.base import BaseProviderModel @@ -40,6 +42,8 @@ def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]: model_class = XinferenceModel elif model_type == ModelType.EMBEDDINGS: model_class = XinferenceEmbedding + elif model_type == ModelType.RERANKING: + model_class = XinferenceReranking else: raise NotImplementedError @@ -113,6 +117,10 @@ def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelT ) embedding.embed_query("ping") + elif model_type == ModelType.RERANKING: + rerank_client = Client(credential_kwargs['server_url']) + model = rerank_client.get_model(credential_kwargs['model_uid']) + model.rerank(query="ping", documents=["ping", "pong"], top_n=2) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) diff --git a/api/core/model_providers/rules/xinference.json b/api/core/model_providers/rules/xinference.json index 3f1ee225f16b3c..3e426a927b502c 100644 --- a/api/core/model_providers/rules/xinference.json +++ b/api/core/model_providers/rules/xinference.json @@ -6,6 +6,7 @@ "model_flexibility": "configurable", "supported_model_types": [ "text-generation", - "embeddings" + "embeddings", + "reranking" ] } \ No newline at end of file diff --git a/api/requirements.txt b/api/requirements.txt index 7b5ed73f8c893d..ca26601e0ecbf7 100644 --- a/api/requirements.txt +++ b/api/requirements.txt @@ -48,7 +48,7 @@ huggingface_hub~=0.16.4 transformers~=4.31.0 stripe~=5.5.0 pandas==1.5.3 -xinference-client~=0.5.4 +xinference-client~=0.6.4 safetensors==0.3.2 zhipuai==1.0.7 werkzeug==2.3.7 diff --git a/api/tests/integration_tests/.env.example b/api/tests/integration_tests/.env.example index 7d00ae0f6aaa95..11091aa34c7b39 100644 --- a/api/tests/integration_tests/.env.example +++ b/api/tests/integration_tests/.env.example @@ -50,4 +50,7 @@ XINFERENCE_MODEL_UID= OPENLLM_SERVER_URL= # LocalAI Credentials -LOCALAI_SERVER_URL= \ No newline at end of file +LOCALAI_SERVER_URL= + +# Cohere Credentials +COHERE_API_KEY= \ No newline at end of file diff --git a/api/tests/integration_tests/models/reranking/__init__.py b/api/tests/integration_tests/models/reranking/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/tests/integration_tests/models/reranking/test_cohere_reranking.py b/api/tests/integration_tests/models/reranking/test_cohere_reranking.py new file mode 100644 index 00000000000000..bbdd94cbeac99e --- /dev/null +++ b/api/tests/integration_tests/models/reranking/test_cohere_reranking.py @@ -0,0 +1,61 @@ +import json +import os +from unittest.mock import patch + +from langchain.schema import Document + +from core.model_providers.models.reranking.cohere_reranking import CohereReranking +from core.model_providers.providers.cohere_provider import CohereProvider +from models.provider import Provider, ProviderType + + +def get_mock_provider(valid_api_key): + return Provider( + id='provider_id', + tenant_id='tenant_id', + provider_name='cohere', + provider_type=ProviderType.CUSTOM.value, + encrypted_config=json.dumps({'api_key': valid_api_key}), + is_valid=True, + ) + + +def get_mock_model(): + valid_api_key = os.environ['COHERE_API_KEY'] + provider = CohereProvider(provider=get_mock_provider(valid_api_key)) + return CohereReranking( + model_provider=provider, + name='rerank-english-v2.0' + ) + + +def decrypt_side_effect(tenant_id, encrypted_api_key): + return encrypted_api_key + + +@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect) +def test_run(mock_decrypt): + model = get_mock_model() + + docs = [] + docs.append(Document( + page_content='bye', + metadata={ + "doc_id": 'a', + "doc_hash": 'doc_hash', + "document_id": 'document_id', + "dataset_id": 'dataset_id', + } + )) + docs.append(Document( + page_content='hello', + metadata={ + "doc_id": 'b', + "doc_hash": 'doc_hash', + "document_id": 'document_id', + "dataset_id": 'dataset_id', + } + )) + rst = model.rerank('hello', docs, None, 2) + + assert rst[0].page_content == 'hello' diff --git a/api/tests/integration_tests/models/reranking/test_xinference_reranking.py b/api/tests/integration_tests/models/reranking/test_xinference_reranking.py new file mode 100644 index 00000000000000..1f22247e84ff93 --- /dev/null +++ b/api/tests/integration_tests/models/reranking/test_xinference_reranking.py @@ -0,0 +1,78 @@ +import json +import os +from unittest.mock import patch, MagicMock + +from langchain.schema import Document + +from core.model_providers.models.entity.model_params import ModelType +from core.model_providers.models.reranking.xinference_reranking import XinferenceReranking +from core.model_providers.providers.xinference_provider import XinferenceProvider +from models.provider import Provider, ProviderType, ProviderModel + + +def get_mock_provider(valid_server_url, valid_model_uid): + return Provider( + id='provider_id', + tenant_id='tenant_id', + provider_name='xinference', + provider_type=ProviderType.CUSTOM.value, + encrypted_config=json.dumps({'server_url': valid_server_url, 'model_uid': valid_model_uid}), + is_valid=True, + ) + + +def get_mock_model(mocker): + valid_server_url = os.environ['XINFERENCE_SERVER_URL'] + valid_model_uid = os.environ['XINFERENCE_MODEL_UID'] + model_name = 'bge-reranker-base' + provider = XinferenceProvider(provider=get_mock_provider(valid_server_url, valid_model_uid)) + + mock_query = MagicMock() + mock_query.filter.return_value.first.return_value = ProviderModel( + provider_name='xinference', + model_name=model_name, + model_type=ModelType.RERANKING.value, + encrypted_config=json.dumps({ + 'server_url': valid_server_url, + 'model_uid': valid_model_uid + }), + is_valid=True, + ) + mocker.patch('extensions.ext_database.db.session.query', return_value=mock_query) + + return XinferenceReranking( + model_provider=provider, + name=model_name + ) + + +def decrypt_side_effect(tenant_id, encrypted_api_key): + return encrypted_api_key + + +@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect) +def test_run(mock_decrypt, mocker): + model = get_mock_model(mocker) + + docs = [] + docs.append(Document( + page_content='bye', + metadata={ + "doc_id": 'a', + "doc_hash": 'doc_hash', + "document_id": 'document_id', + "dataset_id": 'dataset_id', + } + )) + docs.append(Document( + page_content='hello', + metadata={ + "doc_id": 'b', + "doc_hash": 'doc_hash', + "document_id": 'document_id', + "dataset_id": 'dataset_id', + } + )) + rst = model.rerank('hello', docs, None, 2) + + assert rst[0].page_content == 'hello'