Add function to tokenize text stories and split into batches #55
Merged

Commits (8)
e7ab2e8  Add function to tokenize text stories and split into batches
2e69942  Split the tokenization function into two parts, fixing the while-loop…
7609e4f  Add docstrings to the functions
b98f81d  Minor edits in the code, fix the test
9869df7  Uses batch_encode() method to save time
7bd9ef9  Add script to upload to delphi-suite/batched-tokenized-stories
c5c0e09  Remove the test file in tests/train to pass pytest
ba1b109  Update function name
@@ -0,0 +1,78 @@
#!/usr/bin/env python3

import argparse

from datasets import Dataset
from transformers import AutoTokenizer

from delphi.dataset.tokenization import tokenize_dataset
from delphi.eval.utils import load_validation_dataset

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="")

    parser.add_argument(
        "--input-dataset-name",
        type=str,
        help="Text dataset from huggingface to tokenize",
    )
    parser.add_argument(
        "--output-dataset-name",
        type=str,
        help="Name of the tokenized dataset to upload to huggingface",
    )
    parser.add_argument(
        "--tokenizer-name",
        type=str,
        help="Name of the tokenizer from huggingface",
    )
    parser.add_argument(
        "--token",
        type=str,
        help="Hugging Face API token",
    )
    parser.add_argument(
        "--context-size",
        type=int,
        default=512,
        help="Context size of the tokenized dataset as input of the model",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=50,
        help="Batch size of text inputs into the tokenizer",
    )
    parser.add_argument(
        "--column-name",
        type=str,
        help="Name of the column containing text documents in the input dataset",
    )
    args = parser.parse_args()

    input_dataset = load_validation_dataset(f"delphi-suite/{args.input_dataset_name}")
    tokenizer = AutoTokenizer.from_pretrained(f"delphi-suite/{args.tokenizer_name}")

    if args.column_name:
        text_docs = input_dataset[args.column_name]
    else:
        if len(input_dataset.column_names) > 1:
            raise ValueError("There are more than one column in the specified dataset")
        text_docs = input_dataset[input_dataset.column_names[0]]

    output_dataset = Dataset.from_dict(
        {
            "tokens": tokenize_dataset(
                text_docs,
                tokenizer,
                context_size=args.context_size,
                batch_size=args.batch_size,
            )
        }
    )

    output_dataset.push_to_hub(
        repo_id=f"delphi-suite/{args.output_dataset_name}",
        private=False,
        token=args.token,
    )
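For context, a small sketch (not part of the PR) of loading the pushed dataset back from the Hub. The repo name follows the commit message "Add script to upload to delphi-suite/batched-tokenized-stories", and the "tokens" column follows the Dataset.from_dict call above; treat both as assumptions about the final setup.

from datasets import load_dataset

# Hedged sketch: load the dataset that the script above pushes to the Hub.
# push_to_hub on a Dataset stores it under a default "train" split.
ds = load_dataset("delphi-suite/batched-tokenized-stories")
print(ds)  # a single "tokens" column holding lists of token ids
print(ds["train"][0]["tokens"][:10])  # first ten token ids of the first sample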
@@ -0,0 +1,107 @@
from collections import deque
from typing import Optional

from transformers import PreTrainedTokenizerBase


def extend_deque(
    dq: deque[int],
    context_size: int,
    text_documents: list[str],
    doc_idx: int,
    tokenizer: PreTrainedTokenizerBase,
    batch_size: int,
) -> int:
    """
    Extends the deque with tokenized text documents until the deque grows large
    enough to reach the context size, or until all text documents are processed.

    A deque is used here to save memory, as opposed to loading all the
    documents and tokenizing them at once.

    Args:
        dq: Deque to extend with tokenized tokens.
        context_size: Size of the context (input sequences).
        text_documents: List of (untokenized) text documents to be tokenized.
        doc_idx: Index of the current text document.
        tokenizer: Tokenizer to encode the text strings.
        batch_size: Number of text documents to tokenize per tokenizer call.

    Returns:
        int: Updated index in the text documents dataset.
    """
    while len(dq) < context_size and doc_idx < len(text_documents):
        text_doc = text_documents[doc_idx : doc_idx + batch_size]
        batch_input_ids = tokenizer(
            text_doc, return_attention_mask=False, add_special_tokens=False
        )["input_ids"]
        for input_ids in batch_input_ids:
            dq.extend(input_ids + [tokenizer.eos_token_id])
        doc_idx += batch_size
    return doc_idx


def make_new_samples(
    dq: deque[int], context_size: int, bos_token_id: int
) -> list[list[int]]:
    """
    Generates new samples for training by creating sequences of tokens
    from the deque until the deque does not hold enough tokens to generate
    another sample.

    Note: the model is unable to use the last token in an input sequence,
    so we repeat this token in the next input sequence.

    Args:
        dq: Deque containing tokenized tokens.
        context_size: Size of the context (input sequences).
        bos_token_id: bos_token_id of the tokenizer used.

    Returns:
        list[list[int]]: List of token sequences, each consisting of a BOS token
        followed by context_size tokens.
    """

    samples = []
    while len(dq) >= context_size:
        sample = [bos_token_id]

        # For the first (n-1) elements, pop from the left of the deque
        # and add to the new sample; the n-th element will be retained
        # in the deque for making the next sample.
        for _ in range(context_size - 1):
            sample.append(dq.popleft())
        sample.append(dq[0])

        samples.append(sample)
    return samples


def tokenize_dataset(
    text_documents: list[str],
    tokenizer: PreTrainedTokenizerBase,
    context_size: int,
    batch_size: int,
) -> list[list[int]]:
    """
    Tokenizes the input text documents using the provided tokenizer and
    generates token sequences of the specified length.

    Args:
        text_documents: List of (untokenized) text documents to be tokenized.
        tokenizer: Tokenizer to encode the text strings.
        context_size: Size of the context (input sequences).
        batch_size: Number of text documents to tokenize per tokenizer call.

    Returns:
        list[list[int]]: List of token sequences, each consisting of a BOS token
        followed by context_size tokens.
    """

    dq = deque()
    doc_idx = 0
    samples = []

    while doc_idx < len(text_documents):
        doc_idx = extend_deque(
            dq, context_size, text_documents, doc_idx, tokenizer, batch_size
        )
        samples.extend(make_new_samples(dq, context_size, tokenizer.bos_token_id))

    # We discard the last chunk, so no processing on the remainder of the deque here
    return samples
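For orientation, a hedged usage sketch of tokenize_dataset (not part of the PR): the story strings are invented, the tokenizer name is the one used in the PR's tests, and the expected sample length follows test_tokenize_dataset, where CTX_SIZE = 10 yields 11-token samples.

from transformers import AutoTokenizer

from delphi.dataset.tokenization import tokenize_dataset

tokenizer = AutoTokenizer.from_pretrained("delphi-suite/stories-tokenizer")

stories = [
    "Once upon a time, there was a little dog.",
    "Sara and Tom like to play in the park.",
]
samples = tokenize_dataset(stories, tokenizer, context_size=10, batch_size=2)

# Each sample is the BOS token followed by context_size token ids,
# so with context_size=10 every sample has 11 entries.
for sample in samples:
    assert sample[0] == tokenizer.bos_token_id
    assert len(sample) == 11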
@@ -0,0 +1,88 @@
import collections
import random

import pytest
from transformers import AutoTokenizer

from delphi.dataset.tokenization import extend_deque, make_new_samples, tokenize_dataset


@pytest.fixture
def tokenizer():
    return AutoTokenizer.from_pretrained("delphi-suite/stories-tokenizer")


def test_extend_deque(tokenizer):
    CTX_SIZE = 10
    BATCH_SIZE = 2
    # generate 100 random stories
    text_stories = [
        " ".join(
            [
                tokenizer.decode(random.randint(3, tokenizer.vocab_size))
                for _ in range(random.randint(100, 800))
            ]
        )
        for _ in range(100)
    ]
    prompt_idx = 0
    dq = collections.deque()

    while prompt_idx < len(text_stories):
        prompt_idx = extend_deque(
            dq, CTX_SIZE, text_stories, prompt_idx, tokenizer, BATCH_SIZE
        )
        if prompt_idx < len(text_stories) - 1:
            # assert that the deque has grown large enough in each round
            assert len(dq) >= CTX_SIZE
        while len(dq) >= CTX_SIZE:
            for _ in range(CTX_SIZE - 1):
                dq.popleft()


def test_make_new_sample(tokenizer):
    for _ in range(100):
        # [PR review comment on the line above] Why is this done one hundred times?
        total_tokens = random.randint(100, 1000)
        context_size = random.randint(5, total_tokens // 2)
        dq = collections.deque(random.choices(range(3, 1000), k=total_tokens))
        samples = make_new_samples(dq, context_size, tokenizer.bos_token_id)
        tokens_cnt = 0
        for i, sample in enumerate(samples):
            assert sample[0] == tokenizer.bos_token_id
            if i > 0:
                # assert that there is an overlap of the last token in the previous sample
                # and the first token in its following sample
                assert sample[1] == samples[i - 1][-1]
            tokens_cnt += len(sample)

        # We discard the last chunk so the following lines are only for testing
        tokens_cnt += 1 + len(dq)  # the last batch with BOS in the beginning
        assert tokens_cnt == total_tokens + (
            2 * len(samples) + 1
        )  # BOS for each batch + overlapping of the last tokens in the batches
        assert len(dq) > 0  # always leaving at least one element in the deque


def test_tokenize_dataset(tokenizer):
    CTX_SIZE = 10
    BATCH_SIZE = 2

    text_stories = [
        "Once upon a",
        "Mother woke up alert. She put on her coat",
        "Once upon a time, in a small town, there was a weird",
        "Once upon a time, there was a",
        "Sara and Tom are friends. They like to play in the park.",
    ]
    correct_batches = [
        [1, 432, 440, 261, 2, 367, 501, 1917, 372, 3398, 4037],
        [1, 4037, 341, 577, 359, 342, 1854, 2, 432, 440, 261],
        [1, 261, 403, 4045, 317, 261, 560, 1000, 4045, 406, 286],
        [1, 286, 261, 2567, 2, 432, 440, 261, 403, 4045, 406],
        [1, 406, 286, 261, 2, 787, 269, 396, 484, 415, 4037],
        [1, 4037, 311, 519, 268, 326, 317, 264, 525, 4037, 2],
    ]
    assert (
        tokenize_dataset(text_stories, tokenizer, CTX_SIZE, BATCH_SIZE)
        == correct_batches
    )
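As a side note, a hedged sketch of extend_deque's contract (not part of the PR), mirroring test_extend_deque but with two fixed strings instead of random stories:

import collections

from transformers import AutoTokenizer

from delphi.dataset.tokenization import extend_deque

tokenizer = AutoTokenizer.from_pretrained("delphi-suite/stories-tokenizer")
docs = ["Once upon a time, there was a dog.", "Sara and Tom are friends."]

dq = collections.deque()
next_idx = extend_deque(dq, 10, docs, 0, tokenizer, batch_size=2)

# Both documents fit into one batch (batch_size=2), so next_idx == 2, and the
# deque now holds the tokens of both documents, each followed by the EOS token.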
Review comment (on make_new_samples):

I find the explanation a bit confusing (or I am confused). This function does not generate entirely new content, correct? It reduces the length of pre-existing samples by clipping them to context_size, correct? A different wording would make it easier to understand. You could also consider renaming the function make_new_samples to something such as clip_samples or similar. But this is up to you.

Follow-up comment:

I see, my previous assumption was wrong. This function does not clip the samples, but rather splits them, prepending a BOS token and carrying the final token of each split over into the next one.
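To make that concrete, here is a small illustrative sketch (not part of the PR) of make_new_samples applied to a toy deque of made-up token ids, with context_size=5 and bos_token_id=1:

from collections import deque

from delphi.dataset.tokenization import make_new_samples

dq = deque([10, 11, 12, 13, 14, 15, 16, 17, 18])
samples = make_new_samples(dq, context_size=5, bos_token_id=1)

# samples == [[1, 10, 11, 12, 13, 14],
#             [1, 14, 15, 16, 17, 18]]
# Each sample is the BOS token followed by context_size tokens; the last
# token of one sample (14) is repeated as the first real token of the next,
# and the leftover token stays in the deque (dq == deque([18])).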