py lint #12102

Merged (6 commits) on Dec 25, 2024

2 changes: 1 addition & 1 deletion api/commands.py
@@ -587,7 +587,7 @@ def upgrade_db():
click.echo(click.style("Starting database migration.", fg="green"))

# run db migration
import flask_migrate
import flask_migrate # type: ignore

flask_migrate.upgrade()

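Note on the change above: most of this PR silences mypy on imports and calls it cannot type-check. For an untyped third-party import like flask_migrate, a narrower ignore code (or a per-module setting in the mypy config) is a common alternative to a bare `# type: ignore`; the snippet below is only an illustration of that pattern and is not part of this PR. The exact error-code name (`import-untyped` vs `import`) depends on the mypy version in use.

```python
# Illustration only (not part of this PR): scoping the ignore to the import
# error keeps mypy able to report other problems on the same line.
import flask_migrate  # type: ignore[import-untyped]

flask_migrate.upgrade()
```
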
7 changes: 4 additions & 3 deletions api/controllers/console/datasets/datasets_document.py
@@ -413,14 +413,15 @@ def get(self, dataset_id, document_id):
indexing_runner = IndexingRunner()

try:
response = indexing_runner.indexing_estimate(
estimate_response = indexing_runner.indexing_estimate(
current_user.current_tenant_id,
[extract_setting],
data_process_rule_dict,
document.doc_form,
"English",
dataset_id,
)
return estimate_response.model_dump(), 200
except LLMBadRequestError:
raise ProviderNotInitializeError(
"No Embedding Model available. Please configure a valid provider "
@@ -431,7 +432,7 @@ def get(self, dataset_id, document_id):
except Exception as e:
raise IndexingEstimateError(str(e))

return response.model_dump(), 200
return response, 200


class DocumentBatchIndexingEstimateApi(DocumentResource):
@@ -521,6 +522,7 @@ def get(self, dataset_id, batch):
"English",
dataset_id,
)
return response.model_dump(), 200
except LLMBadRequestError:
raise ProviderNotInitializeError(
"No Embedding Model available. Please configure a valid provider "
@@ -530,7 +532,6 @@ def get(self, dataset_id, batch):
raise ProviderNotInitializeError(ex.description)
except Exception as e:
raise IndexingEstimateError(str(e))
return response.model_dump(), 200


class DocumentBatchIndexingStatusApi(DocumentResource):
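
Note on the changes above: moving `return ... .model_dump(), 200` inside the `try` block (and renaming the local to `estimate_response`) means the success value is returned on the only path where it is guaranteed to be bound; every `except` clause re-raises, so there is no fall-through for a type checker to flag as possibly unbound. A minimal sketch of the pattern, with placeholder names standing in for the real service calls:

```python
from pydantic import BaseModel


class Estimate(BaseModel):           # stand-in for the real response model
    total_segments: int = 0


def run_estimate() -> Estimate:      # stand-in for IndexingRunner.indexing_estimate()
    return Estimate()


def get_estimate() -> tuple[dict, int]:
    try:
        estimate_response = run_estimate()
        # Returning inside the try means estimate_response is always bound on
        # the path that uses it; the except clause below only re-raises.
        return estimate_response.model_dump(), 200
    except ValueError as e:
        raise RuntimeError(str(e))
```
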
22 changes: 14 additions & 8 deletions api/controllers/service_api/dataset/document.py
@@ -22,6 +22,7 @@
from libs.login import current_user
from models.dataset import Dataset, Document, DocumentSegment
from services.dataset_service import DocumentService
from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig
from services.file_service import FileService


@@ -67,13 +68,14 @@ def post(self, tenant_id, dataset_id):
"info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}},
}
args["data_source"] = data_source
knowledge_config = KnowledgeConfig(**args)
# validate args
DocumentService.document_create_args_validate(args)
DocumentService.document_create_args_validate(knowledge_config)

try:
documents, batch = DocumentService.save_document_with_dataset_id(
dataset=dataset,
document_data=args,
knowledge_config=knowledge_config,
account=current_user,
dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None,
created_from="api",
@@ -122,12 +124,13 @@ def post(self, tenant_id, dataset_id, document_id):
args["data_source"] = data_source
# validate args
args["original_document_id"] = str(document_id)
DocumentService.document_create_args_validate(args)
knowledge_config = KnowledgeConfig(**args)
DocumentService.document_create_args_validate(knowledge_config)

try:
documents, batch = DocumentService.save_document_with_dataset_id(
dataset=dataset,
document_data=args,
knowledge_config=knowledge_config,
account=current_user,
dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None,
created_from="api",
@@ -186,12 +189,13 @@ def post(self, tenant_id, dataset_id):
data_source = {"type": "upload_file", "info_list": {"file_info_list": {"file_ids": [upload_file.id]}}}
args["data_source"] = data_source
# validate args
DocumentService.document_create_args_validate(args)
knowledge_config = KnowledgeConfig(**args)
DocumentService.document_create_args_validate(knowledge_config)

try:
documents, batch = DocumentService.save_document_with_dataset_id(
dataset=dataset,
document_data=args,
knowledge_config=knowledge_config,
account=dataset.created_by_account,
dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None,
created_from="api",
@@ -245,12 +249,14 @@ def post(self, tenant_id, dataset_id, document_id):
args["data_source"] = data_source
# validate args
args["original_document_id"] = str(document_id)
DocumentService.document_create_args_validate(args)

knowledge_config = KnowledgeConfig(**args)
DocumentService.document_create_args_validate(knowledge_config)

try:
documents, batch = DocumentService.save_document_with_dataset_id(
dataset=dataset,
document_data=args,
knowledge_config=knowledge_config,
account=dataset.created_by_account,
dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None,
created_from="api",
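
Note on the changes above: these endpoints now parse the request args into a `KnowledgeConfig` model before validating and saving, so `DocumentService.document_create_args_validate` and `save_document_with_dataset_id` receive a typed object instead of a raw dict. A rough sketch of that flow, assuming `KnowledgeConfig` is a Pydantic model; the field names below are placeholders, not the real schema:

```python
from pydantic import BaseModel, ValidationError


class KnowledgeConfig(BaseModel):            # placeholder fields, not the real schema
    data_source: dict | None = None
    original_document_id: str | None = None


args = {
    "data_source": {"type": "upload_file", "info_list": {"file_info_list": {"file_ids": ["f-1"]}}},
    "original_document_id": "doc-1",
}

try:
    knowledge_config = KnowledgeConfig(**args)   # typed access: knowledge_config.data_source
except ValidationError as e:
    print(e)                                     # malformed input is rejected up front
```
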
14 changes: 7 additions & 7 deletions api/core/indexing_runner.py
@@ -276,7 +276,7 @@ def indexing_estimate(
tenant_id=tenant_id,
model_type=ModelType.TEXT_EMBEDDING,
)
preview_texts = []
preview_texts = [] # type: ignore

total_segments = 0
index_type = doc_form
@@ -300,13 +300,13 @@
if len(preview_texts) < 10:
if doc_form and doc_form == "qa_model":
preview_detail = QAPreviewDetail(
question=document.page_content, answer=document.metadata.get("answer")
question=document.page_content, answer=document.metadata.get("answer") or ""
)
preview_texts.append(preview_detail)
else:
preview_detail = PreviewDetail(content=document.page_content)
preview_detail = PreviewDetail(content=document.page_content) # type: ignore
if document.children:
preview_detail.child_chunks = [child.page_content for child in document.children]
preview_detail.child_chunks = [child.page_content for child in document.children] # type: ignore
preview_texts.append(preview_detail)

# delete image files and related db records
@@ -325,7 +325,7 @@

if doc_form and doc_form == "qa_model":
return IndexingEstimate(total_segments=total_segments * 20, qa_preview=preview_texts, preview=[])
return IndexingEstimate(total_segments=total_segments, preview=preview_texts)
return IndexingEstimate(total_segments=total_segments, preview=preview_texts) # type: ignore

def _extract(
self, index_processor: BaseIndexProcessor, dataset_document: DatasetDocument, process_rule: dict
@@ -454,7 +454,7 @@ def _get_splitter(
embedding_model_instance=embedding_model_instance,
)

return character_splitter
return character_splitter # type: ignore

def _split_to_documents_for_estimate(
self, text_docs: list[Document], splitter: TextSplitter, processing_rule: DatasetProcessRule
@@ -535,7 +535,7 @@ def _load(
# create keyword index
create_keyword_thread = threading.Thread(
target=self._process_keyword_index,
args=(current_app._get_current_object(), dataset.id, dataset_document.id, documents),
args=(current_app._get_current_object(), dataset.id, dataset_document.id, documents), # type: ignore
)
create_keyword_thread.start()

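Note on the changes above: the ignores on `preview_texts` exist because the same list ends up holding `QAPreviewDetail` objects in the qa_model branch and `PreviewDetail` objects otherwise. A union annotation would likely let mypy accept both appends without per-line ignores, though the ignore on the final `IndexingEstimate(...)` return may still be needed if its `preview` field is typed as `list[PreviewDetail]` only. Sketch only, using the two classes as they already appear in this file:

```python
# Sketch only: an explicit union annotation on the accumulator documents both
# branches and removes the need for the per-append ignores.
preview_texts: list[QAPreviewDetail | PreviewDetail] = []
```
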
129 changes: 65 additions & 64 deletions api/core/rag/datasource/retrieval_service.py
@@ -258,78 +258,79 @@ def format_retrieval_documents(documents: list[Document]) -> list[RetrievalSegme
include_segment_ids = []
segment_child_map = {}
for document in documents:
document_id = document.metadata["document_id"]
document_id = document.metadata.get("document_id")
dataset_document = db.session.query(DatasetDocument).filter(DatasetDocument.id == document_id).first()
if dataset_document and dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
child_index_node_id = document.metadata["doc_id"]
result = (
db.session.query(ChildChunk, DocumentSegment)
.join(DocumentSegment, ChildChunk.segment_id == DocumentSegment.id)
.filter(
ChildChunk.index_node_id == child_index_node_id,
DocumentSegment.dataset_id == dataset_document.dataset_id,
DocumentSegment.enabled == True,
DocumentSegment.status == "completed",
if dataset_document:
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
child_index_node_id = document.metadata.get("doc_id")
result = (
db.session.query(ChildChunk, DocumentSegment)
.join(DocumentSegment, ChildChunk.segment_id == DocumentSegment.id)
.filter(
ChildChunk.index_node_id == child_index_node_id,
DocumentSegment.dataset_id == dataset_document.dataset_id,
DocumentSegment.enabled == True,
DocumentSegment.status == "completed",
)
.first()
)
.first()
)
if result:
child_chunk, segment = result
if not segment:
continue
if segment.id not in include_segment_ids:
include_segment_ids.append(segment.id)
child_chunk_detail = {
"id": child_chunk.id,
"content": child_chunk.content,
"position": child_chunk.position,
"score": document.metadata.get("score", 0.0),
}
map_detail = {
"max_score": document.metadata.get("score", 0.0),
"child_chunks": [child_chunk_detail],
}
segment_child_map[segment.id] = map_detail
record = {
"segment": segment,
}
records.append(record)
if result:
child_chunk, segment = result
if not segment:
continue
if segment.id not in include_segment_ids:
include_segment_ids.append(segment.id)
child_chunk_detail = {
"id": child_chunk.id,
"content": child_chunk.content,
"position": child_chunk.position,
"score": document.metadata.get("score", 0.0),
}
map_detail = {
"max_score": document.metadata.get("score", 0.0),
"child_chunks": [child_chunk_detail],
}
segment_child_map[segment.id] = map_detail
record = {
"segment": segment,
}
records.append(record)
else:
child_chunk_detail = {
"id": child_chunk.id,
"content": child_chunk.content,
"position": child_chunk.position,
"score": document.metadata.get("score", 0.0),
}
segment_child_map[segment.id]["child_chunks"].append(child_chunk_detail)
segment_child_map[segment.id]["max_score"] = max(
segment_child_map[segment.id]["max_score"], document.metadata.get("score", 0.0)
)
else:
child_chunk_detail = {
"id": child_chunk.id,
"content": child_chunk.content,
"position": child_chunk.position,
"score": document.metadata.get("score", 0.0),
}
segment_child_map[segment.id]["child_chunks"].append(child_chunk_detail)
segment_child_map[segment.id]["max_score"] = max(
segment_child_map[segment.id]["max_score"], document.metadata.get("score", 0.0)
)
continue
else:
continue
else:
index_node_id = document.metadata["doc_id"]
index_node_id = document.metadata["doc_id"]

segment = (
db.session.query(DocumentSegment)
.filter(
DocumentSegment.dataset_id == dataset_document.dataset_id,
DocumentSegment.enabled == True,
DocumentSegment.status == "completed",
DocumentSegment.index_node_id == index_node_id,
segment = (
db.session.query(DocumentSegment)
.filter(
DocumentSegment.dataset_id == dataset_document.dataset_id,
DocumentSegment.enabled == True,
DocumentSegment.status == "completed",
DocumentSegment.index_node_id == index_node_id,
)
.first()
)
.first()
)

if not segment:
continue
include_segment_ids.append(segment.id)
record = {
"segment": segment,
"score": document.metadata.get("score", None),
}
if not segment:
continue
include_segment_ids.append(segment.id)
record = {
"segment": segment,
"score": document.metadata.get("score", None),
}

records.append(record)
records.append(record)
for record in records:
if record["segment"].id in segment_child_map:
record["child_chunks"] = segment_child_map[record["segment"].id].get("child_chunks", None)
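
Note on the changes above: the restructuring is mostly None-narrowing. `db.session.query(...).first()` returns an Optional, so the parent/child and plain-segment branches are now nested under `if dataset_document:` instead of dereferencing the result directly, and metadata lookups use `.get()` rather than indexing. A stripped-down sketch of the shape of the new loop (illustrative only; the child-chunk details and score bookkeeping are elided):

```python
# Stripped-down control-flow sketch of the change; not the full method body.
for document in documents:
    dataset_document = (
        db.session.query(DatasetDocument)
        .filter(DatasetDocument.id == document.metadata.get("document_id"))
        .first()
    )
    if not dataset_document:           # .first() may return None; skip orphaned hits
        continue
    if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
        ...  # look up ChildChunk + DocumentSegment, accumulate child-chunk details
    else:
        ...  # look up the DocumentSegment by index_node_id and record it
```
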
37 changes: 19 additions & 18 deletions api/core/rag/docstore/dataset_docstore.py
@@ -122,26 +122,27 @@ def add_documents(self, docs: Sequence[Document], allow_update: bool = True, sav
db.session.add(segment_document)
db.session.flush()
if save_child:
for postion, child in enumerate(doc.children, start=1):
child_segment = ChildChunk(
tenant_id=self._dataset.tenant_id,
dataset_id=self._dataset.id,
document_id=self._document_id,
segment_id=segment_document.id,
position=postion,
index_node_id=child.metadata["doc_id"],
index_node_hash=child.metadata["doc_hash"],
content=child.page_content,
word_count=len(child.page_content),
type="automatic",
created_by=self._user_id,
)
db.session.add(child_segment)
if doc.children:
for postion, child in enumerate(doc.children, start=1):
child_segment = ChildChunk(
tenant_id=self._dataset.tenant_id,
dataset_id=self._dataset.id,
document_id=self._document_id,
segment_id=segment_document.id,
position=postion,
index_node_id=child.metadata.get("doc_id"),
index_node_hash=child.metadata.get("doc_hash"),
content=child.page_content,
word_count=len(child.page_content),
type="automatic",
created_by=self._user_id,
)
db.session.add(child_segment)
else:
segment_document.content = doc.page_content
if doc.metadata.get("answer"):
segment_document.answer = doc.metadata.pop("answer", "")
segment_document.index_node_hash = doc.metadata["doc_hash"]
segment_document.index_node_hash = doc.metadata.get("doc_hash")
segment_document.word_count = len(doc.page_content)
segment_document.tokens = tokens
if save_child and doc.children:
@@ -160,8 +161,8 @@
document_id=self._document_id,
segment_id=segment_document.id,
position=position,
index_node_id=child.metadata["doc_id"],
index_node_hash=child.metadata["doc_hash"],
index_node_id=child.metadata.get("doc_id"),
index_node_hash=child.metadata.get("doc_hash"),
content=child.page_content,
word_count=len(child.page_content),
type="automatic",
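
Note on the changes above: this file follows the same two lint patterns as the rest of the PR, iterating `doc.children` only behind an `if doc.children:` guard (the attribute may be None or empty) and reading metadata keys with `.get()` so a missing `doc_id`/`doc_hash` yields None instead of raising KeyError. A minimal sketch of the guard, with a stand-in document class rather than the real model:

```python
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class Doc:                                   # stand-in, not the real Document model
    page_content: str = ""
    metadata: dict = field(default_factory=dict)
    children: Optional[list["Doc"]] = None


doc = Doc(metadata={"doc_id": "n-1"})
if doc.children:                             # guards both None and []
    for position, child in enumerate(doc.children, start=1):
        print(position, child.metadata.get("doc_id"))   # .get() avoids KeyError
```
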
2 changes: 1 addition & 1 deletion api/core/rag/extractor/excel_extractor.py
@@ -4,7 +4,7 @@
from typing import Optional, cast

import pandas as pd
from openpyxl import load_workbook
from openpyxl import load_workbook # type: ignore

from core.rag.extractor.extractor_base import BaseExtractor
from core.rag.models.document import Document
2 changes: 1 addition & 1 deletion api/core/rag/index_processor/index_processor_base.py
@@ -81,4 +81,4 @@ def _get_splitter(
embedding_model_instance=embedding_model_instance,
)

return character_splitter
return character_splitter # type: ignore