diff --git a/scripts/tokenize_dataset.py b/scripts/tokenize_dataset.py
index 7bafff84..5a3d01d5 100755
--- a/scripts/tokenize_dataset.py
+++ b/scripts/tokenize_dataset.py
@@ -5,7 +5,7 @@
 from datasets import Dataset
 from transformers import AutoTokenizer
 
-from delphi.dataset.tokenization import get_tokenized_batches
+from delphi.dataset.tokenization import tokenize_dataset
 from delphi.eval.utils import load_validation_dataset
 
 if __name__ == "__main__":
@@ -62,7 +62,7 @@
     output_dataset = Dataset.from_dict(
         {
-            "tokens": get_tokenized_batches(
+            "tokens": tokenize_dataset(
                 text_docs,
                 tokenizer,
                 context_size=args.context_size,
diff --git a/src/delphi/dataset/tokenization.py b/src/delphi/dataset/tokenization.py
index f340b4fa..b800b64b 100644
--- a/src/delphi/dataset/tokenization.py
+++ b/src/delphi/dataset/tokenization.py
@@ -74,7 +74,7 @@ def make_new_samples(
     return samples
 
 
-def get_tokenized_batches(
+def tokenize_dataset(
     text_documents: list[str],
     tokenizer: PreTrainedTokenizerBase,
     context_size: int,
diff --git a/tests/dataset/test_tokenizer.py b/tests/dataset/test_tokenizer.py
index 55f0fcbf..99b2dcb3 100644
--- a/tests/dataset/test_tokenizer.py
+++ b/tests/dataset/test_tokenizer.py
@@ -4,12 +4,7 @@
 import pytest
 from transformers import AutoTokenizer
 
-from delphi.dataset.tokenization import (
-    extend_deque,
-    get_tokenized_batches,
-    make_new_samples,
-)
-from delphi.eval.utils import load_validation_dataset
+from delphi.dataset.tokenization import extend_deque, make_new_samples, tokenize_dataset
 
 
 @pytest.fixture
@@ -68,7 +63,7 @@ def test_make_new_sample(tokenizer):
     assert len(dq) > 0  # always leaving at least one element in the deque
 
 
-def test_get_tokenized_batches(tokenizer):
+def test_tokenize_dataset(tokenizer):
     CTX_SIZE = 10
     BATCH_SIZE = 2
 
@@ -88,6 +83,6 @@
         [1, 4037, 311, 519, 268, 326, 317, 264, 525, 4037, 2],
     ]
     assert (
-        get_tokenized_batches(text_stories, tokenizer, CTX_SIZE, BATCH_SIZE)
+        tokenize_dataset(text_stories, tokenizer, CTX_SIZE, BATCH_SIZE)
         == correct_batches
     )