Add function to tokenize text stories and split into batches #55

Merged
merged 8 commits on Mar 30, 2024
78 changes: 78 additions & 0 deletions scripts/tokenize_dataset.py
@@ -0,0 +1,78 @@
#!/usr/bin/env python3

import argparse

from datasets import Dataset
from transformers import AutoTokenizer

from delphi.dataset.tokenization import tokenize_dataset
from delphi.eval.utils import load_validation_dataset

if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Tokenize a text dataset from Hugging Face and upload the tokenized version"
    )

parser.add_argument(
"--input-dataset-name",
type=str,
help="Text dataset from huggingface to tokenize",
)
parser.add_argument(
"--output-dataset-name",
type=str,
help="Name of the tokenized dataset to upload to huggingface",
)
parser.add_argument(
"--tokenizer-name",
type=str,
help="Name of the tokenizer from huggingface",
)
parser.add_argument(
"--token",
type=str,
help="Hugging Face API token",
)
parser.add_argument(
"--context-size",
type=int,
default=512,
help="Context size of the tokenized dataset as input of the model",
)
parser.add_argument(
"--batch-size",
type=int,
default=50,
help="Batch size of text inputs into the tokenizer",
)
parser.add_argument(
"--column-name",
type=str,
help="Name of the column containing text documents in the input dataset",
)
args = parser.parse_args()

input_dataset = load_validation_dataset(f"delphi-suite/{args.input_dataset_name}")
tokenizer = AutoTokenizer.from_pretrained(f"delphi-suite/{args.tokenizer_name}")

if args.column_name:
text_docs = input_dataset[args.column_name]
else:
if len(input_dataset.column_names) > 1:
            raise ValueError(
                "There is more than one column in the specified dataset; please specify --column-name"
            )
text_docs = input_dataset[input_dataset.column_names[0]]

output_dataset = Dataset.from_dict(
{
"tokens": tokenize_dataset(
text_docs,
tokenizer,
context_size=args.context_size,
batch_size=args.batch_size,
)
}
)

output_dataset.push_to_hub(
repo_id=f"delphi-suite/{args.output_dataset_name}",
private=False,
token=args.token,
)
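
For reference, a typical invocation of this script (all names and the token below are placeholders, not real repos or credentials) would look like: python scripts/tokenize_dataset.py --input-dataset-name <input-dataset> --tokenizer-name <tokenizer> --output-dataset-name <output-dataset> --token <hf-token> --context-size 512 --batch-size 50. If --column-name is omitted, the input dataset must contain exactly one column.
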
107 changes: 107 additions & 0 deletions src/delphi/dataset/tokenization.py
@@ -0,0 +1,107 @@
from collections import deque
from typing import Optional

from transformers import PreTrainedTokenizerBase


def extend_deque(
dq: deque[int],
context_size: int,
text_documents: list[str],
doc_idx: int,
tokenizer: PreTrainedTokenizerBase,
batch_size: int,
) -> int:
"""
Extends the deque with tokenized text documents until the deque grows large
enough to reach the context size, or until all text documents are processed.

The usage of a deque here aims to save the memory as opposed to
load all the documents and tokenize them at once.

Args:
dq: Deque to extend with tokenized tokens.
context_size: Size of the context(input sequences).
text_documents: List of (untokenized) text documents to be tokenized.
doc_idx: Index of the current text story.
tokenizer: Tokenizer to encode the text strings.
Returns:
int: Updated index in the text documents dataset.
"""
while len(dq) < context_size and doc_idx < len(text_documents):
text_doc = text_documents[doc_idx : doc_idx + batch_size]
batch_input_ids = tokenizer(
text_doc, return_attention_mask=False, add_special_tokens=False
)["input_ids"]
for input_ids in batch_input_ids:
dq.extend(input_ids + [tokenizer.eos_token_id])
doc_idx += batch_size
return doc_idx
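
A minimal sketch of how extend_deque is meant to be driven, assuming the delphi-suite/stories-tokenizer used by the tests below is available (the exact token ids produced depend on that tokenizer):

from collections import deque

from transformers import AutoTokenizer

from delphi.dataset.tokenization import extend_deque

tokenizer = AutoTokenizer.from_pretrained("delphi-suite/stories-tokenizer")
docs = ["Once upon a time, there was a dog.", "Sara and Tom are friends."]
dq: deque[int] = deque()

# Tokenize documents two at a time until the deque holds at least 10 tokens
# or all documents have been consumed; the return value is the next doc index.
doc_idx = extend_deque(dq, 10, docs, 0, tokenizer, batch_size=2)

# dq now holds the concatenated token ids of the processed documents,
# with an EOS token appended after each document.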


def make_new_samples(
dq: deque[int], context_size: int, bos_token_id: int
) -> list[list[int]]:
"""
    Generates new samples for training by creating sequences of tokens
    from the deque until the deque does not hold enough tokens to generate
    another sample.

    Note: the model is unable to use the last token in an input sequence,
    so we repeat this token at the start of the next input sequence.

    Args:
        dq: Deque containing token ids.
        context_size: Size of the context (input sequences).
        bos_token_id: bos_token_id of the tokenizer used.

    Returns:
        list[list[int]]: List of token sequences, each of length context_size + 1
        (a BOS token followed by context_size tokens from the deque).
    """

Review comment (Collaborator):
I find the explanation a bit confusing (or I am confused). This function does not generate entirely new content, correct? It reduces the length of pre-existing samples by clipping them to context_size, correct? A different wording would make it easier to understand. You could also consider renaming the function make_new_samples to something such as clip_samples or similar, but this is up to you.

Review comment (Collaborator):
I see, my previous assumption was wrong. This function does not clip the samples, but rather splits them, prepending a BOS token as well as repeating the final token from the previous split.

samples = []
while len(dq) >= context_size:
sample = [bos_token_id]

# For the first (n-1) elements, pop from the left of the deque
# and add to the new sample, the n-th element will be retained
# in the deque for making the next sample.
for _ in range(context_size - 1):
sample.append(dq.popleft())
sample.append(dq[0])

samples.append(sample)
return samples
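
To make the splitting and overlap behaviour concrete, here is a small self-contained sketch in which plain integers stand in for token ids and 0 plays the role of bos_token_id:

from collections import deque

from delphi.dataset.tokenization import make_new_samples

dq = deque([10, 11, 12, 13, 14, 15, 16, 17])
samples = make_new_samples(dq, context_size=4, bos_token_id=0)

# Each sample is a BOS token, three tokens popped from the deque, and one token
# that is only peeked at, so the next sample repeats it right after its BOS token:
# samples == [[0, 10, 11, 12, 13], [0, 13, 14, 15, 16]]
# The leftover tokens stay in the deque: dq == deque([16, 17])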


def tokenize_dataset(
text_documents: list[str],
tokenizer: PreTrainedTokenizerBase,
context_size: int,
batch_size: int,
) -> list[list[int]]:
"""
    Tokenizes the input text documents using the provided tokenizer and
    splits the resulting token stream into fixed-length samples.

    Args:
        text_documents: List of (untokenized) text documents to be tokenized.
        tokenizer: Tokenizer to encode the text strings.
        context_size: Size of the context (input sequences).
        batch_size: Number of documents to tokenize per tokenizer call.

    Returns:
        list[list[int]]: List of token sequences, each of length context_size + 1.
    """

dq = deque()
doc_idx = 0
samples = []

while doc_idx < len(text_documents):
doc_idx = extend_deque(
dq, context_size, text_documents, doc_idx, tokenizer, batch_size
)
samples.extend(make_new_samples(dq, context_size, tokenizer.bos_token_id))

# We discard the last chunk, so no processing on the remainder of the deque here
return samples
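
A minimal end-to-end sketch, again assuming the delphi-suite/stories-tokenizer used in the tests is available:

from transformers import AutoTokenizer

from delphi.dataset.tokenization import tokenize_dataset

tokenizer = AutoTokenizer.from_pretrained("delphi-suite/stories-tokenizer")
stories = ["Once upon a time, there was a dog.", "Sara and Tom are friends."]

samples = tokenize_dataset(stories, tokenizer, context_size=10, batch_size=2)

# Every sample is 11 tokens long (BOS followed by 10 tokens from the stream),
# consecutive samples overlap by one token, and any trailing tokens that do not
# fill a complete sample are discarded.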
88 changes: 88 additions & 0 deletions tests/dataset/test_tokenizer.py
@@ -0,0 +1,88 @@
import collections
import random

import pytest
from transformers import AutoTokenizer

from delphi.dataset.tokenization import extend_deque, make_new_samples, tokenize_dataset


@pytest.fixture
def tokenizer():
return AutoTokenizer.from_pretrained("delphi-suite/stories-tokenizer")


def test_extend_deque(tokenizer):
CTX_SIZE = 10
BATCH_SIZE = 2
# generate 100 random stories
text_stories = [
" ".join(
[
tokenizer.decode(random.randint(3, tokenizer.vocab_size))
for _ in range(random.randint(100, 800))
]
)
for _ in range(100)
]
prompt_idx = 0
dq = collections.deque()

while prompt_idx < len(text_stories):
prompt_idx = extend_deque(
dq, CTX_SIZE, text_stories, prompt_idx, tokenizer, BATCH_SIZE
)
if prompt_idx < len(text_stories) - 1:
# assert that the deque has grown large enough in each round
assert len(dq) >= CTX_SIZE
while len(dq) >= CTX_SIZE:
for _ in range(CTX_SIZE - 1):
dq.popleft()


def test_make_new_sample(tokenizer):
for _ in range(100):
Review comment (Collaborator):
Why is this done one hundred times?

total_tokens = random.randint(100, 1000)
context_size = random.randint(5, total_tokens // 2)
dq = collections.deque(random.choices(range(3, 1000), k=total_tokens))
samples = make_new_samples(dq, context_size, tokenizer.bos_token_id)
tokens_cnt = 0
for i, sample in enumerate(samples):
assert sample[0] == tokenizer.bos_token_id
if i > 0:
# assert that there is an overlap of the last token in the previous sample
# and the first token in its following sample
assert sample[1] == samples[i - 1][-1]
tokens_cnt += len(sample)

# We discard the last chunk so the following lines are only for testing
tokens_cnt += 1 + len(dq) # the last batch with BOS in the beginning
assert tokens_cnt == total_tokens + (
2 * len(samples) + 1
) # BOS for each batch + overlapping of the last tokens in the batches
assert len(dq) > 0 # always leaving at least one element in the deque
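
This count identity follows from the bookkeeping in make_new_samples: each sample holds context_size + 1 tokens but removes only context_size - 1 tokens from the deque (the BOS is newly added and the peeked token is counted twice), so summing the samples plus the leftover deque gives total_tokens + 2 * len(samples), and the extra + 1 is the BOS that the test adds for the final, discarded chunk.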


def test_tokenize_dataset(tokenizer):
CTX_SIZE = 10
BATCH_SIZE = 2

text_stories = [
"Once upon a",
"Mother woke up alert. She put on her coat",
"Once upon a time, in a small town, there was a weird",
"Once upon a time, there was a",
"Sara and Tom are friends. They like to play in the park.",
]
correct_batches = [
[1, 432, 440, 261, 2, 367, 501, 1917, 372, 3398, 4037],
[1, 4037, 341, 577, 359, 342, 1854, 2, 432, 440, 261],
[1, 261, 403, 4045, 317, 261, 560, 1000, 4045, 406, 286],
[1, 286, 261, 2567, 2, 432, 440, 261, 403, 4045, 406],
[1, 406, 286, 261, 2, 787, 269, 396, 484, 415, 4037],
[1, 4037, 311, 519, 268, 326, 317, 264, 525, 4037, 2],
]
assert (
tokenize_dataset(text_stories, tokenizer, CTX_SIZE, BATCH_SIZE)
== correct_batches
)
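
Note on the expected values: each sequence in correct_batches is CTX_SIZE + 1 = 11 tokens long; token id 1 appears to be this tokenizer's BOS token and 2 its EOS token, since every row starts with 1 and the 2s fall exactly where story boundaries do; and the last token of each row reappears right after the BOS of the next row (e.g. 4037 ends the first row and follows the BOS in the second), reflecting the one-token overlap between consecutive samples.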