Add function to tokenize text stories and split into batches
Siwei Li committed Mar 10, 2024
Commit e715276 (1 parent: 6177d73)
Showing 2 changed files with 86 additions and 0 deletions.
src/delphi/train/dataset_tokenization.py (39 additions, 0 deletions)
@@ -0,0 +1,39 @@
from collections import deque
from typing import Union

from transformers import PreTrainedTokenizerBase


def get_tokenized_batches(
    text_stories: Union[list[str], list[list[int]]],
    tokenizer: PreTrainedTokenizerBase,
    context_size: int,
    input_tokenized=False,
) -> list[list[int]]:
    """Concatenate stories (each followed by EOS) and pack them into samples.

    Each sample is [BOS] followed by `context_size` tokens. The last token of
    a sample is repeated as the first content token of the next one, so
    adjacent samples overlap by one position; leftover tokens at the end form
    a final, shorter sample.
    """
    dq = deque()
    samples = []

    prompt_idx = 0
    while prompt_idx < len(text_stories):
        # Refill the token buffer until it holds at least one full context.
        while len(dq) < context_size:
            text_story = text_stories[prompt_idx]
            if not input_tokenized:
                dq.extend(
                    tokenizer.encode(text_story, add_special_tokens=False)
                    + [tokenizer.eos_token_id]
                )
            else:
                dq.extend(text_story)
                dq.append(tokenizer.eos_token_id)
            prompt_idx += 1

        sample = [tokenizer.bos_token_id]
        for _ in range(context_size - 1):
            sample.append(dq.popleft())
        # Peek at (but do not pop) the next token so it also starts the
        # following sample.
        sample.append(dq[0])

        samples.append(sample)

    # Whatever remains in the buffer becomes a final, shorter sample.
    if dq:
        samples.append([tokenizer.bos_token_id] + list(dq))
    return samples
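
For illustration, a minimal usage sketch (not part of the commit; the story token ids below are made up). With input_tokenized=True the tokenizer is only consulted for its bos_token_id and eos_token_id:

from transformers import AutoTokenizer

from delphi.train.dataset_tokenization import get_tokenized_batches

tokenizer = AutoTokenizer.from_pretrained("delphi-suite/v0-llama2-tokenizer")

# Hypothetical pre-tokenized stories; the ids are chosen only for illustration.
stories = [[10, 11, 12, 13], [20, 21], [30, 31, 32]]

batches = get_tokenized_batches(stories, tokenizer, context_size=4, input_tokenized=True)
# Each full batch is [BOS] plus context_size tokens; leftovers form a shorter
# final batch. The last token of a batch reappears as the first content token
# of the next, so no position is lost at a seam.
for batch in batches:
    print(batch)
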
tests/train/test_tokenizer.py (47 additions, 0 deletions)
@@ -0,0 +1,47 @@
from transformers import AutoTokenizer

from delphi.train.dataset_tokenization import get_tokenized_batches


def test_get_tokenized_batches():
    CTX_SIZE = 10
    tokenizer = AutoTokenizer.from_pretrained("delphi-suite/v0-llama2-tokenizer")

    # Raw text stories: the function tokenizes them itself.
    text_stories = [
        "Once upon a",
        "Mother woke up alert. She put on her coat",
        "Once upon a time, in a small town, there was a weird",
        "Once upon a time, there was a",
        "Sara and Tom are friends. They like to play in the park.",
    ]
    correct_batches = [
        [1, 432, 440, 261, 2, 367, 501, 1917, 372, 3398, 4037],
        [1, 4037, 341, 577, 359, 342, 1854, 2, 432, 440, 261],
        [1, 261, 403, 4045, 317, 261, 560, 1000, 4045, 406, 286],
        [1, 286, 261, 2567, 2, 432, 440, 261, 403, 4045, 406],
        [1, 406, 286, 261, 2, 787, 269, 396, 484, 415, 4037],
        [1, 4037, 311, 519, 268, 326, 317, 264, 525, 4037, 2],
    ]
    assert get_tokenized_batches(text_stories, tokenizer, CTX_SIZE) == correct_batches

    # Pre-tokenized stories: the tokenizer is only used for its BOS/EOS ids.
    tokenized_stories = [
        [1618, 3520, 2223, 3961, 853, 3376, 1820, 1442, 1573],
        [46, 3515, 2941, 1637, 1377],
        [1439, 3378, 3897, 3807, 343, 1140, 3843, 3848, 1343, 3812, 947, 2871, 1973],
        [1163, 1358, 1930, 3590, 2216, 3659, 278],
        [604, 2920, 1330, 2240, 786, 4088, 1416, 2122, 1556, 3501, 3159, 3427],
    ]
    correct_batches = [
        [1, 1618, 3520, 2223, 3961, 853, 3376, 1820, 1442, 1573, 2],
        [1, 2, 46, 3515, 2941, 1637, 1377, 2, 1439, 3378, 3897],
        [1, 3897, 3807, 343, 1140, 3843, 3848, 1343, 3812, 947, 2871],
        [1, 2871, 1973, 2, 1163, 1358, 1930, 3590, 2216, 3659, 278],
        [1, 278, 2, 604, 2920, 1330, 2240, 786, 4088, 1416, 2122],
        [1, 2122, 1556, 3501, 3159, 3427, 2],
    ]
    assert (
        get_tokenized_batches(
            tokenized_stories, tokenizer, CTX_SIZE, input_tokenized=True
        )
        == correct_batches
    )
