-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add function to tokenize text stories and split into batches
- Loading branch information
Siwei Li
committed
Mar 10, 2024
1 parent
6177d73
commit e715276
Showing
2 changed files
with
86 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
from collections import deque | ||
from typing import Union | ||
|
||
from transformers import PreTrainedTokenizerBase | ||
|
||
|
||
def get_tokenized_batches(
    text_stories: Union[list[str], list[list[int]]],
    tokenizer: PreTrainedTokenizerBase,
    context_size: int,
    input_tokenized=False,
) -> list[list[int]]:
    """Pack stories into fixed-size next-token-prediction samples.

    Stories are concatenated into one token stream with an EOS token after
    each story, then the stream is cut into samples of ``context_size + 1``
    tokens: a leading BOS token, ``context_size - 1`` tokens popped from the
    stream, plus one token that is peeked but NOT consumed — so consecutive
    samples overlap by exactly one token.  Whatever is left after the last
    full window is emitted as a final, shorter sample.

    Args:
        text_stories: raw story strings, or — when ``input_tokenized`` is
            True — stories already tokenized as lists of token ids.
        tokenizer: supplies ``encode``, ``bos_token_id`` and ``eos_token_id``.
        context_size: number of stream tokens consumed per sample.
        input_tokenized: set True when ``text_stories`` holds token-id lists.

    Returns:
        A list of token-id samples (each a ``list[int]``).
    """
    dq = deque()
    samples = []

    prompt_idx = 0
    while prompt_idx < len(text_stories):
        # Fill the buffer up to a full context window, but stop once the
        # stories are exhausted: the previous unguarded loop indexed
        # text_stories past its end and raised IndexError whenever the
        # remaining stories could not fill a whole window.
        while len(dq) < context_size and prompt_idx < len(text_stories):
            text_story = text_stories[prompt_idx]
            if not input_tokenized:
                dq.extend(
                    tokenizer.encode(text_story, add_special_tokens=False)
                    + [tokenizer.eos_token_id]
                )
            else:
                dq.extend(text_story)
                dq.append(tokenizer.eos_token_id)
            prompt_idx += 1

        if len(dq) < context_size:
            # Not enough tokens for a full sample; the remainder is
            # emitted below as the final short sample.
            break

        sample = [tokenizer.bos_token_id]
        for i in range(context_size - 1):
            sample.append(dq.popleft())
        sample.append(dq[0])  # peek at and not pop the last element

        samples.append(sample)

    if dq:
        samples.append([tokenizer.bos_token_id] + list(dq))
    return samples
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
from transformers import AutoTokenizer | ||
|
||
from delphi.train.dataset_tokenization import get_tokenized_batches | ||
|
||
|
||
def test_get_tokenized_batches():
    """Check both input modes of get_tokenized_batches against known outputs.

    Uses the delphi-suite v0 llama2 tokenizer; expected batches were
    produced by that tokenizer with a context size of 10.
    """
    CTX_SIZE = 10
    tokenizer = AutoTokenizer.from_pretrained("delphi-suite/v0-llama2-tokenizer")

    # Mode 1: raw text stories are tokenized internally.
    text_stories = [
        "Once upon a",
        "Mother woke up alert. She put on her coat",
        "Once upon a time, in a small town, there was a weird",
        "Once upon a time, there was a",
        "Sara and Tom are friends. They like to play in the park.",
    ]
    expected_text_batches = [
        [1, 432, 440, 261, 2, 367, 501, 1917, 372, 3398, 4037],
        [1, 4037, 341, 577, 359, 342, 1854, 2, 432, 440, 261],
        [1, 261, 403, 4045, 317, 261, 560, 1000, 4045, 406, 286],
        [1, 286, 261, 2567, 2, 432, 440, 261, 403, 4045, 406],
        [1, 406, 286, 261, 2, 787, 269, 396, 484, 415, 4037],
        [1, 4037, 311, 519, 268, 326, 317, 264, 525, 4037, 2],
    ]
    text_result = get_tokenized_batches(text_stories, tokenizer, CTX_SIZE)
    assert text_result == expected_text_batches

    # Mode 2: stories arrive pre-tokenized (input_tokenized=True).
    tokenized_stories = [
        [1618, 3520, 2223, 3961, 853, 3376, 1820, 1442, 1573],
        [46, 3515, 2941, 1637, 1377],
        [1439, 3378, 3897, 3807, 343, 1140, 3843, 3848, 1343, 3812, 947, 2871, 1973],
        [1163, 1358, 1930, 3590, 2216, 3659, 278],
        [604, 2920, 1330, 2240, 786, 4088, 1416, 2122, 1556, 3501, 3159, 3427],
    ]
    expected_token_batches = [
        [1, 1618, 3520, 2223, 3961, 853, 3376, 1820, 1442, 1573, 2],
        [1, 2, 46, 3515, 2941, 1637, 1377, 2, 1439, 3378, 3897],
        [1, 3897, 3807, 343, 1140, 3843, 3848, 1343, 3812, 947, 2871],
        [1, 2871, 1973, 2, 1163, 1358, 1930, 3590, 2216, 3659, 278],
        [1, 278, 2, 604, 2920, 1330, 2240, 786, 4088, 1416, 2122],
        [1, 2122, 1556, 3501, 3159, 3427, 2],
    ]
    token_result = get_tokenized_batches(
        tokenized_stories, tokenizer, CTX_SIZE, input_tokenized=True
    )
    assert token_result == expected_token_batches