tokenize into seq_len instead of seq_len+1
jettjaniak committed May 14, 2024
1 parent 6702e9f commit 411756a
Showing 2 changed files with 6 additions and 9 deletions.
11 changes: 4 additions & 7 deletions src/delphi/dataset/tokenization.py
@@ -2,11 +2,8 @@
 import itertools
 from collections import deque
 from collections.abc import Iterator
-from pathlib import Path
 
 from datasets import Dataset
-from huggingface_hub import HfApi
-from tqdm.auto import trange
 from transformers import PreTrainedTokenizerBase


@@ -47,7 +44,7 @@ def extend_deque(
     return doc_idx
 
 
-def make_new_sample(deq: deque[int], context_size: int, bos_token_id: int) -> list[int]:
+def make_new_sample(deq: deque[int], seq_len: int, bos_token_id: int) -> list[int]:
     """
     Generates new sample for training by creating sequence of tokens
     from the deque until the deque.
@@ -64,10 +61,10 @@ def make_new_sample(deq: deque[int], context_size: int, bos_token_id: int) -> list[int]:
         list[int]: token sequence.
     """
     sample = [bos_token_id]
-    # For the first (n-1) elements, pop from the left of the deque
-    # and add to the new sample, the n-th element will be retained
+    # For the first n-2 elements, pop from the left of the deque
+    # and add to the new sample, the (n-1)-th element will be retained
     # in the deque for making the next sample.
-    for _ in range(context_size - 1):
+    for _ in range(seq_len - 2):
         sample.append(deq.popleft())
     sample.append(deq[0])
     return sample
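For orientation, a minimal sketch of what the renamed parameter means in practice (the toy token ids and bos_token_id=0 below are illustrative, not taken from the diff): each call now produces exactly seq_len tokens, with the last token only peeked so it stays in the deque to seed the next sample.

from collections import deque

def make_new_sample(deq, seq_len, bos_token_id):
    # BOS token + (seq_len - 2) popped tokens + 1 peeked token = seq_len tokens total
    sample = [bos_token_id]
    for _ in range(seq_len - 2):
        sample.append(deq.popleft())
    sample.append(deq[0])  # retained in the deque for the next sample
    return sample

deq = deque(range(100, 110))  # illustrative token ids
print(make_new_sample(deq, seq_len=6, bos_token_id=0))  # [0, 100, 101, 102, 103, 104]
print(deq[0])  # 104 is still queued, so the next sample overlaps by one token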
4 changes: 2 additions & 2 deletions tests/dataset/test_tokeniation.py
@@ -73,7 +73,7 @@ def test_make_new_sample(tokenizer):
 
 
 def test_tokenize_dataset(tokenizer):
-    CTX_SIZE = 10
+    SEQ_LEN = 11
     BATCH_SIZE = 2
 
     documents = [
@@ -92,5 +92,5 @@ def test_tokenize_dataset(tokenizer):
         [0, 284, 260, 2606, 1, 431, 440, 260, 399, 13, 402],
         [0, 402, 284, 260, 1, 1370, 268, 415, 484, 412, 15],
     ]
-    actual = [x for x in tokenize_dataset(dataset, tokenizer, CTX_SIZE, BATCH_SIZE)]
+    actual = [x for x in tokenize_dataset(dataset, tokenizer, SEQ_LEN, BATCH_SIZE)]
     assert actual == expected
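As a hedged sanity check (assuming only the two expected samples visible in the diff and a BOS id of 0, which is what the test tokenizer appears to use), every expected sample now contains exactly SEQ_LEN tokens and begins with BOS:

SEQ_LEN = 11
expected = [
    [0, 284, 260, 2606, 1, 431, 440, 260, 399, 13, 402],
    [0, 402, 284, 260, 1, 1370, 268, 415, 484, 412, 15],
]
for sample in expected:
    assert len(sample) == SEQ_LEN  # seq_len tokens, not seq_len + 1
    assert sample[0] == 0          # each sample starts with the BOS token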
