From 11707398b50b2fd3f47d3582285b6bcd7533a41b Mon Sep 17 00:00:00 2001
From: David-OC17
Date: Wed, 20 Nov 2024 05:56:44 -0600
Subject: [PATCH] Add optional length-sorting of sentences during encoding

Sorting sentences by length before batching groups similarly sized
sentences into the same bucket, which reduces padding work per batch.
The option defaults to False because enabling it changes the row order
of the returned embedding matrix (embeddings follow the sorted order,
not the input order).
---
 sonar/inference_pipelines/text.py | 19 +++++++++++++-----
 1 file changed, 14 insertions(+), 5 deletions(-)

diff --git a/sonar/inference_pipelines/text.py b/sonar/inference_pipelines/text.py
index 0d8dd00..0679b36 100644
--- a/sonar/inference_pipelines/text.py
+++ b/sonar/inference_pipelines/text.py
@@ -143,6 +143,7 @@ def predict(
         max_seq_len: Optional[int] = None,
         progress_bar: bool = False,
         target_device: Optional[Device] = None,
+        sort_sent: bool = False,
     ) -> torch.Tensor:
         """
         Transform the input texts (from a list of strings or from a text file) into a matrix of their embeddings.
@@ -165,15 +166,23 @@ def predict(
         def truncate(x: torch.Tensor) -> torch.Tensor:
             if x.shape[0] > max_seq_len:
                 nonlocal n_truncated
                 n_truncated += 1
             return x[:max_seq_len]
+
+        # Materialize the input sentences first: sorted() yields a plain
+        # list, which cannot be chained with .map()/.bucket(), so the
+        # sequence is re-wrapped with read_sequence() below. File input is
+        # drained via .and_return(), since the builder itself is not an
+        # iterable of sentences.
+        if isinstance(input, (str, Path)):
+            input_sent = list(read_text(input).and_return())
+        else:
+            input_sent = list(input)
+        if sort_sent:
+            # Similar lengths share a bucket -> less padding. NOTE: the
+            # returned embeddings then follow the sorted order.
+            input_sent.sort(key=len)

         pipeline: Iterable = (
-            (
-                read_text(input)
-                if isinstance(input, (str, Path))
-                else read_sequence(input)
-            )
+            read_sequence(input_sent)
             .map(tokenizer_encoder)
             .map(truncate)
             .bucket(batch_size)