From 1de746778be864fcac32f552b8e937945e80ec3f Mon Sep 17 00:00:00 2001
From: Siwei Li
Date: Sun, 17 Mar 2024 18:35:53 -0700
Subject: [PATCH] Add docstrings to the functions

---
 src/delphi/train/dataset_tokenization.py | 46 ++++++++++++++++++++++++
 tests/train/test_tokenizer.py            | 13 +++++--
 2 files changed, 57 insertions(+), 2 deletions(-)

diff --git a/src/delphi/train/dataset_tokenization.py b/src/delphi/train/dataset_tokenization.py
index ffc0a985..23309e4d 100644
--- a/src/delphi/train/dataset_tokenization.py
+++ b/src/delphi/train/dataset_tokenization.py
@@ -10,6 +10,23 @@ def extend_deque(
     prompt_idx: int,
     tokenizer: PreTrainedTokenizerBase,
 ) -> int:
+    """
+    Extends the deque with tokenized text stories until the deque grows large
+    enough to reach the context size, or until all text stories are processed.
+
+    A deque is used here to save memory, as opposed to loading
+    all the stories and tokenizing them at once.
+
+    Args:
+        dq (deque[int]): Deque to extend with tokens.
+        context_size (int): Size of the context (input sequences).
+        text_stories (list[str]): List of (untokenized) text stories to be tokenized.
+        prompt_idx (int): Index of the current text story.
+        tokenizer (PreTrainedTokenizerBase): Tokenizer to encode the text strings.
+
+    Returns:
+        int: Updated index in the text stories dataset.
+    """
     while len(dq) < context_size and prompt_idx < len(text_stories):
         text_story = text_stories[prompt_idx]
         dq.extend(
@@ -23,6 +40,22 @@ def extend_deque(
 def make_new_samples(
     dq: deque[int], context_size: int, tokenizer: PreTrainedTokenizerBase
 ) -> list[list[int]]:
+    """
+    Generates new samples for training by creating sequences of tokens
+    from the deque until the deque is empty.
+
+    Note: the model is unable to use the last token in an input sequence,
+    so we repeat this token in the next input sequence.
+
+    Args:
+        dq (deque[int]): Deque containing tokens.
+        context_size (int): Size of the context (input sequences).
+        tokenizer (PreTrainedTokenizerBase): Tokenizer to encode the text strings.
+
+    Returns:
+        list[list[int]]: List of token sequences of the same length (context_size).
+    """
+
     samples = []
     while len(dq) >= context_size:
         sample = [tokenizer.bos_token_id]
@@ -38,6 +71,19 @@ def get_tokenized_batches(
     tokenizer: PreTrainedTokenizerBase,
     context_size: int,
 ) -> list[list[int]]:
+    """
+    Tokenizes the input text stories using the provided tokenizer and
+    generates token sequences of the specified length.
+
+    Args:
+        text_stories (list[str]): List of text stories to be tokenized.
+        tokenizer (PreTrainedTokenizerBase): Tokenizer to encode the text strings.
+        context_size (int): Size of the context (input sequences).
+
+    Returns:
+        list[list[int]]: List of token sequences of length equal to context_size.
+ """ + dq = deque() prompt_idx = 0 samples = [] diff --git a/tests/train/test_tokenizer.py b/tests/train/test_tokenizer.py index 9b3552de..a10f72a4 100644 --- a/tests/train/test_tokenizer.py +++ b/tests/train/test_tokenizer.py @@ -15,14 +15,21 @@ def test_extend_deque(): CTX_SIZE = 10 - dataset = load_validation_dataset("delphi-suite/tinystories-v2-clean") - text_stories = dataset["story"][:100] + # generate 100 random stories + text_stories = [ + [ + random.randint(3, tokenizer.vocab_size) + for _ in range(random.randint(100, 800)) + ] + for _ in range(100) + ] prompt_idx = 0 dq = collections.deque() while prompt_idx < len(text_stories): prompt_idx = extend_deque(dq, CTX_SIZE, text_stories, prompt_idx, tokenizer) if prompt_idx < len(text_stories) - 1: + # assert that the deque has grown large enough in each round assert len(dq) >= CTX_SIZE while len(dq) >= CTX_SIZE: for _ in range(CTX_SIZE - 1): @@ -39,6 +46,8 @@ def test_make_new_sample(): for i, sample in enumerate(samples): assert sample[0] == tokenizer.bos_token_id if i > 0: + # assert that there is an overlap of the last token in the previous sample + # and the first token in its following sample assert sample[1] == samples[i - 1][-1] tokens_cnt += len(sample)