diff --git a/api/controllers/console/datasets/data_source.py b/api/controllers/console/datasets/data_source.py index d7c431b95080da..3a4a6d75e1d3a8 100644 --- a/api/controllers/console/datasets/data_source.py +++ b/api/controllers/console/datasets/data_source.py @@ -218,7 +218,7 @@ def post(self): args["doc_form"], args["doc_language"], ) - return response, 200 + return response.model_dump(), 200 class DataSourceNotionDatasetSyncApi(Resource): diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py index f3c3736b25acc5..0c0d2e20035b43 100644 --- a/api/controllers/console/datasets/datasets.py +++ b/api/controllers/console/datasets/datasets.py @@ -464,7 +464,7 @@ def post(self): except Exception as e: raise IndexingEstimateError(str(e)) - return response, 200 + return response.model_dump(), 200 class DatasetRelatedAppListApi(Resource): @@ -733,6 +733,18 @@ def get(self, dataset_id): }, 200 +class DatasetAutoDisableLogApi(Resource): + @setup_required + @login_required + @account_initialization_required + def get(self, dataset_id): + dataset_id_str = str(dataset_id) + dataset = DatasetService.get_dataset(dataset_id_str) + if dataset is None: + raise NotFound("Dataset not found.") + return DatasetService.get_dataset_auto_disable_logs(dataset_id_str), 200 + + api.add_resource(DatasetListApi, "/datasets") api.add_resource(DatasetApi, "/datasets/") api.add_resource(DatasetUseCheckApi, "/datasets//use-check") @@ -747,3 +759,4 @@ def get(self, dataset_id): api.add_resource(DatasetRetrievalSettingApi, "/datasets/retrieval-setting") api.add_resource(DatasetRetrievalSettingMockApi, "/datasets/retrieval-setting/") api.add_resource(DatasetPermissionUserListApi, "/datasets//permission-part-users") +api.add_resource(DatasetAutoDisableLogApi, "/datasets//auto-disable-logs") diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py index ca41e504be7eda..552a2ab3ff819d 100644 --- a/api/controllers/console/datasets/datasets_document.py +++ b/api/controllers/console/datasets/datasets_document.py @@ -52,6 +52,7 @@ from libs.login import login_required from models import Dataset, DatasetProcessRule, Document, DocumentSegment, UploadFile from services.dataset_service import DatasetService, DocumentService +from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig from tasks.add_document_to_index_task import add_document_to_index_task from tasks.remove_document_from_index_task import remove_document_from_index_task @@ -255,20 +256,22 @@ def post(self, dataset_id): parser.add_argument("duplicate", type=bool, default=True, nullable=False, location="json") parser.add_argument("original_document_id", type=str, required=False, location="json") parser.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json") + parser.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json") + parser.add_argument( "doc_language", type=str, default="English", required=False, nullable=False, location="json" ) - parser.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json") args = parser.parse_args() + knowledge_config = KnowledgeConfig(**args) - if not dataset.indexing_technique and not args["indexing_technique"]: + if not dataset.indexing_technique and not knowledge_config.indexing_technique: raise ValueError("indexing_technique is required.") # validate args - 
DocumentService.document_create_args_validate(args) + DocumentService.document_create_args_validate(knowledge_config) try: - documents, batch = DocumentService.save_document_with_dataset_id(dataset, args, current_user) + documents, batch = DocumentService.save_document_with_dataset_id(dataset, knowledge_config, current_user) except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) except QuotaExceededError: @@ -278,6 +281,25 @@ def post(self, dataset_id): return {"documents": documents, "batch": batch} + @setup_required + @login_required + @account_initialization_required + def delete(self, dataset_id): + dataset_id = str(dataset_id) + dataset = DatasetService.get_dataset(dataset_id) + if dataset is None: + raise NotFound("Dataset not found.") + # check user's model setting + DatasetService.check_dataset_model_setting(dataset) + + try: + document_ids = request.args.getlist("document_id") + DocumentService.delete_documents(dataset, document_ids) + except services.errors.document.DocumentIndexingError: + raise DocumentIndexingError("Cannot delete document during indexing.") + + return {"result": "success"}, 204 + class DatasetInitApi(Resource): @setup_required @@ -313,9 +335,9 @@ def post(self): # The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator if not current_user.is_dataset_editor: raise Forbidden() - - if args["indexing_technique"] == "high_quality": - if args["embedding_model"] is None or args["embedding_model_provider"] is None: + knowledge_config = KnowledgeConfig(**args) + if knowledge_config.indexing_technique == "high_quality": + if knowledge_config.embedding_model is None or knowledge_config.embedding_model_provider is None: raise ValueError("embedding model and embedding model provider are required for high quality indexing.") try: model_manager = ModelManager() @@ -334,11 +356,11 @@ def post(self): raise ProviderNotInitializeError(ex.description) # validate args - DocumentService.document_create_args_validate(args) + DocumentService.document_create_args_validate(knowledge_config) try: dataset, documents, batch = DocumentService.save_document_without_dataset_id( - tenant_id=current_user.current_tenant_id, document_data=args, account=current_user + tenant_id=current_user.current_tenant_id, knowledge_config=knowledge_config, account=current_user ) except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) @@ -409,7 +431,7 @@ def get(self, dataset_id, document_id): except Exception as e: raise IndexingEstimateError(str(e)) - return response + return response.model_dump(), 200 class DocumentBatchIndexingEstimateApi(DocumentResource): @@ -422,7 +444,7 @@ def get(self, dataset_id, batch): documents = self.get_batch_documents(dataset_id, batch) response = {"tokens": 0, "total_price": 0, "currency": "USD", "total_segments": 0, "preview": []} if not documents: - return response + return response, 200 data_process_rule = documents[0].dataset_process_rule data_process_rule_dict = data_process_rule.to_dict() info_list = [] @@ -509,7 +531,7 @@ def get(self, dataset_id, batch): raise ProviderNotInitializeError(ex.description) except Exception as e: raise IndexingEstimateError(str(e)) - return response + return response.model_dump(), 200 class DocumentBatchIndexingStatusApi(DocumentResource): @@ -582,7 +604,8 @@ def get(self, dataset_id, document_id): if metadata == "only": response = {"id": document.id, "doc_type": document.doc_type, "doc_metadata": document.doc_metadata} elif metadata 
== "without": - process_rules = DatasetService.get_process_rules(dataset_id) + dataset_process_rules = DatasetService.get_process_rules(dataset_id) + document_process_rules = document.dataset_process_rule.to_dict() data_source_info = document.data_source_detail_dict response = { "id": document.id, @@ -590,7 +613,8 @@ def get(self, dataset_id, document_id): "data_source_type": document.data_source_type, "data_source_info": data_source_info, "dataset_process_rule_id": document.dataset_process_rule_id, - "dataset_process_rule": process_rules, + "dataset_process_rule": dataset_process_rules, + "document_process_rule": document_process_rules, "name": document.name, "created_from": document.created_from, "created_by": document.created_by, @@ -613,7 +637,8 @@ def get(self, dataset_id, document_id): "doc_language": document.doc_language, } else: - process_rules = DatasetService.get_process_rules(dataset_id) + dataset_process_rules = DatasetService.get_process_rules(dataset_id) + document_process_rules = document.dataset_process_rule.to_dict() data_source_info = document.data_source_detail_dict response = { "id": document.id, @@ -621,7 +646,8 @@ def get(self, dataset_id, document_id): "data_source_type": document.data_source_type, "data_source_info": data_source_info, "dataset_process_rule_id": document.dataset_process_rule_id, - "dataset_process_rule": process_rules, + "dataset_process_rule": dataset_process_rules, + "document_process_rule": document_process_rules, "name": document.name, "created_from": document.created_from, "created_by": document.created_by, @@ -757,9 +783,8 @@ class DocumentStatusApi(DocumentResource): @login_required @account_initialization_required @cloud_edition_billing_resource_check("vector_space") - def patch(self, dataset_id, document_id, action): + def patch(self, dataset_id, action): dataset_id = str(dataset_id) - document_id = str(document_id) dataset = DatasetService.get_dataset(dataset_id) if dataset is None: raise NotFound("Dataset not found.") @@ -774,84 +799,79 @@ def patch(self, dataset_id, document_id, action): # check user's permission DatasetService.check_dataset_permission(dataset, current_user) - document = self.get_document(dataset_id, document_id) + document_ids = request.args.getlist("document_id") + for document_id in document_ids: + document = self.get_document(dataset_id, document_id) - indexing_cache_key = "document_{}_indexing".format(document.id) - cache_result = redis_client.get(indexing_cache_key) - if cache_result is not None: - raise InvalidActionError("Document is being indexed, please try again later") + indexing_cache_key = "document_{}_indexing".format(document.id) + cache_result = redis_client.get(indexing_cache_key) + if cache_result is not None: + raise InvalidActionError(f"Document:{document.name} is being indexed, please try again later") - if action == "enable": - if document.enabled: - raise InvalidActionError("Document already enabled.") + if action == "enable": + if document.enabled: + continue + document.enabled = True + document.disabled_at = None + document.disabled_by = None + document.updated_at = datetime.now(UTC).replace(tzinfo=None) + db.session.commit() - document.enabled = True - document.disabled_at = None - document.disabled_by = None - document.updated_at = datetime.now(UTC).replace(tzinfo=None) - db.session.commit() + # Set cache to prevent indexing the same document multiple times + redis_client.setex(indexing_cache_key, 600, 1) - # Set cache to prevent indexing the same document multiple times - 
redis_client.setex(indexing_cache_key, 600, 1) + add_document_to_index_task.delay(document_id) - add_document_to_index_task.delay(document_id) + elif action == "disable": + if not document.completed_at or document.indexing_status != "completed": + raise InvalidActionError(f"Document: {document.name} is not completed.") + if not document.enabled: + continue - return {"result": "success"}, 200 + document.enabled = False + document.disabled_at = datetime.now(UTC).replace(tzinfo=None) + document.disabled_by = current_user.id + document.updated_at = datetime.now(UTC).replace(tzinfo=None) + db.session.commit() - elif action == "disable": - if not document.completed_at or document.indexing_status != "completed": - raise InvalidActionError("Document is not completed.") - if not document.enabled: - raise InvalidActionError("Document already disabled.") + # Set cache to prevent indexing the same document multiple times + redis_client.setex(indexing_cache_key, 600, 1) - document.enabled = False - document.disabled_at = datetime.now(UTC).replace(tzinfo=None) - document.disabled_by = current_user.id - document.updated_at = datetime.now(UTC).replace(tzinfo=None) - db.session.commit() + remove_document_from_index_task.delay(document_id) - # Set cache to prevent indexing the same document multiple times - redis_client.setex(indexing_cache_key, 600, 1) + elif action == "archive": + if document.archived: + continue - remove_document_from_index_task.delay(document_id) + document.archived = True + document.archived_at = datetime.now(UTC).replace(tzinfo=None) + document.archived_by = current_user.id + document.updated_at = datetime.now(UTC).replace(tzinfo=None) + db.session.commit() - return {"result": "success"}, 200 + if document.enabled: + # Set cache to prevent indexing the same document multiple times + redis_client.setex(indexing_cache_key, 600, 1) - elif action == "archive": - if document.archived: - raise InvalidActionError("Document already archived.") + remove_document_from_index_task.delay(document_id) - document.archived = True - document.archived_at = datetime.now(UTC).replace(tzinfo=None) - document.archived_by = current_user.id - document.updated_at = datetime.now(UTC).replace(tzinfo=None) - db.session.commit() + elif action == "un_archive": + if not document.archived: + continue + document.archived = False + document.archived_at = None + document.archived_by = None + document.updated_at = datetime.now(UTC).replace(tzinfo=None) + db.session.commit() - if document.enabled: # Set cache to prevent indexing the same document multiple times redis_client.setex(indexing_cache_key, 600, 1) - remove_document_from_index_task.delay(document_id) - - return {"result": "success"}, 200 - elif action == "un_archive": - if not document.archived: - raise InvalidActionError("Document is not archived.") - - document.archived = False - document.archived_at = None - document.archived_by = None - document.updated_at = datetime.now(UTC).replace(tzinfo=None) - db.session.commit() - - # Set cache to prevent indexing the same document multiple times - redis_client.setex(indexing_cache_key, 600, 1) + add_document_to_index_task.delay(document_id) - add_document_to_index_task.delay(document_id) - - return {"result": "success"}, 200 - else: - raise InvalidActionError() + else: + raise InvalidActionError() + return {"result": "success"}, 200 class DocumentPauseApi(DocumentResource): @@ -1022,7 +1042,7 @@ def get(self, dataset_id, document_id): ) api.add_resource(DocumentDeleteApi, "/datasets//documents/") 
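The DocumentStatusApi refactor above moves the action into the URL path and accepts one or more `document_id` query parameters, skipping documents already in the requested state instead of raising. A minimal client sketch of the new batch route; the `/console/api` prefix, host, and bearer token are assumptions for illustration, not part of this diff:

```python
import requests

BASE_URL = "https://dify.example.com/console/api"  # assumed console API mount point
TOKEN = "console-access-token-here"                # assumed credential

def set_documents_status(dataset_id: str, action: str, document_ids: list[str]) -> dict:
    """Enable, disable, archive, or un_archive several documents in one request."""
    resp = requests.patch(
        f"{BASE_URL}/datasets/{dataset_id}/documents/status/{action}/batch",
        # repeated query parameters: ?document_id=...&document_id=...
        params=[("document_id", doc_id) for doc_id in document_ids],
        headers={"Authorization": f"Bearer {TOKEN}"},
    )
    resp.raise_for_status()
    return resp.json()  # {"result": "success"} once every listed document is processed
```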
api.add_resource(DocumentMetadataApi, "/datasets//documents//metadata") -api.add_resource(DocumentStatusApi, "/datasets//documents//status/") +api.add_resource(DocumentStatusApi, "/datasets//documents/status//batch") api.add_resource(DocumentPauseApi, "/datasets//documents//processing/pause") api.add_resource(DocumentRecoverApi, "/datasets//documents//processing/resume") api.add_resource(DocumentRetryApi, "/datasets//retry") diff --git a/api/controllers/console/datasets/datasets_segments.py b/api/controllers/console/datasets/datasets_segments.py index 2d5933ca23609a..96654c09fd0223 100644 --- a/api/controllers/console/datasets/datasets_segments.py +++ b/api/controllers/console/datasets/datasets_segments.py @@ -1,5 +1,4 @@ import uuid -from datetime import UTC, datetime import pandas as pd from flask import request @@ -10,7 +9,13 @@ import services from controllers.console import api from controllers.console.app.error import ProviderNotInitializeError -from controllers.console.datasets.error import InvalidActionError, NoFileUploadedError, TooManyFilesError +from controllers.console.datasets.error import ( + ChildChunkDeleteIndexError, + ChildChunkIndexingError, + InvalidActionError, + NoFileUploadedError, + TooManyFilesError, +) from controllers.console.wraps import ( account_initialization_required, cloud_edition_billing_knowledge_limit_check, @@ -20,15 +25,15 @@ from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError from core.model_manager import ModelManager from core.model_runtime.entities.model_entities import ModelType -from extensions.ext_database import db from extensions.ext_redis import redis_client -from fields.segment_fields import segment_fields +from fields.segment_fields import child_chunk_fields, segment_fields from libs.login import login_required -from models import DocumentSegment +from models.dataset import ChildChunk, DocumentSegment from services.dataset_service import DatasetService, DocumentService, SegmentService +from services.entities.knowledge_entities.knowledge_entities import ChildChunkUpdateArgs, SegmentUpdateArgs +from services.errors.chunk import ChildChunkDeleteIndexError as ChildChunkDeleteIndexServiceError +from services.errors.chunk import ChildChunkIndexingError as ChildChunkIndexingServiceError from tasks.batch_create_segment_to_index_task import batch_create_segment_to_index_task -from tasks.disable_segment_from_index_task import disable_segment_from_index_task -from tasks.enable_segment_to_index_task import enable_segment_to_index_task class DatasetDocumentSegmentListApi(Resource): @@ -53,15 +58,16 @@ def get(self, dataset_id, document_id): raise NotFound("Document not found.") parser = reqparse.RequestParser() - parser.add_argument("last_id", type=str, default=None, location="args") parser.add_argument("limit", type=int, default=20, location="args") parser.add_argument("status", type=str, action="append", default=[], location="args") parser.add_argument("hit_count_gte", type=int, default=None, location="args") parser.add_argument("enabled", type=str, default="all", location="args") parser.add_argument("keyword", type=str, default=None, location="args") + parser.add_argument("page", type=int, default=1, location="args") + args = parser.parse_args() - last_id = args["last_id"] + page = args["page"] limit = min(args["limit"], 100) status_list = args["status"] hit_count_gte = args["hit_count_gte"] @@ -69,14 +75,7 @@ def get(self, dataset_id, document_id): query = DocumentSegment.query.filter( DocumentSegment.document_id == 
str(document_id), DocumentSegment.tenant_id == current_user.current_tenant_id - ) - - if last_id is not None: - last_segment = db.session.get(DocumentSegment, str(last_id)) - if last_segment: - query = query.filter(DocumentSegment.position > last_segment.position) - else: - return {"data": [], "has_more": False, "limit": limit}, 200 + ).order_by(DocumentSegment.position.asc()) if status_list: query = query.filter(DocumentSegment.status.in_(status_list)) @@ -93,21 +92,44 @@ def get(self, dataset_id, document_id): elif args["enabled"].lower() == "false": query = query.filter(DocumentSegment.enabled == False) - total = query.count() - segments = query.order_by(DocumentSegment.position).limit(limit + 1).all() - - has_more = False - if len(segments) > limit: - has_more = True - segments = segments[:-1] + segments = query.paginate(page=page, per_page=limit, max_per_page=100, error_out=False) - return { - "data": marshal(segments, segment_fields), - "doc_form": document.doc_form, - "has_more": has_more, + response = { + "data": marshal(segments.items, segment_fields), "limit": limit, - "total": total, - }, 200 + "total": segments.total, + "total_pages": segments.pages, + "page": page, + } + return response, 200 + + @setup_required + @login_required + @account_initialization_required + def delete(self, dataset_id, document_id): + # check dataset + dataset_id = str(dataset_id) + dataset = DatasetService.get_dataset(dataset_id) + if not dataset: + raise NotFound("Dataset not found.") + # check user's model setting + DatasetService.check_dataset_model_setting(dataset) + # check document + document_id = str(document_id) + document = DocumentService.get_document(dataset_id, document_id) + if not document: + raise NotFound("Document not found.") + segment_ids = request.args.getlist("segment_id") + + # The role of the current user in the ta table must be admin or owner + if not current_user.is_editor: + raise Forbidden() + try: + DatasetService.check_dataset_permission(dataset, current_user) + except services.errors.account.NoPermissionError as e: + raise Forbidden(str(e)) + SegmentService.delete_segments(segment_ids, document, dataset) + return {"result": "success"}, 200 class DatasetDocumentSegmentApi(Resource): @@ -115,11 +137,15 @@ class DatasetDocumentSegmentApi(Resource): @login_required @account_initialization_required @cloud_edition_billing_resource_check("vector_space") - def patch(self, dataset_id, segment_id, action): + def patch(self, dataset_id, document_id, action): dataset_id = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id) if not dataset: raise NotFound("Dataset not found.") + document_id = str(document_id) + document = DocumentService.get_document(dataset_id, document_id) + if not document: + raise NotFound("Document not found.") # check user's model setting DatasetService.check_dataset_model_setting(dataset) # The role of the current user in the ta table must be admin, owner, or editor @@ -147,59 +173,17 @@ def patch(self, dataset_id, segment_id, action): ) except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) + segment_ids = request.args.getlist("segment_id") - segment = DocumentSegment.query.filter( - DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id - ).first() - - if not segment: - raise NotFound("Segment not found.") - - if segment.status != "completed": - raise NotFound("Segment is not completed, enable or disable function is not allowed") - - document_indexing_cache_key = 
"document_{}_indexing".format(segment.document_id) + document_indexing_cache_key = "document_{}_indexing".format(document.id) cache_result = redis_client.get(document_indexing_cache_key) if cache_result is not None: raise InvalidActionError("Document is being indexed, please try again later") - - indexing_cache_key = "segment_{}_indexing".format(segment.id) - cache_result = redis_client.get(indexing_cache_key) - if cache_result is not None: - raise InvalidActionError("Segment is being indexed, please try again later") - - if action == "enable": - if segment.enabled: - raise InvalidActionError("Segment is already enabled.") - - segment.enabled = True - segment.disabled_at = None - segment.disabled_by = None - db.session.commit() - - # Set cache to prevent indexing the same segment multiple times - redis_client.setex(indexing_cache_key, 600, 1) - - enable_segment_to_index_task.delay(segment.id) - - return {"result": "success"}, 200 - elif action == "disable": - if not segment.enabled: - raise InvalidActionError("Segment is already disabled.") - - segment.enabled = False - segment.disabled_at = datetime.now(UTC).replace(tzinfo=None) - segment.disabled_by = current_user.id - db.session.commit() - - # Set cache to prevent indexing the same segment multiple times - redis_client.setex(indexing_cache_key, 600, 1) - - disable_segment_from_index_task.delay(segment.id) - - return {"result": "success"}, 200 - else: - raise InvalidActionError() + try: + SegmentService.update_segments_status(segment_ids, action, dataset, document) + except Exception as e: + raise InvalidActionError(str(e)) + return {"result": "success"}, 200 class DatasetDocumentSegmentAddApi(Resource): @@ -307,9 +291,12 @@ def patch(self, dataset_id, document_id, segment_id): parser.add_argument("content", type=str, required=True, nullable=False, location="json") parser.add_argument("answer", type=str, required=False, nullable=True, location="json") parser.add_argument("keywords", type=list, required=False, nullable=True, location="json") + parser.add_argument( + "regenerate_child_chunks", type=bool, required=False, nullable=True, default=False, location="json" + ) args = parser.parse_args() SegmentService.segment_create_args_validate(args, document) - segment = SegmentService.update_segment(args, segment, document, dataset) + segment = SegmentService.update_segment(SegmentUpdateArgs(**args), segment, document, dataset) return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200 @setup_required @@ -412,8 +399,248 @@ def get(self, job_id): return {"job_id": job_id, "job_status": cache_result.decode()}, 200 +class ChildChunkAddApi(Resource): + @setup_required + @login_required + @account_initialization_required + @cloud_edition_billing_resource_check("vector_space") + @cloud_edition_billing_knowledge_limit_check("add_segment") + def post(self, dataset_id, document_id, segment_id): + # check dataset + dataset_id = str(dataset_id) + dataset = DatasetService.get_dataset(dataset_id) + if not dataset: + raise NotFound("Dataset not found.") + # check document + document_id = str(document_id) + document = DocumentService.get_document(dataset_id, document_id) + if not document: + raise NotFound("Document not found.") + # check segment + segment_id = str(segment_id) + segment = DocumentSegment.query.filter( + DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id + ).first() + if not segment: + raise NotFound("Segment not found.") + if not current_user.is_editor: + raise 
Forbidden() + # check embedding model setting + if dataset.indexing_technique == "high_quality": + try: + model_manager = ModelManager() + model_manager.get_model_instance( + tenant_id=current_user.current_tenant_id, + provider=dataset.embedding_model_provider, + model_type=ModelType.TEXT_EMBEDDING, + model=dataset.embedding_model, + ) + except LLMBadRequestError: + raise ProviderNotInitializeError( + "No Embedding Model available. Please configure a valid provider " + "in the Settings -> Model Provider." + ) + except ProviderTokenNotInitError as ex: + raise ProviderNotInitializeError(ex.description) + try: + DatasetService.check_dataset_permission(dataset, current_user) + except services.errors.account.NoPermissionError as e: + raise Forbidden(str(e)) + # validate args + parser = reqparse.RequestParser() + parser.add_argument("content", type=str, required=True, nullable=False, location="json") + args = parser.parse_args() + try: + child_chunk = SegmentService.create_child_chunk(args.get("content"), segment, document, dataset) + except ChildChunkIndexingServiceError as e: + raise ChildChunkIndexingError(str(e)) + return {"data": marshal(child_chunk, child_chunk_fields)}, 200 + + @setup_required + @login_required + @account_initialization_required + def get(self, dataset_id, document_id, segment_id): + # check dataset + dataset_id = str(dataset_id) + dataset = DatasetService.get_dataset(dataset_id) + if not dataset: + raise NotFound("Dataset not found.") + # check user's model setting + DatasetService.check_dataset_model_setting(dataset) + # check document + document_id = str(document_id) + document = DocumentService.get_document(dataset_id, document_id) + if not document: + raise NotFound("Document not found.") + # check segment + segment_id = str(segment_id) + segment = DocumentSegment.query.filter( + DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id + ).first() + if not segment: + raise NotFound("Segment not found.") + parser = reqparse.RequestParser() + parser.add_argument("limit", type=int, default=20, location="args") + parser.add_argument("keyword", type=str, default=None, location="args") + parser.add_argument("page", type=int, default=1, location="args") + + args = parser.parse_args() + + page = args["page"] + limit = min(args["limit"], 100) + keyword = args["keyword"] + + child_chunks = SegmentService.get_child_chunks(segment_id, document_id, dataset_id, page, limit, keyword) + return { + "data": marshal(child_chunks.items, child_chunk_fields), + "total": child_chunks.total, + "total_pages": child_chunks.pages, + "page": page, + "limit": limit, + }, 200 + + @setup_required + @login_required + @account_initialization_required + @cloud_edition_billing_resource_check("vector_space") + def patch(self, dataset_id, document_id, segment_id): + # check dataset + dataset_id = str(dataset_id) + dataset = DatasetService.get_dataset(dataset_id) + if not dataset: + raise NotFound("Dataset not found.") + # check user's model setting + DatasetService.check_dataset_model_setting(dataset) + # check document + document_id = str(document_id) + document = DocumentService.get_document(dataset_id, document_id) + if not document: + raise NotFound("Document not found.") + # check segment + segment_id = str(segment_id) + segment = DocumentSegment.query.filter( + DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id + ).first() + if not segment: + raise NotFound("Segment not found.") + # The role of the current 
user in the ta table must be admin, owner, or editor + if not current_user.is_editor: + raise Forbidden() + try: + DatasetService.check_dataset_permission(dataset, current_user) + except services.errors.account.NoPermissionError as e: + raise Forbidden(str(e)) + # validate args + parser = reqparse.RequestParser() + parser.add_argument("chunks", type=list, required=True, nullable=False, location="json") + args = parser.parse_args() + try: + chunks = [ChildChunkUpdateArgs(**chunk) for chunk in args.get("chunks")] + child_chunks = SegmentService.update_child_chunks(chunks, segment, document, dataset) + except ChildChunkIndexingServiceError as e: + raise ChildChunkIndexingError(str(e)) + return {"data": marshal(child_chunks, child_chunk_fields)}, 200 + + +class ChildChunkUpdateApi(Resource): + @setup_required + @login_required + @account_initialization_required + def delete(self, dataset_id, document_id, segment_id, child_chunk_id): + # check dataset + dataset_id = str(dataset_id) + dataset = DatasetService.get_dataset(dataset_id) + if not dataset: + raise NotFound("Dataset not found.") + # check user's model setting + DatasetService.check_dataset_model_setting(dataset) + # check document + document_id = str(document_id) + document = DocumentService.get_document(dataset_id, document_id) + if not document: + raise NotFound("Document not found.") + # check segment + segment_id = str(segment_id) + segment = DocumentSegment.query.filter( + DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id + ).first() + if not segment: + raise NotFound("Segment not found.") + # check child chunk + child_chunk_id = str(child_chunk_id) + child_chunk = ChildChunk.query.filter( + ChildChunk.id == str(child_chunk_id), ChildChunk.tenant_id == current_user.current_tenant_id + ).first() + if not child_chunk: + raise NotFound("Child chunk not found.") + # The role of the current user in the ta table must be admin or owner + if not current_user.is_editor: + raise Forbidden() + try: + DatasetService.check_dataset_permission(dataset, current_user) + except services.errors.account.NoPermissionError as e: + raise Forbidden(str(e)) + try: + SegmentService.delete_child_chunk(child_chunk, dataset) + except ChildChunkDeleteIndexServiceError as e: + raise ChildChunkDeleteIndexError(str(e)) + return {"result": "success"}, 200 + + @setup_required + @login_required + @account_initialization_required + @cloud_edition_billing_resource_check("vector_space") + def patch(self, dataset_id, document_id, segment_id, child_chunk_id): + # check dataset + dataset_id = str(dataset_id) + dataset = DatasetService.get_dataset(dataset_id) + if not dataset: + raise NotFound("Dataset not found.") + # check user's model setting + DatasetService.check_dataset_model_setting(dataset) + # check document + document_id = str(document_id) + document = DocumentService.get_document(dataset_id, document_id) + if not document: + raise NotFound("Document not found.") + # check segment + segment_id = str(segment_id) + segment = DocumentSegment.query.filter( + DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id + ).first() + if not segment: + raise NotFound("Segment not found.") + # check child chunk + child_chunk_id = str(child_chunk_id) + child_chunk = ChildChunk.query.filter( + ChildChunk.id == str(child_chunk_id), ChildChunk.tenant_id == current_user.current_tenant_id + ).first() + if not child_chunk: + raise NotFound("Child chunk not found.") + # The role of the current 
user in the ta table must be admin or owner + if not current_user.is_editor: + raise Forbidden() + try: + DatasetService.check_dataset_permission(dataset, current_user) + except services.errors.account.NoPermissionError as e: + raise Forbidden(str(e)) + # validate args + parser = reqparse.RequestParser() + parser.add_argument("content", type=str, required=True, nullable=False, location="json") + args = parser.parse_args() + try: + child_chunk = SegmentService.update_child_chunk( + args.get("content"), child_chunk, segment, document, dataset + ) + except ChildChunkIndexingServiceError as e: + raise ChildChunkIndexingError(str(e)) + return {"data": marshal(child_chunk, child_chunk_fields)}, 200 + + api.add_resource(DatasetDocumentSegmentListApi, "/datasets//documents//segments") -api.add_resource(DatasetDocumentSegmentApi, "/datasets//segments//") +api.add_resource( + DatasetDocumentSegmentApi, "/datasets//documents//segment/" +) api.add_resource(DatasetDocumentSegmentAddApi, "/datasets//documents//segment") api.add_resource( DatasetDocumentSegmentUpdateApi, @@ -424,3 +651,11 @@ def get(self, job_id): "/datasets//documents//segments/batch_import", "/datasets/batch_import_status/", ) +api.add_resource( + ChildChunkAddApi, + "/datasets//documents//segments//child_chunks", +) +api.add_resource( + ChildChunkUpdateApi, + "/datasets//documents//segments//child_chunks/", +) diff --git a/api/controllers/console/datasets/error.py b/api/controllers/console/datasets/error.py index 6a7a3971a8b33f..2f00a84de697a7 100644 --- a/api/controllers/console/datasets/error.py +++ b/api/controllers/console/datasets/error.py @@ -89,3 +89,15 @@ class IndexingEstimateError(BaseHTTPException): error_code = "indexing_estimate_error" description = "Knowledge indexing estimate failed: {message}" code = 500 + + +class ChildChunkIndexingError(BaseHTTPException): + error_code = "child_chunk_indexing_error" + description = "Create child chunk index failed: {message}" + code = 500 + + +class ChildChunkDeleteIndexError(BaseHTTPException): + error_code = "child_chunk_delete_index_error" + description = "Delete child chunk index failed: {message}" + code = 500 diff --git a/api/controllers/service_api/dataset/segment.py b/api/controllers/service_api/dataset/segment.py index 34904574a8b88d..1c500f51bffc08 100644 --- a/api/controllers/service_api/dataset/segment.py +++ b/api/controllers/service_api/dataset/segment.py @@ -16,6 +16,7 @@ from fields.segment_fields import segment_fields from models.dataset import Dataset, DocumentSegment from services.dataset_service import DatasetService, DocumentService, SegmentService +from services.entities.knowledge_entities.knowledge_entities import SegmentUpdateArgs class SegmentApi(DatasetApiResource): @@ -193,7 +194,7 @@ def post(self, tenant_id, dataset_id, document_id, segment_id): args = parser.parse_args() SegmentService.segment_create_args_validate(args["segment"], document) - segment = SegmentService.update_segment(args["segment"], segment, document, dataset) + segment = SegmentService.update_segment(SegmentUpdateArgs(**args["segment"]), segment, document, dataset) return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200 diff --git a/api/core/entities/knowledge_entities.py b/api/core/entities/knowledge_entities.py new file mode 100644 index 00000000000000..90c98797338270 --- /dev/null +++ b/api/core/entities/knowledge_entities.py @@ -0,0 +1,19 @@ +from typing import Optional + +from pydantic import BaseModel + + +class PreviewDetail(BaseModel): + content: str 
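The entities declared in this new knowledge_entities module (continued just below) are what the updated indexing-estimate endpoints serialize with `model_dump()`. A rough sketch of the resulting payload shape, with illustrative values only:

```python
from core.entities.knowledge_entities import IndexingEstimate, PreviewDetail, QAPreviewDetail

# Illustrative values; a real estimate is produced by IndexingRunner.indexing_estimate().
estimate = IndexingEstimate(
    total_segments=2,
    preview=[
        PreviewDetail(content="Parent chunk text ...", child_chunks=["child chunk 1", "child chunk 2"]),
        PreviewDetail(content="A paragraph-mode chunk with no children"),
    ],
    qa_preview=[QAPreviewDetail(question="What is ...?", answer="...")],
)

payload = estimate.model_dump()
# {"total_segments": 2,
#  "preview": [{"content": "...", "child_chunks": ["...", "..."]}, {"content": "...", "child_chunks": None}],
#  "qa_preview": [{"question": "...", "answer": "..."}]}
```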
+ child_chunks: Optional[list[str]] = None + + +class QAPreviewDetail(BaseModel): + question: str + answer: str + + +class IndexingEstimate(BaseModel): + total_segments: int + preview: list[PreviewDetail] + qa_preview: Optional[list[QAPreviewDetail]] = None diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py index 1f0a0d0ef1dda4..c51dca79efb513 100644 --- a/api/core/indexing_runner.py +++ b/api/core/indexing_runner.py @@ -8,34 +8,34 @@ import uuid from typing import Any, Optional, cast -from flask import Flask, current_app +from flask import current_app from flask_login import current_user # type: ignore from sqlalchemy.orm.exc import ObjectDeletedError from configs import dify_config +from core.entities.knowledge_entities import IndexingEstimate, PreviewDetail, QAPreviewDetail from core.errors.error import ProviderTokenNotInitError -from core.llm_generator.llm_generator import LLMGenerator from core.model_manager import ModelInstance, ModelManager from core.model_runtime.entities.model_entities import ModelType from core.rag.cleaner.clean_processor import CleanProcessor from core.rag.datasource.keyword.keyword_factory import Keyword from core.rag.docstore.dataset_docstore import DatasetDocumentStore from core.rag.extractor.entity.extract_setting import ExtractSetting +from core.rag.index_processor.constant.index_type import IndexType from core.rag.index_processor.index_processor_base import BaseIndexProcessor from core.rag.index_processor.index_processor_factory import IndexProcessorFactory -from core.rag.models.document import Document +from core.rag.models.document import ChildDocument, Document from core.rag.splitter.fixed_text_splitter import ( EnhanceRecursiveCharacterTextSplitter, FixedRecursiveCharacterTextSplitter, ) from core.rag.splitter.text_splitter import TextSplitter -from core.tools.utils.text_processing_utils import remove_leading_symbols from core.tools.utils.web_reader_tool import get_image_upload_file_ids from extensions.ext_database import db from extensions.ext_redis import redis_client from extensions.ext_storage import storage from libs import helper -from models.dataset import Dataset, DatasetProcessRule, DocumentSegment +from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegment from models.dataset import Document as DatasetDocument from models.model import UploadFile from services.feature_service import FeatureService @@ -115,6 +115,9 @@ def run_in_splitting_status(self, dataset_document: DatasetDocument): for document_segment in document_segments: db.session.delete(document_segment) + if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX: + # delete child chunks + db.session.query(ChildChunk).filter(ChildChunk.segment_id == document_segment.id).delete() db.session.commit() # get the process rule processing_rule = ( @@ -183,7 +186,22 @@ def run_in_indexing_status(self, dataset_document: DatasetDocument): "dataset_id": document_segment.dataset_id, }, ) - + if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX: + child_chunks = document_segment.child_chunks + if child_chunks: + child_documents = [] + for child_chunk in child_chunks: + child_document = ChildDocument( + page_content=child_chunk.content, + metadata={ + "doc_id": child_chunk.index_node_id, + "doc_hash": child_chunk.index_node_hash, + "document_id": document_segment.document_id, + "dataset_id": document_segment.dataset_id, + }, + ) + child_documents.append(child_document) + document.children = child_documents documents.append(document) # build 
index @@ -222,7 +240,7 @@ def indexing_estimate( doc_language: str = "English", dataset_id: Optional[str] = None, indexing_technique: str = "economy", - ) -> dict: + ) -> IndexingEstimate: """ Estimate the indexing for the document. """ @@ -258,31 +276,38 @@ def indexing_estimate( tenant_id=tenant_id, model_type=ModelType.TEXT_EMBEDDING, ) - preview_texts: list[str] = [] + preview_texts = [] + total_segments = 0 index_type = doc_form index_processor = IndexProcessorFactory(index_type).init_index_processor() - all_text_docs = [] for extract_setting in extract_settings: # extract - text_docs = index_processor.extract(extract_setting, process_rule_mode=tmp_processing_rule["mode"]) - all_text_docs.extend(text_docs) processing_rule = DatasetProcessRule( mode=tmp_processing_rule["mode"], rules=json.dumps(tmp_processing_rule["rules"]) ) - - # get splitter - splitter = self._get_splitter(processing_rule, embedding_model_instance) - - # split to documents - documents = self._split_to_documents_for_estimate( - text_docs=text_docs, splitter=splitter, processing_rule=processing_rule + text_docs = index_processor.extract(extract_setting, process_rule_mode=tmp_processing_rule["mode"]) + documents = index_processor.transform( + text_docs, + embedding_model_instance=embedding_model_instance, + process_rule=processing_rule.to_dict(), + tenant_id=current_user.current_tenant_id, + doc_language=doc_language, + preview=True, ) - total_segments += len(documents) for document in documents: - if len(preview_texts) < 5: - preview_texts.append(document.page_content) + if len(preview_texts) < 10: + if doc_form and doc_form == "qa_model": + preview_detail = QAPreviewDetail( + question=document.page_content, answer=document.metadata.get("answer") + ) + preview_texts.append(preview_detail) + else: + preview_detail = PreviewDetail(content=document.page_content) + if document.children: + preview_detail.child_chunks = [child.page_content for child in document.children] + preview_texts.append(preview_detail) # delete image files and related db records image_upload_file_ids = get_image_upload_file_ids(document.page_content) @@ -299,15 +324,8 @@ def indexing_estimate( db.session.delete(image_file) if doc_form and doc_form == "qa_model": - if len(preview_texts) > 0: - # qa model document - response = LLMGenerator.generate_qa_document( - current_user.current_tenant_id, preview_texts[0], doc_language - ) - document_qa_list = self.format_split_text(response) - - return {"total_segments": total_segments * 20, "qa_preview": document_qa_list, "preview": preview_texts} - return {"total_segments": total_segments, "preview": preview_texts} + return IndexingEstimate(total_segments=total_segments * 20, qa_preview=preview_texts, preview=[]) + return IndexingEstimate(total_segments=total_segments, preview=preview_texts) def _extract( self, index_processor: BaseIndexProcessor, dataset_document: DatasetDocument, process_rule: dict @@ -401,31 +419,26 @@ def filter_string(text): @staticmethod def _get_splitter( - processing_rule: DatasetProcessRule, embedding_model_instance: Optional[ModelInstance] + processing_rule_mode: str, + max_tokens: int, + chunk_overlap: int, + separator: str, + embedding_model_instance: Optional[ModelInstance], ) -> TextSplitter: """ Get the NodeParser object according to the processing rule. 
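The refactored `_get_splitter` now takes the rule fields directly instead of a `DatasetProcessRule` object. A hedged sketch of the new call shape for a custom or hierarchical rule; `processing_rule`, `text_docs`, and `embedding_model_instance` are assumed to be in scope as in the surrounding runner code:

```python
import json

from core.indexing_runner import IndexingRunner

# `processing_rule` is a DatasetProcessRule row whose `rules` column is a JSON string,
# as in the existing runner code above.
segmentation = json.loads(processing_rule.rules)["segmentation"]
splitter = IndexingRunner._get_splitter(
    processing_rule_mode=processing_rule.mode,              # "custom", "hierarchical", or "automatic"
    max_tokens=segmentation["max_tokens"],
    chunk_overlap=segmentation.get("chunk_overlap", 0) or 0,
    separator=segmentation["separator"],
    embedding_model_instance=embedding_model_instance,
)
chunks = splitter.split_documents(text_docs)                # text_docs: list[Document]
```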
""" - character_splitter: TextSplitter - if processing_rule.mode == "custom": + if processing_rule_mode in ["custom", "hierarchical"]: # The user-defined segmentation rule - rules = json.loads(processing_rule.rules) - segmentation = rules["segmentation"] max_segmentation_tokens_length = dify_config.INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH - if segmentation["max_tokens"] < 50 or segmentation["max_tokens"] > max_segmentation_tokens_length: + if max_tokens < 50 or max_tokens > max_segmentation_tokens_length: raise ValueError(f"Custom segment length should be between 50 and {max_segmentation_tokens_length}.") - separator = segmentation["separator"] if separator: separator = separator.replace("\\n", "\n") - if segmentation.get("chunk_overlap"): - chunk_overlap = segmentation["chunk_overlap"] - else: - chunk_overlap = 0 - character_splitter = FixedRecursiveCharacterTextSplitter.from_encoder( - chunk_size=segmentation["max_tokens"], + chunk_size=max_tokens, chunk_overlap=chunk_overlap, fixed_separator=separator, separators=["\n\n", "。", ". ", " ", ""], @@ -443,142 +456,6 @@ def _get_splitter( return character_splitter - def _step_split( - self, - text_docs: list[Document], - splitter: TextSplitter, - dataset: Dataset, - dataset_document: DatasetDocument, - processing_rule: DatasetProcessRule, - ) -> list[Document]: - """ - Split the text documents into documents and save them to the document segment. - """ - documents = self._split_to_documents( - text_docs=text_docs, - splitter=splitter, - processing_rule=processing_rule, - tenant_id=dataset.tenant_id, - document_form=dataset_document.doc_form, - document_language=dataset_document.doc_language, - ) - - # save node to document segment - doc_store = DatasetDocumentStore( - dataset=dataset, user_id=dataset_document.created_by, document_id=dataset_document.id - ) - - # add document segments - doc_store.add_documents(documents) - - # update document status to indexing - cur_time = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) - self._update_document_index_status( - document_id=dataset_document.id, - after_indexing_status="indexing", - extra_update_params={ - DatasetDocument.cleaning_completed_at: cur_time, - DatasetDocument.splitting_completed_at: cur_time, - }, - ) - - # update segment status to indexing - self._update_segments_by_document( - dataset_document_id=dataset_document.id, - update_params={ - DocumentSegment.status: "indexing", - DocumentSegment.indexing_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None), - }, - ) - - return documents - - def _split_to_documents( - self, - text_docs: list[Document], - splitter: TextSplitter, - processing_rule: DatasetProcessRule, - tenant_id: str, - document_form: str, - document_language: str, - ) -> list[Document]: - """ - Split the text documents into nodes. 
- """ - all_documents: list[Document] = [] - all_qa_documents: list[Document] = [] - for text_doc in text_docs: - # document clean - document_text = self._document_clean(text_doc.page_content, processing_rule) - text_doc.page_content = document_text - - # parse document to nodes - documents = splitter.split_documents([text_doc]) - split_documents = [] - for document_node in documents: - if document_node.page_content.strip(): - if document_node.metadata is not None: - doc_id = str(uuid.uuid4()) - hash = helper.generate_text_hash(document_node.page_content) - document_node.metadata["doc_id"] = doc_id - document_node.metadata["doc_hash"] = hash - # delete Splitter character - page_content = document_node.page_content - document_node.page_content = remove_leading_symbols(page_content) - - if document_node.page_content: - split_documents.append(document_node) - all_documents.extend(split_documents) - # processing qa document - if document_form == "qa_model": - for i in range(0, len(all_documents), 10): - threads = [] - sub_documents = all_documents[i : i + 10] - for doc in sub_documents: - document_format_thread = threading.Thread( - target=self.format_qa_document, - kwargs={ - "flask_app": current_app._get_current_object(), # type: ignore - "tenant_id": tenant_id, - "document_node": doc, - "all_qa_documents": all_qa_documents, - "document_language": document_language, - }, - ) - threads.append(document_format_thread) - document_format_thread.start() - for thread in threads: - thread.join() - return all_qa_documents - return all_documents - - def format_qa_document(self, flask_app: Flask, tenant_id: str, document_node, all_qa_documents, document_language): - format_documents = [] - if document_node.page_content is None or not document_node.page_content.strip(): - return - with flask_app.app_context(): - try: - # qa model document - response = LLMGenerator.generate_qa_document(tenant_id, document_node.page_content, document_language) - document_qa_list = self.format_split_text(response) - qa_documents = [] - for result in document_qa_list: - qa_document = Document( - page_content=result["question"], metadata=document_node.metadata.model_copy() - ) - if qa_document.metadata is not None: - doc_id = str(uuid.uuid4()) - hash = helper.generate_text_hash(result["question"]) - qa_document.metadata["answer"] = result["answer"] - qa_document.metadata["doc_id"] = doc_id - qa_document.metadata["doc_hash"] = hash - qa_documents.append(qa_document) - format_documents.extend(qa_documents) - except Exception as e: - logging.exception("Failed to format qa document") - - all_qa_documents.extend(format_documents) - def _split_to_documents_for_estimate( self, text_docs: list[Document], splitter: TextSplitter, processing_rule: DatasetProcessRule ) -> list[Document]: @@ -624,11 +501,11 @@ def _document_clean(text: str, processing_rule: DatasetProcessRule) -> str: return document_text @staticmethod - def format_split_text(text): + def format_split_text(text: str) -> list[QAPreviewDetail]: regex = r"Q\d+:\s*(.*?)\s*A\d+:\s*([\s\S]*?)(?=Q\d+:|$)" matches = re.findall(regex, text, re.UNICODE) - return [{"question": q, "answer": re.sub(r"\n\s*", "\n", a.strip())} for q, a in matches if q and a] + return [QAPreviewDetail(question=q, answer=re.sub(r"\n\s*", "\n", a.strip())) for q, a in matches if q and a] def _load( self, @@ -654,13 +531,14 @@ def _load( indexing_start_at = time.perf_counter() tokens = 0 chunk_size = 10 + if dataset_document.doc_form != IndexType.PARENT_CHILD_INDEX: + # create keyword index + 
create_keyword_thread = threading.Thread( + target=self._process_keyword_index, + args=(current_app._get_current_object(), dataset.id, dataset_document.id, documents), + ) + create_keyword_thread.start() - # create keyword index - create_keyword_thread = threading.Thread( - target=self._process_keyword_index, - args=(current_app._get_current_object(), dataset.id, dataset_document.id, documents), # type: ignore - ) - create_keyword_thread.start() if dataset.indexing_technique == "high_quality": with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor: futures = [] @@ -680,8 +558,8 @@ def _load( for future in futures: tokens += future.result() - - create_keyword_thread.join() + if dataset_document.doc_form != IndexType.PARENT_CHILD_INDEX: + create_keyword_thread.join() indexing_end_at = time.perf_counter() # update document status to completed @@ -793,28 +671,6 @@ def _update_segments_by_document(dataset_document_id: str, update_params: dict) DocumentSegment.query.filter_by(document_id=dataset_document_id).update(update_params) db.session.commit() - @staticmethod - def batch_add_segments(segments: list[DocumentSegment], dataset: Dataset): - """ - Batch add segments index processing - """ - documents = [] - for segment in segments: - document = Document( - page_content=segment.content, - metadata={ - "doc_id": segment.index_node_id, - "doc_hash": segment.index_node_hash, - "document_id": segment.document_id, - "dataset_id": segment.dataset_id, - }, - ) - documents.append(document) - # save vector index - index_type = dataset.doc_form - index_processor = IndexProcessorFactory(index_type).init_index_processor() - index_processor.load(dataset, documents) - def _transform( self, index_processor: BaseIndexProcessor, @@ -856,7 +712,7 @@ def _load_segments(self, dataset, dataset_document, documents): ) # add document segments - doc_store.add_documents(documents) + doc_store.add_documents(docs=documents, save_child=dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX) # update document status to indexing cur_time = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) diff --git a/api/core/rag/datasource/retrieval_service.py b/api/core/rag/datasource/retrieval_service.py index 34343ad60ea4c1..568517c0ea6d36 100644 --- a/api/core/rag/datasource/retrieval_service.py +++ b/api/core/rag/datasource/retrieval_service.py @@ -6,11 +6,14 @@ from core.rag.data_post_processor.data_post_processor import DataPostProcessor from core.rag.datasource.keyword.keyword_factory import Keyword from core.rag.datasource.vdb.vector_factory import Vector +from core.rag.embedding.retrieval import RetrievalSegments +from core.rag.index_processor.constant.index_type import IndexType from core.rag.models.document import Document from core.rag.rerank.rerank_type import RerankMode from core.rag.retrieval.retrieval_methods import RetrievalMethod from extensions.ext_database import db -from models.dataset import Dataset +from models.dataset import ChildChunk, Dataset, DocumentSegment +from models.dataset import Document as DatasetDocument from services.external_knowledge_service import ExternalDatasetService default_retrieval_model = { @@ -248,3 +251,88 @@ def full_text_index_search( @staticmethod def escape_query_for_search(query: str) -> str: return query.replace('"', '\\"') + + @staticmethod + def format_retrieval_documents(documents: list[Document]) -> list[RetrievalSegments]: + records = [] + include_segment_ids = [] + segment_child_map = {} + for document in documents: + document_id = 
document.metadata["document_id"] + dataset_document = db.session.query(DatasetDocument).filter(DatasetDocument.id == document_id).first() + if dataset_document and dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX: + child_index_node_id = document.metadata["doc_id"] + result = ( + db.session.query(ChildChunk, DocumentSegment) + .join(DocumentSegment, ChildChunk.segment_id == DocumentSegment.id) + .filter( + ChildChunk.index_node_id == child_index_node_id, + DocumentSegment.dataset_id == dataset_document.dataset_id, + DocumentSegment.enabled == True, + DocumentSegment.status == "completed", + ) + .first() + ) + if result: + child_chunk, segment = result + if not segment: + continue + if segment.id not in include_segment_ids: + include_segment_ids.append(segment.id) + child_chunk_detail = { + "id": child_chunk.id, + "content": child_chunk.content, + "position": child_chunk.position, + "score": document.metadata.get("score", 0.0), + } + map_detail = { + "max_score": document.metadata.get("score", 0.0), + "child_chunks": [child_chunk_detail], + } + segment_child_map[segment.id] = map_detail + record = { + "segment": segment, + } + records.append(record) + else: + child_chunk_detail = { + "id": child_chunk.id, + "content": child_chunk.content, + "position": child_chunk.position, + "score": document.metadata.get("score", 0.0), + } + segment_child_map[segment.id]["child_chunks"].append(child_chunk_detail) + segment_child_map[segment.id]["max_score"] = max( + segment_child_map[segment.id]["max_score"], document.metadata.get("score", 0.0) + ) + else: + continue + else: + index_node_id = document.metadata["doc_id"] + + segment = ( + db.session.query(DocumentSegment) + .filter( + DocumentSegment.dataset_id == dataset_document.dataset_id, + DocumentSegment.enabled == True, + DocumentSegment.status == "completed", + DocumentSegment.index_node_id == index_node_id, + ) + .first() + ) + + if not segment: + continue + include_segment_ids.append(segment.id) + record = { + "segment": segment, + "score": document.metadata.get("score", None), + } + + records.append(record) + for record in records: + if record["segment"].id in segment_child_map: + record["child_chunks"] = segment_child_map[record["segment"].id].get("child_chunks", None) + record["score"] = segment_child_map[record["segment"].id]["max_score"] + + return [RetrievalSegments(**record) for record in records] diff --git a/api/core/rag/docstore/dataset_docstore.py b/api/core/rag/docstore/dataset_docstore.py index 35becaa0c7bea7..8dfc60184c2fb6 100644 --- a/api/core/rag/docstore/dataset_docstore.py +++ b/api/core/rag/docstore/dataset_docstore.py @@ -7,7 +7,7 @@ from core.model_runtime.entities.model_entities import ModelType from core.rag.models.document import Document from extensions.ext_database import db -from models.dataset import Dataset, DocumentSegment +from models.dataset import ChildChunk, Dataset, DocumentSegment class DatasetDocumentStore: @@ -60,7 +60,7 @@ def docs(self) -> dict[str, Document]: return output - def add_documents(self, docs: Sequence[Document], allow_update: bool = True) -> None: + def add_documents(self, docs: Sequence[Document], allow_update: bool = True, save_child: bool = False) -> None: max_position = ( db.session.query(func.max(DocumentSegment.position)) .filter(DocumentSegment.document_id == self._document_id) @@ -120,6 +120,23 @@ def add_documents(self, docs: Sequence[Document], allow_update: bool = True) -> segment_document.answer = doc.metadata.pop("answer", "") db.session.add(segment_document) + 
db.session.flush() + if save_child: + for postion, child in enumerate(doc.children, start=1): + child_segment = ChildChunk( + tenant_id=self._dataset.tenant_id, + dataset_id=self._dataset.id, + document_id=self._document_id, + segment_id=segment_document.id, + position=postion, + index_node_id=child.metadata["doc_id"], + index_node_hash=child.metadata["doc_hash"], + content=child.page_content, + word_count=len(child.page_content), + type="automatic", + created_by=self._user_id, + ) + db.session.add(child_segment) else: segment_document.content = doc.page_content if doc.metadata.get("answer"): @@ -127,6 +144,30 @@ def add_documents(self, docs: Sequence[Document], allow_update: bool = True) -> segment_document.index_node_hash = doc.metadata["doc_hash"] segment_document.word_count = len(doc.page_content) segment_document.tokens = tokens + if save_child and doc.children: + # delete the existing child chunks + db.session.query(ChildChunk).filter( + ChildChunk.tenant_id == self._dataset.tenant_id, + ChildChunk.dataset_id == self._dataset.id, + ChildChunk.document_id == self._document_id, + ChildChunk.segment_id == segment_document.id, + ).delete() + # add new child chunks + for position, child in enumerate(doc.children, start=1): + child_segment = ChildChunk( + tenant_id=self._dataset.tenant_id, + dataset_id=self._dataset.id, + document_id=self._document_id, + segment_id=segment_document.id, + position=position, + index_node_id=child.metadata["doc_id"], + index_node_hash=child.metadata["doc_hash"], + content=child.page_content, + word_count=len(child.page_content), + type="automatic", + created_by=self._user_id, + ) + db.session.add(child_segment) db.session.commit() diff --git a/api/core/rag/embedding/retrieval.py b/api/core/rag/embedding/retrieval.py new file mode 100644 index 00000000000000..800422d888e4fa --- /dev/null +++ b/api/core/rag/embedding/retrieval.py @@ -0,0 +1,23 @@ +from typing import Optional + +from pydantic import BaseModel + +from models.dataset import DocumentSegment + + +class RetrievalChildChunk(BaseModel): + """Retrieval segments.""" + + id: str + content: str + score: float + position: int + + +class RetrievalSegments(BaseModel): + """Retrieval segments.""" + + model_config = {"arbitrary_types_allowed": True} + segment: DocumentSegment + child_chunks: Optional[list[RetrievalChildChunk]] = None + score: Optional[float] = None diff --git a/api/core/rag/extractor/extract_processor.py b/api/core/rag/extractor/extract_processor.py index a473b3dfa78a90..23ccab63b80732 100644 --- a/api/core/rag/extractor/extract_processor.py +++ b/api/core/rag/extractor/extract_processor.py @@ -24,7 +24,6 @@ from core.rag.extractor.unstructured.unstructured_msg_extractor import UnstructuredMsgExtractor from core.rag.extractor.unstructured.unstructured_ppt_extractor import UnstructuredPPTExtractor from core.rag.extractor.unstructured.unstructured_pptx_extractor import UnstructuredPPTXExtractor -from core.rag.extractor.unstructured.unstructured_text_extractor import UnstructuredTextExtractor from core.rag.extractor.unstructured.unstructured_xml_extractor import UnstructuredXmlExtractor from core.rag.extractor.word_extractor import WordExtractor from core.rag.models.document import Document @@ -141,11 +140,7 @@ def extract( extractor = UnstructuredEpubExtractor(file_path, unstructured_api_url, unstructured_api_key) else: # txt - extractor = ( - UnstructuredTextExtractor(file_path, unstructured_api_url) - if is_automatic - else TextExtractor(file_path, autodetect_encoding=True) - ) + extractor = 
TextExtractor(file_path, autodetect_encoding=True) else: if file_extension in {".xlsx", ".xls"}: extractor = ExcelExtractor(file_path) diff --git a/api/core/rag/extractor/word_extractor.py b/api/core/rag/extractor/word_extractor.py index c3161bc812cb73..d93de5fef948d4 100644 --- a/api/core/rag/extractor/word_extractor.py +++ b/api/core/rag/extractor/word_extractor.py @@ -267,8 +267,10 @@ def parse_paragraph(paragraph): if isinstance(element.tag, str) and element.tag.endswith("p"): # paragraph para = paragraphs.pop(0) parsed_paragraph = parse_paragraph(para) - if parsed_paragraph: + if parsed_paragraph.strip(): content.append(parsed_paragraph) + else: + content.append("\n") elif isinstance(element.tag, str) and element.tag.endswith("tbl"): # table table = tables.pop(0) content.append(self._table_to_markdown(table, image_map)) diff --git a/api/core/rag/index_processor/constant/index_type.py b/api/core/rag/index_processor/constant/index_type.py index e42cc44c6f8043..0845b58e25b558 100644 --- a/api/core/rag/index_processor/constant/index_type.py +++ b/api/core/rag/index_processor/constant/index_type.py @@ -1,8 +1,7 @@ from enum import Enum -class IndexType(Enum): +class IndexType(str, Enum): PARAGRAPH_INDEX = "text_model" QA_INDEX = "qa_model" - PARENT_CHILD_INDEX = "parent_child_index" - SUMMARY_INDEX = "summary_index" + PARENT_CHILD_INDEX = "hierarchical_model" diff --git a/api/core/rag/index_processor/index_processor_base.py b/api/core/rag/index_processor/index_processor_base.py index 7e5efdc66ed533..6d7aa0f7df172e 100644 --- a/api/core/rag/index_processor/index_processor_base.py +++ b/api/core/rag/index_processor/index_processor_base.py @@ -27,10 +27,10 @@ def transform(self, documents: list[Document], **kwargs) -> list[Document]: raise NotImplementedError @abstractmethod - def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True): + def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True, **kwargs): raise NotImplementedError - def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True): + def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True, **kwargs): raise NotImplementedError @abstractmethod @@ -45,26 +45,29 @@ def retrieve( ) -> list[Document]: raise NotImplementedError - def _get_splitter(self, processing_rule: dict, embedding_model_instance: Optional[ModelInstance]) -> TextSplitter: + def _get_splitter( + self, + processing_rule_mode: str, + max_tokens: int, + chunk_overlap: int, + separator: str, + embedding_model_instance: Optional[ModelInstance], + ) -> TextSplitter: """ Get the NodeParser object according to the processing rule. 
""" - character_splitter: TextSplitter - if processing_rule["mode"] == "custom": + if processing_rule_mode in ["custom", "hierarchical"]: # The user-defined segmentation rule - rules = processing_rule["rules"] - segmentation = rules["segmentation"] max_segmentation_tokens_length = dify_config.INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH - if segmentation["max_tokens"] < 50 or segmentation["max_tokens"] > max_segmentation_tokens_length: + if max_tokens < 50 or max_tokens > max_segmentation_tokens_length: raise ValueError(f"Custom segment length should be between 50 and {max_segmentation_tokens_length}.") - separator = segmentation["separator"] if separator: separator = separator.replace("\\n", "\n") character_splitter = FixedRecursiveCharacterTextSplitter.from_encoder( - chunk_size=segmentation["max_tokens"], - chunk_overlap=segmentation.get("chunk_overlap", 0) or 0, + chunk_size=max_tokens, + chunk_overlap=chunk_overlap, fixed_separator=separator, separators=["\n\n", "。", ". ", " ", ""], embedding_model_instance=embedding_model_instance, diff --git a/api/core/rag/index_processor/index_processor_factory.py b/api/core/rag/index_processor/index_processor_factory.py index c5ba6295f32f84..c987edf342ab8a 100644 --- a/api/core/rag/index_processor/index_processor_factory.py +++ b/api/core/rag/index_processor/index_processor_factory.py @@ -3,6 +3,7 @@ from core.rag.index_processor.constant.index_type import IndexType from core.rag.index_processor.index_processor_base import BaseIndexProcessor from core.rag.index_processor.processor.paragraph_index_processor import ParagraphIndexProcessor +from core.rag.index_processor.processor.parent_child_index_processor import ParentChildIndexProcessor from core.rag.index_processor.processor.qa_index_processor import QAIndexProcessor @@ -18,9 +19,11 @@ def init_index_processor(self) -> BaseIndexProcessor: if not self._index_type: raise ValueError("Index type must be specified.") - if self._index_type == IndexType.PARAGRAPH_INDEX.value: + if self._index_type == IndexType.PARAGRAPH_INDEX: return ParagraphIndexProcessor() - elif self._index_type == IndexType.QA_INDEX.value: + elif self._index_type == IndexType.QA_INDEX: return QAIndexProcessor() + elif self._index_type == IndexType.PARENT_CHILD_INDEX: + return ParentChildIndexProcessor() else: raise ValueError(f"Index type {self._index_type} is not supported.") diff --git a/api/core/rag/index_processor/processor/paragraph_index_processor.py b/api/core/rag/index_processor/processor/paragraph_index_processor.py index c66fa54d503e9f..ec7126159021ea 100644 --- a/api/core/rag/index_processor/processor/paragraph_index_processor.py +++ b/api/core/rag/index_processor/processor/paragraph_index_processor.py @@ -13,21 +13,34 @@ from core.rag.models.document import Document from core.tools.utils.text_processing_utils import remove_leading_symbols from libs import helper -from models.dataset import Dataset +from models.dataset import Dataset, DatasetProcessRule +from services.entities.knowledge_entities.knowledge_entities import Rule class ParagraphIndexProcessor(BaseIndexProcessor): def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]: text_docs = ExtractProcessor.extract( - extract_setting=extract_setting, is_automatic=kwargs.get("process_rule_mode") == "automatic" + extract_setting=extract_setting, + is_automatic=( + kwargs.get("process_rule_mode") == "automatic" or kwargs.get("process_rule_mode") == "hierarchical" + ), ) return text_docs def transform(self, documents: list[Document], **kwargs) -> 
list[Document]: + process_rule = kwargs.get("process_rule") + if process_rule.get("mode") == "automatic": + automatic_rule = DatasetProcessRule.AUTOMATIC_RULES + rules = Rule(**automatic_rule) + else: + rules = Rule(**process_rule.get("rules")) # Split the text documents into nodes. splitter = self._get_splitter( - processing_rule=kwargs.get("process_rule", {}), + processing_rule_mode=process_rule.get("mode"), + max_tokens=rules.segmentation.max_tokens, + chunk_overlap=rules.segmentation.chunk_overlap, + separator=rules.segmentation.separator, embedding_model_instance=kwargs.get("embedding_model_instance"), ) all_documents = [] @@ -53,15 +66,19 @@ def transform(self, documents: list[Document], **kwargs) -> list[Document]: all_documents.extend(split_documents) return all_documents - def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True): + def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True, **kwargs): if dataset.indexing_technique == "high_quality": vector = Vector(dataset) vector.create(documents) if with_keywords: + keywords_list = kwargs.get("keywords_list") keyword = Keyword(dataset) - keyword.create(documents) + if keywords_list and len(keywords_list) > 0: + keyword.add_texts(documents, keywords_list=keywords_list) + else: + keyword.add_texts(documents) - def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True): + def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True, **kwargs): if dataset.indexing_technique == "high_quality": vector = Vector(dataset) if node_ids: diff --git a/api/core/rag/index_processor/processor/parent_child_index_processor.py b/api/core/rag/index_processor/processor/parent_child_index_processor.py new file mode 100644 index 00000000000000..7ff15b9f4c86d5 --- /dev/null +++ b/api/core/rag/index_processor/processor/parent_child_index_processor.py @@ -0,0 +1,189 @@ +"""Paragraph index processor.""" + +import uuid +from typing import Optional + +from core.model_manager import ModelInstance +from core.rag.cleaner.clean_processor import CleanProcessor +from core.rag.datasource.retrieval_service import RetrievalService +from core.rag.datasource.vdb.vector_factory import Vector +from core.rag.extractor.entity.extract_setting import ExtractSetting +from core.rag.extractor.extract_processor import ExtractProcessor +from core.rag.index_processor.index_processor_base import BaseIndexProcessor +from core.rag.models.document import ChildDocument, Document +from extensions.ext_database import db +from libs import helper +from models.dataset import ChildChunk, Dataset, DocumentSegment +from services.entities.knowledge_entities.knowledge_entities import ParentMode, Rule + + +class ParentChildIndexProcessor(BaseIndexProcessor): + def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]: + text_docs = ExtractProcessor.extract( + extract_setting=extract_setting, + is_automatic=( + kwargs.get("process_rule_mode") == "automatic" or kwargs.get("process_rule_mode") == "hierarchical" + ), + ) + + return text_docs + + def transform(self, documents: list[Document], **kwargs) -> list[Document]: + process_rule = kwargs.get("process_rule") + rules = Rule(**process_rule.get("rules")) + all_documents = [] + if rules.parent_mode == ParentMode.PARAGRAPH: + # Split the text documents into nodes. 
+ splitter = self._get_splitter( + processing_rule_mode=process_rule.get("mode"), + max_tokens=rules.segmentation.max_tokens, + chunk_overlap=rules.segmentation.chunk_overlap, + separator=rules.segmentation.separator, + embedding_model_instance=kwargs.get("embedding_model_instance"), + ) + for document in documents: + # document clean + document_text = CleanProcessor.clean(document.page_content, process_rule) + document.page_content = document_text + # parse document to nodes + document_nodes = splitter.split_documents([document]) + split_documents = [] + for document_node in document_nodes: + if document_node.page_content.strip(): + doc_id = str(uuid.uuid4()) + hash = helper.generate_text_hash(document_node.page_content) + document_node.metadata["doc_id"] = doc_id + document_node.metadata["doc_hash"] = hash + # delete Splitter character + page_content = document_node.page_content + if page_content.startswith(".") or page_content.startswith("。"): + page_content = page_content[1:].strip() + else: + page_content = page_content + if len(page_content) > 0: + document_node.page_content = page_content + # parse document to child nodes + child_nodes = self._split_child_nodes( + document_node, rules, process_rule.get("mode"), kwargs.get("embedding_model_instance") + ) + document_node.children = child_nodes + split_documents.append(document_node) + all_documents.extend(split_documents) + elif rules.parent_mode == ParentMode.FULL_DOC: + page_content = "\n".join([document.page_content for document in documents]) + document = Document(page_content=page_content, metadata=documents[0].metadata) + # parse document to child nodes + child_nodes = self._split_child_nodes( + document, rules, process_rule.get("mode"), kwargs.get("embedding_model_instance") + ) + document.children = child_nodes + doc_id = str(uuid.uuid4()) + hash = helper.generate_text_hash(document.page_content) + document.metadata["doc_id"] = doc_id + document.metadata["doc_hash"] = hash + all_documents.append(document) + + return all_documents + + def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True, **kwargs): + if dataset.indexing_technique == "high_quality": + vector = Vector(dataset) + for document in documents: + child_documents = document.children + if child_documents: + formatted_child_documents = [ + Document(**child_document.model_dump()) for child_document in child_documents + ] + vector.create(formatted_child_documents) + + def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True, **kwargs): + # node_ids is segment's node_ids + if dataset.indexing_technique == "high_quality": + delete_child_chunks = kwargs.get("delete_child_chunks") or False + vector = Vector(dataset) + if node_ids: + child_node_ids = ( + db.session.query(ChildChunk.index_node_id) + .join(DocumentSegment, ChildChunk.segment_id == DocumentSegment.id) + .filter( + DocumentSegment.dataset_id == dataset.id, + DocumentSegment.index_node_id.in_(node_ids), + ChildChunk.dataset_id == dataset.id, + ) + .all() + ) + child_node_ids = [child_node_id[0] for child_node_id in child_node_ids] + vector.delete_by_ids(child_node_ids) + if delete_child_chunks: + db.session.query(ChildChunk).filter( + ChildChunk.dataset_id == dataset.id, ChildChunk.index_node_id.in_(child_node_ids) + ).delete() + db.session.commit() + else: + vector.delete() + + if delete_child_chunks: + db.session.query(ChildChunk).filter(ChildChunk.dataset_id == dataset.id).delete() + db.session.commit() + + def retrieve( + self, + 
retrieval_method: str, + query: str, + dataset: Dataset, + top_k: int, + score_threshold: float, + reranking_model: dict, + ) -> list[Document]: + # Set search parameters. + results = RetrievalService.retrieve( + retrieval_method=retrieval_method, + dataset_id=dataset.id, + query=query, + top_k=top_k, + score_threshold=score_threshold, + reranking_model=reranking_model, + ) + # Organize results. + docs = [] + for result in results: + metadata = result.metadata + metadata["score"] = result.score + if result.score > score_threshold: + doc = Document(page_content=result.page_content, metadata=metadata) + docs.append(doc) + return docs + + def _split_child_nodes( + self, + document_node: Document, + rules: Rule, + process_rule_mode: str, + embedding_model_instance: Optional[ModelInstance], + ) -> list[ChildDocument]: + child_splitter = self._get_splitter( + processing_rule_mode=process_rule_mode, + max_tokens=rules.subchunk_segmentation.max_tokens, + chunk_overlap=rules.subchunk_segmentation.chunk_overlap, + separator=rules.subchunk_segmentation.separator, + embedding_model_instance=embedding_model_instance, + ) + # parse document to child nodes + child_nodes = [] + child_documents = child_splitter.split_documents([document_node]) + for child_document_node in child_documents: + if child_document_node.page_content.strip(): + doc_id = str(uuid.uuid4()) + hash = helper.generate_text_hash(child_document_node.page_content) + child_document = ChildDocument( + page_content=child_document_node.page_content, metadata=document_node.metadata + ) + child_document.metadata["doc_id"] = doc_id + child_document.metadata["doc_hash"] = hash + child_page_content = child_document.page_content + if child_page_content.startswith(".") or child_page_content.startswith("。"): + child_page_content = child_page_content[1:].strip() + if len(child_page_content) > 0: + child_document.page_content = child_page_content + child_nodes.append(child_document) + return child_nodes diff --git a/api/core/rag/index_processor/processor/qa_index_processor.py b/api/core/rag/index_processor/processor/qa_index_processor.py index 20fd16e8f39b65..6535d4626117f8 100644 --- a/api/core/rag/index_processor/processor/qa_index_processor.py +++ b/api/core/rag/index_processor/processor/qa_index_processor.py @@ -21,18 +21,28 @@ from core.tools.utils.text_processing_utils import remove_leading_symbols from libs import helper from models.dataset import Dataset +from services.entities.knowledge_entities.knowledge_entities import Rule class QAIndexProcessor(BaseIndexProcessor): def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]: text_docs = ExtractProcessor.extract( - extract_setting=extract_setting, is_automatic=kwargs.get("process_rule_mode") == "automatic" + extract_setting=extract_setting, + is_automatic=( + kwargs.get("process_rule_mode") == "automatic" or kwargs.get("process_rule_mode") == "hierarchical" + ), ) return text_docs def transform(self, documents: list[Document], **kwargs) -> list[Document]: + preview = kwargs.get("preview") + process_rule = kwargs.get("process_rule") + rules = Rule(**process_rule.get("rules")) splitter = self._get_splitter( - processing_rule=kwargs.get("process_rule") or {}, + processing_rule_mode=process_rule.get("mode"), + max_tokens=rules.segmentation.max_tokens, + chunk_overlap=rules.segmentation.chunk_overlap, + separator=rules.segmentation.separator, embedding_model_instance=kwargs.get("embedding_model_instance"), ) @@ -59,24 +69,33 @@ def transform(self, documents: list[Document], 
**kwargs) -> list[Document]: document_node.page_content = remove_leading_symbols(page_content) split_documents.append(document_node) all_documents.extend(split_documents) - for i in range(0, len(all_documents), 10): - threads = [] - sub_documents = all_documents[i : i + 10] - for doc in sub_documents: - document_format_thread = threading.Thread( - target=self._format_qa_document, - kwargs={ - "flask_app": current_app._get_current_object(), # type: ignore - "tenant_id": kwargs.get("tenant_id"), - "document_node": doc, - "all_qa_documents": all_qa_documents, - "document_language": kwargs.get("doc_language", "English"), - }, - ) - threads.append(document_format_thread) - document_format_thread.start() - for thread in threads: - thread.join() + if preview: + self._format_qa_document( + current_app._get_current_object(), + kwargs.get("tenant_id"), + all_documents[0], + all_qa_documents, + kwargs.get("doc_language", "English"), + ) + else: + for i in range(0, len(all_documents), 10): + threads = [] + sub_documents = all_documents[i : i + 10] + for doc in sub_documents: + document_format_thread = threading.Thread( + target=self._format_qa_document, + kwargs={ + "flask_app": current_app._get_current_object(), + "tenant_id": kwargs.get("tenant_id"), + "document_node": doc, + "all_qa_documents": all_qa_documents, + "document_language": kwargs.get("doc_language", "English"), + }, + ) + threads.append(document_format_thread) + document_format_thread.start() + for thread in threads: + thread.join() return all_qa_documents def format_by_template(self, file: FileStorage, **kwargs) -> list[Document]: @@ -98,12 +117,12 @@ def format_by_template(self, file: FileStorage, **kwargs) -> list[Document]: raise ValueError(str(e)) return text_docs - def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True): + def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True, **kwargs): if dataset.indexing_technique == "high_quality": vector = Vector(dataset) vector.create(documents) - def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True): + def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True, **kwargs): vector = Vector(dataset) if node_ids: vector.delete_by_ids(node_ids) diff --git a/api/core/rag/models/document.py b/api/core/rag/models/document.py index 1e9aaa24f04c98..a34afc7bd75fda 100644 --- a/api/core/rag/models/document.py +++ b/api/core/rag/models/document.py @@ -5,6 +5,19 @@ from pydantic import BaseModel, Field +class ChildDocument(BaseModel): + """Class for storing a piece of text and associated metadata.""" + + page_content: str + + vector: Optional[list[float]] = None + + """Arbitrary metadata about the page content (e.g., source, relationships to other + documents, etc.). + """ + metadata: Optional[dict] = Field(default_factory=dict) + + class Document(BaseModel): """Class for storing a piece of text and associated metadata.""" @@ -19,6 +32,8 @@ class Document(BaseModel): provider: Optional[str] = "dify" + children: Optional[list[ChildDocument]] = None + class BaseDocumentTransformer(ABC): """Abstract base class for document transformation systems. 
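Aside on the models above: ChildDocument and the new children field on Document are what ParentChildIndexProcessor.transform populates, and its load() only pushes the child chunks into the vector store. A minimal sketch of the resulting structure, using invented example texts (the real parent/child content comes from the configured splitters, not from anything in this patch):

import uuid

from core.rag.models.document import ChildDocument, Document
from libs import helper

# Invented example content; in the indexing flow these strings come from the
# parent splitter and the subchunk splitter respectively.
parent_text = "Parent-child indexing keeps large parent chunks and embeds smaller child chunks."
child_texts = [
    "Parent-child indexing keeps large parent chunks.",
    "It embeds smaller child chunks.",
]

parent = Document(
    page_content=parent_text,
    metadata={"doc_id": str(uuid.uuid4()), "doc_hash": helper.generate_text_hash(parent_text)},
)
parent.children = [
    ChildDocument(
        page_content=text,
        # each child starts from the parent's metadata, then gets its own doc_id/doc_hash
        metadata={**parent.metadata, "doc_id": str(uuid.uuid4()), "doc_hash": helper.generate_text_hash(text)},
    )
    for text in child_texts
]

# Only parent.children are embedded; at query time a child hit is mapped back to its
# parent segment (see RetrievalService.format_retrieval_documents earlier in this patch).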
diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py index a265f36671b04b..e1d36aad1fa5d7 100644 --- a/api/core/rag/retrieval/dataset_retrieval.py +++ b/api/core/rag/retrieval/dataset_retrieval.py @@ -166,43 +166,29 @@ def retrieve( "content": item.page_content, } retrieval_resource_list.append(source) - document_score_list = {} # deal with dify documents if dify_documents: - for item in dify_documents: - if item.metadata.get("score"): - document_score_list[item.metadata["doc_id"]] = item.metadata["score"] - - index_node_ids = [document.metadata["doc_id"] for document in dify_documents] - segments = DocumentSegment.query.filter( - DocumentSegment.dataset_id.in_(dataset_ids), - DocumentSegment.status == "completed", - DocumentSegment.enabled == True, - DocumentSegment.index_node_id.in_(index_node_ids), - ).all() - - if segments: - index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)} - sorted_segments = sorted( - segments, key=lambda segment: index_node_id_to_position.get(segment.index_node_id, float("inf")) - ) - for segment in sorted_segments: + records = RetrievalService.format_retrieval_documents(dify_documents) + if records: + for record in records: + segment = record.segment if segment.answer: document_context_list.append( DocumentContext( content=f"question:{segment.get_sign_content()} answer:{segment.answer}", - score=document_score_list.get(segment.index_node_id, None), + score=record.score, ) ) else: document_context_list.append( DocumentContext( content=segment.get_sign_content(), - score=document_score_list.get(segment.index_node_id, None), + score=record.score, ) ) if show_retrieve_source: - for segment in sorted_segments: + for record in records: + segment = record.segment dataset = Dataset.query.filter_by(id=segment.dataset_id).first() document = DatasetDocument.query.filter( DatasetDocument.id == segment.document_id, @@ -218,7 +204,7 @@ def retrieve( "data_source_type": document.data_source_type, "segment_id": segment.id, "retriever_from": invoke_from.to_source(), - "score": document_score_list.get(segment.index_node_id, 0.0), + "score": record.score or 0.0, } if invoke_from.to_source() == "dev": diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py index bfd93c074dd6d5..0f239af51ae79c 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -11,6 +11,7 @@ from core.model_manager import ModelInstance, ModelManager from core.model_runtime.entities.model_entities import ModelFeature, ModelType from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from core.rag.datasource.retrieval_service import RetrievalService from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.variables import StringSegment @@ -18,7 +19,7 @@ from core.workflow.nodes.base import BaseNode from core.workflow.nodes.enums import NodeType from extensions.ext_database import db -from models.dataset import Dataset, Document, DocumentSegment +from models.dataset import Dataset, Document from models.workflow import WorkflowNodeExecutionStatus from .entities import KnowledgeRetrievalNodeData @@ -211,29 +212,12 @@ def _fetch_dataset_retriever(self, node_data: KnowledgeRetrievalNodeData, query: 
"content": item.page_content, } retrieval_resource_list.append(source) - document_score_list: dict[str, float] = {} # deal with dify documents if dify_documents: - document_score_list = {} - for item in dify_documents: - if item.metadata.get("score"): - document_score_list[item.metadata["doc_id"]] = item.metadata["score"] - - index_node_ids = [document.metadata["doc_id"] for document in dify_documents] - segments = DocumentSegment.query.filter( - DocumentSegment.dataset_id.in_(dataset_ids), - DocumentSegment.completed_at.isnot(None), - DocumentSegment.status == "completed", - DocumentSegment.enabled == True, - DocumentSegment.index_node_id.in_(index_node_ids), - ).all() - if segments: - index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)} - sorted_segments = sorted( - segments, key=lambda segment: index_node_id_to_position.get(segment.index_node_id, float("inf")) - ) - - for segment in sorted_segments: + records = RetrievalService.format_retrieval_documents(dify_documents) + if records: + for record in records: + segment = record.segment dataset = Dataset.query.filter_by(id=segment.dataset_id).first() document = Document.query.filter( Document.id == segment.document_id, @@ -251,7 +235,7 @@ def _fetch_dataset_retriever(self, node_data: KnowledgeRetrievalNodeData, query: "document_data_source_type": document.data_source_type, "segment_id": segment.id, "retriever_from": "workflow", - "score": document_score_list.get(segment.index_node_id, None), + "score": record.score or 0.0, "segment_hit_count": segment.hit_count, "segment_word_count": segment.word_count, "segment_position": segment.position, @@ -270,10 +254,8 @@ def _fetch_dataset_retriever(self, node_data: KnowledgeRetrievalNodeData, query: key=lambda x: x["metadata"]["score"] if x["metadata"].get("score") is not None else 0.0, reverse=True, ) - position = 1 - for item in retrieval_resource_list: + for position, item in enumerate(retrieval_resource_list, start=1): item["metadata"]["position"] = position - position += 1 return retrieval_resource_list @classmethod diff --git a/api/fields/dataset_fields.py b/api/fields/dataset_fields.py index a74e6f54fb3858..bedab5750f1d66 100644 --- a/api/fields/dataset_fields.py +++ b/api/fields/dataset_fields.py @@ -73,6 +73,7 @@ "embedding_available": fields.Boolean, "retrieval_model_dict": fields.Nested(dataset_retrieval_model_fields), "tags": fields.List(fields.Nested(tag_fields)), + "doc_form": fields.String, "external_knowledge_info": fields.Nested(external_knowledge_info_fields), "external_retrieval_model": fields.Nested(external_retrieval_model_fields, allow_null=True), } diff --git a/api/fields/document_fields.py b/api/fields/document_fields.py index 2b2ac6243f4da5..f2250d964ac12d 100644 --- a/api/fields/document_fields.py +++ b/api/fields/document_fields.py @@ -34,6 +34,7 @@ "data_source_info": fields.Raw(attribute="data_source_info_dict"), "data_source_detail_dict": fields.Raw(attribute="data_source_detail_dict"), "dataset_process_rule_id": fields.String, + "process_rule_dict": fields.Raw(attribute="process_rule_dict"), "name": fields.String, "created_from": fields.String, "created_by": fields.String, diff --git a/api/fields/hit_testing_fields.py b/api/fields/hit_testing_fields.py index aaafcab8ab6ba0..b9f7e78c170529 100644 --- a/api/fields/hit_testing_fields.py +++ b/api/fields/hit_testing_fields.py @@ -34,8 +34,16 @@ "document": fields.Nested(document_fields), } +child_chunk_fields = { + "id": fields.String, + "content": fields.String, + "position": 
fields.Integer, + "score": fields.Float, +} + hit_testing_record_fields = { "segment": fields.Nested(segment_fields), + "child_chunks": fields.List(fields.Nested(child_chunk_fields)), "score": fields.Float, "tsne_position": fields.Raw, } diff --git a/api/fields/segment_fields.py b/api/fields/segment_fields.py index 4413af31607897..52f89859c931b7 100644 --- a/api/fields/segment_fields.py +++ b/api/fields/segment_fields.py @@ -2,6 +2,17 @@ from libs.helper import TimestampField +child_chunk_fields = { + "id": fields.String, + "segment_id": fields.String, + "content": fields.String, + "position": fields.Integer, + "word_count": fields.Integer, + "type": fields.String, + "created_at": TimestampField, + "updated_at": TimestampField, +} + segment_fields = { "id": fields.String, "position": fields.Integer, @@ -20,10 +31,13 @@ "status": fields.String, "created_by": fields.String, "created_at": TimestampField, + "updated_at": TimestampField, + "updated_by": fields.String, "indexing_at": TimestampField, "completed_at": TimestampField, "error": fields.String, "stopped_at": TimestampField, + "child_chunks": fields.List(fields.Nested(child_chunk_fields)), } segment_list_response = { diff --git a/api/migrations/versions/2024_11_22_0701-e19037032219_parent_child_index.py b/api/migrations/versions/2024_11_22_0701-e19037032219_parent_child_index.py new file mode 100644 index 00000000000000..9238e5a0a81c5a --- /dev/null +++ b/api/migrations/versions/2024_11_22_0701-e19037032219_parent_child_index.py @@ -0,0 +1,55 @@ +"""parent-child-index + +Revision ID: e19037032219 +Revises: 01d6889832f7 +Create Date: 2024-11-22 07:01:17.550037 + +""" +from alembic import op +import models as models +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = 'e19037032219' +down_revision = 'd7999dfa4aae' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + op.create_table('child_chunks', + sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('tenant_id', models.types.StringUUID(), nullable=False), + sa.Column('dataset_id', models.types.StringUUID(), nullable=False), + sa.Column('document_id', models.types.StringUUID(), nullable=False), + sa.Column('segment_id', models.types.StringUUID(), nullable=False), + sa.Column('position', sa.Integer(), nullable=False), + sa.Column('content', sa.Text(), nullable=False), + sa.Column('word_count', sa.Integer(), nullable=False), + sa.Column('index_node_id', sa.String(length=255), nullable=True), + sa.Column('index_node_hash', sa.String(length=255), nullable=True), + sa.Column('type', sa.String(length=255), server_default=sa.text("'automatic'::character varying"), nullable=False), + sa.Column('created_by', models.types.StringUUID(), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.Column('updated_by', models.types.StringUUID(), nullable=True), + sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.Column('indexing_at', sa.DateTime(), nullable=True), + sa.Column('completed_at', sa.DateTime(), nullable=True), + sa.Column('error', sa.Text(), nullable=True), + sa.PrimaryKeyConstraint('id', name='child_chunk_pkey') + ) + with op.batch_alter_table('child_chunks', schema=None) as batch_op: + batch_op.create_index('child_chunk_dataset_id_idx', ['tenant_id', 'dataset_id', 'document_id', 'segment_id', 'index_node_id'], unique=False) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('child_chunks', schema=None) as batch_op: + batch_op.drop_index('child_chunk_dataset_id_idx') + + op.drop_table('child_chunks') + # ### end Alembic commands ### diff --git a/api/migrations/versions/2024_12_25_1137-923752d42eb6_add_auto_disabled_dataset_logs.py b/api/migrations/versions/2024_12_25_1137-923752d42eb6_add_auto_disabled_dataset_logs.py new file mode 100644 index 00000000000000..6dadd4e4a8afe5 --- /dev/null +++ b/api/migrations/versions/2024_12_25_1137-923752d42eb6_add_auto_disabled_dataset_logs.py @@ -0,0 +1,47 @@ +"""add_auto_disabled_dataset_logs + +Revision ID: 923752d42eb6 +Revises: e19037032219 +Create Date: 2024-12-25 11:37:55.467101 + +""" +from alembic import op +import models as models +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = '923752d42eb6' +down_revision = 'e19037032219' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + op.create_table('dataset_auto_disable_logs', + sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('tenant_id', models.types.StringUUID(), nullable=False), + sa.Column('dataset_id', models.types.StringUUID(), nullable=False), + sa.Column('document_id', models.types.StringUUID(), nullable=False), + sa.Column('notified', sa.Boolean(), server_default=sa.text('false'), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.PrimaryKeyConstraint('id', name='dataset_auto_disable_log_pkey') + ) + with op.batch_alter_table('dataset_auto_disable_logs', schema=None) as batch_op: + batch_op.create_index('dataset_auto_disable_log_created_atx', ['created_at'], unique=False) + batch_op.create_index('dataset_auto_disable_log_dataset_idx', ['dataset_id'], unique=False) + batch_op.create_index('dataset_auto_disable_log_tenant_idx', ['tenant_id'], unique=False) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('dataset_auto_disable_logs', schema=None) as batch_op: + batch_op.drop_index('dataset_auto_disable_log_tenant_idx') + batch_op.drop_index('dataset_auto_disable_log_dataset_idx') + batch_op.drop_index('dataset_auto_disable_log_created_atx') + + op.drop_table('dataset_auto_disable_logs') + # ### end Alembic commands ### diff --git a/api/models/dataset.py b/api/models/dataset.py index b9b41dcf475bb1..f6c8a4511bb774 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -17,6 +17,7 @@ from configs import dify_config from core.rag.retrieval.retrieval_methods import RetrievalMethod from extensions.ext_storage import storage +from services.entities.knowledge_entities.knowledge_entities import ParentMode, Rule from .account import Account from .engine import db @@ -215,7 +216,7 @@ class DatasetProcessRule(db.Model): # type: ignore[name-defined] created_by = db.Column(StringUUID, nullable=False) created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - MODES = ["automatic", "custom"] + MODES = ["automatic", "custom", "hierarchical"] PRE_PROCESSING_RULES = ["remove_stopwords", "remove_extra_spaces", "remove_urls_emails"] AUTOMATIC_RULES: dict[str, Any] = { "pre_processing_rules": [ @@ -231,8 +232,6 @@ def to_dict(self): "dataset_id": self.dataset_id, "mode": self.mode, "rules": self.rules_dict, - "created_by": self.created_by, - "created_at": self.created_at, } @property @@ -396,6 +395,12 @@ def hit_count(self): .scalar() ) + @property + def process_rule_dict(self): + if self.dataset_process_rule_id: + return self.dataset_process_rule.to_dict() + return None + def to_dict(self): return { "id": self.id, @@ -560,6 +565,24 @@ def next_segment(self): .first() ) + @property + def child_chunks(self): + process_rule = self.document.dataset_process_rule + if process_rule.mode == "hierarchical": + rules = Rule(**process_rule.rules_dict) + if rules.parent_mode and rules.parent_mode != ParentMode.FULL_DOC: + child_chunks = ( + db.session.query(ChildChunk) + .filter(ChildChunk.segment_id == self.id) + .order_by(ChildChunk.position.asc()) + .all() + ) + return child_chunks or [] + else: + return [] + else: + return [] + def get_sign_content(self): signed_urls = [] text = self.content @@ -605,6 +628,47 @@ def get_sign_content(self): return text +class ChildChunk(db.Model): + __tablename__ = "child_chunks" + __table_args__ = ( + 
db.PrimaryKeyConstraint("id", name="child_chunk_pkey"), + db.Index("child_chunk_dataset_id_idx", "tenant_id", "dataset_id", "document_id", "segment_id", "index_node_id"), + ) + + # initial fields + id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) + tenant_id = db.Column(StringUUID, nullable=False) + dataset_id = db.Column(StringUUID, nullable=False) + document_id = db.Column(StringUUID, nullable=False) + segment_id = db.Column(StringUUID, nullable=False) + position = db.Column(db.Integer, nullable=False) + content = db.Column(db.Text, nullable=False) + word_count = db.Column(db.Integer, nullable=False) + # indexing fields + index_node_id = db.Column(db.String(255), nullable=True) + index_node_hash = db.Column(db.String(255), nullable=True) + type = db.Column(db.String(255), nullable=False, server_default=db.text("'automatic'::character varying")) + created_by = db.Column(StringUUID, nullable=False) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_by = db.Column(StringUUID, nullable=True) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + indexing_at = db.Column(db.DateTime, nullable=True) + completed_at = db.Column(db.DateTime, nullable=True) + error = db.Column(db.Text, nullable=True) + + @property + def dataset(self): + return db.session.query(Dataset).filter(Dataset.id == self.dataset_id).first() + + @property + def document(self): + return db.session.query(Document).filter(Document.id == self.document_id).first() + + @property + def segment(self): + return db.session.query(DocumentSegment).filter(DocumentSegment.id == self.segment_id).first() + + class AppDatasetJoin(db.Model): # type: ignore[name-defined] __tablename__ = "app_dataset_joins" __table_args__ = ( @@ -844,3 +908,20 @@ class ExternalKnowledgeBindings(db.Model): # type: ignore[name-defined] created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) updated_by = db.Column(StringUUID, nullable=True) updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + + +class DatasetAutoDisableLog(db.Model): + __tablename__ = "dataset_auto_disable_logs" + __table_args__ = ( + db.PrimaryKeyConstraint("id", name="dataset_auto_disable_log_pkey"), + db.Index("dataset_auto_disable_log_tenant_idx", "tenant_id"), + db.Index("dataset_auto_disable_log_dataset_idx", "dataset_id"), + db.Index("dataset_auto_disable_log_created_atx", "created_at"), + ) + + id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) + tenant_id = db.Column(StringUUID, nullable=False) + dataset_id = db.Column(StringUUID, nullable=False) + document_id = db.Column(StringUUID, nullable=False) + notified = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) diff --git a/api/schedule/clean_unused_datasets_task.py b/api/schedule/clean_unused_datasets_task.py index f66b3c47979435..eb73cc285d6d55 100644 --- a/api/schedule/clean_unused_datasets_task.py +++ b/api/schedule/clean_unused_datasets_task.py @@ -10,7 +10,7 @@ from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from extensions.ext_database import db from extensions.ext_redis import redis_client -from models.dataset import Dataset, DatasetQuery, Document +from models.dataset import Dataset, DatasetAutoDisableLog, DatasetQuery, Document 
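Aside on DatasetAutoDisableLog: a rough sketch of how these rows are written by the cleanup scheduler and read back for reporting. The helper names and the standalone commit are assumptions made for the sketch; the scheduler hunk below and DatasetService.get_dataset_auto_disable_logs later in this patch do the same work inside their own flows.

import datetime

from extensions.ext_database import db
from models.dataset import Dataset, DatasetAutoDisableLog, Document


def record_auto_disable_logs(dataset: Dataset) -> None:
    # One log row per enabled, non-archived document, written just before the
    # cleanup job drops the dataset's index.
    documents = (
        db.session.query(Document)
        .filter(
            Document.dataset_id == dataset.id,
            Document.enabled == True,
            Document.archived == False,
        )
        .all()
    )
    for document in documents:
        db.session.add(
            DatasetAutoDisableLog(
                tenant_id=dataset.tenant_id,
                dataset_id=dataset.id,
                document_id=document.id,
            )
        )
    db.session.commit()


def recently_auto_disabled_document_ids(dataset_id: str, days: int = 30) -> list[str]:
    # Mirrors the 30-day window used by the new auto-disable-logs service method.
    start_date = datetime.datetime.now() - datetime.timedelta(days=days)
    logs = (
        db.session.query(DatasetAutoDisableLog)
        .filter(
            DatasetAutoDisableLog.dataset_id == dataset_id,
            DatasetAutoDisableLog.created_at >= start_date,
        )
        .all()
    )
    return [log.document_id for log in logs]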
from services.feature_service import FeatureService @@ -75,6 +75,23 @@ def clean_unused_datasets_task(): ) if not dataset_query or len(dataset_query) == 0: try: + # add auto disable log + documents = ( + db.session.query(Document) + .filter( + Document.dataset_id == dataset.id, + Document.enabled == True, + Document.archived == False, + ) + .all() + ) + for document in documents: + dataset_auto_disable_log = DatasetAutoDisableLog( + tenant_id=dataset.tenant_id, + dataset_id=dataset.id, + document_id=document.id, + ) + db.session.add(dataset_auto_disable_log) # remove index index_processor = IndexProcessorFactory(dataset.doc_form).init_index_processor() index_processor.clean(dataset, None) @@ -151,6 +168,23 @@ def clean_unused_datasets_task(): else: plan = plan_cache.decode() if plan == "sandbox": + # add auto disable log + documents = ( + db.session.query(Document) + .filter( + Document.dataset_id == dataset.id, + Document.enabled == True, + Document.archived == False, + ) + .all() + ) + for document in documents: + dataset_auto_disable_log = DatasetAutoDisableLog( + tenant_id=dataset.tenant_id, + dataset_id=dataset.id, + document_id=document.id, + ) + db.session.add(dataset_auto_disable_log) # remove index index_processor = IndexProcessorFactory(dataset.doc_form).init_index_processor() index_processor.clean(dataset, None) diff --git a/api/schedule/mail_clean_document_notify_task.py b/api/schedule/mail_clean_document_notify_task.py new file mode 100644 index 00000000000000..080e78113150f7 --- /dev/null +++ b/api/schedule/mail_clean_document_notify_task.py @@ -0,0 +1,66 @@ +import logging +import time + +import click +from celery import shared_task +from flask import render_template + +from extensions.ext_mail import mail +from models.account import Account, Tenant, TenantAccountJoin +from models.dataset import Dataset, DatasetAutoDisableLog + + +@shared_task(queue="mail") +def send_document_clean_notify_task(): + """ + Async Send document clean notify mail + + Usage: send_document_clean_notify_task.delay() + """ + if not mail.is_inited(): + return + + logging.info(click.style("Start send document clean notify mail", fg="green")) + start_at = time.perf_counter() + + # send document clean notify mail + try: + dataset_auto_disable_logs = DatasetAutoDisableLog.query.filter(DatasetAutoDisableLog.notified == False).all() + # group by tenant_id + dataset_auto_disable_logs_map = {} + for dataset_auto_disable_log in dataset_auto_disable_logs: + dataset_auto_disable_logs_map.setdefault(dataset_auto_disable_log.tenant_id, []).append(dataset_auto_disable_log) + + for tenant_id, tenant_dataset_auto_disable_logs in dataset_auto_disable_logs_map.items(): + knowledge_details = [] + tenant = Tenant.query.filter(Tenant.id == tenant_id).first() + if not tenant: + continue + current_owner_join = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, role="owner").first() + account = Account.query.filter(Account.id == current_owner_join.account_id).first() + if not account: + continue + + dataset_auto_dataset_map = {} + for dataset_auto_disable_log in tenant_dataset_auto_disable_logs: + dataset_auto_dataset_map.setdefault(dataset_auto_disable_log.dataset_id, []).append( + dataset_auto_disable_log.document_id + ) + + for dataset_id, document_ids in dataset_auto_dataset_map.items(): + dataset = Dataset.query.filter(Dataset.id == dataset_id).first() + if dataset: + document_count = len(document_ids) + knowledge_details.append(f"
  • Knowledge base {dataset.name}: {document_count} documents
  • ") + html_content = render_template( + "clean_document_job_mail_template-US.html", + ) + mail.send(to=account.email, subject="Dify: some documents in your knowledge base have been disabled", html=html_content) + + end_at = time.perf_counter() + logging.info( + click.style("Send document clean notify mail succeeded: latency: {}".format(end_at - start_at), fg="green") + ) + except Exception: + logging.exception("Send document clean notify mail failed") diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index d2d8a718d55c8a..8de28085d45457 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -14,6 +14,7 @@ from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError from core.model_manager import ModelManager from core.model_runtime.entities.model_entities import ModelType +from core.rag.index_processor.constant.index_type import IndexType from core.rag.retrieval.retrieval_methods import RetrievalMethod from events.dataset_event import dataset_was_deleted from events.document_event import document_was_deleted @@ -23,7 +24,9 @@ from models.account import Account, TenantAccountRole from models.dataset import ( AppDatasetJoin, + ChildChunk, Dataset, + DatasetAutoDisableLog, DatasetCollectionBinding, DatasetPermission, DatasetPermissionEnum, @@ -35,8 +38,14 @@ ) from models.model import UploadFile from models.source import DataSourceOauthBinding -from services.entities.knowledge_entities.knowledge_entities import SegmentUpdateEntity -from services.errors.account import NoPermissionError +from services.entities.knowledge_entities.knowledge_entities import ( + ChildChunkUpdateArgs, + KnowledgeConfig, + RetrievalModel, + SegmentUpdateArgs, +) +from services.errors.account import InvalidActionError, NoPermissionError +from services.errors.chunk import ChildChunkDeleteIndexError, ChildChunkIndexingError from services.errors.dataset import DatasetNameDuplicateError from services.errors.document import DocumentIndexingError from services.errors.file import FileNotExistsError @@ -44,13 +53,16 @@ from services.feature_service import FeatureModel, FeatureService from services.tag_service import TagService from services.vector_service import VectorService +from tasks.batch_clean_document_task import batch_clean_document_task from tasks.clean_notion_document_task import clean_notion_document_task from tasks.deal_dataset_vector_index_task import deal_dataset_vector_index_task from tasks.delete_segment_from_index_task import delete_segment_from_index_task from tasks.disable_segment_from_index_task import disable_segment_from_index_task +from tasks.disable_segments_from_index_task import disable_segments_from_index_task from tasks.document_indexing_task import document_indexing_task from tasks.document_indexing_update_task import document_indexing_update_task from tasks.duplicate_document_indexing_task import duplicate_document_indexing_task +from tasks.enable_segments_to_index_task import enable_segments_to_index_task from tasks.recover_document_indexing_task import recover_document_indexing_task from tasks.retry_document_indexing_task import retry_document_indexing_task from tasks.sync_website_document_indexing_task import sync_website_document_indexing_task @@ -408,6 +420,24 @@ def get_related_apps(dataset_id: str): .all() ) + @staticmethod + def get_dataset_auto_disable_logs(dataset_id: str) -> dict: + # get recent 30 days auto disable logs + start_date = datetime.datetime.now() - datetime.timedelta(days=30) + dataset_auto_disable_logs = DatasetAutoDisableLog.query.filter( + 
DatasetAutoDisableLog.dataset_id == dataset_id, + DatasetAutoDisableLog.created_at >= start_date, + ).all() + if dataset_auto_disable_logs: + return { + "document_ids": [log.document_id for log in dataset_auto_disable_logs], + "count": len(dataset_auto_disable_logs), + } + return { + "document_ids": [], + "count": 0, + } + class DocumentService: DEFAULT_RULES = { @@ -588,6 +618,20 @@ def delete_document(document): db.session.delete(document) db.session.commit() + @staticmethod + def delete_documents(dataset: Dataset, document_ids: list[str]): + documents = db.session.query(Document).filter(Document.id.in_(document_ids)).all() + file_ids = [ + document.data_source_info_dict["upload_file_id"] + for document in documents + if document.data_source_type == "upload_file" + ] + batch_clean_document_task.delay(document_ids, dataset.id, dataset.doc_form, file_ids) + + for document in documents: + db.session.delete(document) + db.session.commit() + @staticmethod def rename_document(dataset_id: str, document_id: str, name: str) -> Document: dataset = DatasetService.get_dataset(dataset_id) @@ -689,7 +733,7 @@ def get_documents_position(dataset_id): @staticmethod def save_document_with_dataset_id( dataset: Dataset, - document_data: dict, + knowledge_config: KnowledgeConfig, account: Account | Any, dataset_process_rule: Optional[DatasetProcessRule] = None, created_from: str = "web", @@ -698,18 +742,18 @@ def save_document_with_dataset_id( features = FeatureService.get_features(current_user.current_tenant_id) if features.billing.enabled: - if "original_document_id" not in document_data or not document_data["original_document_id"]: + if not knowledge_config.original_document_id: count = 0 - if document_data["data_source"]["type"] == "upload_file": - upload_file_list = document_data["data_source"]["info_list"]["file_info_list"]["file_ids"] + if knowledge_config.data_source.info_list.data_source_type == "upload_file": + upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids count = len(upload_file_list) - elif document_data["data_source"]["type"] == "notion_import": - notion_info_list = document_data["data_source"]["info_list"]["notion_info_list"] + elif knowledge_config.data_source.info_list.data_source_type == "notion_import": + notion_info_list = knowledge_config.data_source.info_list.notion_info_list for notion_info in notion_info_list: - count = count + len(notion_info["pages"]) - elif document_data["data_source"]["type"] == "website_crawl": - website_info = document_data["data_source"]["info_list"]["website_info_list"] - count = len(website_info["urls"]) + count = count + len(notion_info.pages) + elif knowledge_config.data_source.info_list.data_source_type == "website_crawl": + website_info = knowledge_config.data_source.info_list.website_info_list + count = len(website_info.urls) batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT) if count > batch_upload_limit: raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.") @@ -718,17 +762,14 @@ def save_document_with_dataset_id( # if dataset is empty, update dataset data_source_type if not dataset.data_source_type: - dataset.data_source_type = document_data["data_source"]["type"] + dataset.data_source_type = knowledge_config.data_source.info_list.data_source_type if not dataset.indexing_technique: - if ( - "indexing_technique" not in document_data - or document_data["indexing_technique"] not in Dataset.INDEXING_TECHNIQUE_LIST - ): - raise ValueError("Indexing technique is required") + if 
knowledge_config.indexing_technique not in Dataset.INDEXING_TECHNIQUE_LIST: + raise ValueError("Indexing technique is invalid") - dataset.indexing_technique = document_data["indexing_technique"] - if document_data["indexing_technique"] == "high_quality": + dataset.indexing_technique = knowledge_config.indexing_technique + if knowledge_config.indexing_technique == "high_quality": model_manager = ModelManager() embedding_model = model_manager.get_default_model_instance( tenant_id=current_user.current_tenant_id, model_type=ModelType.TEXT_EMBEDDING @@ -748,29 +789,29 @@ def save_document_with_dataset_id( "score_threshold_enabled": False, } - dataset.retrieval_model = document_data.get("retrieval_model") or default_retrieval_model + dataset.retrieval_model = knowledge_config.retrieval_model.model_dump() or default_retrieval_model documents = [] - if document_data.get("original_document_id"): - document = DocumentService.update_document_with_dataset_id(dataset, document_data, account) + if knowledge_config.original_document_id: + document = DocumentService.update_document_with_dataset_id(dataset, knowledge_config, account) documents.append(document) batch = document.batch else: batch = time.strftime("%Y%m%d%H%M%S") + str(random.randint(100000, 999999)) # save process rule if not dataset_process_rule: - process_rule = document_data["process_rule"] - if process_rule["mode"] == "custom": + process_rule = knowledge_config.process_rule + if process_rule.mode in ("custom", "hierarchical"): dataset_process_rule = DatasetProcessRule( dataset_id=dataset.id, - mode=process_rule["mode"], - rules=json.dumps(process_rule["rules"]), + mode=process_rule.mode, + rules=process_rule.rules.model_dump_json(), created_by=account.id, ) - elif process_rule["mode"] == "automatic": + elif process_rule.mode == "automatic": dataset_process_rule = DatasetProcessRule( dataset_id=dataset.id, - mode=process_rule["mode"], + mode=process_rule.mode, rules=json.dumps(DatasetProcessRule.AUTOMATIC_RULES), created_by=account.id, ) @@ -786,8 +827,8 @@ def save_document_with_dataset_id( position = DocumentService.get_documents_position(dataset.id) document_ids = [] duplicate_document_ids = [] - if document_data["data_source"]["type"] == "upload_file": - upload_file_list = document_data["data_source"]["info_list"]["file_info_list"]["file_ids"] + if knowledge_config.data_source.info_list.data_source_type == "upload_file": + upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids for file_id in upload_file_list: file = ( db.session.query(UploadFile) @@ -804,7 +845,7 @@ def save_document_with_dataset_id( "upload_file_id": file_id, } # check duplicate - if document_data.get("duplicate", False): + if knowledge_config.duplicate: document = Document.query.filter_by( dataset_id=dataset.id, tenant_id=current_user.current_tenant_id, @@ -814,10 +855,10 @@ def save_document_with_dataset_id( ).first() if document: document.dataset_process_rule_id = dataset_process_rule.id - document.updated_at = datetime.datetime.utcnow() + document.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) document.created_from = created_from - document.doc_form = document_data["doc_form"] - document.doc_language = document_data["doc_language"] + document.doc_form = knowledge_config.doc_form + document.doc_language = knowledge_config.doc_language document.data_source_info = json.dumps(data_source_info) document.batch = batch document.indexing_status = "waiting" @@ -828,9 +869,9 @@ def save_document_with_dataset_id( 
document = DocumentService.build_document( dataset, dataset_process_rule.id, - document_data["data_source"]["type"], - document_data["doc_form"], - document_data["doc_language"], + knowledge_config.data_source.info_list.data_source_type, + knowledge_config.doc_form, + knowledge_config.doc_language, data_source_info, created_from, position, @@ -843,8 +884,8 @@ def save_document_with_dataset_id( document_ids.append(document.id) documents.append(document) position += 1 - elif document_data["data_source"]["type"] == "notion_import": - notion_info_list = document_data["data_source"]["info_list"]["notion_info_list"] + elif knowledge_config.data_source.info_list.data_source_type == "notion_import": + notion_info_list = knowledge_config.data_source.info_list.notion_info_list exist_page_ids = [] exist_document = {} documents = Document.query.filter_by( @@ -859,7 +900,7 @@ def save_document_with_dataset_id( exist_page_ids.append(data_source_info["notion_page_id"]) exist_document[data_source_info["notion_page_id"]] = document.id for notion_info in notion_info_list: - workspace_id = notion_info["workspace_id"] + workspace_id = notion_info.workspace_id data_source_binding = DataSourceOauthBinding.query.filter( db.and_( DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, @@ -870,25 +911,25 @@ def save_document_with_dataset_id( ).first() if not data_source_binding: raise ValueError("Data source binding not found.") - for page in notion_info["pages"]: - if page["page_id"] not in exist_page_ids: + for page in notion_info.pages: + if page.page_id not in exist_page_ids: data_source_info = { "notion_workspace_id": workspace_id, - "notion_page_id": page["page_id"], - "notion_page_icon": page["page_icon"], - "type": page["type"], + "notion_page_id": page.page_id, + "notion_page_icon": page.page_icon.model_dump() if page.page_icon else None, + "type": page.type, } document = DocumentService.build_document( dataset, dataset_process_rule.id, - document_data["data_source"]["type"], - document_data["doc_form"], - document_data["doc_language"], + knowledge_config.data_source.info_list.data_source_type, + knowledge_config.doc_form, + knowledge_config.doc_language, data_source_info, created_from, position, account, - page["page_name"], + page.page_name, batch, ) db.session.add(document) @@ -897,19 +938,19 @@ def save_document_with_dataset_id( documents.append(document) position += 1 else: - exist_document.pop(page["page_id"]) + exist_document.pop(page.page_id) # delete not selected documents if len(exist_document) > 0: clean_notion_document_task.delay(list(exist_document.values()), dataset.id) - elif document_data["data_source"]["type"] == "website_crawl": - website_info = document_data["data_source"]["info_list"]["website_info_list"] - urls = website_info["urls"] + elif knowledge_config.data_source.info_list.data_source_type == "website_crawl": + website_info = knowledge_config.data_source.info_list.website_info_list + urls = website_info.urls for url in urls: data_source_info = { "url": url, - "provider": website_info["provider"], - "job_id": website_info["job_id"], - "only_main_content": website_info.get("only_main_content", False), + "provider": website_info.provider, + "job_id": website_info.job_id, + "only_main_content": website_info.only_main_content, "mode": "crawl", } if len(url) > 255: @@ -919,9 +960,9 @@ def save_document_with_dataset_id( document = DocumentService.build_document( dataset, dataset_process_rule.id, - document_data["data_source"]["type"], - document_data["doc_form"], - 
document_data["doc_language"], + knowledge_config.data_source.info_list.data_source_type, + knowledge_config.doc_form, + knowledge_config.doc_language, data_source_info, created_from, position, @@ -995,31 +1036,31 @@ def get_tenant_documents_count(): @staticmethod def update_document_with_dataset_id( dataset: Dataset, - document_data: dict, + document_data: KnowledgeConfig, account: Account, dataset_process_rule: Optional[DatasetProcessRule] = None, created_from: str = "web", ): DatasetService.check_dataset_model_setting(dataset) - document = DocumentService.get_document(dataset.id, document_data["original_document_id"]) + document = DocumentService.get_document(dataset.id, document_data.original_document_id) if document is None: raise NotFound("Document not found") if document.display_status != "available": raise ValueError("Document is not available") # save process rule - if document_data.get("process_rule"): - process_rule = document_data["process_rule"] - if process_rule["mode"] == "custom": + if document_data.process_rule: + process_rule = document_data.process_rule + if process_rule.mode in {"custom", "hierarchical"}: dataset_process_rule = DatasetProcessRule( dataset_id=dataset.id, - mode=process_rule["mode"], - rules=json.dumps(process_rule["rules"]), + mode=process_rule.mode, + rules=process_rule.rules.model_dump_json(), created_by=account.id, ) - elif process_rule["mode"] == "automatic": + elif process_rule.mode == "automatic": dataset_process_rule = DatasetProcessRule( dataset_id=dataset.id, - mode=process_rule["mode"], + mode=process_rule.mode, rules=json.dumps(DatasetProcessRule.AUTOMATIC_RULES), created_by=account.id, ) @@ -1028,11 +1069,11 @@ def update_document_with_dataset_id( db.session.commit() document.dataset_process_rule_id = dataset_process_rule.id # update document data source - if document_data.get("data_source"): + if document_data.data_source: file_name = "" data_source_info = {} - if document_data["data_source"]["type"] == "upload_file": - upload_file_list = document_data["data_source"]["info_list"]["file_info_list"]["file_ids"] + if document_data.data_source.info_list.data_source_type == "upload_file": + upload_file_list = document_data.data_source.info_list.file_info_list.file_ids for file_id in upload_file_list: file = ( db.session.query(UploadFile) @@ -1048,10 +1089,10 @@ def update_document_with_dataset_id( data_source_info = { "upload_file_id": file_id, } - elif document_data["data_source"]["type"] == "notion_import": - notion_info_list = document_data["data_source"]["info_list"]["notion_info_list"] + elif document_data.data_source.info_list.data_source_type == "notion_import": + notion_info_list = document_data.data_source.info_list.notion_info_list for notion_info in notion_info_list: - workspace_id = notion_info["workspace_id"] + workspace_id = notion_info.workspace_id data_source_binding = DataSourceOauthBinding.query.filter( db.and_( DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, @@ -1062,31 +1103,31 @@ def update_document_with_dataset_id( ).first() if not data_source_binding: raise ValueError("Data source binding not found.") - for page in notion_info["pages"]: + for page in notion_info.pages: data_source_info = { "notion_workspace_id": workspace_id, - "notion_page_id": page["page_id"], - "notion_page_icon": page["page_icon"], - "type": page["type"], + "notion_page_id": page.page_id, + "notion_page_icon": page.page_icon, + "type": page.type, } - elif document_data["data_source"]["type"] == "website_crawl": - website_info = 
document_data["data_source"]["info_list"]["website_info_list"] - urls = website_info["urls"] + elif document_data.data_source.info_list.data_source_type == "website_crawl": + website_info = document_data.data_source.info_list.website_info_list + urls = website_info.urls for url in urls: data_source_info = { "url": url, - "provider": website_info["provider"], - "job_id": website_info["job_id"], - "only_main_content": website_info.get("only_main_content", False), + "provider": website_info.provider, + "job_id": website_info.job_id, + "only_main_content": website_info.only_main_content, "mode": "crawl", } - document.data_source_type = document_data["data_source"]["type"] + document.data_source_type = document_data.data_source.info_list.data_source_type document.data_source_info = json.dumps(data_source_info) document.name = file_name # update document name - if document_data.get("name"): - document.name = document_data["name"] + if document_data.name: + document.name = document_data.name # update document to be waiting document.indexing_status = "waiting" document.completed_at = None @@ -1096,7 +1137,7 @@ def update_document_with_dataset_id( document.splitting_completed_at = None document.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) document.created_from = created_from - document.doc_form = document_data["doc_form"] + document.doc_form = document_data.doc_form db.session.add(document) db.session.commit() # update document segment @@ -1108,21 +1149,21 @@ def update_document_with_dataset_id( return document @staticmethod - def save_document_without_dataset_id(tenant_id: str, document_data: dict, account: Account): + def save_document_without_dataset_id(tenant_id: str, knowledge_config: KnowledgeConfig, account: Account): features = FeatureService.get_features(current_user.current_tenant_id) if features.billing.enabled: count = 0 - if document_data["data_source"]["type"] == "upload_file": - upload_file_list = document_data["data_source"]["info_list"]["file_info_list"]["file_ids"] + if knowledge_config.data_source.info_list.data_source_type == "upload_file": + upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids count = len(upload_file_list) - elif document_data["data_source"]["type"] == "notion_import": - notion_info_list = document_data["data_source"]["info_list"]["notion_info_list"] + elif knowledge_config.data_source.info_list.data_source_type == "notion_import": + notion_info_list = knowledge_config.data_source.info_list.notion_info_list for notion_info in notion_info_list: - count = count + len(notion_info["pages"]) - elif document_data["data_source"]["type"] == "website_crawl": - website_info = document_data["data_source"]["info_list"]["website_info_list"] - count = len(website_info["urls"]) + count = count + len(notion_info.pages) + elif knowledge_config.data_source.info_list.data_source_type == "website_crawl": + website_info = knowledge_config.data_source.info_list.website_info_list + count = len(website_info.urls) batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT) if count > batch_upload_limit: raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.") @@ -1131,13 +1172,13 @@ def save_document_without_dataset_id(tenant_id: str, document_data: dict, accoun dataset_collection_binding_id = None retrieval_model = None - if document_data["indexing_technique"] == "high_quality": + if knowledge_config.indexing_technique == "high_quality": dataset_collection_binding = 
DatasetCollectionBindingService.get_dataset_collection_binding( - document_data["embedding_model_provider"], document_data["embedding_model"] + knowledge_config.embedding_model_provider, knowledge_config.embedding_model ) dataset_collection_binding_id = dataset_collection_binding.id - if document_data.get("retrieval_model"): - retrieval_model = document_data["retrieval_model"] + if knowledge_config.retrieval_model: + retrieval_model = knowledge_config.retrieval_model else: default_retrieval_model = { "search_method": RetrievalMethod.SEMANTIC_SEARCH.value, @@ -1146,24 +1187,24 @@ def save_document_without_dataset_id(tenant_id: str, document_data: dict, accoun "top_k": 2, "score_threshold_enabled": False, } - retrieval_model = default_retrieval_model + retrieval_model = RetrievalModel(**default_retrieval_model) # save dataset dataset = Dataset( tenant_id=tenant_id, name="", - data_source_type=document_data["data_source"]["type"], - indexing_technique=document_data.get("indexing_technique", "high_quality"), + data_source_type=knowledge_config.data_source.info_list.data_source_type, + indexing_technique=knowledge_config.indexing_technique, created_by=account.id, - embedding_model=document_data.get("embedding_model"), - embedding_model_provider=document_data.get("embedding_model_provider"), + embedding_model=knowledge_config.embedding_model, + embedding_model_provider=knowledge_config.embedding_model_provider, collection_binding_id=dataset_collection_binding_id, - retrieval_model=retrieval_model, + retrieval_model=retrieval_model.model_dump() if retrieval_model else None, ) - db.session.add(dataset) + db.session.add(dataset) # type: ignore db.session.flush() - documents, batch = DocumentService.save_document_with_dataset_id(dataset, document_data, account) + documents, batch = DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account) cut_length = 18 cut_name = documents[0].name[:cut_length] @@ -1174,133 +1215,86 @@ def save_document_without_dataset_id(tenant_id: str, document_data: dict, accoun return dataset, documents, batch @classmethod - def document_create_args_validate(cls, args: dict): - if "original_document_id" not in args or not args["original_document_id"]: - DocumentService.data_source_args_validate(args) - DocumentService.process_rule_args_validate(args) + def document_create_args_validate(cls, knowledge_config: KnowledgeConfig): + if not knowledge_config.data_source and not knowledge_config.process_rule: + raise ValueError("Data source or Process rule is required") else: - if ("data_source" not in args or not args["data_source"]) and ( - "process_rule" not in args or not args["process_rule"] - ): - raise ValueError("Data source or Process rule is required") - else: - if args.get("data_source"): - DocumentService.data_source_args_validate(args) - if args.get("process_rule"): - DocumentService.process_rule_args_validate(args) + if knowledge_config.data_source: + DocumentService.data_source_args_validate(knowledge_config) + if knowledge_config.process_rule: + DocumentService.process_rule_args_validate(knowledge_config) @classmethod - def data_source_args_validate(cls, args: dict): - if "data_source" not in args or not args["data_source"]: + def data_source_args_validate(cls, knowledge_config: KnowledgeConfig): + if not knowledge_config.data_source: raise ValueError("Data source is required") - if not isinstance(args["data_source"], dict): - raise ValueError("Data source is invalid") - - if "type" not in args["data_source"] or not args["data_source"]["type"]: - 
raise ValueError("Data source type is required") - - if args["data_source"]["type"] not in Document.DATA_SOURCES: + if knowledge_config.data_source.info_list.data_source_type not in Document.DATA_SOURCES: raise ValueError("Data source type is invalid") - if "info_list" not in args["data_source"] or not args["data_source"]["info_list"]: + if not knowledge_config.data_source.info_list: raise ValueError("Data source info is required") - if args["data_source"]["type"] == "upload_file": - if ( - "file_info_list" not in args["data_source"]["info_list"] - or not args["data_source"]["info_list"]["file_info_list"] - ): + if knowledge_config.data_source.info_list.data_source_type == "upload_file": + if not knowledge_config.data_source.info_list.file_info_list: raise ValueError("File source info is required") - if args["data_source"]["type"] == "notion_import": - if ( - "notion_info_list" not in args["data_source"]["info_list"] - or not args["data_source"]["info_list"]["notion_info_list"] - ): + if knowledge_config.data_source.info_list.data_source_type == "notion_import": + if not knowledge_config.data_source.info_list.notion_info_list: raise ValueError("Notion source info is required") - if args["data_source"]["type"] == "website_crawl": - if ( - "website_info_list" not in args["data_source"]["info_list"] - or not args["data_source"]["info_list"]["website_info_list"] - ): + if knowledge_config.data_source.info_list.data_source_type == "website_crawl": + if not knowledge_config.data_source.info_list.website_info_list: raise ValueError("Website source info is required") @classmethod - def process_rule_args_validate(cls, args: dict): - if "process_rule" not in args or not args["process_rule"]: + def process_rule_args_validate(cls, knowledge_config: KnowledgeConfig): + if not knowledge_config.process_rule: raise ValueError("Process rule is required") - if not isinstance(args["process_rule"], dict): - raise ValueError("Process rule is invalid") - - if "mode" not in args["process_rule"] or not args["process_rule"]["mode"]: + if not knowledge_config.process_rule.mode: raise ValueError("Process rule mode is required") - if args["process_rule"]["mode"] not in DatasetProcessRule.MODES: + if knowledge_config.process_rule.mode not in DatasetProcessRule.MODES: raise ValueError("Process rule mode is invalid") - if args["process_rule"]["mode"] == "automatic": - args["process_rule"]["rules"] = {} + if knowledge_config.process_rule.mode == "automatic": + knowledge_config.process_rule.rules = None else: - if "rules" not in args["process_rule"] or not args["process_rule"]["rules"]: + if not knowledge_config.process_rule.rules: raise ValueError("Process rule rules is required") - if not isinstance(args["process_rule"]["rules"], dict): - raise ValueError("Process rule rules is invalid") - - if ( - "pre_processing_rules" not in args["process_rule"]["rules"] - or args["process_rule"]["rules"]["pre_processing_rules"] is None - ): + if knowledge_config.process_rule.rules.pre_processing_rules is None: raise ValueError("Process rule pre_processing_rules is required") - if not isinstance(args["process_rule"]["rules"]["pre_processing_rules"], list): - raise ValueError("Process rule pre_processing_rules is invalid") - unique_pre_processing_rule_dicts = {} - for pre_processing_rule in args["process_rule"]["rules"]["pre_processing_rules"]: - if "id" not in pre_processing_rule or not pre_processing_rule["id"]: + for pre_processing_rule in knowledge_config.process_rule.rules.pre_processing_rules: + if not pre_processing_rule.id: 
raise ValueError("Process rule pre_processing_rules id is required") - if pre_processing_rule["id"] not in DatasetProcessRule.PRE_PROCESSING_RULES: - raise ValueError("Process rule pre_processing_rules id is invalid") - - if "enabled" not in pre_processing_rule or pre_processing_rule["enabled"] is None: - raise ValueError("Process rule pre_processing_rules enabled is required") - - if not isinstance(pre_processing_rule["enabled"], bool): + if not isinstance(pre_processing_rule.enabled, bool): raise ValueError("Process rule pre_processing_rules enabled is invalid") - unique_pre_processing_rule_dicts[pre_processing_rule["id"]] = pre_processing_rule + unique_pre_processing_rule_dicts[pre_processing_rule.id] = pre_processing_rule - args["process_rule"]["rules"]["pre_processing_rules"] = list(unique_pre_processing_rule_dicts.values()) + knowledge_config.process_rule.rules.pre_processing_rules = list(unique_pre_processing_rule_dicts.values()) - if ( - "segmentation" not in args["process_rule"]["rules"] - or args["process_rule"]["rules"]["segmentation"] is None - ): + if not knowledge_config.process_rule.rules.segmentation: raise ValueError("Process rule segmentation is required") - if not isinstance(args["process_rule"]["rules"]["segmentation"], dict): - raise ValueError("Process rule segmentation is invalid") - - if ( - "separator" not in args["process_rule"]["rules"]["segmentation"] - or not args["process_rule"]["rules"]["segmentation"]["separator"] - ): + if not knowledge_config.process_rule.rules.segmentation.separator: raise ValueError("Process rule segmentation separator is required") - if not isinstance(args["process_rule"]["rules"]["segmentation"]["separator"], str): + if not isinstance(knowledge_config.process_rule.rules.segmentation.separator, str): raise ValueError("Process rule segmentation separator is invalid") - if ( - "max_tokens" not in args["process_rule"]["rules"]["segmentation"] - or not args["process_rule"]["rules"]["segmentation"]["max_tokens"] + if not ( + knowledge_config.process_rule.mode == "hierarchical" + and knowledge_config.process_rule.rules.parent_mode == "full-doc" ): - raise ValueError("Process rule segmentation max_tokens is required") + if not knowledge_config.process_rule.rules.segmentation.max_tokens: + raise ValueError("Process rule segmentation max_tokens is required") - if not isinstance(args["process_rule"]["rules"]["segmentation"]["max_tokens"], int): - raise ValueError("Process rule segmentation max_tokens is invalid") + if not isinstance(knowledge_config.process_rule.rules.segmentation.max_tokens, int): + raise ValueError("Process rule segmentation max_tokens is invalid") @classmethod def estimate_args_validate(cls, args: dict): @@ -1447,7 +1441,7 @@ def create_segment(cls, args: dict, document: Document, dataset: Dataset): # save vector index try: - VectorService.create_segments_vector([args["keywords"]], [segment_document], dataset) + VectorService.create_segments_vector([args["keywords"]], [segment_document], dataset, document.doc_form) except Exception as e: logging.exception("create segment index failed") segment_document.enabled = False @@ -1525,7 +1519,7 @@ def multi_create_segment(cls, segments: list, document: Document, dataset: Datas db.session.add(document) try: # save vector index - VectorService.create_segments_vector(keywords_list, pre_segment_data_list, dataset) + VectorService.create_segments_vector(keywords_list, pre_segment_data_list, dataset, document.doc_form) except Exception as e: logging.exception("create segment index failed") 
for segment_document in segment_data_list: @@ -1537,14 +1531,13 @@ def multi_create_segment(cls, segments: list, document: Document, dataset: Datas return segment_data_list @classmethod - def update_segment(cls, args: dict, segment: DocumentSegment, document: Document, dataset: Dataset): - segment_update_entity = SegmentUpdateEntity(**args) + def update_segment(cls, args: SegmentUpdateArgs, segment: DocumentSegment, document: Document, dataset: Dataset): indexing_cache_key = "segment_{}_indexing".format(segment.id) cache_result = redis_client.get(indexing_cache_key) if cache_result is not None: raise ValueError("Segment is indexing, please try again later") - if segment_update_entity.enabled is not None: - action = segment_update_entity.enabled + if args.enabled is not None: + action = args.enabled if segment.enabled != action: if not action: segment.enabled = action @@ -1557,22 +1550,22 @@ def update_segment(cls, args: dict, segment: DocumentSegment, document: Document disable_segment_from_index_task.delay(segment.id) return segment if not segment.enabled: - if segment_update_entity.enabled is not None: - if not segment_update_entity.enabled: + if args.enabled is not None: + if not args.enabled: raise ValueError("Can't update disabled segment") else: raise ValueError("Can't update disabled segment") try: word_count_change = segment.word_count - content = segment_update_entity.content + content = args.content if segment.content == content: segment.word_count = len(content) if document.doc_form == "qa_model": - segment.answer = segment_update_entity.answer - segment.word_count += len(segment_update_entity.answer or "") + segment.answer = args.answer + segment.word_count += len(args.answer) word_count_change = segment.word_count - word_count_change - if segment_update_entity.keywords: - segment.keywords = segment_update_entity.keywords + if args.keywords: + segment.keywords = args.keywords segment.enabled = True segment.disabled_at = None segment.disabled_by = None @@ -1583,9 +1576,38 @@ def update_segment(cls, args: dict, segment: DocumentSegment, document: Document document.word_count = max(0, document.word_count + word_count_change) db.session.add(document) # update segment index task - if segment_update_entity.enabled: - keywords = segment_update_entity.keywords or [] - VectorService.create_segments_vector([keywords], [segment], dataset) + if args.enabled: + VectorService.create_segments_vector([args.keywords], [segment], dataset) + if document.doc_form == IndexType.PARENT_CHILD_INDEX and args.regenerate_child_chunks: + # regenerate child chunks + # get embedding model instance + if dataset.indexing_technique == "high_quality": + # check embedding model setting + model_manager = ModelManager() + + if dataset.embedding_model_provider: + embedding_model_instance = model_manager.get_model_instance( + tenant_id=dataset.tenant_id, + provider=dataset.embedding_model_provider, + model_type=ModelType.TEXT_EMBEDDING, + model=dataset.embedding_model, + ) + else: + embedding_model_instance = model_manager.get_default_model_instance( + tenant_id=dataset.tenant_id, + model_type=ModelType.TEXT_EMBEDDING, + ) + else: + raise ValueError("The knowledge base index technique is not high quality!") + # get the process rule + processing_rule = ( + db.session.query(DatasetProcessRule) + .filter(DatasetProcessRule.id == document.dataset_process_rule_id) + .first() + ) + VectorService.generate_child_chunks( + segment, document, dataset, embedding_model_instance, processing_rule, True + ) else: segment_hash = 
helper.generate_text_hash(content) tokens = 0 @@ -1616,8 +1638,8 @@ def update_segment(cls, args: dict, segment: DocumentSegment, document: Document segment.disabled_at = None segment.disabled_by = None if document.doc_form == "qa_model": - segment.answer = segment_update_entity.answer - segment.word_count += len(segment_update_entity.answer or "") + segment.answer = args.answer + segment.word_count += len(args.answer) word_count_change = segment.word_count - word_count_change # update document word count if word_count_change != 0: @@ -1625,8 +1647,38 @@ def update_segment(cls, args: dict, segment: DocumentSegment, document: Document db.session.add(document) db.session.add(segment) db.session.commit() - # update segment vector index - VectorService.update_segment_vector(segment_update_entity.keywords, segment, dataset) + if document.doc_form == IndexType.PARENT_CHILD_INDEX and args.regenerate_child_chunks: + # get embedding model instance + if dataset.indexing_technique == "high_quality": + # check embedding model setting + model_manager = ModelManager() + + if dataset.embedding_model_provider: + embedding_model_instance = model_manager.get_model_instance( + tenant_id=dataset.tenant_id, + provider=dataset.embedding_model_provider, + model_type=ModelType.TEXT_EMBEDDING, + model=dataset.embedding_model, + ) + else: + embedding_model_instance = model_manager.get_default_model_instance( + tenant_id=dataset.tenant_id, + model_type=ModelType.TEXT_EMBEDDING, + ) + else: + raise ValueError("The knowledge base index technique is not high quality!") + # get the process rule + processing_rule = ( + db.session.query(DatasetProcessRule) + .filter(DatasetProcessRule.id == document.dataset_process_rule_id) + .first() + ) + VectorService.generate_child_chunks( + segment, document, dataset, embedding_model_instance, processing_rule, True + ) + elif document.doc_form in (IndexType.PARAGRAPH_INDEX, IndexType.QA_INDEX): + # update segment vector index + VectorService.update_segment_vector(args.keywords, segment, dataset) except Exception as e: logging.exception("update segment index failed") @@ -1649,13 +1701,265 @@ def delete_segment(cls, segment: DocumentSegment, document: Document, dataset: D if segment.enabled: # send delete segment index task redis_client.setex(indexing_cache_key, 600, 1) - delete_segment_from_index_task.delay(segment.id, segment.index_node_id, dataset.id, document.id) + delete_segment_from_index_task.delay([segment.index_node_id], dataset.id, document.id) db.session.delete(segment) # update document word count document.word_count -= segment.word_count db.session.add(document) db.session.commit() + @classmethod + def delete_segments(cls, segment_ids: list, document: Document, dataset: Dataset): + index_node_ids = ( + DocumentSegment.query.with_entities(DocumentSegment.index_node_id) + .filter( + DocumentSegment.id.in_(segment_ids), + DocumentSegment.dataset_id == dataset.id, + DocumentSegment.document_id == document.id, + DocumentSegment.tenant_id == current_user.current_tenant_id, + ) + .all() + ) + index_node_ids = [index_node_id[0] for index_node_id in index_node_ids] + + delete_segment_from_index_task.delay(index_node_ids, dataset.id, document.id) + db.session.query(DocumentSegment).filter(DocumentSegment.id.in_(segment_ids)).delete() + db.session.commit() + + @classmethod + def update_segments_status(cls, segment_ids: list, action: str, dataset: Dataset, document: Document): + if action == "enable": + segments = ( + db.session.query(DocumentSegment) + .filter( + 
DocumentSegment.id.in_(segment_ids), + DocumentSegment.dataset_id == dataset.id, + DocumentSegment.document_id == document.id, + DocumentSegment.enabled == False, + ) + .all() + ) + if not segments: + return + real_deal_segmment_ids = [] + for segment in segments: + indexing_cache_key = "segment_{}_indexing".format(segment.id) + cache_result = redis_client.get(indexing_cache_key) + if cache_result is not None: + continue + segment.enabled = True + segment.disabled_at = None + segment.disabled_by = None + db.session.add(segment) + real_deal_segmment_ids.append(segment.id) + db.session.commit() + + enable_segments_to_index_task.delay(real_deal_segmment_ids, dataset.id, document.id) + elif action == "disable": + segments = ( + db.session.query(DocumentSegment) + .filter( + DocumentSegment.id.in_(segment_ids), + DocumentSegment.dataset_id == dataset.id, + DocumentSegment.document_id == document.id, + DocumentSegment.enabled == True, + ) + .all() + ) + if not segments: + return + real_deal_segmment_ids = [] + for segment in segments: + indexing_cache_key = "segment_{}_indexing".format(segment.id) + cache_result = redis_client.get(indexing_cache_key) + if cache_result is not None: + continue + segment.enabled = False + segment.disabled_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) + segment.disabled_by = current_user.id + db.session.add(segment) + real_deal_segmment_ids.append(segment.id) + db.session.commit() + + disable_segments_from_index_task.delay(real_deal_segmment_ids, dataset.id, document.id) + else: + raise InvalidActionError() + + @classmethod + def create_child_chunk( + cls, content: str, segment: DocumentSegment, document: Document, dataset: Dataset + ) -> ChildChunk: + lock_name = "add_child_lock_{}".format(segment.id) + with redis_client.lock(lock_name, timeout=20): + index_node_id = str(uuid.uuid4()) + index_node_hash = helper.generate_text_hash(content) + child_chunk_count = ( + db.session.query(ChildChunk) + .filter( + ChildChunk.tenant_id == current_user.current_tenant_id, + ChildChunk.dataset_id == dataset.id, + ChildChunk.document_id == document.id, + ChildChunk.segment_id == segment.id, + ) + .count() + ) + max_position = ( + db.session.query(func.max(ChildChunk.position)) + .filter( + ChildChunk.tenant_id == current_user.current_tenant_id, + ChildChunk.dataset_id == dataset.id, + ChildChunk.document_id == document.id, + ChildChunk.segment_id == segment.id, + ) + .scalar() + ) + child_chunk = ChildChunk( + tenant_id=current_user.current_tenant_id, + dataset_id=dataset.id, + document_id=document.id, + segment_id=segment.id, + position=max_position + 1, + index_node_id=index_node_id, + index_node_hash=index_node_hash, + content=content, + word_count=len(content), + type="customized", + created_by=current_user.id, + ) + db.session.add(child_chunk) + # save vector index + try: + VectorService.create_child_chunk_vector(child_chunk, dataset) + except Exception as e: + logging.exception("create child chunk index failed") + db.session.rollback() + raise ChildChunkIndexingError(str(e)) + db.session.commit() + + return child_chunk + + @classmethod + def update_child_chunks( + cls, + child_chunks_update_args: list[ChildChunkUpdateArgs], + segment: DocumentSegment, + document: Document, + dataset: Dataset, + ) -> list[ChildChunk]: + child_chunks = ( + db.session.query(ChildChunk) + .filter( + ChildChunk.dataset_id == dataset.id, + ChildChunk.document_id == document.id, + ChildChunk.segment_id == segment.id, + ) + .all() + ) + child_chunks_map = {chunk.id: chunk 
for chunk in child_chunks} + + new_child_chunks, update_child_chunks, delete_child_chunks, new_child_chunks_args = [], [], [], [] + + for child_chunk_update_args in child_chunks_update_args: + if child_chunk_update_args.id: + child_chunk = child_chunks_map.pop(child_chunk_update_args.id, None) + if child_chunk: + if child_chunk.content != child_chunk_update_args.content: + child_chunk.content = child_chunk_update_args.content + child_chunk.word_count = len(child_chunk.content) + child_chunk.updated_by = current_user.id + child_chunk.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) + child_chunk.type = "customized" + update_child_chunks.append(child_chunk) + else: + new_child_chunks_args.append(child_chunk_update_args) + if child_chunks_map: + delete_child_chunks = list(child_chunks_map.values()) + try: + if update_child_chunks: + db.session.bulk_save_objects(update_child_chunks) + + if delete_child_chunks: + for child_chunk in delete_child_chunks: + db.session.delete(child_chunk) + if new_child_chunks_args: + child_chunk_count = len(child_chunks) + for position, args in enumerate(new_child_chunks_args, start=child_chunk_count + 1): + index_node_id = str(uuid.uuid4()) + index_node_hash = helper.generate_text_hash(args.content) + child_chunk = ChildChunk( + tenant_id=current_user.current_tenant_id, + dataset_id=dataset.id, + document_id=document.id, + segment_id=segment.id, + position=position, + index_node_id=index_node_id, + index_node_hash=index_node_hash, + content=args.content, + word_count=len(args.content), + type="customized", + created_by=current_user.id, + ) + + db.session.add(child_chunk) + db.session.flush() + new_child_chunks.append(child_chunk) + VectorService.update_child_chunk_vector(new_child_chunks, update_child_chunks, delete_child_chunks, dataset) + db.session.commit() + except Exception as e: + logging.exception("update child chunk index failed") + db.session.rollback() + raise ChildChunkIndexingError(str(e)) + return sorted(new_child_chunks + update_child_chunks, key=lambda x: x.position) + + @classmethod + def update_child_chunk( + cls, + content: str, + child_chunk: ChildChunk, + segment: DocumentSegment, + document: Document, + dataset: Dataset, + ) -> ChildChunk: + try: + child_chunk.content = content + child_chunk.word_count = len(content) + child_chunk.updated_by = current_user.id + child_chunk.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) + child_chunk.type = "customized" + db.session.add(child_chunk) + VectorService.update_child_chunk_vector([], [child_chunk], [], dataset) + db.session.commit() + except Exception as e: + logging.exception("update child chunk index failed") + db.session.rollback() + raise ChildChunkIndexingError(str(e)) + return child_chunk + + @classmethod + def delete_child_chunk(cls, child_chunk: ChildChunk, dataset: Dataset): + db.session.delete(child_chunk) + try: + VectorService.delete_child_chunk_vector(child_chunk, dataset) + except Exception as e: + logging.exception("delete child chunk index failed") + db.session.rollback() + raise ChildChunkDeleteIndexError(str(e)) + db.session.commit() + + @classmethod + def get_child_chunks( + cls, segment_id: str, document_id: str, dataset_id: str, page: int, limit: int, keyword: Optional[str] = None + ): + query = ChildChunk.query.filter_by( + tenant_id=current_user.current_tenant_id, + dataset_id=dataset_id, + document_id=document_id, + segment_id=segment_id, + ).order_by(ChildChunk.position.asc()) + if keyword: + query = 
query.where(ChildChunk.content.ilike(f"%{keyword}%")) + return query.paginate(page=page, per_page=limit, max_per_page=100, error_out=False) + class DatasetCollectionBindingService: @classmethod diff --git a/api/services/entities/knowledge_entities/knowledge_entities.py b/api/services/entities/knowledge_entities/knowledge_entities.py index 449b79f339b9a9..8d6a246b6428d0 100644 --- a/api/services/entities/knowledge_entities/knowledge_entities.py +++ b/api/services/entities/knowledge_entities/knowledge_entities.py @@ -1,4 +1,5 @@ -from typing import Optional +from enum import Enum +from typing import Literal, Optional from pydantic import BaseModel @@ -8,3 +9,112 @@ class SegmentUpdateEntity(BaseModel): answer: Optional[str] = None keywords: Optional[list[str]] = None enabled: Optional[bool] = None + + +class ParentMode(str, Enum): + FULL_DOC = "full-doc" + PARAGRAPH = "paragraph" + + +class NotionIcon(BaseModel): + type: str + url: Optional[str] = None + emoji: Optional[str] = None + + +class NotionPage(BaseModel): + page_id: str + page_name: str + page_icon: Optional[NotionIcon] = None + type: str + + +class NotionInfo(BaseModel): + workspace_id: str + pages: list[NotionPage] + + +class WebsiteInfo(BaseModel): + provider: str + job_id: str + urls: list[str] + only_main_content: bool = True + + +class FileInfo(BaseModel): + file_ids: list[str] + + +class InfoList(BaseModel): + data_source_type: Literal["upload_file", "notion_import", "website_crawl"] + notion_info_list: Optional[list[NotionInfo]] = None + file_info_list: Optional[FileInfo] = None + website_info_list: Optional[WebsiteInfo] = None + + +class DataSource(BaseModel): + info_list: InfoList + + +class PreProcessingRule(BaseModel): + id: str + enabled: bool + + +class Segmentation(BaseModel): + separator: str = "\n" + max_tokens: int + chunk_overlap: int = 0 + + +class Rule(BaseModel): + pre_processing_rules: Optional[list[PreProcessingRule]] = None + segmentation: Optional[Segmentation] = None + parent_mode: Optional[Literal["full-doc", "paragraph"]] = None + subchunk_segmentation: Optional[Segmentation] = None + + +class ProcessRule(BaseModel): + mode: Literal["automatic", "custom", "hierarchical"] + rules: Optional[Rule] = None + + +class RerankingModel(BaseModel): + reranking_provider_name: Optional[str] = None + reranking_model_name: Optional[str] = None + + +class RetrievalModel(BaseModel): + search_method: Literal["hybrid_search", "semantic_search", "full_text_search"] + reranking_enable: bool + reranking_model: Optional[RerankingModel] = None + top_k: int + score_threshold_enabled: bool + score_threshold: Optional[float] = None + + +class KnowledgeConfig(BaseModel): + original_document_id: Optional[str] = None + duplicate: bool = True + indexing_technique: Literal["high_quality", "economy"] + data_source: Optional[DataSource] = None + process_rule: Optional[ProcessRule] = None + retrieval_model: Optional[RetrievalModel] = None + doc_form: str = "text_model" + doc_language: str = "English" + embedding_model: Optional[str] = None + embedding_model_provider: Optional[str] = None + name: Optional[str] = None + + +class SegmentUpdateArgs(BaseModel): + content: Optional[str] = None + answer: Optional[str] = None + keywords: Optional[list[str]] = None + regenerate_child_chunks: bool = False + enabled: Optional[bool] = None + + +class ChildChunkUpdateArgs(BaseModel): + id: Optional[str] = None + content: str diff --git a/api/services/errors/chunk.py b/api/services/errors/chunk.py new file mode 100644 index 
00000000000000..75bf4d5d5f8122 --- /dev/null +++ b/api/services/errors/chunk.py @@ -0,0 +1,9 @@ +from services.errors.base import BaseServiceError + + +class ChildChunkIndexingError(BaseServiceError): + description = "{message}" + + +class ChildChunkDeleteIndexError(BaseServiceError): + description = "{message}" diff --git a/api/services/hit_testing_service.py b/api/services/hit_testing_service.py index 41b4e1ec46374a..0e61beaa90ef20 100644 --- a/api/services/hit_testing_service.py +++ b/api/services/hit_testing_service.py @@ -7,7 +7,7 @@ from core.rag.retrieval.retrieval_methods import RetrievalMethod from extensions.ext_database import db from models.account import Account -from models.dataset import Dataset, DatasetQuery, DocumentSegment +from models.dataset import Dataset, DatasetQuery default_retrieval_model = { "search_method": RetrievalMethod.SEMANTIC_SEARCH.value, @@ -69,7 +69,7 @@ def retrieve( db.session.add(dataset_query) db.session.commit() - return dict(cls.compact_retrieve_response(dataset, query, all_documents)) + return cls.compact_retrieve_response(query, all_documents) @classmethod def external_retrieve( @@ -106,41 +106,14 @@ def external_retrieve( return dict(cls.compact_external_retrieve_response(dataset, query, all_documents)) @classmethod - def compact_retrieve_response(cls, dataset: Dataset, query: str, documents: list[Document]): - records = [] - - for document in documents: - if document.metadata is None: - continue - - index_node_id = document.metadata["doc_id"] - - segment = ( - db.session.query(DocumentSegment) - .filter( - DocumentSegment.dataset_id == dataset.id, - DocumentSegment.enabled == True, - DocumentSegment.status == "completed", - DocumentSegment.index_node_id == index_node_id, - ) - .first() - ) - - if not segment: - continue - - record = { - "segment": segment, - "score": document.metadata.get("score", None), - } - - records.append(record) + def compact_retrieve_response(cls, query: str, documents: list[Document]): + records = RetrievalService.format_retrieval_documents(documents) return { "query": { "content": query, }, - "records": records, + "records": [record.model_dump() for record in records], } @classmethod diff --git a/api/services/vector_service.py b/api/services/vector_service.py index 3c67351335359d..6698e6e7188223 100644 --- a/api/services/vector_service.py +++ b/api/services/vector_service.py @@ -1,40 +1,68 @@ from typing import Optional +from core.model_manager import ModelInstance, ModelManager +from core.model_runtime.entities.model_entities import ModelType from core.rag.datasource.keyword.keyword_factory import Keyword from core.rag.datasource.vdb.vector_factory import Vector +from core.rag.index_processor.constant.index_type import IndexType +from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from core.rag.models.document import Document -from models.dataset import Dataset, DocumentSegment +from extensions.ext_database import db +from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegment +from models.dataset import Document as DatasetDocument +from services.entities.knowledge_entities.knowledge_entities import ParentMode class VectorService: @classmethod def create_segments_vector( - cls, keywords_list: Optional[list[list[str]]], segments: list[DocumentSegment], dataset: Dataset + cls, keywords_list: Optional[list[list[str]]], segments: list[DocumentSegment], dataset: Dataset, doc_form: str ): documents = [] - for segment in segments: - document = Document( - 
page_content=segment.content, - metadata={ - "doc_id": segment.index_node_id, - "doc_hash": segment.index_node_hash, - "document_id": segment.document_id, - "dataset_id": segment.dataset_id, - }, - ) - documents.append(document) - if dataset.indexing_technique == "high_quality": - # save vector index - vector = Vector(dataset=dataset) - vector.add_texts(documents, duplicate_check=True) - # save keyword index - keyword = Keyword(dataset) + for segment in segments: + if doc_form == IndexType.PARENT_CHILD_INDEX: + document = DatasetDocument.query.filter_by(id=segment.document_id).first() + # get the process rule + processing_rule = ( + db.session.query(DatasetProcessRule) + .filter(DatasetProcessRule.id == document.dataset_process_rule_id) + .first() + ) + # get embedding model instance + if dataset.indexing_technique == "high_quality": + # check embedding model setting + model_manager = ModelManager() - if keywords_list and len(keywords_list) > 0: - keyword.add_texts(documents, keywords_list=keywords_list) - else: - keyword.add_texts(documents) + if dataset.embedding_model_provider: + embedding_model_instance = model_manager.get_model_instance( + tenant_id=dataset.tenant_id, + provider=dataset.embedding_model_provider, + model_type=ModelType.TEXT_EMBEDDING, + model=dataset.embedding_model, + ) + else: + embedding_model_instance = model_manager.get_default_model_instance( + tenant_id=dataset.tenant_id, + model_type=ModelType.TEXT_EMBEDDING, + ) + else: + raise ValueError("The knowledge base index technique is not high quality!") + cls.generate_child_chunks(segment, document, dataset, embedding_model_instance, processing_rule, False) + else: + document = Document( + page_content=segment.content, + metadata={ + "doc_id": segment.index_node_id, + "doc_hash": segment.index_node_hash, + "document_id": segment.document_id, + "dataset_id": segment.dataset_id, + }, + ) + documents.append(document) + if len(documents) > 0: + index_processor = IndexProcessorFactory(doc_form).init_index_processor() + index_processor.load(dataset, documents, with_keywords=True, keywords_list=keywords_list) @classmethod def update_segment_vector(cls, keywords: Optional[list[str]], segment: DocumentSegment, dataset: Dataset): @@ -65,3 +93,123 @@ def update_segment_vector(cls, keywords: Optional[list[str]], segment: DocumentS keyword.add_texts([document], keywords_list=[keywords]) else: keyword.add_texts([document]) + + @classmethod + def generate_child_chunks( + cls, + segment: DocumentSegment, + dataset_document: Document, + dataset: Dataset, + embedding_model_instance: ModelInstance, + processing_rule: DatasetProcessRule, + regenerate: bool = False, + ): + index_processor = IndexProcessorFactory(dataset.doc_form).init_index_processor() + if regenerate: + # delete child chunks + index_processor.clean(dataset, [segment.index_node_id], with_keywords=True, delete_child_chunks=True) + + # generate child chunks + document = Document( + page_content=segment.content, + metadata={ + "doc_id": segment.index_node_id, + "doc_hash": segment.index_node_hash, + "document_id": segment.document_id, + "dataset_id": segment.dataset_id, + }, + ) + # use full doc mode to generate segment's child chunk + processing_rule_dict = processing_rule.to_dict() + processing_rule_dict["rules"]["parent_mode"] = ParentMode.FULL_DOC.value + documents = index_processor.transform( + [document], + embedding_model_instance=embedding_model_instance, + process_rule=processing_rule_dict, + tenant_id=dataset.tenant_id, + 
doc_language=dataset_document.doc_language, + ) + # save child chunks + if len(documents) > 0 and len(documents[0].children) > 0: + index_processor.load(dataset, documents) + + for position, child_chunk in enumerate(documents[0].children, start=1): + child_segment = ChildChunk( + tenant_id=dataset.tenant_id, + dataset_id=dataset.id, + document_id=dataset_document.id, + segment_id=segment.id, + position=position, + index_node_id=child_chunk.metadata["doc_id"], + index_node_hash=child_chunk.metadata["doc_hash"], + content=child_chunk.page_content, + word_count=len(child_chunk.page_content), + type="automatic", + created_by=dataset_document.created_by, + ) + db.session.add(child_segment) + db.session.commit() + + @classmethod + def create_child_chunk_vector(cls, child_segment: ChildChunk, dataset: Dataset): + child_document = Document( + page_content=child_segment.content, + metadata={ + "doc_id": child_segment.index_node_id, + "doc_hash": child_segment.index_node_hash, + "document_id": child_segment.document_id, + "dataset_id": child_segment.dataset_id, + }, + ) + if dataset.indexing_technique == "high_quality": + # save vector index + vector = Vector(dataset=dataset) + vector.add_texts([child_document], duplicate_check=True) + + @classmethod + def update_child_chunk_vector( + cls, + new_child_chunks: list[ChildChunk], + update_child_chunks: list[ChildChunk], + delete_child_chunks: list[ChildChunk], + dataset: Dataset, + ): + documents = [] + delete_node_ids = [] + for new_child_chunk in new_child_chunks: + new_child_document = Document( + page_content=new_child_chunk.content, + metadata={ + "doc_id": new_child_chunk.index_node_id, + "doc_hash": new_child_chunk.index_node_hash, + "document_id": new_child_chunk.document_id, + "dataset_id": new_child_chunk.dataset_id, + }, + ) + documents.append(new_child_document) + for update_child_chunk in update_child_chunks: + child_document = Document( + page_content=update_child_chunk.content, + metadata={ + "doc_id": update_child_chunk.index_node_id, + "doc_hash": update_child_chunk.index_node_hash, + "document_id": update_child_chunk.document_id, + "dataset_id": update_child_chunk.dataset_id, + }, + ) + documents.append(child_document) + delete_node_ids.append(update_child_chunk.index_node_id) + for delete_child_chunk in delete_child_chunks: + delete_node_ids.append(delete_child_chunk.index_node_id) + if dataset.indexing_technique == "high_quality": + # update vector index + vector = Vector(dataset=dataset) + if delete_node_ids: + vector.delete_by_ids(delete_node_ids) + if documents: + vector.add_texts(documents, duplicate_check=True) + + @classmethod + def delete_child_chunk_vector(cls, child_chunk: ChildChunk, dataset: Dataset): + vector = Vector(dataset=dataset) + vector.delete_by_ids([child_chunk.index_node_id]) diff --git a/api/tasks/add_document_to_index_task.py b/api/tasks/add_document_to_index_task.py index 50bb2b6e634fba..9a172b2d9d8157 100644 --- a/api/tasks/add_document_to_index_task.py +++ b/api/tasks/add_document_to_index_task.py @@ -6,12 +6,13 @@ from celery import shared_task # type: ignore from werkzeug.exceptions import NotFound +from core.rag.index_processor.constant.index_type import IndexType from core.rag.index_processor.index_processor_factory import IndexProcessorFactory -from core.rag.models.document import Document +from core.rag.models.document import ChildDocument, Document from extensions.ext_database import db from extensions.ext_redis import redis_client +from models.dataset import DatasetAutoDisableLog, DocumentSegment 
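Reviewer note: the child-chunk handling added to the index tasks below repeats one assembly pattern per segment. A condensed, hedged sketch of the shape that gets loaded into the index follows; to_parent_child_document is an illustrative helper, not something this patch adds, and segment / child_chunks stand in for the DocumentSegment and ChildChunk rows the tasks read from the database.

```python
# Hedged sketch of the per-segment parent/child assembly used by the updated
# index tasks for IndexType.PARENT_CHILD_INDEX.
from core.rag.models.document import ChildDocument, Document


def to_parent_child_document(segment, child_chunks) -> Document:
    document = Document(
        page_content=segment.content,
        metadata={
            "doc_id": segment.index_node_id,
            "doc_hash": segment.index_node_hash,
            "document_id": segment.document_id,
            "dataset_id": segment.dataset_id,
        },
    )
    # Child chunks become ChildDocument entries attached to the parent document.
    document.children = [
        ChildDocument(
            page_content=child_chunk.content,
            metadata={
                "doc_id": child_chunk.index_node_id,
                "doc_hash": child_chunk.index_node_hash,
                "document_id": segment.document_id,
                "dataset_id": segment.dataset_id,
            },
        )
        for child_chunk in child_chunks
    ]
    return document
```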
from models.dataset import Document as DatasetDocument -from models.dataset import DocumentSegment @shared_task(queue="dataset") @@ -53,7 +54,22 @@ def add_document_to_index_task(dataset_document_id: str): "dataset_id": segment.dataset_id, }, ) - + if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX: + child_chunks = segment.child_chunks + if child_chunks: + child_documents = [] + for child_chunk in child_chunks: + child_document = ChildDocument( + page_content=child_chunk.content, + metadata={ + "doc_id": child_chunk.index_node_id, + "doc_hash": child_chunk.index_node_hash, + "document_id": segment.document_id, + "dataset_id": segment.dataset_id, + }, + ) + child_documents.append(child_document) + document.children = child_documents documents.append(document) dataset = dataset_document.dataset @@ -65,6 +81,12 @@ def add_document_to_index_task(dataset_document_id: str): index_processor = IndexProcessorFactory(index_type).init_index_processor() index_processor.load(dataset, documents) + # delete auto disable log + db.session.query(DatasetAutoDisableLog).filter( + DatasetAutoDisableLog.document_id == dataset_document.id + ).delete() + db.session.commit() + end_at = time.perf_counter() logging.info( click.style( diff --git a/api/tasks/batch_clean_document_task.py b/api/tasks/batch_clean_document_task.py new file mode 100644 index 00000000000000..9e81fefaa75d0e --- /dev/null +++ b/api/tasks/batch_clean_document_task.py @@ -0,0 +1,75 @@ +import logging +import time + +import click +from celery import shared_task + +from core.rag.index_processor.index_processor_factory import IndexProcessorFactory +from core.tools.utils.web_reader_tool import get_image_upload_file_ids +from extensions.ext_database import db +from extensions.ext_storage import storage +from models.dataset import Dataset, DocumentSegment +from models.model import UploadFile + + +@shared_task(queue="dataset") +def batch_clean_document_task(document_ids: list[str], dataset_id: str, doc_form: str, file_ids: list[str]): + """ + Clean document when document deleted. 
+ :param document_ids: document ids + :param dataset_id: dataset id + :param doc_form: doc_form + :param file_ids: file ids + + Usage: batch_clean_document_task.delay(document_ids, dataset_id, doc_form, file_ids) + """ + logging.info(click.style("Start batch clean documents when documents deleted", fg="green")) + start_at = time.perf_counter() + + try: + dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() + + if not dataset: + raise Exception("Document has no dataset") + + segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id.in_(document_ids)).all() + # check segment is exist + if segments: + index_node_ids = [segment.index_node_id for segment in segments] + index_processor = IndexProcessorFactory(doc_form).init_index_processor() + index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True) + + for segment in segments: + image_upload_file_ids = get_image_upload_file_ids(segment.content) + for upload_file_id in image_upload_file_ids: + image_file = db.session.query(UploadFile).filter(UploadFile.id == upload_file_id).first() + try: + storage.delete(image_file.key) + except Exception: + logging.exception( + "Delete image_files failed when storage deleted, \ + image_upload_file_id: {}".format(upload_file_id) + ) + db.session.delete(image_file) + db.session.delete(segment) + + db.session.commit() + if file_ids: + files = db.session.query(UploadFile).filter(UploadFile.id.in_(file_ids)).all() + for file in files: + try: + storage.delete(file.key) + except Exception: + logging.exception("Delete file failed when document deleted, file_id: {}".format(file.id)) + db.session.delete(file) + db.session.commit() + + end_at = time.perf_counter() + logging.info( + click.style( + "Cleaned documents when documents deleted latency: {}".format(end_at - start_at), + fg="green", + ) + ) + except Exception: + logging.exception("Cleaned documents when documents deleted failed") diff --git a/api/tasks/batch_create_segment_to_index_task.py b/api/tasks/batch_create_segment_to_index_task.py index 26ae9f8736d79a..05a0f0a407f5f3 100644 --- a/api/tasks/batch_create_segment_to_index_task.py +++ b/api/tasks/batch_create_segment_to_index_task.py @@ -7,13 +7,13 @@ from celery import shared_task # type: ignore from sqlalchemy import func -from core.indexing_runner import IndexingRunner from core.model_manager import ModelManager from core.model_runtime.entities.model_entities import ModelType from extensions.ext_database import db from extensions.ext_redis import redis_client from libs import helper from models.dataset import Dataset, Document, DocumentSegment +from services.vector_service import VectorService @shared_task(queue="dataset") @@ -96,8 +96,7 @@ def batch_create_segment_to_index_task( dataset_document.word_count += word_count_change db.session.add(dataset_document) # add index to db - indexing_runner = IndexingRunner() - indexing_runner.batch_add_segments(document_segments, dataset) + VectorService.create_segments_vector(None, document_segments, dataset, dataset_document.doc_form) db.session.commit() redis_client.setex(indexing_cache_key, 600, "completed") end_at = time.perf_counter() diff --git a/api/tasks/clean_dataset_task.py b/api/tasks/clean_dataset_task.py index d9278c03793877..dfc7a896fc05aa 100644 --- a/api/tasks/clean_dataset_task.py +++ b/api/tasks/clean_dataset_task.py @@ -62,7 +62,7 @@ def clean_dataset_task( if doc_form is None: raise ValueError("Index type must be specified.") index_processor =
IndexProcessorFactory(doc_form).init_index_processor() - index_processor.clean(dataset, None) + index_processor.clean(dataset, None, with_keywords=True, delete_child_chunks=True) for document in documents: db.session.delete(document) diff --git a/api/tasks/clean_document_task.py b/api/tasks/clean_document_task.py index 3e80dd13771802..7a536f74265757 100644 --- a/api/tasks/clean_document_task.py +++ b/api/tasks/clean_document_task.py @@ -38,7 +38,7 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i if segments: index_node_ids = [segment.index_node_id for segment in segments] index_processor = IndexProcessorFactory(doc_form).init_index_processor() - index_processor.clean(dataset, index_node_ids) + index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True) for segment in segments: image_upload_file_ids = get_image_upload_file_ids(segment.content) diff --git a/api/tasks/clean_notion_document_task.py b/api/tasks/clean_notion_document_task.py index f5d6406d9cc04f..5a6eb00a6259d5 100644 --- a/api/tasks/clean_notion_document_task.py +++ b/api/tasks/clean_notion_document_task.py @@ -37,7 +37,7 @@ def clean_notion_document_task(document_ids: list[str], dataset_id: str): segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all() index_node_ids = [segment.index_node_id for segment in segments] - index_processor.clean(dataset, index_node_ids) + index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True) for segment in segments: db.session.delete(segment) diff --git a/api/tasks/deal_dataset_vector_index_task.py b/api/tasks/deal_dataset_vector_index_task.py index b025509aebe674..0efc924a77aae4 100644 --- a/api/tasks/deal_dataset_vector_index_task.py +++ b/api/tasks/deal_dataset_vector_index_task.py @@ -4,8 +4,9 @@ import click from celery import shared_task # type: ignore +from core.rag.index_processor.constant.index_type import IndexType from core.rag.index_processor.index_processor_factory import IndexProcessorFactory -from core.rag.models.document import Document +from core.rag.models.document import ChildDocument, Document from extensions.ext_database import db from models.dataset import Dataset, DocumentSegment from models.dataset import Document as DatasetDocument @@ -105,7 +106,7 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str): db.session.commit() # clean index - index_processor.clean(dataset, None, with_keywords=False) + index_processor.clean(dataset, None, with_keywords=False, delete_child_chunks=False) for dataset_document in dataset_documents: # update from vector index @@ -128,7 +129,22 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str): "dataset_id": segment.dataset_id, }, ) - + if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX: + child_chunks = segment.child_chunks + if child_chunks: + child_documents = [] + for child_chunk in child_chunks: + child_document = ChildDocument( + page_content=child_chunk.content, + metadata={ + "doc_id": child_chunk.index_node_id, + "doc_hash": child_chunk.index_node_hash, + "document_id": segment.document_id, + "dataset_id": segment.dataset_id, + }, + ) + child_documents.append(child_document) + document.children = child_documents documents.append(document) # save vector index index_processor.load(dataset, documents, with_keywords=False) diff --git a/api/tasks/delete_segment_from_index_task.py b/api/tasks/delete_segment_from_index_task.py index 
45a612c74550cd..3b04143dd9a075 100644 --- a/api/tasks/delete_segment_from_index_task.py +++ b/api/tasks/delete_segment_from_index_task.py @@ -6,48 +6,38 @@ from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from extensions.ext_database import db -from extensions.ext_redis import redis_client from models.dataset import Dataset, Document @shared_task(queue="dataset") -def delete_segment_from_index_task(segment_id: str, index_node_id: str, dataset_id: str, document_id: str): +def delete_segment_from_index_task(index_node_ids: list, dataset_id: str, document_id: str): """ Async Remove segment from index - :param segment_id: - :param index_node_id: + :param index_node_ids: :param dataset_id: :param document_id: - Usage: delete_segment_from_index_task.delay(segment_id) + Usage: delete_segment_from_index_task.delay(segment_ids) """ - logging.info(click.style("Start delete segment from index: {}".format(segment_id), fg="green")) + logging.info(click.style("Start delete segment from index", fg="green")) start_at = time.perf_counter() - indexing_cache_key = "segment_{}_delete_indexing".format(segment_id) try: dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() if not dataset: - logging.info(click.style("Segment {} has no dataset, pass.".format(segment_id), fg="cyan")) return dataset_document = db.session.query(Document).filter(Document.id == document_id).first() if not dataset_document: - logging.info(click.style("Segment {} has no document, pass.".format(segment_id), fg="cyan")) return if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed": - logging.info(click.style("Segment {} document status is invalid, pass.".format(segment_id), fg="cyan")) return index_type = dataset_document.doc_form index_processor = IndexProcessorFactory(index_type).init_index_processor() - index_processor.clean(dataset, [index_node_id]) + index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True) end_at = time.perf_counter() - logging.info( - click.style("Segment deleted from index: {} latency: {}".format(segment_id, end_at - start_at), fg="green") - ) + logging.info(click.style("Segment deleted from index latency: {}".format(end_at - start_at), fg="green")) except Exception: logging.exception("delete segment from index failed") - finally: - redis_client.delete(indexing_cache_key) diff --git a/api/tasks/disable_segments_from_index_task.py b/api/tasks/disable_segments_from_index_task.py new file mode 100644 index 00000000000000..97eb9a40e71834 --- /dev/null +++ b/api/tasks/disable_segments_from_index_task.py @@ -0,0 +1,76 @@ +import logging +import time + +import click +from celery import shared_task + +from core.rag.index_processor.index_processor_factory import IndexProcessorFactory +from extensions.ext_database import db +from extensions.ext_redis import redis_client +from models.dataset import Dataset, DocumentSegment +from models.dataset import Document as DatasetDocument + + +@shared_task(queue="dataset") +def disable_segments_from_index_task(segment_ids: list, dataset_id: str, document_id: str): + """ + Async disable segments from index + :param segment_ids: + + Usage: disable_segments_from_index_task.delay(segment_ids, dataset_id, document_id) + """ + start_at = time.perf_counter() + + dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() + if not dataset: + logging.info(click.style("Dataset {} not found, pass.".format(dataset_id), fg="cyan")) + 
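Reviewer note: with the signature change above, callers batch index node IDs before dispatch, mirroring the delete_segment / delete_segments call sites earlier in this patch. A small hedged usage sketch with placeholder IDs:

```python
# Hedged dispatch sketch for the reworked task: a list of index node IDs now
# replaces the old (segment_id, index_node_id) pair. IDs are placeholders.
from tasks.delete_segment_from_index_task import delete_segment_from_index_task

index_node_ids = ["<index-node-id-1>", "<index-node-id-2>"]
delete_segment_from_index_task.delay(index_node_ids, "<dataset-id>", "<document-id>")
```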
return + + dataset_document = db.session.query(DatasetDocument).filter(DatasetDocument.id == document_id).first() + + if not dataset_document: + logging.info(click.style("Document {} not found, pass.".format(document_id), fg="cyan")) + return + if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed": + logging.info(click.style("Document {} status is invalid, pass.".format(document_id), fg="cyan")) + return + # sync index processor + index_processor = IndexProcessorFactory(dataset_document.doc_form).init_index_processor() + + segments = ( + db.session.query(DocumentSegment) + .filter( + DocumentSegment.id.in_(segment_ids), + DocumentSegment.dataset_id == dataset_id, + DocumentSegment.document_id == document_id, + ) + .all() + ) + + if not segments: + return + + try: + index_node_ids = [segment.index_node_id for segment in segments] + index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=False) + + end_at = time.perf_counter() + logging.info(click.style("Segments removed from index latency: {}".format(end_at - start_at), fg="green")) + except Exception: + # update segment error msg + db.session.query(DocumentSegment).filter( + DocumentSegment.id.in_(segment_ids), + DocumentSegment.dataset_id == dataset_id, + DocumentSegment.document_id == document_id, + ).update( + { + "disabled_at": None, + "disabled_by": None, + "enabled": True, + } + ) + db.session.commit() + finally: + for segment in segments: + indexing_cache_key = "segment_{}_indexing".format(segment.id) + redis_client.delete(indexing_cache_key) diff --git a/api/tasks/document_indexing_sync_task.py b/api/tasks/document_indexing_sync_task.py index ac4e81f95d127e..d686698b9a5338 100644 --- a/api/tasks/document_indexing_sync_task.py +++ b/api/tasks/document_indexing_sync_task.py @@ -82,7 +82,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str): index_node_ids = [segment.index_node_id for segment in segments] # delete from vector index - index_processor.clean(dataset, index_node_ids) + index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True) for segment in segments: db.session.delete(segment) diff --git a/api/tasks/document_indexing_update_task.py b/api/tasks/document_indexing_update_task.py index 5f1e9a892f54e3..d8f14830c979ad 100644 --- a/api/tasks/document_indexing_update_task.py +++ b/api/tasks/document_indexing_update_task.py @@ -47,7 +47,7 @@ def document_indexing_update_task(dataset_id: str, document_id: str): index_node_ids = [segment.index_node_id for segment in segments] # delete from vector index - index_processor.clean(dataset, index_node_ids) + index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True) for segment in segments: db.session.delete(segment) diff --git a/api/tasks/duplicate_document_indexing_task.py b/api/tasks/duplicate_document_indexing_task.py index 6db2620eb6eef0..8e1d2b6b5d147e 100644 --- a/api/tasks/duplicate_document_indexing_task.py +++ b/api/tasks/duplicate_document_indexing_task.py @@ -51,7 +51,7 @@ def duplicate_document_indexing_task(dataset_id: str, document_ids: list): if document: document.indexing_status = "error" document.error = str(e) - document.stopped_at = datetime.datetime.utcnow() + document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) db.session.add(document) db.session.commit() return @@ -73,14 +73,14 @@ def duplicate_document_indexing_task(dataset_id: str, document_ids: list): 
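Reviewer note: a recurring mechanical change across these task files is the timestamp idiom; a one-line sketch of the equivalence for reference.

```python
# The tasks now derive naive-UTC timestamps from an aware datetime instead of
# the deprecated datetime.utcnow(); both produce the same naive UTC value.
import datetime

stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
```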
index_node_ids = [segment.index_node_id for segment in segments] # delete from vector index - index_processor.clean(dataset, index_node_ids) + index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True) for segment in segments: db.session.delete(segment) db.session.commit() document.indexing_status = "parsing" - document.processing_started_at = datetime.datetime.utcnow() + document.processing_started_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) documents.append(document) db.session.add(document) db.session.commit() diff --git a/api/tasks/enable_segment_to_index_task.py b/api/tasks/enable_segment_to_index_task.py index 2f6eb7b82a0633..76522f4720cf95 100644 --- a/api/tasks/enable_segment_to_index_task.py +++ b/api/tasks/enable_segment_to_index_task.py @@ -6,8 +6,9 @@ from celery import shared_task # type: ignore from werkzeug.exceptions import NotFound +from core.rag.index_processor.constant.index_type import IndexType from core.rag.index_processor.index_processor_factory import IndexProcessorFactory -from core.rag.models.document import Document +from core.rag.models.document import ChildDocument, Document from extensions.ext_database import db from extensions.ext_redis import redis_client from models.dataset import DocumentSegment @@ -61,6 +62,22 @@ def enable_segment_to_index_task(segment_id: str): return index_processor = IndexProcessorFactory(dataset_document.doc_form).init_index_processor() + if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX: + child_chunks = segment.child_chunks + if child_chunks: + child_documents = [] + for child_chunk in child_chunks: + child_document = ChildDocument( + page_content=child_chunk.content, + metadata={ + "doc_id": child_chunk.index_node_id, + "doc_hash": child_chunk.index_node_hash, + "document_id": segment.document_id, + "dataset_id": segment.dataset_id, + }, + ) + child_documents.append(child_document) + document.children = child_documents # save vector index index_processor.load(dataset, [document]) diff --git a/api/tasks/enable_segments_to_index_task.py b/api/tasks/enable_segments_to_index_task.py new file mode 100644 index 00000000000000..ecafc99a94b433 --- /dev/null +++ b/api/tasks/enable_segments_to_index_task.py @@ -0,0 +1,108 @@ +import datetime +import logging +import time + +import click +from celery import shared_task + +from core.rag.index_processor.constant.index_type import IndexType +from core.rag.index_processor.index_processor_factory import IndexProcessorFactory +from core.rag.models.document import ChildDocument, Document +from extensions.ext_database import db +from extensions.ext_redis import redis_client +from models.dataset import Dataset, DocumentSegment +from models.dataset import Document as DatasetDocument + + +@shared_task(queue="dataset") +def enable_segments_to_index_task(segment_ids: list, dataset_id: str, document_id: str): + """ + Async enable segments to index + :param segment_ids: + + Usage: enable_segments_to_index_task.delay(segment_ids, dataset_id, document_id) + """ + start_at = time.perf_counter() + dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() + if not dataset: + logging.info(click.style("Dataset {} not found, pass.".format(dataset_id), fg="cyan")) + return + + dataset_document = db.session.query(DatasetDocument).filter(DatasetDocument.id == document_id).first() + + if not dataset_document: + logging.info(click.style("Document {} not found, pass.".format(document_id), fg="cyan")) + return + if not dataset_document.enabled or
dataset_document.archived or dataset_document.indexing_status != "completed": + logging.info(click.style("Document {} status is invalid, pass.".format(document_id), fg="cyan")) + return + # sync index processor + index_processor = IndexProcessorFactory(dataset_document.doc_form).init_index_processor() + + segments = ( + db.session.query(DocumentSegment) + .filter( + DocumentSegment.id.in_(segment_ids), + DocumentSegment.dataset_id == dataset_id, + DocumentSegment.document_id == document_id, + ) + .all() + ) + if not segments: + return + + try: + documents = [] + for segment in segments: + document = Document( + page_content=segment.content, + metadata={ + "doc_id": segment.index_node_id, + "doc_hash": segment.index_node_hash, + "document_id": document_id, + "dataset_id": dataset_id, + }, + ) + + if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX: + child_chunks = segment.child_chunks + if child_chunks: + child_documents = [] + for child_chunk in child_chunks: + child_document = ChildDocument( + page_content=child_chunk.content, + metadata={ + "doc_id": child_chunk.index_node_id, + "doc_hash": child_chunk.index_node_hash, + "document_id": document_id, + "dataset_id": dataset_id, + }, + ) + child_documents.append(child_document) + document.children = child_documents + documents.append(document) + # save vector index + index_processor.load(dataset, documents) + + end_at = time.perf_counter() + logging.info(click.style("Segments enabled to index latency: {}".format(end_at - start_at), fg="green")) + except Exception as e: + logging.exception("enable segments to index failed") + # update segment error msg + db.session.query(DocumentSegment).filter( + DocumentSegment.id.in_(segment_ids), + DocumentSegment.dataset_id == dataset_id, + DocumentSegment.document_id == document_id, + ).update( + { + "error": str(e), + "status": "error", + "disabled_at": datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), + "enabled": False, + } + ) + db.session.commit() + finally: + for segment in segments: + indexing_cache_key = "segment_{}_indexing".format(segment.id) + redis_client.delete(indexing_cache_key) diff --git a/api/tasks/remove_document_from_index_task.py b/api/tasks/remove_document_from_index_task.py index 4ba6d1a83e32ae..1d580b38028f37 100644 --- a/api/tasks/remove_document_from_index_task.py +++ b/api/tasks/remove_document_from_index_task.py @@ -43,7 +43,7 @@ def remove_document_from_index_task(document_id: str): index_node_ids = [segment.index_node_id for segment in segments] if index_node_ids: try: - index_processor.clean(dataset, index_node_ids) + index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=False) except Exception: logging.exception(f"clean dataset {dataset.id} from index failed") diff --git a/api/tasks/retry_document_indexing_task.py b/api/tasks/retry_document_indexing_task.py index 485caa5152ea78..74fd542f6c4a80 100644 --- a/api/tasks/retry_document_indexing_task.py +++ b/api/tasks/retry_document_indexing_task.py @@ -48,7 +48,7 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str]): if document: document.indexing_status = "error" document.error = str(e) - document.stopped_at = datetime.datetime.utcnow() + document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) db.session.add(document) db.session.commit() redis_client.delete(retry_indexing_cache_key) @@ -69,14 +69,14 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str]): if segments: index_node_ids = 
[segment.index_node_id for segment in segments] # delete from vector index - index_processor.clean(dataset, index_node_ids) + index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True) - for segment in segments: - db.session.delete(segment) - db.session.commit() + for segment in segments: + db.session.delete(segment) + db.session.commit() document.indexing_status = "parsing" - document.processing_started_at = datetime.datetime.utcnow() + document.processing_started_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) db.session.add(document) db.session.commit() @@ -86,7 +86,7 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str]): except Exception as ex: document.indexing_status = "error" document.error = str(ex) - document.stopped_at = datetime.datetime.utcnow() + document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) db.session.add(document) db.session.commit() logging.info(click.style(str(ex), fg="yellow")) diff --git a/api/tasks/sync_website_document_indexing_task.py b/api/tasks/sync_website_document_indexing_task.py index 5d6b069cf44919..8da050d0d1e2d3 100644 --- a/api/tasks/sync_website_document_indexing_task.py +++ b/api/tasks/sync_website_document_indexing_task.py @@ -46,7 +46,7 @@ def sync_website_document_indexing_task(dataset_id: str, document_id: str): if document: document.indexing_status = "error" document.error = str(e) - document.stopped_at = datetime.datetime.utcnow() + document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) db.session.add(document) db.session.commit() redis_client.delete(sync_indexing_cache_key) @@ -65,14 +65,14 @@ def sync_website_document_indexing_task(dataset_id: str, document_id: str): if segments: index_node_ids = [segment.index_node_id for segment in segments] # delete from vector index - index_processor.clean(dataset, index_node_ids) + index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True) - for segment in segments: - db.session.delete(segment) - db.session.commit() + for segment in segments: + db.session.delete(segment) + db.session.commit() document.indexing_status = "parsing" - document.processing_started_at = datetime.datetime.utcnow() + document.processing_started_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) db.session.add(document) db.session.commit() @@ -82,7 +82,7 @@ def sync_website_document_indexing_task(dataset_id: str, document_id: str): except Exception as ex: document.indexing_status = "error" document.error = str(ex) - document.stopped_at = datetime.datetime.utcnow() + document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) db.session.add(document) db.session.commit() logging.info(click.style(str(ex), fg="yellow")) diff --git a/api/templates/clean_document_job_mail_template-US.html b/api/templates/clean_document_job_mail_template-US.html new file mode 100644 index 00000000000000..b7c9538f9f8bee --- /dev/null +++ b/api/templates/clean_document_job_mail_template-US.html @@ -0,0 +1,98 @@ + + + + + + Documents Disabled Notification + + + + + + \ No newline at end of file
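The two new Celery tasks above show only the worker side of the batch enable/disable flow. As a rough sketch of how a caller might drive them, the snippet below sets the per-segment "segment_{id}_indexing" Redis key that both tasks clear in their finally blocks and then dispatches the matching task; the dispatch_segment_batch helper, the action parameter, and the 600-second TTL are illustrative assumptions, not code taken from this diff.

# Assumed caller sketch -- not part of this diff.
from extensions.ext_redis import redis_client
from tasks.disable_segments_from_index_task import disable_segments_from_index_task
from tasks.enable_segments_to_index_task import enable_segments_to_index_task


def dispatch_segment_batch(action: str, segment_ids: list, dataset_id: str, document_id: str):
    # Mark each segment as in-flight; the tasks delete "segment_{id}_indexing"
    # in their finally blocks, so the flag clears even if indexing fails.
    for segment_id in segment_ids:
        redis_client.setex("segment_{}_indexing".format(segment_id), 600, 1)  # TTL value is an assumption

    if action == "enable":
        enable_segments_to_index_task.delay(segment_ids, dataset_id, document_id)
    else:
        disable_segments_from_index_task.delay(segment_ids, dataset_id, document_id)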
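The timestamp hunks in these tasks replace datetime.datetime.utcnow() with datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), which yields the same naive UTC value without relying on the now-deprecated utcnow() API. A small helper along the following lines would capture the pattern; the naive_utc_now name is an assumption and does not appear in this diff.

import datetime


def naive_utc_now() -> datetime.datetime:
    # Current UTC time with tzinfo stripped -- the same naive value utcnow()
    # returned, matching the naive datetime columns these tasks write to.
    return datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)


# e.g. document.stopped_at = naive_utc_now()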