diff --git a/laser_encoders/laser_tokenizer.py b/laser_encoders/laser_tokenizer.py
index 0488cb2c..728fdde5 100644
--- a/laser_encoders/laser_tokenizer.py
+++ b/laser_encoders/laser_tokenizer.py
@@ -103,8 +103,8 @@ def tokenize_file(self, inp_fname: Path, out_fname: Path) -> None:
                 tokens = self.tokenize(line.strip())
                 file_out.write(tokens + "\n")

-    def __call__(self, text_or_batch, batch=False):
-        if not batch:
+    def __call__(self, text_or_batch):
+        if isinstance(text_or_batch, str):
             return self.tokenize(text_or_batch)
         else:
             return self.tokenize_batch(text_or_batch)
diff --git a/laser_encoders/models.py b/laser_encoders/models.py
index 037a4f9f..6d2a567f 100644
--- a/laser_encoders/models.py
+++ b/laser_encoders/models.py
@@ -94,10 +94,12 @@ def __init__(
         self.encoder.eval()
         self.sort_kind = sort_kind

-    def __call__(self, sentences):
+    def __call__(self, text_or_batch):
         if self.spm_model:
-            sentences = self.tokenizer(sentences)
-            return self.encode_sentences(sentences)
+            text_or_batch = self.tokenizer(text_or_batch)
+            if isinstance(text_or_batch, str):
+                text_or_batch = [text_or_batch]
+            return self.encode_sentences(text_or_batch)
         else:
             raise ValueError(
                 "Either initialize the encoder with an spm_model or pre-tokenize and use the encode_sentences method."
diff --git a/laser_encoders/test_laser_tokenizer.py b/laser_encoders/test_laser_tokenizer.py
index 1155f8d2..1350c108 100644
--- a/laser_encoders/test_laser_tokenizer.py
+++ b/laser_encoders/test_laser_tokenizer.py
@@ -65,6 +65,19 @@ def test_tokenize(tokenizer, input_text: str):
     assert tokenizer.tokenize(input_text) == expected_output


+def test_tokenizer_call_method(tokenizer, input_text: str):
+    single_string = "This is a test sentence."
+    expected_output = "▁this ▁is ▁a ▁test ▁sent ence ."
+    assert tokenizer(single_string) == expected_output
+
+    list_of_strings = ["This is a test sentence.", "This is another test sentence."]
+    expected_output = [
+        "▁this ▁is ▁a ▁test ▁sent ence .",
+        "▁this ▁is ▁another ▁test ▁sent ence .",
+    ]
+    assert tokenizer(list_of_strings) == expected_output
+
+
 def test_normalization(tokenizer):
     test_data = "Hello!!! How are you??? I'm doing great."
     expected_output = "▁hel lo !!! ▁how ▁are ▁you ??? ▁i ' m ▁do ing ▁great ."
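
For reference, a minimal usage sketch of the unified __call__ API after this change. The import path, the LaserTokenizer constructor call, and the spm model filename are assumptions for illustration; the dispatch behaviour and the expected tokenizer outputs are taken from the diff and its new test.

from pathlib import Path

from laser_encoders.laser_tokenizer import LaserTokenizer

# Assumed local path to a SentencePiece model; adjust to your setup.
tokenizer = LaserTokenizer(spm_model=Path("laser2.spm"))

# A single string now dispatches to tokenize() and returns one string,
# so the old batch=False flag is no longer needed.
assert tokenizer("This is a test sentence.") == "▁this ▁is ▁a ▁test ▁sent ence ."

# A list of strings dispatches to tokenize_batch() and returns a list.
print(tokenizer(["This is a test sentence.", "This is another test sentence."]))
# ['▁this ▁is ▁a ▁test ▁sent ence .', '▁this ▁is ▁another ▁test ▁sent ence .']

On the encoder side, __call__ wraps a lone string in a list before calling encode_sentences, so single sentences and batches share one code path and a single input yields a batch of one embedding.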