From cbb0f9e76239c20e67e9a1e9f63f8313a2bded1f Mon Sep 17 00:00:00 2001 From: Jyong <76649700+JohnJyong@users.noreply.github.com> Date: Wed, 4 Sep 2024 21:47:12 +0800 Subject: [PATCH] fix splitter length missed (#7987) --- api/core/rag/splitter/fixed_text_splitter.py | 10 +++++++--- api/core/rag/splitter/text_splitter.py | 5 ++++- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/api/core/rag/splitter/fixed_text_splitter.py b/api/core/rag/splitter/fixed_text_splitter.py index 6a0804f890db39..0c1cb57c7f4e03 100644 --- a/api/core/rag/splitter/fixed_text_splitter.py +++ b/api/core/rag/splitter/fixed_text_splitter.py @@ -93,17 +93,21 @@ def recursive_split_text(self, text: str) -> list[str]: splits = list(text) # Now go merging things, recursively splitting longer texts. _good_splits = [] + _good_splits_lengths = [] # cache the lengths of the splits for s in splits: - if self._length_function(s) < self._chunk_size: + s_len = self._length_function(s) + if s_len < self._chunk_size: _good_splits.append(s) + _good_splits_lengths.append(s_len) else: if _good_splits: - merged_text = self._merge_splits(_good_splits, separator) + merged_text = self._merge_splits(_good_splits, separator, _good_splits_lengths) final_chunks.extend(merged_text) _good_splits = [] + _good_splits_lengths = [] other_info = self.recursive_split_text(s) final_chunks.extend(other_info) if _good_splits: - merged_text = self._merge_splits(_good_splits, separator) + merged_text = self._merge_splits(_good_splits, separator, _good_splits_lengths) final_chunks.extend(merged_text) return final_chunks diff --git a/api/core/rag/splitter/text_splitter.py b/api/core/rag/splitter/text_splitter.py index 943f9918a79c7e..f06f22a00e1855 100644 --- a/api/core/rag/splitter/text_splitter.py +++ b/api/core/rag/splitter/text_splitter.py @@ -243,7 +243,10 @@ def split_text(self, text: str) -> list[str]: # First we naively split the large input into a bunch of smaller ones. 
splits = _split_text_with_regex(text, self._separator, self._keep_separator) _separator = "" if self._keep_separator else self._separator - return self._merge_splits(splits, _separator) + _good_splits_lengths = [] # cache the lengths of the splits + for split in splits: + _good_splits_lengths.append(self._length_function(split)) + return self._merge_splits(splits, _separator, _good_splits_lengths) class LineType(TypedDict):