Skip to content

Commit

Permalink
fix: keywords (#7357)
Browse files (browse the repository at this point in the history)
  • Loading branch information
crazywoola authored Aug 16, 2024
1 parent 3a33062 commit 4d4af00
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 10 deletions.
19 changes: 11 additions & 8 deletions api/controllers/service_api/dataset/segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,19 +53,22 @@ def post(self, tenant_id, dataset_id, document_id):
raise ProviderNotInitializeError(
"No Embedding Model available. Please configure a valid provider "
"in the Settings -> Model Provider.")
except ProviderTokenNotInitError as ex:
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
# validate args
parser = reqparse.RequestParser()
parser.add_argument('segments', type=list, required=False, nullable=True, location='json')
args = parser.parse_args()
for args_item in args['segments']:
SegmentService.segment_create_args_validate(args_item, document)
segments = SegmentService.multi_create_segment(args['segments'], document, dataset)
return {
'data': marshal(segments, segment_fields),
'doc_form': document.doc_form
}, 200
if args['segments'] is not None:
for args_item in args['segments']:
SegmentService.segment_create_args_validate(args_item, document)
segments = SegmentService.multi_create_segment(args['segments'], document, dataset)
return {
'data': marshal(segments, segment_fields),
'doc_form': document.doc_form
}, 200
else:
return {"error": "Segemtns is required"}, 400

def get(self, tenant_id, dataset_id, document_id):
"""Create single segment."""
Expand Down
7 changes: 5 additions & 2 deletions api/services/dataset_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -1429,7 +1429,10 @@ def multi_create_segment(cls, segments: list, document: Document, dataset: Datas
segment_data_list.append(segment_document)

pre_segment_data_list.append(segment_document)
keywords_list.append(segment_item['keywords'])
if 'keywords' in segment_item:
keywords_list.append(segment_item['keywords'])
else:
keywords_list.append(None)

try:
# save vector index
Expand Down Expand Up @@ -1482,7 +1485,7 @@ def update_segment(cls, args: dict, segment: DocumentSegment, document: Document
db.session.add(segment)
db.session.commit()
# update segment index task
if args['keywords']:
if 'keywords' in args:
keyword = Keyword(dataset)
keyword.delete_by_ids([segment.index_node_id])
document = RAGDocument(
Expand Down

0 comments on commit 4d4af00

Please sign in to comment.