From 3a3306240524f8bd12c16f8998c6f9c1ad5e1d16 Mon Sep 17 00:00:00 2001 From: Weaxs <459312872@qq.com> Date: Fri, 16 Aug 2024 20:21:41 +0800 Subject: [PATCH] feat: support siliconflow rerank (#7337) --- .../siliconflow/rerank/__init__.py | 0 .../rerank/bce-reranker-base_v1.yaml | 4 + .../rerank/bge-reranker-v2-m3.yaml | 4 + .../siliconflow/rerank/rerank.py | 87 +++++++++++++++++++ .../siliconflow/siliconflow.yaml | 3 +- .../model_runtime/siliconflow/test_rerank.py | 49 +++++++++++ 6 files changed, 146 insertions(+), 1 deletion(-) create mode 100644 api/core/model_runtime/model_providers/siliconflow/rerank/__init__.py create mode 100644 api/core/model_runtime/model_providers/siliconflow/rerank/bce-reranker-base_v1.yaml create mode 100644 api/core/model_runtime/model_providers/siliconflow/rerank/bge-reranker-v2-m3.yaml create mode 100644 api/core/model_runtime/model_providers/siliconflow/rerank/rerank.py create mode 100644 api/tests/integration_tests/model_runtime/siliconflow/test_rerank.py diff --git a/api/core/model_runtime/model_providers/siliconflow/rerank/__init__.py b/api/core/model_runtime/model_providers/siliconflow/rerank/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/model_runtime/model_providers/siliconflow/rerank/bce-reranker-base_v1.yaml b/api/core/model_runtime/model_providers/siliconflow/rerank/bce-reranker-base_v1.yaml new file mode 100644 index 00000000000000..ff3635bfeb1273 --- /dev/null +++ b/api/core/model_runtime/model_providers/siliconflow/rerank/bce-reranker-base_v1.yaml @@ -0,0 +1,4 @@ +model: netease-youdao/bce-reranker-base_v1 +model_type: rerank +model_properties: + context_size: 512 diff --git a/api/core/model_runtime/model_providers/siliconflow/rerank/bge-reranker-v2-m3.yaml b/api/core/model_runtime/model_providers/siliconflow/rerank/bge-reranker-v2-m3.yaml new file mode 100644 index 00000000000000..807f531b084892 --- /dev/null +++ 
class SiliconflowRerankModel(RerankModel):
    """
    Rerank model that scores documents through the SiliconFlow ``/rerank`` HTTP endpoint.
    """

    def _invoke(self, model: str, credentials: dict, query: str, docs: list[str],
                score_threshold: Optional[float] = None, top_n: Optional[int] = None,
                user: Optional[str] = None) -> RerankResult:
        """
        Invoke the rerank endpoint and return the documents that pass the threshold.

        :param model: model name sent in the request body
        :param credentials: provider credentials; ``api_key`` is sent as a bearer token and
            ``base_url`` (optional) overrides the default SiliconFlow endpoint
        :param query: search query
        :param docs: documents to rerank
        :param score_threshold: when given, results scoring below it are dropped
        :param top_n: forwarded to the API as ``top_n``
        :param user: accepted for interface compatibility; not used by this implementation
        :return: rerank result holding the surviving documents in the order the API returned them
        :raises InvokeAuthorizationError: the API answered 401 or 403
        :raises InvokeRateLimitError: the API answered 429
        :raises InvokeBadRequestError: the API answered with any other 4xx status
        :raises InvokeServerUnavailableError: the API answered with any other non-success status
        """
        if not docs:
            return RerankResult(model=model, docs=[])

        # `or` (rather than a .get() default) also covers a key that is present but None/empty.
        base_url = (credentials.get('base_url') or 'https://api.siliconflow.cn/v1').rstrip('/')

        try:
            response = httpx.post(
                base_url + '/rerank',
                json={
                    "model": model,
                    "query": query,
                    "documents": docs,
                    "top_n": top_n,
                    "return_documents": True
                },
                headers={"Authorization": f"Bearer {credentials.get('api_key')}"}
            )
            response.raise_for_status()
        except httpx.HTTPStatusError as e:
            # Classify by status code instead of reporting every failure as "server unavailable".
            raise self._status_to_invoke_error(e) from e

        results = response.json()

        # Filter on the raw score first so no RerankDocument is built for a dropped result.
        rerank_documents = [
            RerankDocument(
                index=result['index'],
                text=result['document']['text'],
                score=result['relevance_score'],
            )
            for result in results['results']
            if score_threshold is None or result['relevance_score'] >= score_threshold
        ]

        return RerankResult(model=model, docs=rerank_documents)

    @staticmethod
    def _status_to_invoke_error(error: httpx.HTTPStatusError) -> InvokeError:
        """
        Translate an HTTP status failure into the matching unified invoke error.

        :param error: exception raised by ``Response.raise_for_status()``
        :return: unified error instance carrying ``str(error)`` as its message
        """
        status_code = error.response.status_code
        message = str(error)
        if status_code in (401, 403):
            return InvokeAuthorizationError(message)
        if status_code == 429:
            return InvokeRateLimitError(message)
        if 400 <= status_code < 500:
            return InvokeBadRequestError(message)
        return InvokeServerUnavailableError(message)

    def validate_credentials(self, model: str, credentials: dict) -> None:
        """
        Validate credentials by issuing a small real rerank request.

        :param model: model name
        :param credentials: provider credentials
        :raises CredentialsValidateFailedError: if the probe request fails for any reason
        """
        try:
            self._invoke(
                model=model,
                credentials=credentials,
                query="What is the capital of the United States?",
                docs=[
                    "Carson City is the capital city of the American state of Nevada. At the 2010 United States "
                    "Census, Carson City had a population of 55,274.",
                    "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that "
                    "are a political division controlled by the United States. Its capital is Saipan.",
                ],
                score_threshold=0.8
            )
        except Exception as ex:
            # Any failure of the probe means the credentials cannot be confirmed; keep the cause chained.
            raise CredentialsValidateFailedError(str(ex)) from ex

    @property
    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
        """
        Map model invoke error to unified error

        The unified errors that `_invoke` raises itself are listed under their own key so a
        lookup by exception type resolves to the same category. Entries are ordered from
        specific to generic: `httpx.ConnectError`, `httpx.TimeoutException` and
        `httpx.RemoteProtocolError` are all subclasses of `httpx.RequestError`, which
        therefore has to stay last.
        """
        return {
            InvokeConnectionError: [InvokeConnectionError, httpx.ConnectError, httpx.TimeoutException],
            InvokeServerUnavailableError: [InvokeServerUnavailableError, httpx.RemoteProtocolError],
            InvokeRateLimitError: [InvokeRateLimitError],
            # HTTPStatusError is kept for backward compatibility; `_invoke` already translates
            # status failures itself, so this entry only matters for errors raised elsewhere.
            InvokeAuthorizationError: [InvokeAuthorizationError, httpx.HTTPStatusError],
            InvokeBadRequestError: [InvokeBadRequestError, httpx.RequestError]
        }
def test_validate_credentials():
    """An invalid API key must be rejected, while the key from the environment must pass."""
    reranker = SiliconflowRerankModel()
    model_name = "BAAI/bge-reranker-v2-m3"

    with pytest.raises(CredentialsValidateFailedError):
        reranker.validate_credentials(model=model_name, credentials={"api_key": "invalid_key"})

    # Must not raise when the real key is supplied through the environment.
    valid_credentials = {"api_key": os.environ.get("API_KEY")}
    reranker.validate_credentials(model=model_name, credentials=valid_credentials)


def test_invoke_model():
    """Only the first candidate is expected to clear the 0.8 relevance threshold."""
    candidate_docs = [
        "Kasumi is a girl's name of Japanese origin meaning \"mist\".",
        "Her music is a kawaii bass, a mix of future bass, pop, and kawaii music ",
        "and she leads a team named PopiParty."
    ]
    reranker = SiliconflowRerankModel()

    result = reranker.invoke(
        model='BAAI/bge-reranker-v2-m3',
        credentials={"api_key": os.environ.get("API_KEY")},
        query="Who is Kasumi?",
        docs=candidate_docs,
        score_threshold=0.8
    )

    assert isinstance(result, RerankResult)
    assert len(result.docs) == 1

    best_match = result.docs[0]
    assert best_match.index == 0
    assert best_match.score >= 0.8