Added sorting option for sentences during encoding
David-OC17 committed Nov 20, 2024
1 parent f17dffa commit 1170739
Showing 1 changed file with 12 additions and 5 deletions.
sonar/inference_pipelines/text.py (12 additions & 5 deletions)
@@ -143,6 +143,7 @@ def predict(
         max_seq_len: Optional[int] = None,
         progress_bar: bool = False,
         target_device: Optional[Device] = None,
+        sort_sent: bool = True,
     ) -> torch.Tensor:
         """
         Transform the input texts (from a list of strings or from a text file) into a matrix of their embeddings.
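The new keyword defaults to True, so existing callers pick up length-sorted encoding without changing their code; passing sort_sent=False keeps the previous behavior. A caller-side usage sketch follows; the pipeline construction and source_lang value are taken from SONAR's README example rather than from this diff, so treat them as assumptions.

# Illustrative usage of the new flag; the model/tokenizer names and source_lang
# follow SONAR's README example and are assumptions, not part of this commit.
from sonar.inference_pipelines.text import TextToEmbeddingModelPipeline

t2vec = TextToEmbeddingModelPipeline(
    encoder="text_sonar_basic_encoder",
    tokenizer="text_sonar_basic_encoder",
)
embeddings = t2vec.predict(
    ["Short one.", "A much longer sentence that would otherwise force extra padding."],
    source_lang="eng_Latn",
    sort_sent=True,  # option added in this commit; set False to keep the input order
)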
@@ -165,13 +166,19 @@ def truncate(x: torch.Tensor) -> torch.Tensor:
                 nonlocal n_truncated
                 n_truncated += 1
             return x[:max_seq_len]

+        def sort_input(input_sent: Iterable[str]) -> Iterable[str]:
+            return sorted(input_sent, key=len) if sort_sent else input_sent
+
+        if isinstance(input, (str, Path)):
+            input_sent = read_text(input)
+        else:
+            input_sent = read_sequence(input)
+
+        sorted_input_sent = sort_input(input_sent=input_sent)
+
         pipeline: Iterable = (
-            (
-                read_text(input)
-                if isinstance(input, (str, Path))
-                else read_sequence(input)
-            )
+            sorted_input_sent
             .map(tokenizer_encoder)
             .map(truncate)
             .bucket(batch_size)
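The sorting matters because the pipeline buckets the token sequences into batches of batch_size via .bucket(batch_size), and each batch is padded to its longest member; when short and long sentences share a batch, the short ones carry many pad tokens. Feeding length-sorted sentences keeps each bucket roughly uniform in length, so less compute is spent on padding. The sketch below is a minimal illustration of that effect, not code from the repository; padding_cost, the whitespace tokenization, and the toy sentences are assumptions.

from typing import List


def padding_cost(sentences: List[str], batch_size: int) -> int:
    """Count the pad tokens needed when sentences are batched in the given order."""
    token_lists = [s.split() for s in sentences]  # stand-in for real tokenization
    total = 0
    for i in range(0, len(token_lists), batch_size):
        batch = token_lists[i : i + batch_size]
        longest = max(len(seq) for seq in batch)  # every sequence is padded to this length
        total += sum(longest - len(seq) for seq in batch)
    return total


sentences = ["a", "a b c d e f", "a b", "a b c d e", "a b c", "a b c d"]
print(padding_cost(sentences, batch_size=3))                   # 12 pad tokens
print(padding_cost(sorted(sentences, key=len), batch_size=3))  # 6 pad tokens

One design consideration: with sort_sent=True the embeddings come back in length-sorted order rather than the caller's original order, at least in the hunks shown here, so code that matches embeddings to inputs by position would need to restore the original ordering itself.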
