diff --git a/src/delphi/train/dataset_tokenization.py b/src/delphi/train/dataset_tokenization.py index 5ca3371f..ffc0a985 100644 --- a/src/delphi/train/dataset_tokenization.py +++ b/src/delphi/train/dataset_tokenization.py @@ -1,39 +1,50 @@ from collections import deque -from typing import Union from transformers import PreTrainedTokenizerBase +def extend_deque( + dq: deque[int], + context_size: int, + text_stories: list[str], + prompt_idx: int, + tokenizer: PreTrainedTokenizerBase, +) -> int: + while len(dq) < context_size and prompt_idx < len(text_stories): + text_story = text_stories[prompt_idx] + dq.extend( + tokenizer.encode(text_story, add_special_tokens=False) + + [tokenizer.eos_token_id] + ) + prompt_idx += 1 + return prompt_idx + + +def make_new_samples( + dq: deque[int], context_size: int, tokenizer: PreTrainedTokenizerBase +) -> list[list[int]]: + samples = [] + while len(dq) >= context_size: + sample = [tokenizer.bos_token_id] + for _ in range(context_size - 1): # peek at and not pop the last element + sample.append(dq.popleft()) + sample.append(dq[0]) + samples.append(sample) + return samples + + def get_tokenized_batches( - text_stories: Union[list[str], list[list[int]]], + text_stories: list[str], tokenizer: PreTrainedTokenizerBase, context_size: int, - input_tokenized=False, ) -> list[list[int]]: dq = deque() + prompt_idx = 0 samples = [] - prompt_idx = 0 while prompt_idx < len(text_stories): - while len(dq) < context_size: - text_story = text_stories[prompt_idx] - if not input_tokenized: - dq.extend( - tokenizer.encode(text_story, add_special_tokens=False) - + [tokenizer.eos_token_id] - ) - else: - dq.extend(text_story) - dq.append(tokenizer.eos_token_id) - prompt_idx += 1 - - sample = [tokenizer.bos_token_id] - for i in range(context_size - 1): # peek at and not pop the last element - sample.append(dq.popleft()) - sample.append(dq[0]) - - samples.append(sample) + prompt_idx = extend_deque(dq, context_size, text_stories, prompt_idx, tokenizer) + samples.extend(make_new_samples(dq, context_size, tokenizer)) - if dq: - samples.append([tokenizer.bos_token_id] + list(dq)) + # We discard the last chunk, so no processing on the remainder of the deque here return samples diff --git a/tests/train/test_tokenizer.py b/tests/train/test_tokenizer.py index ca0415f8..9b3552de 100644 --- a/tests/train/test_tokenizer.py +++ b/tests/train/test_tokenizer.py @@ -1,11 +1,58 @@ +import collections +import random + from transformers import AutoTokenizer -from delphi.train.dataset_tokenization import get_tokenized_batches +from delphi.eval.utils import load_validation_dataset +from delphi.train.dataset_tokenization import ( + extend_deque, + get_tokenized_batches, + make_new_samples, +) + +tokenizer = AutoTokenizer.from_pretrained("delphi-suite/stories-tokenizer") + + +def test_extend_deque(): + CTX_SIZE = 10 + dataset = load_validation_dataset("delphi-suite/tinystories-v2-clean") + text_stories = dataset["story"][:100] + prompt_idx = 0 + dq = collections.deque() + + while prompt_idx < len(text_stories): + prompt_idx = extend_deque(dq, CTX_SIZE, text_stories, prompt_idx, tokenizer) + if prompt_idx < len(text_stories) - 1: + assert len(dq) >= CTX_SIZE + while len(dq) >= CTX_SIZE: + for _ in range(CTX_SIZE - 1): + dq.popleft() + + +def test_make_new_sample(): + for _ in range(100): + total_tokens = random.randint(100, 1000) + context_size = random.randint(5, total_tokens // 2) + dq = collections.deque([random.randint(3, 1000) for _ in range(total_tokens)]) + samples = make_new_samples(dq, context_size, tokenizer) + tokens_cnt = 0 + for i, sample in enumerate(samples): + assert sample[0] == tokenizer.bos_token_id + if i > 0: + assert sample[1] == samples[i - 1][-1] + tokens_cnt += len(sample) + + # We discard the last chunk so the following lines are only for testing + tokens_cnt += 1 + len(dq) # the last batch with BOS in the beginning + assert tokens_cnt == total_tokens + ( + 2 * len(samples) + 1 + ) # BOS for each batch + overlapping of the last tokens in the batches + assert len(dq) > 0 # always leaving at least one element in the deque def test_get_tokenized_batches(): CTX_SIZE = 10 - tokenizer = AutoTokenizer.from_pretrained("delphi-suite/v0-llama2-tokenizer") + tokenizer = AutoTokenizer.from_pretrained("delphi-suite/stories-tokenizer") text_stories = [ "Once upon a", @@ -23,25 +70,3 @@ def test_get_tokenized_batches(): [1, 4037, 311, 519, 268, 326, 317, 264, 525, 4037, 2], ] assert get_tokenized_batches(text_stories, tokenizer, CTX_SIZE) == correct_batches - - tokenized_stories = [ - [1618, 3520, 2223, 3961, 853, 3376, 1820, 1442, 1573], - [46, 3515, 2941, 1637, 1377], - [1439, 3378, 3897, 3807, 343, 1140, 3843, 3848, 1343, 3812, 947, 2871, 1973], - [1163, 1358, 1930, 3590, 2216, 3659, 278], - [604, 2920, 1330, 2240, 786, 4088, 1416, 2122, 1556, 3501, 3159, 3427], - ] - correct_batches = [ - [1, 1618, 3520, 2223, 3961, 853, 3376, 1820, 1442, 1573, 2], - [1, 2, 46, 3515, 2941, 1637, 1377, 2, 1439, 3378, 3897], - [1, 3897, 3807, 343, 1140, 3843, 3848, 1343, 3812, 947, 2871], - [1, 2871, 1973, 2, 1163, 1358, 1930, 3590, 2216, 3659, 278], - [1, 278, 2, 604, 2920, 1330, 2240, 786, 4088, 1416, 2122], - [1, 2122, 1556, 3501, 3159, 3427, 2], - ] - assert ( - get_tokenized_batches( - tokenized_stories, tokenizer, CTX_SIZE, input_tokenized=True - ) - == correct_batches - )