From 43ac320dce4034d7a54a01d112818ecf2a31fd02 Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Fri, 13 Sep 2024 10:17:49 +0800 Subject: [PATCH] fixed score threshold is none --- api/core/rag/datasource/retrieval_service.py | 2 +- api/core/rag/retrieval/dataset_retrieval.py | 2 +- .../tool/dataset_retriever/dataset_multi_retriever_tool.py | 2 +- api/core/tools/tool/dataset_retriever/dataset_retriever_tool.py | 2 +- api/services/hit_testing_service.py | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/api/core/rag/datasource/retrieval_service.py b/api/core/rag/datasource/retrieval_service.py index afac1bf30086e9..b438339ecbcbff 100644 --- a/api/core/rag/datasource/retrieval_service.py +++ b/api/core/rag/datasource/retrieval_service.py @@ -28,7 +28,7 @@ def retrieve( dataset_id: str, query: str, top_k: int, - score_threshold: Optional[float] = 0.0, + score_threshold: Optional[float] = .0, reranking_model: Optional[dict] = None, reranking_mode: Optional[str] = "reranking_model", weights: Optional[dict] = None, diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py index cdaca8387d2272..12868d6ae43945 100644 --- a/api/core/rag/retrieval/dataset_retrieval.py +++ b/api/core/rag/retrieval/dataset_retrieval.py @@ -429,7 +429,7 @@ def _retriever(self, flask_app: Flask, dataset_id: str, query: str, top_k: int, top_k=top_k, score_threshold=retrieval_model.get("score_threshold", 0.0) if retrieval_model["score_threshold_enabled"] - else None, + else 0.0, reranking_model=retrieval_model.get("reranking_model", None) if retrieval_model["reranking_enable"] else None, diff --git a/api/core/tools/tool/dataset_retriever/dataset_multi_retriever_tool.py b/api/core/tools/tool/dataset_retriever/dataset_multi_retriever_tool.py index 067600c6013372..6073b8e92e39b3 100644 --- a/api/core/tools/tool/dataset_retriever/dataset_multi_retriever_tool.py +++ b/api/core/tools/tool/dataset_retriever/dataset_multi_retriever_tool.py @@ -179,7 +179,7 @@ def _retriever( top_k=self.top_k, score_threshold=retrieval_model.get("score_threshold", 0.0) if retrieval_model["score_threshold_enabled"] - else None, + else 0.0, reranking_model=retrieval_model.get("reranking_model", None) if retrieval_model["reranking_enable"] else None, diff --git a/api/core/tools/tool/dataset_retriever/dataset_retriever_tool.py b/api/core/tools/tool/dataset_retriever/dataset_retriever_tool.py index ad533946a17bd9..8dc60408c93b41 100644 --- a/api/core/tools/tool/dataset_retriever/dataset_retriever_tool.py +++ b/api/core/tools/tool/dataset_retriever/dataset_retriever_tool.py @@ -72,7 +72,7 @@ def _run(self, query: str) -> str: top_k=self.top_k, score_threshold=retrieval_model.get("score_threshold", 0.0) if retrieval_model["score_threshold_enabled"] - else None, + else 0.0, reranking_model=retrieval_model.get("reranking_model", None) if retrieval_model["reranking_enable"] else None, diff --git a/api/services/hit_testing_service.py b/api/services/hit_testing_service.py index a9f963dbacfeb4..3dafafd5b46f6c 100644 --- a/api/services/hit_testing_service.py +++ b/api/services/hit_testing_service.py @@ -42,7 +42,7 @@ def retrieve(cls, dataset: Dataset, query: str, account: Account, retrieval_mode top_k=retrieval_model.get("top_k", 2), score_threshold=retrieval_model.get("score_threshold", 0.0) if retrieval_model["score_threshold_enabled"] - else None, + else 0.0, reranking_model=retrieval_model.get("reranking_model", None) if retrieval_model["reranking_enable"] else None,