Split the tokenization function into two parts, fixing the while-loop issues
Siwei Li authored and siwei-li committed Mar 18, 2024
1 parent 194e538 commit 5fe3c9f
Showing 2 changed files with 84 additions and 48 deletions.
59 changes: 35 additions & 24 deletions src/delphi/train/dataset_tokenization.py
@@ -1,39 +1,50 @@
 from collections import deque
-from typing import Union

 from transformers import PreTrainedTokenizerBase


+def extend_deque(
+    dq: deque[int],
+    context_size: int,
+    text_stories: list[str],
+    prompt_idx: int,
+    tokenizer: PreTrainedTokenizerBase,
+) -> int:
+    while len(dq) < context_size and prompt_idx < len(text_stories):
+        text_story = text_stories[prompt_idx]
+        dq.extend(
+            tokenizer.encode(text_story, add_special_tokens=False)
+            + [tokenizer.eos_token_id]
+        )
+        prompt_idx += 1
+    return prompt_idx
+
+
+def make_new_samples(
+    dq: deque[int], context_size: int, tokenizer: PreTrainedTokenizerBase
+) -> list[list[int]]:
+    samples = []
+    while len(dq) >= context_size:
+        sample = [tokenizer.bos_token_id]
+        for _ in range(context_size - 1):  # peek at and not pop the last element
+            sample.append(dq.popleft())
+        sample.append(dq[0])
+        samples.append(sample)
+    return samples
+
+
 def get_tokenized_batches(
-    text_stories: Union[list[str], list[list[int]]],
+    text_stories: list[str],
     tokenizer: PreTrainedTokenizerBase,
     context_size: int,
-    input_tokenized=False,
 ) -> list[list[int]]:
     dq = deque()
+    prompt_idx = 0
     samples = []

-    prompt_idx = 0
     while prompt_idx < len(text_stories):
-        while len(dq) < context_size:
-            text_story = text_stories[prompt_idx]
-            if not input_tokenized:
-                dq.extend(
-                    tokenizer.encode(text_story, add_special_tokens=False)
-                    + [tokenizer.eos_token_id]
-                )
-            else:
-                dq.extend(text_story)
-                dq.append(tokenizer.eos_token_id)
-            prompt_idx += 1
-
-        sample = [tokenizer.bos_token_id]
-        for i in range(context_size - 1):  # peek at and not pop the last element
-            sample.append(dq.popleft())
-        sample.append(dq[0])
-
-        samples.append(sample)
+        prompt_idx = extend_deque(dq, context_size, text_stories, prompt_idx, tokenizer)
+        samples.extend(make_new_samples(dq, context_size, tokenizer))

-    if dq:
-        samples.append([tokenizer.bos_token_id] + list(dq))
+    # We discard the last chunk, so no processing on the remainder of the deque here
     return samples
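
For context, a minimal sketch of how the two new helpers cooperate (not part of the commit). StubTokenizer is an illustrative stand-in for a real PreTrainedTokenizerBase, and the stories and context size are made up; each sample is one BOS token, context_size - 1 tokens popped from the deque, and one peeked token that reappears at the start of the next sample.

from collections import deque

from delphi.train.dataset_tokenization import extend_deque, make_new_samples


class StubTokenizer:
    # illustrative stand-in for a PreTrainedTokenizerBase with a tiny word-level vocab
    bos_token_id = 1
    eos_token_id = 2

    def __init__(self):
        self._vocab = {}

    def encode(self, text, add_special_tokens=False):
        # assign each unseen word the next free id, starting at 10
        ids = []
        for word in text.split():
            self._vocab.setdefault(word, 10 + len(self._vocab))
            ids.append(self._vocab[word])
        return ids


tok = StubTokenizer()
stories = ["one two three", "four five", "six seven eight nine"]
context_size = 4

dq = deque()
prompt_idx = 0
samples = []
while prompt_idx < len(stories):
    # pull stories into the deque (EOS appended after each) until it holds
    # at least context_size tokens or the stories run out
    prompt_idx = extend_deque(dq, context_size, stories, prompt_idx, tok)
    # cut BOS-prefixed windows of context_size + 1 tokens; the last token of each
    # window is peeked rather than popped, so consecutive samples overlap by one token
    samples.extend(make_new_samples(dq, context_size, tok))

print(samples)
# [[1, 10, 11, 12, 2], [1, 2, 13, 14, 2], [1, 2, 15, 16, 17]]; the tokens left
# in dq ([17, 18, 2]) are discarded, just as in get_tokenized_batches
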
73 changes: 49 additions & 24 deletions tests/train/test_tokenizer.py
@@ -1,11 +1,58 @@
+import collections
+import random
+
 from transformers import AutoTokenizer

-from delphi.train.dataset_tokenization import get_tokenized_batches
+from delphi.eval.utils import load_validation_dataset
+from delphi.train.dataset_tokenization import (
+    extend_deque,
+    get_tokenized_batches,
+    make_new_samples,
+)
+
+tokenizer = AutoTokenizer.from_pretrained("delphi-suite/stories-tokenizer")
+
+
+def test_extend_deque():
+    CTX_SIZE = 10
+    dataset = load_validation_dataset("delphi-suite/tinystories-v2-clean")
+    text_stories = dataset["story"][:100]
+    prompt_idx = 0
+    dq = collections.deque()
+
+    while prompt_idx < len(text_stories):
+        prompt_idx = extend_deque(dq, CTX_SIZE, text_stories, prompt_idx, tokenizer)
+        if prompt_idx < len(text_stories) - 1:
+            assert len(dq) >= CTX_SIZE
+        while len(dq) >= CTX_SIZE:
+            for _ in range(CTX_SIZE - 1):
+                dq.popleft()
+
+
+def test_make_new_sample():
+    for _ in range(100):
+        total_tokens = random.randint(100, 1000)
+        context_size = random.randint(5, total_tokens // 2)
+        dq = collections.deque([random.randint(3, 1000) for _ in range(total_tokens)])
+        samples = make_new_samples(dq, context_size, tokenizer)
+        tokens_cnt = 0
+        for i, sample in enumerate(samples):
+            assert sample[0] == tokenizer.bos_token_id
+            if i > 0:
+                assert sample[1] == samples[i - 1][-1]
+            tokens_cnt += len(sample)
+
+        # We discard the last chunk so the following lines are only for testing
+        tokens_cnt += 1 + len(dq)  # the last batch with BOS in the beginning
+        assert tokens_cnt == total_tokens + (
+            2 * len(samples) + 1
+        )  # BOS for each batch + overlapping of the last tokens in the batches
+        assert len(dq) > 0  # always leaving at least one element in the deque


 def test_get_tokenized_batches():
     CTX_SIZE = 10
-    tokenizer = AutoTokenizer.from_pretrained("delphi-suite/v0-llama2-tokenizer")
+    tokenizer = AutoTokenizer.from_pretrained("delphi-suite/stories-tokenizer")

     text_stories = [
         "Once upon a",
@@ -23,25 +70,3 @@ def test_get_tokenized_batches():
         [1, 4037, 311, 519, 268, 326, 317, 264, 525, 4037, 2],
     ]
     assert get_tokenized_batches(text_stories, tokenizer, CTX_SIZE) == correct_batches
-
-    tokenized_stories = [
-        [1618, 3520, 2223, 3961, 853, 3376, 1820, 1442, 1573],
-        [46, 3515, 2941, 1637, 1377],
-        [1439, 3378, 3897, 3807, 343, 1140, 3843, 3848, 1343, 3812, 947, 2871, 1973],
-        [1163, 1358, 1930, 3590, 2216, 3659, 278],
-        [604, 2920, 1330, 2240, 786, 4088, 1416, 2122, 1556, 3501, 3159, 3427],
-    ]
-    correct_batches = [
-        [1, 1618, 3520, 2223, 3961, 853, 3376, 1820, 1442, 1573, 2],
-        [1, 2, 46, 3515, 2941, 1637, 1377, 2, 1439, 3378, 3897],
-        [1, 3897, 3807, 343, 1140, 3843, 3848, 1343, 3812, 947, 2871],
-        [1, 2871, 1973, 2, 1163, 1358, 1930, 3590, 2216, 3659, 278],
-        [1, 278, 2, 604, 2920, 1330, 2240, 786, 4088, 1416, 2122],
-        [1, 2122, 1556, 3501, 3159, 3427, 2],
-    ]
-    assert (
-        get_tokenized_batches(
-            tokenized_stories, tokenizer, CTX_SIZE, input_tokenized=True
-        )
-        == correct_batches
-    )

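As an aside on the accounting asserted in test_make_new_sample (a worked check, not part of the commit): each sample holds 1 BOS token, context_size - 1 tokens popped from the deque, and 1 peeked token that stays in the deque, so the samples contribute len(samples) * (context_size + 1) tokens while consuming len(samples) * (context_size - 1) of the original total_tokens. Counting the hypothetical final chunk as 1 BOS plus the len(dq) leftover tokens gives

    len(samples) * (context_size + 1) + 1 + (total_tokens - len(samples) * (context_size - 1))
        = total_tokens + 2 * len(samples) + 1

which is exactly the identity the test checks.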