From 7a95dd701fc508594efff294e7239802050d539f Mon Sep 17 00:00:00 2001 From: Jyong <76649700+JohnJyong@users.noreply.github.com> Date: Tue, 10 Sep 2024 12:52:50 +0800 Subject: [PATCH] code merge error (#8183) Co-authored-by: crazywoola <427733928@qq.com> --- api/controllers/console/datasets/datasets_document.py | 4 ++++ api/services/dataset_service.py | 6 +----- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py index 6bc29a86435fa4..076f3cd44d5af5 100644 --- a/api/controllers/console/datasets/datasets_document.py +++ b/api/controllers/console/datasets/datasets_document.py @@ -302,6 +302,8 @@ def post(self): "doc_language", type=str, default="English", required=False, nullable=False, location="json" ) parser.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json") + parser.add_argument("embedding_model", type=str, required=False, nullable=True, location="json") + parser.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json") args = parser.parse_args() # The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator @@ -309,6 +311,8 @@ def post(self): raise Forbidden() if args["indexing_technique"] == "high_quality": + if args["embedding_model"] is None or args["embedding_model_provider"] is None: + raise ValueError("embedding model and embedding model provider are required for high quality indexing.") try: model_manager = ModelManager() model_manager.get_default_model_instance( diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index 4a11de281c9438..cce0874cf4f9b2 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -1057,12 +1057,8 @@ def save_document_without_dataset_id(tenant_id: str, document_data: dict, accoun dataset_collection_binding_id = None retrieval_model = None if document_data["indexing_technique"] == "high_quality": - model_manager = ModelManager() - embedding_model = model_manager.get_default_model_instance( - tenant_id=current_user.current_tenant_id, model_type=ModelType.TEXT_EMBEDDING - ) dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( - embedding_model.provider, embedding_model.model + document_data["embedding_model_provider"], document_data["embedding_model"] ) dataset_collection_binding_id = dataset_collection_binding.id if document_data.get("retrieval_model"):