diff --git a/scripts/tokenize_dataset.py b/scripts/tokenize_dataset.py
new file mode 100755
index 00000000..5a3d01d5
--- /dev/null
+++ b/scripts/tokenize_dataset.py
@@ -0,0 +1,78 @@
+#!/usr/bin/env python3
+
+import argparse
+
+from datasets import Dataset
+from transformers import AutoTokenizer
+
+from delphi.dataset.tokenization import tokenize_dataset
+from delphi.eval.utils import load_validation_dataset
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(
+        description="Tokenize a text dataset with a given tokenizer and upload it to Hugging Face"
+    )
+
+    parser.add_argument(
+        "--input-dataset-name",
+        type=str,
+        help="Text dataset from Hugging Face to tokenize",
+    )
+    parser.add_argument(
+        "--output-dataset-name",
+        type=str,
+        help="Name of the tokenized dataset to upload to Hugging Face",
+    )
+    parser.add_argument(
+        "--tokenizer-name",
+        type=str,
+        help="Name of the tokenizer from Hugging Face",
+    )
+    parser.add_argument(
+        "--token",
+        type=str,
+        help="Hugging Face API token",
+    )
+    parser.add_argument(
+        "--context-size",
+        type=int,
+        default=512,
+        help="Context size of the tokenized dataset as input of the model",
+    )
+    parser.add_argument(
+        "--batch-size",
+        type=int,
+        default=50,
+        help="Batch size of text inputs into the tokenizer",
+    )
+    parser.add_argument(
+        "--column-name",
+        type=str,
+        help="Name of the column containing text documents in the input dataset",
+    )
+    args = parser.parse_args()
+
+    input_dataset = load_validation_dataset(f"delphi-suite/{args.input_dataset_name}")
+    tokenizer = AutoTokenizer.from_pretrained(f"delphi-suite/{args.tokenizer_name}")
+
+    if args.column_name:
+        text_docs = input_dataset[args.column_name]
+    else:
+        if len(input_dataset.column_names) > 1:
+            raise ValueError(
+                "The input dataset has more than one column; please specify --column-name"
+            )
+        text_docs = input_dataset[input_dataset.column_names[0]]
+
+    output_dataset = Dataset.from_dict(
+        {
+            "tokens": tokenize_dataset(
+                text_docs,
+                tokenizer,
+                context_size=args.context_size,
+                batch_size=args.batch_size,
+            )
+        }
+    )
+
+    output_dataset.push_to_hub(
+        repo_id=f"delphi-suite/{args.output_dataset_name}",
+        private=False,
+        token=args.token,
+    )
diff --git a/src/delphi/dataset/tokenization.py b/src/delphi/dataset/tokenization.py
new file mode 100644
index 00000000..b800b64b
--- /dev/null
+++ b/src/delphi/dataset/tokenization.py
@@ -0,0 +1,107 @@
+from collections import deque
+
+from transformers import PreTrainedTokenizerBase
+
+
+def extend_deque(
+    dq: deque[int],
+    context_size: int,
+    text_documents: list[str],
+    doc_idx: int,
+    tokenizer: PreTrainedTokenizerBase,
+    batch_size: int,
+) -> int:
+    """
+    Extends the deque with tokenized text documents until the deque grows large
+    enough to reach the context size, or until all text documents are processed.
+
+    A deque is used to save memory: documents are tokenized in batches only as
+    they are needed, instead of loading and tokenizing the whole dataset at once.
+
+    Args:
+        dq: Deque to extend with token ids.
+        context_size: Size of the context (input sequences).
+        text_documents: List of (untokenized) text documents to be tokenized.
+        doc_idx: Index of the first unprocessed text document.
+        tokenizer: Tokenizer to encode the text strings.
+        batch_size: Number of documents to tokenize per tokenizer call.
+
+    Returns:
+        int: Updated index in the text documents dataset.
+ """ + while len(dq) < context_size and doc_idx < len(text_documents): + text_doc = text_documents[doc_idx : doc_idx + batch_size] + batch_input_ids = tokenizer( + text_doc, return_attention_mask=False, add_special_tokens=False + )["input_ids"] + for input_ids in batch_input_ids: + dq.extend(input_ids + [tokenizer.eos_token_id]) + doc_idx += batch_size + return doc_idx + + +def make_new_samples( + dq: deque[int], context_size: int, bos_token_id: int +) -> list[list[int]]: + """ + Generates new samples for training by creating sequences of tokens + from the deque until the deque does not hold enough tokens to generate + another sample. + + Note: the model is unable to use the last token in an input sequence, + so we repeat this token in the next input sequence. + + Args: + dq: Deque containing tokenized tokens. + context_size: Size of the context (input sequences). + bos_token_id: bos_token_id of the tokenizer used. + + Returns: + list[list[int]]: List of token sequences of the same length(context_size). + """ + + samples = [] + while len(dq) >= context_size: + sample = [bos_token_id] + + # For the first (n-1) elements, pop from the left of the deque + # and add to the new sample, the n-th element will be retained + # in the deque for making the next sample. + for _ in range(context_size - 1): + sample.append(dq.popleft()) + sample.append(dq[0]) + + samples.append(sample) + return samples + + +def tokenize_dataset( + text_documents: list[str], + tokenizer: PreTrainedTokenizerBase, + context_size: int, + batch_size: int, +) -> list[list[int]]: + """ + Tokenizes the input text documents using the provided tokenizer and + generates token sequences of the specified length. + + Args: + text_documents: List[str], + tokenizer, + context_size, + + Returns: + list[list[int]]: List of token sequences of length equal to context_size. 
+ """ + + dq = deque() + doc_idx = 0 + samples = [] + + while doc_idx < len(text_documents): + doc_idx = extend_deque( + dq, context_size, text_documents, doc_idx, tokenizer, batch_size + ) + samples.extend(make_new_samples(dq, context_size, tokenizer.bos_token_id)) + + # We discard the last chunk, so no processing on the remainder of the deque here + return samples diff --git a/tests/dataset/test_tokenizer.py b/tests/dataset/test_tokenizer.py new file mode 100644 index 00000000..99b2dcb3 --- /dev/null +++ b/tests/dataset/test_tokenizer.py @@ -0,0 +1,88 @@ +import collections +import random + +import pytest +from transformers import AutoTokenizer + +from delphi.dataset.tokenization import extend_deque, make_new_samples, tokenize_dataset + + +@pytest.fixture +def tokenizer(): + return AutoTokenizer.from_pretrained("delphi-suite/stories-tokenizer") + + +def test_extend_deque(tokenizer): + CTX_SIZE = 10 + BATCH_SIZE = 2 + # generate 100 random stories + text_stories = [ + " ".join( + [ + tokenizer.decode(random.randint(3, tokenizer.vocab_size)) + for _ in range(random.randint(100, 800)) + ] + ) + for _ in range(100) + ] + prompt_idx = 0 + dq = collections.deque() + + while prompt_idx < len(text_stories): + prompt_idx = extend_deque( + dq, CTX_SIZE, text_stories, prompt_idx, tokenizer, BATCH_SIZE + ) + if prompt_idx < len(text_stories) - 1: + # assert that the deque has grown large enough in each round + assert len(dq) >= CTX_SIZE + while len(dq) >= CTX_SIZE: + for _ in range(CTX_SIZE - 1): + dq.popleft() + + +def test_make_new_sample(tokenizer): + for _ in range(100): + total_tokens = random.randint(100, 1000) + context_size = random.randint(5, total_tokens // 2) + dq = collections.deque(random.choices(range(3, 1000), k=total_tokens)) + samples = make_new_samples(dq, context_size, tokenizer.bos_token_id) + tokens_cnt = 0 + for i, sample in enumerate(samples): + assert sample[0] == tokenizer.bos_token_id + if i > 0: + # assert that there is an overlap of the last token in the previous sample + # and the first token in its following sample + assert sample[1] == samples[i - 1][-1] + tokens_cnt += len(sample) + + # We discard the last chunk so the following lines are only for testing + tokens_cnt += 1 + len(dq) # the last batch with BOS in the beginning + assert tokens_cnt == total_tokens + ( + 2 * len(samples) + 1 + ) # BOS for each batch + overlapping of the last tokens in the batches + assert len(dq) > 0 # always leaving at least one element in the deque + + +def test_tokenize_dataset(tokenizer): + CTX_SIZE = 10 + BATCH_SIZE = 2 + + text_stories = [ + "Once upon a", + "Mother woke up alert. She put on her coat", + "Once upon a time, in a small town, there was a weird", + "Once upon a time, there was a", + "Sara and Tom are friends. They like to play in the park.", + ] + correct_batches = [ + [1, 432, 440, 261, 2, 367, 501, 1917, 372, 3398, 4037], + [1, 4037, 341, 577, 359, 342, 1854, 2, 432, 440, 261], + [1, 261, 403, 4045, 317, 261, 560, 1000, 4045, 406, 286], + [1, 286, 261, 2567, 2, 432, 440, 261, 403, 4045, 406], + [1, 406, 286, 261, 2, 787, 269, 396, 484, 415, 4037], + [1, 4037, 311, 519, 268, 326, 317, 264, 525, 4037, 2], + ] + assert ( + tokenize_dataset(text_stories, tokenizer, CTX_SIZE, BATCH_SIZE) + == correct_batches + )