From e7152762fd647322724e8125908df02c5c3bd096 Mon Sep 17 00:00:00 2001
From: Siwei Li
Date: Sat, 9 Mar 2024 21:03:15 -0800
Subject: [PATCH] Add function to tokenize text stories and split into batches

---
 src/delphi/train/dataset_tokenization.py | 39 ++++++++++++++++++++
 tests/train/test_tokenizer.py            | 47 ++++++++++++++++++++++++
 2 files changed, 86 insertions(+)
 create mode 100644 src/delphi/train/dataset_tokenization.py
 create mode 100644 tests/train/test_tokenizer.py

diff --git a/src/delphi/train/dataset_tokenization.py b/src/delphi/train/dataset_tokenization.py
new file mode 100644
index 00000000..5ca3371f
--- /dev/null
+++ b/src/delphi/train/dataset_tokenization.py
@@ -0,0 +1,39 @@
+from collections import deque
+from typing import Union
+
+from transformers import PreTrainedTokenizerBase
+
+
+def get_tokenized_batches(
+    text_stories: Union[list[str], list[list[int]]],
+    tokenizer: PreTrainedTokenizerBase,
+    context_size: int,
+    input_tokenized=False,
+) -> list[list[int]]:
+    dq = deque()
+    samples = []
+
+    prompt_idx = 0
+    while prompt_idx < len(text_stories):
+        while len(dq) < context_size:
+            text_story = text_stories[prompt_idx]
+            if not input_tokenized:
+                dq.extend(
+                    tokenizer.encode(text_story, add_special_tokens=False)
+                    + [tokenizer.eos_token_id]
+                )
+            else:
+                dq.extend(text_story)
+                dq.append(tokenizer.eos_token_id)
+            prompt_idx += 1
+
+        sample = [tokenizer.bos_token_id]
+        for i in range(context_size - 1):  # peek at and not pop the last element
+            sample.append(dq.popleft())
+        sample.append(dq[0])
+
+        samples.append(sample)
+
+    if dq:
+        samples.append([tokenizer.bos_token_id] + list(dq))
+    return samples
diff --git a/tests/train/test_tokenizer.py b/tests/train/test_tokenizer.py
new file mode 100644
index 00000000..ca0415f8
--- /dev/null
+++ b/tests/train/test_tokenizer.py
@@ -0,0 +1,47 @@
+from transformers import AutoTokenizer
+
+from delphi.train.dataset_tokenization import get_tokenized_batches
+
+
+def test_get_tokenized_batches():
+    CTX_SIZE = 10
+    tokenizer = AutoTokenizer.from_pretrained("delphi-suite/v0-llama2-tokenizer")
+
+    text_stories = [
+        "Once upon a",
+        "Mother woke up alert. She put on her coat",
+        "Once upon a time, in a small town, there was a weird",
+        "Once upon a time, there was a",
+        "Sara and Tom are friends. They like to play in the park.",
+    ]
+    correct_batches = [
+        [1, 432, 440, 261, 2, 367, 501, 1917, 372, 3398, 4037],
+        [1, 4037, 341, 577, 359, 342, 1854, 2, 432, 440, 261],
+        [1, 261, 403, 4045, 317, 261, 560, 1000, 4045, 406, 286],
+        [1, 286, 261, 2567, 2, 432, 440, 261, 403, 4045, 406],
+        [1, 406, 286, 261, 2, 787, 269, 396, 484, 415, 4037],
+        [1, 4037, 311, 519, 268, 326, 317, 264, 525, 4037, 2],
+    ]
+    assert get_tokenized_batches(text_stories, tokenizer, CTX_SIZE) == correct_batches
+
+    tokenized_stories = [
+        [1618, 3520, 2223, 3961, 853, 3376, 1820, 1442, 1573],
+        [46, 3515, 2941, 1637, 1377],
+        [1439, 3378, 3897, 3807, 343, 1140, 3843, 3848, 1343, 3812, 947, 2871, 1973],
+        [1163, 1358, 1930, 3590, 2216, 3659, 278],
+        [604, 2920, 1330, 2240, 786, 4088, 1416, 2122, 1556, 3501, 3159, 3427],
+    ]
+    correct_batches = [
+        [1, 1618, 3520, 2223, 3961, 853, 3376, 1820, 1442, 1573, 2],
+        [1, 2, 46, 3515, 2941, 1637, 1377, 2, 1439, 3378, 3897],
+        [1, 3897, 3807, 343, 1140, 3843, 3848, 1343, 3812, 947, 2871],
+        [1, 2871, 1973, 2, 1163, 1358, 1930, 3590, 2216, 3659, 278],
+        [1, 278, 2, 604, 2920, 1330, 2240, 786, 4088, 1416, 2122],
+        [1, 2122, 1556, 3501, 3159, 3427, 2],
+    ]
+    assert (
+        get_tokenized_batches(
+            tokenized_stories, tokenizer, CTX_SIZE, input_tokenized=True
+        )
+        == correct_batches
+    )
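
Note: a minimal usage sketch of the new helper, for reviewers. The story strings below are arbitrary examples (not from the patch), and it assumes the delphi package is importable and the delphi-suite/v0-llama2-tokenizer checkpoint used in the test can be downloaded.

    from transformers import AutoTokenizer

    from delphi.train.dataset_tokenization import get_tokenized_batches

    # Same tokenizer as in the test; any PreTrainedTokenizerBase with
    # bos_token_id and eos_token_id set is expected to work the same way.
    tokenizer = AutoTokenizer.from_pretrained("delphi-suite/v0-llama2-tokenizer")

    # Arbitrary example stories, used here only for illustration.
    stories = [
        "Once upon a time, there was a little dog.",
        "The cat sat on the mat and watched the rain.",
    ]

    # Each full sample is context_size + 1 tokens: BOS, context_size - 1 tokens
    # popped from the queue, and a peek at the next token, so consecutive
    # samples overlap by one token (compare the 11-token batches for
    # CTX_SIZE = 10 in the test above). Any leftover tokens end up in a final,
    # shorter sample.
    samples = get_tokenized_batches(stories, tokenizer, context_size=10)
    for sample in samples:
        print(sample)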