diff --git a/sonar/inference_pipelines/text.py b/sonar/inference_pipelines/text.py
index 0d8dd00..0679b36 100644
--- a/sonar/inference_pipelines/text.py
+++ b/sonar/inference_pipelines/text.py
@@ -143,6 +143,7 @@ def predict(
         max_seq_len: Optional[int] = None,
         progress_bar: bool = False,
         target_device: Optional[Device] = None,
+        sort_sent: bool = True,
     ) -> torch.Tensor:
         """
         Transform the input texts (from a list of strings or from a text file) into a matrix of their embeddings.
@@ -165,13 +166,22 @@ def truncate(x: torch.Tensor) -> torch.Tensor:
                 nonlocal n_truncated
                 n_truncated += 1
             return x[:max_seq_len]
+
+        def sort_input(input_sent: Sequence[str]) -> Sequence[str]:
+            # Sorting by length lets bucketing pack similar-length sentences
+            # together, which minimizes padding within each batch.
+            return sorted(input_sent, key=len) if sort_sent else input_sent
+
+        if isinstance(input, (str, Path)):
+            # Materialize the file so the sentences can be sorted in memory.
+            input_sent = list(read_text(input).and_return())
+        else:
+            input_sent = input
+
+        sorted_input_sent = sort_input(input_sent)
 
         pipeline: Iterable = (
-            (
-                read_text(input)
-                if isinstance(input, (str, Path))
-                else read_sequence(input)
-            )
+            read_sequence(sorted_input_sent)
             .map(tokenizer_encoder)
             .map(truncate)
             .bucket(batch_size)
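
One caveat with length-sorting: it permutes the rows of the returned embedding matrix relative to the caller's input order, so the sorted indices have to be tracked and inverted before the tensor is handed back. The snippet below is a minimal sketch of that sort-and-restore pattern, outside the fairseq2 pipeline; `embed_batch` is a hypothetical stand-in for the tokenize-and-encode step, not part of SONAR's API.

```python
from typing import Callable, List, Sequence

import torch


def embed_sorted(
    sentences: Sequence[str],
    embed_batch: Callable[[List[str]], torch.Tensor],  # hypothetical encoder
    batch_size: int = 32,
) -> torch.Tensor:
    # Indices of the sentences in ascending length order, so each batch
    # contains similar-length sentences and padding is minimized.
    order = sorted(range(len(sentences)), key=lambda i: len(sentences[i]))

    chunks = []
    for start in range(0, len(order), batch_size):
        batch = [sentences[i] for i in order[start : start + batch_size]]
        chunks.append(embed_batch(batch))  # shape: (len(batch), embed_dim)
    embeddings = torch.cat(chunks, dim=0)

    # Row k of `embeddings` belongs to sentence order[k]; build the inverse
    # permutation so that row i of the result matches sentences[i] again.
    inverse = torch.empty(len(order), dtype=torch.long)
    inverse[torch.tensor(order)] = torch.arange(len(order))
    return embeddings[inverse]
```

Applied to this PR, the same inverse-permutation step would belong after the pipeline collects the batched embeddings, since `sorted_input_sent` feeds the encoder in length order rather than in the caller's order.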