-
Notifications
You must be signed in to change notification settings - Fork 520
/
tokenizer.py
112 lines (89 loc) · 3.54 KB
/
tokenizer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import os
import sentencepiece as spm
import tiktoken
from tiktoken.load import load_tiktoken_bpe
from pathlib import Path
from typing import Dict
class TokenizerInterface:
    """Common surface shared by the tokenizer backends in this module.

    Concrete backends subclass this and replace all four operations; the
    versions defined here do nothing except raise ``NotImplementedError``.
    """

    def __init__(self, model_path):
        # Keep the location of the tokenizer model on the instance.
        self.model_path = model_path

    def encode(self, text):
        """Convert ``text`` to token IDs; subclasses provide the logic."""
        raise NotImplementedError("This method should be overridden by subclasses.")

    def decode(self, tokens):
        """Convert ``tokens`` back to text; subclasses provide the logic."""
        raise NotImplementedError("This method should be overridden by subclasses.")

    def bos_id(self):
        """Return the beginning-of-sequence ID; subclasses provide the value."""
        raise NotImplementedError("This method should be overridden by subclasses.")

    def eos_id(self):
        """Return the end-of-sequence ID; subclasses provide the value."""
        raise NotImplementedError("This method should be overridden by subclasses.")
class SentencePieceWrapper(TokenizerInterface):
    """Tokenizer backend that forwards every call to a SentencePiece processor."""

    def __init__(self, model_path):
        super().__init__(model_path)
        # The processor is given the path as a plain string.
        model_file = str(model_path)
        self.processor = spm.SentencePieceProcessor(model_file)

    def encode(self, text):
        """Return the processor's ``EncodeAsIds`` result for ``text``."""
        token_ids = self.processor.EncodeAsIds(text)
        return token_ids

    def decode(self, tokens):
        """Return the processor's ``DecodeIds`` result for ``tokens``."""
        text = self.processor.DecodeIds(tokens)
        return text

    def bos_id(self):
        """Return the beginning-of-sequence ID reported by the processor."""
        processor = self.processor
        return processor.bos_id()

    def eos_id(self):
        """Return the end-of-sequence ID reported by the processor."""
        processor = self.processor
        return processor.eos_id()
class TiktokenWrapper(TokenizerInterface):
    """
    Tokenizer backend built on a ``tiktoken.Encoding``.

    The BPE ranks are loaded from ``model_path``; a fixed table of special
    tokens is then appended, with IDs starting right after the last base token.
    """

    # Special-token name -> token ID; filled in per instance by ``__init__``.
    special_tokens: Dict[str, int]

    # Total number of special tokens registered (named ones plus reserved ones).
    num_reserved_special_tokens = 256

    # Pre-tokenization split pattern handed to ``tiktoken.Encoding``.
    pat_str = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"  # noqa: E501

    def __init__(self, model_path):
        super().__init__(model_path)
        assert os.path.isfile(model_path), str(model_path)

        bpe_ranks = load_tiktoken_bpe(str(model_path))
        first_special_id = len(bpe_ranks)

        # The first ten slots are fixed; their order determines their IDs.
        leading_names = [
            "<|begin_of_text|>",
            "<|end_of_text|>",
            "<|reserved_special_token_0|>",
            "<|reserved_special_token_1|>",
            "<|reserved_special_token_2|>",
            "<|reserved_special_token_3|>",
            "<|start_header_id|>",
            "<|end_header_id|>",
            "<|reserved_special_token_4|>",
            "<|eot_id|>",  # end of turn
        ]
        # Remaining slots are numbered placeholders, continuing from index 5.
        last_reserved = self.num_reserved_special_tokens - 5
        filler_names = [
            f"<|reserved_special_token_{i}|>" for i in range(5, last_reserved)
        ]
        token_names = leading_names + filler_names

        # Special-token IDs continue directly after the base vocabulary.
        self.special_tokens = {
            name: token_id
            for token_id, name in enumerate(token_names, start=first_special_id)
        }

        encoding_name = Path(model_path).name
        self.model = tiktoken.Encoding(
            name=encoding_name,
            pat_str=self.pat_str,
            mergeable_ranks=bpe_ranks,
            special_tokens=self.special_tokens,
        )

        # Cache the BOS / EOS token IDs for the accessor methods.
        self._bos_id: int = self.special_tokens["<|begin_of_text|>"]
        self._eos_id: int = self.special_tokens["<|end_of_text|>"]

    def encode(self, text):
        """Return the encoding's ``encode`` result for ``text``."""
        token_ids = self.model.encode(text)
        return token_ids

    def decode(self, tokens):
        """Return the encoding's ``decode`` result for ``tokens``."""
        text = self.model.decode(tokens)
        return text

    def bos_id(self):
        """Return the ID assigned to ``<|begin_of_text|>``."""
        return self._bos_id

    def eos_id(self):
        """Return the ID assigned to ``<|end_of_text|>``."""
        return self._eos_id
def get_tokenizer(tokenizer_model_path, model_name):
    """
    Build the tokenizer wrapper that matches ``model_name``.

    Args:
    - tokenizer_model_path (str): Path of the tokenizer model file, passed
      unchanged to the selected wrapper's constructor.
    - model_name (str): Model name; it is stringified and lower-cased before
      being checked for the substring "llama-3".

    Returns:
    - TokenizerInterface: A ``TiktokenWrapper`` when the name contains
      "llama-3", otherwise a ``SentencePieceWrapper``.
    """
    normalized_name = str(model_name).lower()
    # Only the selected class is looked up and instantiated.
    wrapper_cls = (
        TiktokenWrapper if "llama-3" in normalized_name else SentencePieceWrapper
    )
    return wrapper_cls(tokenizer_model_path)