"""Split a document into semantic chunks."""
import re
import numpy as np
from scipy.optimize import linprog
from scipy.sparse import coo_matrix
from raglite._typing import FloatMatrix
def split_chunks(  # noqa: C901, PLR0915
    sentences: list[str],
    sentence_embeddings: FloatMatrix,
    sentence_window_size: int = 3,
    max_size: int = 1440,
) -> tuple[list[str], list[FloatMatrix]]:
    """Split sentences into optimal semantic chunks with corresponding sentence embeddings."""
    # Validate the input.
    sentence_length = np.asarray([len(sentence) for sentence in sentences])
    if not np.all(sentence_length <= max_size):
        error_message = "Sentence with length larger than chunk max_size detected."
        raise ValueError(error_message)
    if not np.all(np.linalg.norm(sentence_embeddings, axis=1) > 0.0):
        error_message = "Sentence embeddings with zero norm detected."
        raise ValueError(error_message)
    # Exit early if there is only one chunk to return.
    if len(sentences) <= 1 or sum(sentence_length) <= max_size:
        return ["".join(sentences)] if sentences else sentences, [sentence_embeddings]
    # Normalise the sentence embeddings to unit norm.
    X = sentence_embeddings.astype(np.float32)  # noqa: N806
    X = X / np.linalg.norm(X, axis=1, keepdims=True)  # noqa: N806
    # Select nonoutlying sentences and remove the discourse vector.
    q15, q85 = np.quantile(sentence_length, [0.15, 0.85])
    nonoutlying_sentences = (q15 <= sentence_length) & (sentence_length <= q85)
    discourse = np.mean(X[nonoutlying_sentences, :], axis=0)
    discourse = discourse / np.linalg.norm(discourse)
    if not np.any(np.linalg.norm(X - discourse[np.newaxis, :], axis=1) <= np.finfo(X.dtype).eps):
        X = X - np.outer(X @ discourse, discourse)  # noqa: N806
        X = X / np.linalg.norm(X, axis=1, keepdims=True)  # noqa: N806
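    # Removing the shared "discourse" component is presumably intended to make the cosine
    # similarities computed below more sensitive to differences between neighbouring
    # sentences. The guard above skips the projection whenever a sentence embedding
    # coincides with the discourse vector, which the projection would otherwise collapse
    # to a zero vector.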
    # For each partition point in the list of sentences, compute the similarity of the windows
    # before and after the partition point. Sentence embeddings are assumed to be of the sentence
    # itself and at most the (sentence_window_size - 1) sentences that precede it.
    sentence_window_size = min(len(sentences) - 1, sentence_window_size)
    windows_before = X[:-sentence_window_size]
    windows_after = X[sentence_window_size:]
    partition_similarity = np.ones(len(sentences) - 1, dtype=X.dtype)
    partition_similarity[: len(windows_before)] = np.sum(windows_before * windows_after, axis=1)
    # Make partition similarity nonnegative before modification and optimisation.
    partition_similarity = np.maximum(
        (partition_similarity + 1) / 2, np.sqrt(np.finfo(X.dtype).eps)
    )
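    # The affine map from [-1, 1] to [0, 1] with a small positive floor keeps every
    # candidate split's cost strictly positive, so the minimisation below never adds
    # splits that the size constraints do not require, and the heading discount below
    # always produces a smaller yet still valid cost.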
    # Modify the partition similarity to encourage splitting on Markdown headings.
    prev_sentence_is_heading = True
    for i, sentence in enumerate(sentences[:-1]):
        is_heading = bool(re.match(r"^#+\s", sentence.replace("\n", "").strip()))
        if is_heading:
            # Encourage splitting before a heading.
            if not prev_sentence_is_heading:
                partition_similarity[i - 1] = partition_similarity[i - 1] / 4
            # Don't split immediately after a heading.
            partition_similarity[i] = 1.0
        prev_sentence_is_heading = is_heading
    # Solve an optimisation problem to find the best partition points.
    sentence_length_cumsum = np.cumsum(sentence_length)
    row_indices = []
    col_indices = []
    data = []
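    # Each constraint row i requires at least one partition point among the candidate
    # positions i..idx - 1, i.e. somewhere within the longest run of sentences starting at
    # sentence i whose total length still fits within max_size. Taken together, these
    # constraints guarantee that no chunk can exceed max_size. The loop below breaks early
    # once the remaining sentences fit into a single chunk, so the tail needs no constraint.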
    for i in range(len(sentences) - 1):
        r = sentence_length_cumsum[i - 1] if i > 0 else 0
        idx = np.searchsorted(sentence_length_cumsum - r, max_size)
        assert idx > i
        if idx == len(sentence_length_cumsum):
            break
        cols = list(range(i, idx))
        col_indices.extend(cols)
        row_indices.extend([i] * len(cols))
        data.extend([1] * len(cols))
    A = coo_matrix(  # noqa: N806
        (data, (row_indices, col_indices)),
        shape=(max(row_indices) + 1, len(sentences) - 1),
        dtype=np.float32,
    )
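    # linprog minimises the total cost partition_similarity @ x over binary x (requested
    # via `integrality`); x[j] == 1 means "split between sentence j and sentence j + 1".
    # Because linprog only accepts upper-bound inequalities, the "at least one split per
    # row" constraints A @ x >= 1 are passed as -A @ x <= -1.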
    b_ub = np.ones(A.shape[0], dtype=np.float32)
    res = linprog(
        partition_similarity,
        A_ub=-A,
        b_ub=-b_ub,
        bounds=(0, 1),
        integrality=[1] * A.shape[1],
    )
    if not res.success:
        error_message = "Optimization of chunk partitions failed."
        raise ValueError(error_message)
    # Split the sentences and their window embeddings into optimal chunks.
    partition_indices = (np.where(res.x)[0] + 1).tolist()
    chunks = [
        "".join(sentences[i:j])
        for i, j in zip([0, *partition_indices], [*partition_indices, len(sentences)], strict=True)
    ]
    chunk_embeddings = np.split(sentence_embeddings, partition_indices)
    return chunks, chunk_embeddings
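

# Usage sketch (illustrative only, not part of raglite): exercise split_chunks() with
# made-up sentences and random unit-norm embeddings standing in for a real embedding
# model's output. The sentence texts, embedding dimension (8), and max_size below are
# assumptions chosen only to keep the demo small.
if __name__ == "__main__":
    rng = np.random.default_rng(42)
    demo_sentences = [f"This is sentence number {k} of the demo document. " for k in range(20)]
    demo_embeddings = rng.normal(size=(len(demo_sentences), 8)).astype(np.float32)
    demo_embeddings /= np.linalg.norm(demo_embeddings, axis=1, keepdims=True)
    demo_chunks, demo_chunk_embeddings = split_chunks(
        demo_sentences, demo_embeddings, max_size=250
    )
    for chunk, embeddings in zip(demo_chunks, demo_chunk_embeddings, strict=True):
        print(f"Chunk of {len(chunk)} characters covering {len(embeddings)} sentences.")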