
Commit: Update function name
Siwei Li authored and joshuawe committed Mar 30, 2024
1 parent c5c0e09 commit ba1b109
Showing 3 changed files with 6 additions and 11 deletions.
4 changes: 2 additions & 2 deletions scripts/tokenize_dataset.py
@@ -5,7 +5,7 @@
from datasets import Dataset
from transformers import AutoTokenizer

-from delphi.dataset.tokenization import get_tokenized_batches
+from delphi.dataset.tokenization import tokenize_dataset
from delphi.eval.utils import load_validation_dataset

if __name__ == "__main__":
@@ -62,7 +62,7 @@

    output_dataset = Dataset.from_dict(
        {
-            "tokens": get_tokenized_batches(
+            "tokens": tokenize_dataset(
                text_docs,
                tokenizer,
                context_size=args.context_size,
2 changes: 1 addition & 1 deletion src/delphi/dataset/tokenization.py
@@ -74,7 +74,7 @@ def make_new_samples(
    return samples


-def get_tokenized_batches(
+def tokenize_dataset(
    text_documents: list[str],
    tokenizer: PreTrainedTokenizerBase,
    context_size: int,
11 changes: 3 additions & 8 deletions tests/dataset/test_tokenizer.py
@@ -4,12 +4,7 @@
import pytest
from transformers import AutoTokenizer

-from delphi.dataset.tokenization import (
-    extend_deque,
-    get_tokenized_batches,
-    make_new_samples,
-)
-from delphi.eval.utils import load_validation_dataset
+from delphi.dataset.tokenization import extend_deque, make_new_samples, tokenize_dataset


@pytest.fixture
@@ -68,7 +63,7 @@ def test_make_new_sample(tokenizer):
    assert len(dq) > 0  # always leaving at least one element in the deque


-def test_get_tokenized_batches(tokenizer):
+def test_tokenize_dataset(tokenizer):
    CTX_SIZE = 10
    BATCH_SIZE = 2

@@ -88,6 +83,6 @@ def test_get_tokenized_batches(tokenizer):
        [1, 4037, 311, 519, 268, 326, 317, 264, 525, 4037, 2],
    ]
    assert (
-        get_tokenized_batches(text_stories, tokenizer, CTX_SIZE, BATCH_SIZE)
+        tokenize_dataset(text_stories, tokenizer, CTX_SIZE, BATCH_SIZE)
        == correct_batches
    )
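
For context, a minimal usage sketch of the renamed function, inferred from the call in the test above. The tokenizer checkpoint and sample texts are placeholders, and the fourth positional argument is assumed to be a batch size based on how the test passes BATCH_SIZE; this is an illustration, not part of the commit.

from transformers import AutoTokenizer

from delphi.dataset.tokenization import tokenize_dataset

# Placeholder tokenizer checkpoint; substitute whichever tokenizer the project uses.
tokenizer = AutoTokenizer.from_pretrained("gpt2")

# Placeholder documents standing in for the dataset's text column.
text_docs = ["Once upon a time there was a cat.", "The dog chased the ball."]

# Same positional order as in test_tokenize_dataset: texts, tokenizer, context size, batch size (assumed).
batches = tokenize_dataset(text_docs, tokenizer, 10, 2)
# `batches` is expected to be a list of token-id sequences, analogous to the
# `correct_batches` fixture in the test above.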
