fix: Bring in fix from custom nodes (#8539)
* Bring in fix from custom nodes

* Add to_dict function and test

* reno

* Fix pylint
sjrl authored Nov 14, 2024
1 parent f5683bc commit 0c11c7b
Showing 3 changed files with 104 additions and 16 deletions.
47 changes: 42 additions & 5 deletions haystack/components/preprocessors/nltk_document_splitter.py
@@ -5,11 +5,13 @@
import re
from copy import deepcopy
from pathlib import Path
from typing import Any, Dict, List, Literal, Tuple
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple

from haystack import Document, component, logging
from haystack.components.preprocessors.document_splitter import DocumentSplitter
from haystack.core.serialization import default_to_dict
from haystack.lazy_imports import LazyImport
from haystack.utils import serialize_callable

with LazyImport("Run 'pip install nltk'") as nltk_imports:
import nltk
@@ -23,7 +25,7 @@

@component
class NLTKDocumentSplitter(DocumentSplitter):
def __init__(
def __init__( # pylint: disable=too-many-positional-arguments
self,
split_by: Literal["word", "sentence", "page", "passage", "function"] = "word",
split_length: int = 200,
@@ -33,6 +35,7 @@ def __init__(
language: Language = "en",
use_split_rules: bool = True,
extend_abbreviations: bool = True,
splitting_function: Optional[Callable[[str], List[str]]] = None,
):
"""
Splits your documents using NLTK to respect sentence boundaries.
@@ -53,10 +56,17 @@
:param extend_abbreviations: Choose whether to extend NLTK's PunktTokenizer abbreviations with a list
of curated abbreviations, if available.
This is currently supported for English ("en") and German ("de").
:param splitting_function: Necessary when `split_by` is set to "function".
This is a function which must accept a single `str` as input and return a `list` of `str` as output,
representing the chunks after splitting.
"""

super(NLTKDocumentSplitter, self).__init__(
split_by=split_by, split_length=split_length, split_overlap=split_overlap, split_threshold=split_threshold
split_by=split_by,
split_length=split_length,
split_overlap=split_overlap,
split_threshold=split_threshold,
splitting_function=splitting_function,
)
nltk_imports.check()
if respect_sentence_boundary and split_by != "word":
@@ -66,6 +76,8 @@ def __init__(
)
respect_sentence_boundary = False
self.respect_sentence_boundary = respect_sentence_boundary
self.use_split_rules = use_split_rules
self.extend_abbreviations = extend_abbreviations
self.sentence_splitter = SentenceSplitter(
language=language,
use_split_rules=use_split_rules,
@@ -100,9 +112,11 @@ def _split_into_units(
elif split_by == "word":
self.split_at = " "
units = text.split(self.split_at)
elif split_by == "function" and self.splitting_function is not None:
return self.splitting_function(text)
else:
raise NotImplementedError(
"DocumentSplitter only supports 'word', 'sentence', 'page' or 'passage' split_by options."
"DocumentSplitter only supports 'function', 'page', 'passage', 'sentence' or 'word' split_by options."
)

# Add the delimiter back to all units except the last one
@@ -138,6 +152,9 @@ def run(self, documents: List[Document]) -> Dict[str, List[Document]]:
raise ValueError(
f"DocumentSplitter only works with text documents but content for document ID {doc.id} is None."
)
if doc.content == "":
logger.warning("Document ID {doc_id} has an empty content. Skipping this document.", doc_id=doc.id)
continue

if self.respect_sentence_boundary:
units = self._split_into_units(doc.content, "sentence")
@@ -159,6 +176,25 @@ def run(self, documents: List[Document]) -> Dict[str, List[Document]]:
)
return {"documents": split_docs}

def to_dict(self) -> Dict[str, Any]:
"""
Serializes the component to a dictionary.
"""
serialized = default_to_dict(
self,
split_by=self.split_by,
split_length=self.split_length,
split_overlap=self.split_overlap,
split_threshold=self.split_threshold,
respect_sentence_boundary=self.respect_sentence_boundary,
language=self.language,
use_split_rules=self.use_split_rules,
extend_abbreviations=self.extend_abbreviations,
)
if self.splitting_function:
serialized["init_parameters"]["splitting_function"] = serialize_callable(self.splitting_function)
return serialized

@staticmethod
def _number_of_sentences_to_keep(sentences: List[str], split_length: int, split_overlap: int) -> int:
"""
@@ -175,7 +211,8 @@ def _number_of_sentences_to_keep(sentences: List[str], split_length: int, split_

num_sentences_to_keep = 0
num_words = 0
for sent in reversed(sentences):
# Next overlapping Document should not start exactly the same as the previous one, so we skip the first sentence
for sent in reversed(sentences[1:]):
num_words += len(sent.split())
# If the number of words is larger than the split_length then don't add any more sentences
if num_words > split_length:
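
As context for the changes above, here is a minimal usage sketch of the new splitting_function parameter together with the to_dict override. It assumes a Haystack 2.x install with nltk available; split_on_periods is a made-up helper used only for illustration and is not part of this commit.

from typing import List

from haystack.components.preprocessors.nltk_document_splitter import NLTKDocumentSplitter


def split_on_periods(text: str) -> List[str]:
    # Hypothetical splitting function: one chunk per period-delimited fragment.
    return text.split(".")


splitter = NLTKDocumentSplitter(split_by="function", splitting_function=split_on_periods)

# The to_dict override stores the callable as a dotted import path (via serialize_callable),
# so the component configuration can be written out to YAML and restored later.
data = splitter.to_dict()
print(data["init_parameters"]["split_by"])            # "function"
print(data["init_parameters"]["splitting_function"])  # e.g. "__main__.split_on_periods"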
10 changes: 10 additions & 0 deletions releasenotes/notes/fix-nltk-doc-splitter-d0864dda906c45b0.yaml
@@ -0,0 +1,10 @@
---
fixes:
- |
  For the NLTKDocumentSplitter, we updated how chunks are built when splitting by word while respecting sentence boundaries.
  Namely, to avoid the previous chunk being fully subsumed by the next one, we now ignore the first sentence of that chunk when calculating the sentence overlap,
  i.e. we want to avoid cases such as Doc1 = [s1, s2], Doc2 = [s1, s2, s3].
  We also finished adding function support for this component by updating the _split_into_units function and adding the splitting_function init parameter.
  Finally, a dedicated to_dict method was added to override the one inherited from DocumentSplitter; this is needed to properly save the component's settings to YAML.
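
To illustrate the overlap change described in this note, here is a small standalone sketch of the sentence-counting rule. It is not the component's exact code; the part of the loop after the split_length check is elided in the diff above, so that tail is an assumption.

from typing import List


def sentences_to_keep(sentences: List[str], split_length: int, split_overlap: int) -> int:
    num_to_keep = 0
    num_words = 0
    # Skip sentences[0] so the next chunk can never start identically to the previous one.
    for sent in reversed(sentences[1:]):
        num_words += len(sent.split())
        # Do not let the carried-over overlap exceed split_length words.
        if num_words > split_length:
            break
        num_to_keep += 1
        # Assumed behaviour: stop once the requested word overlap is covered.
        if num_words > split_overlap:
            break
    return num_to_keep


print(sentences_to_keep(["The sun set.", "The moon was full."], split_length=5, split_overlap=0))  # 1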
63 changes: 52 additions & 11 deletions test/components/preprocessors/test_nltk_document_splitter.py
@@ -5,13 +5,18 @@
from pytest import LogCaptureFixture

from haystack.components.preprocessors.nltk_document_splitter import NLTKDocumentSplitter, SentenceSplitter
from haystack.utils import deserialize_callable


def test_init_warning_message(caplog: LogCaptureFixture) -> None:
_ = NLTKDocumentSplitter(split_by="page", respect_sentence_boundary=True)
assert "The 'respect_sentence_boundary' option is only supported for" in caplog.text


def custom_split(text):
return text.split(".")


class TestNLTKDocumentSplitterSplitIntoUnits:
def test_document_splitter_split_into_units_word(self) -> None:
document_splitter = NLTKDocumentSplitter(
@@ -87,9 +92,11 @@ class TestNLTKDocumentSplitterNumberOfSentencesToKeep:
@pytest.mark.parametrize(
"sentences, expected_num_sentences",
[
(["Moonlight shimmered softly, wolves howled nearby, night enveloped everything."], 0),
([" It was a dark night ..."], 0),
([" The moon was full."], 1),
(["The sun set.", "Moonlight shimmered softly, wolves howled nearby, night enveloped everything."], 0),
(["The sun set.", "It was a dark night ..."], 0),
(["The sun set.", " The moon was full."], 1),
(["The sun.", " The moon."], 1), # Ignores the first sentence
(["Sun", "Moon"], 1), # Ignores the first sentence even if its inclusion would be < split_overlap
],
)
def test_number_of_sentences_to_keep(self, sentences: List[str], expected_num_sentences: int) -> None:
@@ -304,7 +311,7 @@ def test_run_split_by_word_respect_sentence_boundary_no_repeats(self) -> None:
def test_run_split_by_word_respect_sentence_boundary_with_split_overlap_and_page_breaks(self) -> None:
document_splitter = NLTKDocumentSplitter(
split_by="word",
split_length=5,
split_length=8,
split_overlap=1,
split_threshold=0,
language="en",
@@ -313,26 +320,60 @@ def test_run_split_by_word_respect_sentence_boundary_with_split_overlap_and_page
respect_sentence_boundary=True,
)

text = "Sentence on page 1.\fSentence on page 2. \fSentence on page 3. \f\f Sentence on page 5."
text = (
"Sentence on page 1. Another on page 1.\fSentence on page 2. Another on page 2.\f"
"Sentence on page 3. Another on page 3.\f\f Sentence on page 5."
)
documents = document_splitter.run(documents=[Document(content=text)])["documents"]

assert len(documents) == 4
assert documents[0].content == "Sentence on page 1.\f"
assert len(documents) == 6
assert documents[0].content == "Sentence on page 1. Another on page 1.\f"
assert documents[0].meta["page_number"] == 1
assert documents[0].meta["split_id"] == 0
assert documents[0].meta["split_idx_start"] == text.index(documents[0].content)
assert documents[1].content == "Sentence on page 1.\fSentence on page 2. \f"
assert documents[1].content == "Another on page 1.\fSentence on page 2. "
assert documents[1].meta["page_number"] == 1
assert documents[1].meta["split_id"] == 1
assert documents[1].meta["split_idx_start"] == text.index(documents[1].content)
assert documents[2].content == "Sentence on page 2. \fSentence on page 3. \f\f "
assert documents[2].content == "Sentence on page 2. Another on page 2.\f"
assert documents[2].meta["page_number"] == 2
assert documents[2].meta["split_id"] == 2
assert documents[2].meta["split_idx_start"] == text.index(documents[2].content)
assert documents[3].content == "Sentence on page 3. \f\f Sentence on page 5."
assert documents[3].meta["page_number"] == 3
assert documents[3].content == "Another on page 2.\fSentence on page 3. "
assert documents[3].meta["page_number"] == 2
assert documents[3].meta["split_id"] == 3
assert documents[3].meta["split_idx_start"] == text.index(documents[3].content)
assert documents[4].content == "Sentence on page 3. Another on page 3.\f\f "
assert documents[4].meta["page_number"] == 3
assert documents[4].meta["split_id"] == 4
assert documents[4].meta["split_idx_start"] == text.index(documents[4].content)
assert documents[5].content == "Another on page 3.\f\f Sentence on page 5."
assert documents[5].meta["page_number"] == 3
assert documents[5].meta["split_id"] == 5
assert documents[5].meta["split_idx_start"] == text.index(documents[5].content)

def test_to_dict(self):
splitter = NLTKDocumentSplitter(split_by="word", split_length=10, split_overlap=2, split_threshold=5)
serialized = splitter.to_dict()

assert serialized["type"] == "haystack.components.preprocessors.nltk_document_splitter.NLTKDocumentSplitter"
assert serialized["init_parameters"]["split_by"] == "word"
assert serialized["init_parameters"]["split_length"] == 10
assert serialized["init_parameters"]["split_overlap"] == 2
assert serialized["init_parameters"]["split_threshold"] == 5
assert serialized["init_parameters"]["language"] == "en"
assert serialized["init_parameters"]["use_split_rules"] is True
assert serialized["init_parameters"]["extend_abbreviations"] is True
assert "splitting_function" not in serialized["init_parameters"]

def test_to_dict_with_splitting_function(self):
splitter = NLTKDocumentSplitter(split_by="function", splitting_function=custom_split)
serialized = splitter.to_dict()

assert serialized["type"] == "haystack.components.preprocessors.nltk_document_splitter.NLTKDocumentSplitter"
assert serialized["init_parameters"]["split_by"] == "function"
assert "splitting_function" in serialized["init_parameters"]
assert callable(deserialize_callable(serialized["init_parameters"]["splitting_function"]))


class TestSentenceSplitter:
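
Finally, a rough run-level sketch of the empty-document guard added to run() earlier in this commit. It assumes nltk and its tokenizer data are available locally; the example texts are arbitrary.

from haystack import Document
from haystack.components.preprocessors.nltk_document_splitter import NLTKDocumentSplitter

splitter = NLTKDocumentSplitter(split_by="word", split_length=10, respect_sentence_boundary=True)
docs = [
    Document(content=""),  # logged with a warning and skipped by the new guard
    Document(content="The sun set. The moon was full."),
]
result = splitter.run(documents=docs)
print(len(result["documents"]))  # splits come only from the non-empty document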
