Fix bug huggingface#35447: Tokenizer does not split text according to newly added input tokens

The root cause is that the Trie.split method didn't ignore partial matches that had already been marked for removal.

Add a test case for token splitting.
jiongjiongli committed Dec 29, 2024
1 parent 5c75087 commit 1f3fdc1
Showing 2 changed files with 16 additions and 1 deletion.
src/transformers/tokenization_utils.py (4 additions, 1 deletion)
@@ -175,7 +175,10 @@ def split(self, text: str) -> List[str]:
                 # matches
                 # "[CLS]", "L", we need to match CLS even if L is special
                 for lookstart, looktrie_pointer in states.items():
-                    if lookstart > start:
+                    if lookstart in to_remove:
+                        # This partial match should be removed
+                        continue
+                    elif lookstart > start:
                         # This partial match is later, we can stop looking
                         break
                     elif lookstart < start:
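In rough terms, the guard matters because to_remove is filled earlier in the same pass: when a character kills an in-flight partial match (for "read" with added tokens "red" and "e", the "a" kills the partial match for "red"), that match's start offset is queued in to_remove, but before this fix the lookahead loop run by a completed match (here "e") still consulted the dead state and could derail the split. A minimal sketch of the fixed behavior, using the public Trie class from transformers.tokenization_utils:

```python
from transformers.tokenization_utils import Trie

# Mirror the scenario from the linked issue: two added tokens that
# overlap inside the input "read".
trie = Trie()
trie.add("red")
trie.add("e")

# The partial match for "red" dies at "a" and is queued for removal;
# with the fix, the lookahead loop skips it, so the completed match
# for "e" produces the expected split.
print(trie.split("read"))  # ['r', 'e', 'ad']
```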
tests/tokenization/test_tokenization_utils.py (12 additions, 0 deletions)
@@ -367,6 +367,18 @@ def test_decoding_skip_special_tokens(self):
         decoded_sent = tokenizer.decode(pad_id, skip_special_tokens=False)
         self.assertEqual(decoded_sent, "[PAD]")
 
+    @require_tokenizers
+    def test_split_tokens(self):
+        for tokenizer_class in [BertTokenizer, BertTokenizerFast]:
+            with self.subTest(f"{tokenizer_class}"):
+                tokenizer = tokenizer_class.from_pretrained("google-bert/bert-base-cased")
+                tokenizer.add_tokens(["red", "e"])
+
+                # test split tokens
+                sentence = "read"
+                output_tokens = tokenizer.tokenize(sentence)
+                self.assertEqual(output_tokens, ["r", "e", "ad"])
+
     @require_torch
     def test_padding_accepts_tensors_pt(self):
         import torch
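At the tokenizer level, the scenario the new test pins down can be reproduced directly (checkpoint name and tokens taken from the test above):

```python
from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("google-bert/bert-base-cased")
tokenizer.add_tokens(["red", "e"])

# With the fix, the added token "e" splits "read" as the test expects.
print(tokenizer.tokenize("read"))  # ['r', 'e', 'ad']
```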
