Commit
Split the tokenization function into two parts, fixing the while-loop issues
Showing 2 changed files with 84 additions and 48 deletions.
@@ -1,39 +1,50 @@
 from collections import deque
-from typing import Union
 
 from transformers import PreTrainedTokenizerBase
 
 
+def extend_deque(
+    dq: deque[int],
+    context_size: int,
+    text_stories: list[str],
+    prompt_idx: int,
+    tokenizer: PreTrainedTokenizerBase,
+) -> int:
+    while len(dq) < context_size and prompt_idx < len(text_stories):
+        text_story = text_stories[prompt_idx]
+        dq.extend(
+            tokenizer.encode(text_story, add_special_tokens=False)
+            + [tokenizer.eos_token_id]
+        )
+        prompt_idx += 1
+    return prompt_idx
+
+
+def make_new_samples(
+    dq: deque[int], context_size: int, tokenizer: PreTrainedTokenizerBase
+) -> list[list[int]]:
+    samples = []
+    while len(dq) >= context_size:
+        sample = [tokenizer.bos_token_id]
+        for _ in range(context_size - 1):  # peek at and not pop the last element
+            sample.append(dq.popleft())
+        sample.append(dq[0])
+        samples.append(sample)
+    return samples
+
+
 def get_tokenized_batches(
-    text_stories: Union[list[str], list[list[int]]],
+    text_stories: list[str],
     tokenizer: PreTrainedTokenizerBase,
     context_size: int,
-    input_tokenized=False,
 ) -> list[list[int]]:
     dq = deque()
+    prompt_idx = 0
     samples = []
 
-    prompt_idx = 0
     while prompt_idx < len(text_stories):
-        while len(dq) < context_size:
-            text_story = text_stories[prompt_idx]
-            if not input_tokenized:
-                dq.extend(
-                    tokenizer.encode(text_story, add_special_tokens=False)
-                    + [tokenizer.eos_token_id]
-                )
-            else:
-                dq.extend(text_story)
-                dq.append(tokenizer.eos_token_id)
-            prompt_idx += 1
-
-        sample = [tokenizer.bos_token_id]
-        for i in range(context_size - 1):  # peek at and not pop the last element
-            sample.append(dq.popleft())
-        sample.append(dq[0])
-
-        samples.append(sample)
+        prompt_idx = extend_deque(dq, context_size, text_stories, prompt_idx, tokenizer)
+        samples.extend(make_new_samples(dq, context_size, tokenizer))
 
-    if dq:
-        samples.append([tokenizer.bos_token_id] + list(dq))
+    # We discard the last chunk, so no processing on the remainder of the deque here
     return samples
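
Below is a minimal usage sketch, not part of the commit, showing how the refactored pipeline behaves. It assumes the functions from the changed file are importable; "tokenization" is a hypothetical module name, and StubTokenizer is a hypothetical stand-in for a PreTrainedTokenizerBase exposing only what get_tokenized_batches needs.

```python
# Hypothetical import path for the file changed in this commit.
from tokenization import get_tokenized_batches


class StubTokenizer:
    """Stand-in for a PreTrainedTokenizerBase: bos/eos ids plus encode()."""

    bos_token_id = 0
    eos_token_id = 1

    def encode(self, text, add_special_tokens=False):
        # Fake tokenization: one id per whitespace-separated word.
        return [10 + i for i, _ in enumerate(text.split())]


stories = ["a b c", "d e f g", "h i"]
samples = get_tokenized_batches(stories, StubTokenizer(), context_size=4)
for s in samples:
    print(s)
# With this stub the call yields:
#   [0, 10, 11, 12, 1]
#   [0, 1, 10, 11, 12]
#   [0, 12, 13, 1, 10]
# Each sample is bos_token_id, then context_size - 1 tokens popped from the deque,
# then a peek at the next token, so consecutive samples overlap by one token; the
# partial chunk left in the deque at the end is discarded.
```

The combined guard in extend_deque (len(dq) < context_size and prompt_idx < len(text_stories)) is presumably the while-loop fix named in the commit message: the old inner loop only checked the deque length and could keep indexing text_stories after the last story had been consumed.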