From 3ebeada7d158e74583740011befc4ac961829848 Mon Sep 17 00:00:00 2001 From: Wei Lu Date: Sat, 25 May 2024 21:46:07 -0700 Subject: [PATCH] implement an efficient function using stack --- .DS_Store | Bin 0 -> 6148 bytes src/encoder.py | 37 +++++++------------------------------ 2 files changed, 7 insertions(+), 30 deletions(-) create mode 100644 .DS_Store diff --git a/.DS_Store b/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..8a808272e4f2bbd94e4046f2201cde0fb5995d73 GIT binary patch literal 6148 zcmeHK%}T>S5Z-O8O({YSiXIod7Hn)0!Apqs1&ruHr6#6mFlI}VnnNk%sxRc5_&m<+ zZotx>MeGdhe)GGV{UH0p7~|e7au~B2V-_?-j>;NA_u5dyBqMShBcFze48i&grzZB- z0l&S;a+a~ju>AS`Nt_q`{uggFTU*<0t8I0yd;eLMei0P2+zYZ>w63H~!def)>v&R* z?cH;k7D1d&GF1>q6G*wciPK1yo}8sou4;W9usT*}Z12zKPER<)<3&%*o#Dx#Cr$^> zV$reo4v)?*$1lliD&IAm9LQF(W3Yw~P|RZ9{47mm`UIXjyUrpc28aP-fEd_t2F$@= zbvK*`T0SvA46HJM`-6an=o+jvs;vV$ygp;Rg@^(=z9kTaLDyiV5h5U5mjdciZk`xi zmxEuJJl9~QQI|8WW`=Rh%=P1itJ%RXR666XM(T+HVqlqprZ#On|1aQ|S^LOeE+LB; zAO`*!1H3izM;8Toy?} literal 0 HcmV?d00001 diff --git a/src/encoder.py b/src/encoder.py index 5f52e723c..828ba9b40 100644 --- a/src/encoder.py +++ b/src/encoder.py @@ -56,39 +56,16 @@ def bpe(self, token): if token in self.cache: return self.cache[token] word = tuple(token) - pairs = get_pairs(word) - if not pairs: - return token + new_word = [] - while True: - bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) - if bigram not in self.bpe_ranks: - break - first, second = bigram - new_word = [] - i = 0 - while i < len(word): - try: - j = word.index(first, i) - new_word.extend(word[i:j]) - i = j - except: - new_word.extend(word[i:]) - break + for x in word: + curr = x + while new_word and (new_word[-1], curr) in self.bpe_ranks: + curr = self.bpe_ranks[(new_word.pop(), curr)] + new_word.append(curr) - if word[i] == first and i < len(word)-1 and word[i+1] == second: - new_word.append(first+second) - i += 2 - else: - new_word.append(word[i]) - i += 1 - new_word = tuple(new_word) - word = new_word - if len(word) == 1: - break - else: - pairs = get_pairs(word) + word = new_word word = ' '.join(word) self.cache[token] = word return word