Skip to content

Commit

Permalink
Fix DocumentSplitter not splitting by function (#8549)
Browse files Browse the repository at this point in the history
* Fix DocumentSplitter not splitting by function

* Make the split_by mapping a constant
  • Loading branch information
silvanocerza authored Nov 18, 2024
1 parent cea1e3f commit bd77120
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 51 deletions.
74 changes: 37 additions & 37 deletions haystack/components/preprocessors/document_splitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@

logger = logging.getLogger(__name__)

# Maps the 'split_by' argument to the actual char used to split the Documents.
# 'function' is not in the mapping cause it doesn't split on chars.
_SPLIT_BY_MAPPING = {"page": "\f", "passage": "\n\n", "sentence": ".", "word": " ", "line": "\n"}


@component
class DocumentSplitter:
Expand Down Expand Up @@ -73,7 +77,7 @@ def __init__( # pylint: disable=too-many-positional-arguments

self.split_by = split_by
if split_by not in ["function", "page", "passage", "sentence", "word", "line"]:
raise ValueError("split_by must be one of 'word', 'sentence', 'page', 'passage' or 'line'.")
raise ValueError("split_by must be one of 'function', 'word', 'sentence', 'page', 'passage' or 'line'.")
if split_by == "function" and splitting_function is None:
raise ValueError("When 'split_by' is set to 'function', a valid 'splitting_function' must be provided.")
if split_length <= 0:
Expand Down Expand Up @@ -108,7 +112,7 @@ def run(self, documents: List[Document]):
if not isinstance(documents, list) or (documents and not isinstance(documents[0], Document)):
raise TypeError("DocumentSplitter expects a List of Documents as input.")

split_docs = []
split_docs: List[Document] = []
for doc in documents:
if doc.content is None:
raise ValueError(
Expand All @@ -117,42 +121,38 @@ def run(self, documents: List[Document]):
if doc.content == "":
logger.warning("Document ID {doc_id} has an empty content. Skipping this document.", doc_id=doc.id)
continue
units = self._split_into_units(doc.content, self.split_by)
text_splits, splits_pages, splits_start_idxs = self._concatenate_units(
units, self.split_length, self.split_overlap, self.split_threshold
)
metadata = deepcopy(doc.meta)
metadata["source_id"] = doc.id
split_docs += self._create_docs_from_splits(
text_splits=text_splits, splits_pages=splits_pages, splits_start_idxs=splits_start_idxs, meta=metadata
)
split_docs += self._split(doc)
return {"documents": split_docs}

def _split_into_units(
self, text: str, split_by: Literal["function", "page", "passage", "sentence", "word", "line"]
) -> List[str]:
if split_by == "page":
self.split_at = "\f"
elif split_by == "passage":
self.split_at = "\n\n"
elif split_by == "sentence":
self.split_at = "."
elif split_by == "word":
self.split_at = " "
elif split_by == "line":
self.split_at = "\n"
elif split_by == "function" and self.splitting_function is not None:
return self.splitting_function(text)
else:
raise NotImplementedError(
"""DocumentSplitter only supports 'function', 'line', 'page',
'passage', 'sentence' or 'word' split_by options."""
)
units = text.split(self.split_at)
def _split(self, to_split: Document) -> List[Document]:
# We already check this before calling _split but
# we need to make linters happy
if to_split.content is None:
return []

if self.split_by == "function" and self.splitting_function is not None:
splits = self.splitting_function(to_split.content)
docs: List[Document] = []
for s in splits:
meta = deepcopy(to_split.meta)
meta["source_id"] = to_split.id
docs.append(Document(content=s, meta=meta))
return docs

split_at = _SPLIT_BY_MAPPING[self.split_by]
units = to_split.content.split(split_at)
# Add the delimiter back to all units except the last one
for i in range(len(units) - 1):
units[i] += self.split_at
return units
units[i] += split_at

text_splits, splits_pages, splits_start_idxs = self._concatenate_units(
units, self.split_length, self.split_overlap, self.split_threshold
)
metadata = deepcopy(to_split.meta)
metadata["source_id"] = to_split.id
return self._create_docs_from_splits(
text_splits=text_splits, splits_pages=splits_pages, splits_start_idxs=splits_start_idxs, meta=metadata
)

def _concatenate_units(
self, elements: List[str], split_length: int, split_overlap: int, split_threshold: int
Expand All @@ -166,8 +166,8 @@ def _concatenate_units(
"""

text_splits: List[str] = []
splits_pages = []
splits_start_idxs = []
splits_pages: List[int] = []
splits_start_idxs: List[int] = []
cur_start_idx = 0
cur_page = 1
segments = windowed(elements, n=split_length, step=split_length - split_overlap)
Expand Down Expand Up @@ -200,7 +200,7 @@ def _concatenate_units(
return text_splits, splits_pages, splits_start_idxs

def _create_docs_from_splits(
self, text_splits: List[str], splits_pages: List[int], splits_start_idxs: List[int], meta: Dict
self, text_splits: List[str], splits_pages: List[int], splits_start_idxs: List[int], meta: Dict[str, Any]
) -> List[Document]:
"""
Creates Document objects from splits enriching them with page number and the metadata of the original document.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def __init__( # pylint: disable=too-many-positional-arguments
self.language = language

def _split_into_units(
self, text: str, split_by: Literal["word", "sentence", "passage", "page", "function"]
self, text: str, split_by: Literal["function", "page", "passage", "sentence", "word", "line"]
) -> List[str]:
"""
Splits the text into units based on the specified split_by parameter.
Expand Down
5 changes: 5 additions & 0 deletions releasenotes/notes/split-by-function-62ce32fac70d8f8c.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
fixes:
- |
Fix `DocumentSplitter` to handle custom `splitting_function` without requiring `split_length`.
Previously the `splitting_function` provided would not override other settings.
37 changes: 24 additions & 13 deletions test/components/preprocessors/test_document_splitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def test_empty_list(self):

def test_unsupported_split_by(self):
with pytest.raises(
ValueError, match="split_by must be one of 'word', 'sentence', 'page', 'passage' or 'line'."
ValueError, match="split_by must be one of 'function', 'word', 'sentence', 'page', 'passage' or 'line'."
):
DocumentSplitter(split_by="unsupported")

Expand Down Expand Up @@ -177,25 +177,36 @@ def test_split_by_page(self):
assert docs[2].meta["page_number"] == 3

def test_split_by_function(self):
splitting_function = lambda input_str: input_str.split(".")
splitter = DocumentSplitter(split_by="function", splitting_function=splitting_function, split_length=1)
splitting_function = lambda s: s.split(".")
splitter = DocumentSplitter(split_by="function", splitting_function=splitting_function)
text = "This.Is.A.Test"
result = splitter.run(documents=[Document(content=text)])
result = splitter.run(documents=[Document(id="1", content=text, meta={"key": "value"})])
docs = result["documents"]

word_list = ["This", "Is", "A", "Test"]
assert len(docs) == 4
for w_target, w_split in zip(word_list, docs):
assert w_split.content == w_target

splitting_function = lambda input_str: re.split("[\s]{2,}", input_str)
splitter = DocumentSplitter(split_by="function", splitting_function=splitting_function, split_length=1)
assert docs[0].content == "This"
assert docs[0].meta == {"key": "value", "source_id": "1"}
assert docs[1].content == "Is"
assert docs[1].meta == {"key": "value", "source_id": "1"}
assert docs[2].content == "A"
assert docs[2].meta == {"key": "value", "source_id": "1"}
assert docs[3].content == "Test"
assert docs[3].meta == {"key": "value", "source_id": "1"}

splitting_function = lambda s: re.split(r"[\s]{2,}", s)
splitter = DocumentSplitter(split_by="function", splitting_function=splitting_function)
text = "This Is\n A Test"
result = splitter.run(documents=[Document(content=text)])
result = splitter.run(documents=[Document(id="1", content=text, meta={"key": "value"})])
docs = result["documents"]
assert len(docs) == 4
for w_target, w_split in zip(word_list, docs):
assert w_split.content == w_target
assert docs[0].content == "This"
assert docs[0].meta == {"key": "value", "source_id": "1"}
assert docs[1].content == "Is"
assert docs[1].meta == {"key": "value", "source_id": "1"}
assert docs[2].content == "A"
assert docs[2].meta == {"key": "value", "source_id": "1"}
assert docs[3].content == "Test"
assert docs[3].meta == {"key": "value", "source_id": "1"}

def test_split_by_word_with_overlap(self):
splitter = DocumentSplitter(split_by="word", split_length=10, split_overlap=2)
Expand Down

0 comments on commit bd77120

Please sign in to comment.