// Encoder.js (forked from latitudegames/GPT-3-Encoder)
// This file includes code which was modified from https://github.com/openai/gpt-2
//
// The upstream version loaded encoder.json and vocab.bpe from disk at runtime;
// here both are prebuilt into plain JS modules so the file can be bundled by browserify.
// The old runtime-loading code, kept for reference:
//   const encoder = JSON.parse(fs.readFileSync(path.join(__dirname, './encoder.json')));
//   const bpe_file = fs.readFileSync(path.join(__dirname, './vocab.bpe'), 'utf-8');
const encoder = require("./encoder");
const bpe_ranks = require("./bpe_ranks");
// range(x, y): the integers x, x + 1, ..., y - 1 (like Python's range)
const range = (x, y) => {
  const res = Array.from(Array(y).keys()).slice(x)
  return res
}
// ord / chr: code point <-> single-character string
const ord = x => {
  return x.charCodeAt(0)
}
const chr = x => {
  return String.fromCharCode(x)
}
// encodeStr: string -> array of its UTF-8 byte values (stringified)
const encodeStr = str => {
  return Array.from(Buffer.from(str, 'utf-8')).map(x => x.toString());
}
// decodeStr: array of byte values -> UTF-8 string
const decodeStr = arr => {
  return Buffer.from(arr).toString('utf-8')
}
/**
 * Returns a mapping from every byte value (0-255) to a printable unicode character.
 * Printable ASCII and most Latin-1 bytes map to themselves; the remaining bytes
 * (control characters, space, etc.) are shifted above 255 so that every byte gets a
 * distinct, visible stand-in. This lets the BPE merges operate on ordinary strings
 * while still representing arbitrary UTF-8 byte sequences losslessly.
 */
function bytes_to_unicode() {
  const bs = range(ord('!'), ord('~') + 1).concat(range(ord('¡'), ord('¬') + 1), range(ord('®'), ord('ÿ') + 1))
  let cs = bs.slice()
  let n = 0
  for (let b = 0; b < 2 ** 8; b++) {
    if (!bs.includes(b)) {
      bs.push(b)
      cs.push(2 ** 8 + n)
      n = n + 1
    }
  }
  cs = cs.map(x => chr(x))
  const result = {}
  bs.forEach((_, i) => {
    result[bs[i]] = cs[i]
  })
  return result
}
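// A couple of values of the mapping, derived from the algorithm above (shown for
// illustration only, nothing below is hard-coded anywhere in this file):
//   const b2u = bytes_to_unicode()
//   b2u[65]  // 'A'  - printable bytes map to themselves
//   b2u[32]  // 'Ġ'  - the space byte (0x20) is remapped to U+0120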
/**
 * Returns the set of adjacent symbol pairs (bigrams) in a word, where the word is an
 * array of symbols (single characters, or subwords produced by earlier merges).
 */
function get_pairs(word) {
  const pairs = new Set()
  let prev_char = word[0]
  for (let i = 1; i < word.length; i++) {
    const char = word[i]
    pairs.add([prev_char, char])
    prev_char = char
  }
  return pairs
}
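// For example, get_pairs(['h', 'e', 'l', 'l', 'o']) yields a Set containing
// ['h','e'], ['e','l'], ['l','l'] and ['l','o']. The entries are arrays, so repeated
// bigrams are stored as separate elements; bpe() below only iterates the set, so the
// duplicates are harmless.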
const pat = /'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+/gu
// The regular expression pat is used to split a string into an array of token candidates.
//
// It consists of several alternatives, tried in order:
// 's|'t|'re|'ve|'m|'ll|'d: common English contraction suffixes (as in "it's", "don't", "we're").
// The | symbol means "or", so this part matches any one of these suffixes.
//
//  ?\p{L}+: an optional leading space followed by one or more letters, i.e. a word. The ? makes
// the preceding space optional, so words are matched with or without a space in front of them,
// and that space stays attached to the word.
//
//  ?\p{N}+: an optional leading space followed by one or more digits. As above, the ? makes the
// preceding space optional.
//
//  ?[^\s\p{L}\p{N}]+: an optional leading space followed by one or more characters that are neither
// whitespace, letters, nor digits (punctuation, symbols). The [^...] notation means "any character
// except the ones listed inside the brackets", \s is whitespace, and \p{L} and \p{N} are the Unicode
// letter and number classes. The + means "one or more occurrences".
//
// \s+(?!\S): one or more whitespace characters that are NOT immediately followed by a non-whitespace
// character. \S matches non-whitespace, and (?!...) is a negative lookahead ("do not match if the
// pattern inside the parentheses follows"). In effect this consumes runs of whitespace except for
// the final space that belongs to the next token, plus any trailing whitespace at the end of the string.
//
// \s+: any remaining run of whitespace (only reached when the previous alternative cannot match).
//
// The g flag at the end means "global", so matchAll keeps searching after the first match is found.
// The u flag means "Unicode", which enables the \p{L} and \p{N} character classes.
//
// Overall, this regular expression splits a string into contraction suffixes, words, numbers,
// punctuation runs, and whitespace runs, each optionally carrying its leading space.
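// Applying the pattern to a sample string shows the kind of pieces it produces:
//   Array.from("Hello world, it's 2023!".matchAll(pat)).map(m => m[0])
//   // => [ 'Hello', ' world', ',', ' it', "'s", ' 2023', '!' ]
// Each piece keeps its leading space (if any); the byte encoder below later turns that
// space into the 'Ġ' character before the BPE merges are applied.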
// decoder: the inverse of the encoder table (token id -> BPE string)
const decoder = {}
Object.keys(encoder).forEach(x => {
  decoder[encoder[x]] = x
})
// byte_encoder / byte_decoder: byte value <-> printable unicode stand-in
const byte_encoder = bytes_to_unicode()
const byte_decoder = {}
Object.keys(byte_encoder).forEach(x => {
  byte_decoder[byte_encoder[x]] = x
})
// Cache of token -> space-joined subwords, so each distinct token is only BPE-merged once
const cache = new Map()
/**
 * Implements the Byte Pair Encoding (BPE) algorithm for subword tokenization.
 *
 * BPE builds its vocabulary by iteratively merging the most frequent pair of subwords until a
 * target vocabulary size is reached. This results in a set of subwords that can represent any
 * word in a language while still preserving some of the structure and meaning of the original
 * words. At encoding time, this function applies those precomputed merges (bpe_ranks) to a
 * single token.
 *
 * Here's a breakdown of the function:
 * 1 It first checks whether the input token is in the cache and, if so, returns the cached value.
 *   This avoids reprocessing tokens that have already been seen.
 * 2 The token is split into individual characters, and the set of pairs of adjacent characters
 *   (bigrams) is built with get_pairs. If there are no pairs, the token is returned as is.
 * 3 The function then enters a loop that continues until a termination condition is met. In each
 *   iteration, the pair with the lowest rank (as given by bpe_ranks) is identified and stored in
 *   the bigram variable. If the bigram is not in bpe_ranks, the loop terminates.
 * 4 Every occurrence of the bigram in the word list is then merged into a single new subword.
 * 5 Finally the word list is joined back into a space-separated string, stored in the cache, and
 *   returned as the result of the function.
 * @param {string} token - The input token to be tokenized.
 * @return {string} word - The tokenized subwords, joined with spaces.
 */
function bpe(token) {
  if (cache.has(token)) {
    return cache.get(token)
  }
  // Start from individual characters and the set of adjacent pairs (bigrams)
  let word = token.split('')
  let pairs = get_pairs(word)
  if (pairs.size === 0) {
    // Single-character token: nothing to merge
    return token
  }
  while (true) {
    // Pick the pair with the lowest merge rank; pairs missing from bpe_ranks get a huge sentinel rank
    const minPairs = {}
    Array.from(pairs).forEach(pair => {
      const rank = bpe_ranks[pair]
      minPairs[(isNaN(rank) ? 10e10 : rank)] = pair
    })
    const bigram = minPairs[Math.min(...Object.keys(minPairs).map(x => {
      return parseInt(x)
    }))]
    if (!(bigram in bpe_ranks)) {
      // No known merge left to apply
      break
    }
    const first = bigram[0]
    const second = bigram[1]
    // Rebuild the word, merging every occurrence of (first, second) into a single subword
    let new_word = []
    let i = 0
    while (i < word.length) {
      const j = word.indexOf(first, i)
      if (j === -1) {
        new_word = new_word.concat(word.slice(i))
        break
      }
      new_word = new_word.concat(word.slice(i, j))
      i = j
      if (word[i] === first && i < word.length - 1 && word[i + 1] === second) {
        new_word.push(first + second)
        i = i + 2
      } else {
        new_word.push(word[i])
        i = i + 1
      }
    }
    word = new_word
    if (word.length === 1) {
      break
    } else {
      pairs = get_pairs(word)
    }
  }
  // Subwords are joined with spaces so callers can split them back apart
  word = word.join(' ')
  cache.set(token, word)
  return word
}
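// Illustrative sketch of what bpe() returns (the exact splits depend on the merge ranks
// shipped in bpe_ranks.js, so the outputs below are assumptions, not guaranteed values):
//   bpe('Ġhello')   // might be 'Ġhello' if the whole word exists as a single subword,
//                   // or something like 'Ġhe llo' if it has to be split into pieces
// encode() below splits this space-joined string again and looks each piece up in the
// encoder table to get the final integer token ids.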
/**
 * Encodes a given text string into a list of BPE tokens.
 *
 * @param {string} text - The text to be encoded.
 * @return {Array<number>} bpe_tokens - The encoded BPE token ids.
 */
function encode(text) {
  if (typeof text != "string") {
    if (typeof text == "undefined") {
      console.warn("encode received undefined, returning empty []");
      return [];
    }
    console.warn("encode received a non-string input, casting it to a string");
    text = "" + text;
  }
  let bpe_tokens = []
  const matches = Array.from(text.matchAll(pat)).map(x => x[0])
  for (let token of matches) {
    // Map the token's UTF-8 bytes to their unicode stand-ins, apply the BPE merges,
    // then look up each resulting subword in the encoder table
    token = encodeStr(token).map(x => {
      return byte_encoder[x]
    }).join('')
    const new_tokens = bpe(token).split(' ').map(x => encoder[x])
    bpe_tokens = bpe_tokens.concat(new_tokens)
  }
  return bpe_tokens
}
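// Minimal usage sketch (the exact ids depend on the bundled encoder.json, so none are
// shown here; assume this file is required as './Encoder'):
//   const { encode, decode } = require('./Encoder')
//   const ids = encode('hello world')  // an array of integer token ids, one per subword
//   decode(ids)                        // 'hello world' again, see decode() below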
/**
 * Computes count, unique, and frequency statistics for a string or an array of tokens.
 * This function can be used to get insights into the characteristics of a text dataset,
 * or to analyze the distribution of tokens in a body of text.
 *
 * @param {(string|Array<number>)} input - The input string or array of tokens.
 * @return {Object} stats - An object with count, unique, frequency, positions, and tokens properties.
 *
 * @property {number} stats.count - The total number of tokens.
 * @property {number} stats.unique - The number of unique tokens.
 * @property {Object} stats.frequency - An object with token-frequency pairs, sorted by frequency in descending order.
 * @property {Object} stats.positions - An object with token-position pairs, where positions is an array of the indices of the token in the input string or array.
 * @property {Array<number>} stats.tokens - The array of tokens passed to the function.
 */
function tokenStats(input) {
  let tokens
  if (typeof input === 'string') {
    // Encode the string into tokens
    tokens = encode(input)
  } else {
    tokens = input
  }
  const stats = {
    count: tokens.length,
    unique: new Set(tokens).size,
    frequency: {},
    positions: {},
    tokens,
  }
  // Compute the frequency of each token and record where it occurs
  for (let i = 0; i < tokens.length; i++) {
    const token = tokens[i];
    if (stats.frequency[token]) {
      stats.frequency[token]++;
      stats.positions[token].push(i);
    } else {
      stats.frequency[token] = 1;
      stats.positions[token] = [i];
    }
  }
  // TODO: count words and compute some string-level stats as well
  // Sort the frequency object by frequency in descending order
  stats.frequency = Object.fromEntries(
    Object.entries(stats.frequency)
      .sort((a, b) => b[1] - a[1])
  )
  return stats
}
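// Usage sketch for tokenStats (the concrete numbers depend on the tokenizer data; only the
// shape of the returned object is shown here):
//   const stats = tokenStats('the cat sat on the mat')
//   stats.count      // total number of tokens
//   stats.unique     // number of distinct token ids
//   stats.frequency  // { '<tokenId>': occurrences, ... } sorted by descending frequency
//   stats.positions  // { '<tokenId>': [indices into stats.tokens], ... }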
/**
 * This function works by iterating through the matches of the pat pattern in the input text,
 * encoding each match using the encodeStr function and the byte_encoder mapping, and then
 * applying the bpe function to the encoded token. The number of subwords produced by bpe is
 * added to the count variable, and the final count is returned as the result.
 * @param {string} text - The text whose tokens should be counted.
 * @return {number} count - The number of BPE tokens in the text.
 */
function countTokens(text) {
  let count = 0
  const matches = Array.from(text.matchAll(pat)).map(x => x[0])
  // A reduce-based version and a for...of loop were also tried; timings for 20 * chars(200000):
  //   reduce:    counting took average: 572.8
  //   for...of:  counting took average: 570.8
  //   for loop:  counting took average: 559.85
  // Not much difference, so the plain for loop is kept.
  for (let i = 0; i < matches.length; i++) {
    const token = encodeStr(matches[i]).map(x => {
      return byte_encoder[x]
    }).join('')
    count += bpe(token).split(' ').length
  }
  return count
}
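// countTokens(text) agrees with encode(text).length but avoids building the full token array.
// A sketch of the typical use, checking a prompt against a model's context limit (the 4096
// figure is only an example, it is not defined anywhere in this file):
//   const MAX_TOKENS = 4096
//   if (countTokens(prompt) > MAX_TOKENS) {
//     // trim or summarize the prompt before sending it to the model
//   }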
/**
 * Decodes a list of BPE tokens back into a text string.
 *
 * @param {Array<number>} tokens - The list of BPE tokens to be decoded.
 * @return {string} text - The decoded text string.
 */
function decode(tokens) {
  if (!tokens) {
    console.warn("No tokens to decode, returning empty string")
    return "";
  }
  // Map token ids back to their BPE strings, then map each unicode stand-in character
  // back to its original byte value and decode the byte sequence as UTF-8
  let text = tokens.map(x => decoder[x]).join('')
  text = decodeStr(text.split('').map(x => byte_decoder[x]))
  return text
}
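// Since the byte-level mapping is lossless, decoding the encoding of a string should give the
// original string back, e.g.:
//   decode(encode('Hello   world 🙂'))  // 'Hello   world 🙂'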
module.exports = {
  encode,
  decode,
  countTokens,
  tokenStats
};