util.py
import math
from functools import partial

from torch.optim import Optimizer
from torch.optim.lr_scheduler import LambdaLR


def pack(dataset, tokenizer, context_length, key='text'):
    """Concatenate ("pack") samples from a dataset into tokenized chunks of `context_length`.

    Used for efficient training of causal models without padding. No special measures are taken
    to disallow a sequence attending to a previous sequence. The model is left to learn the
    unrelatedness of sequences from the presence of the start- and end-of-sequence tokens
    between the samples, following a convention similar to GPT-3 and T5.

    See https://github.com/huggingface/transformers/issues/17726 for a feature request for
    Hugging Face Transformers.

    The incomplete final chunk is discarded.

    :param dataset: Dataset of samples (iterable of dict-like, e.g. Hugging Face dataset)
    :param tokenizer: Callable that tokenizes the samples (e.g. Hugging Face tokenizer)
    :param context_length: number of tokens in the packed sequences
    :param key: key of the text field in the sample. Defaults to 'text'
    :yield: dicts of packed input_ids, attention_mask and (self-supervised) labels
    """
    cache = []
    for row in dataset:
        ids = tokenizer(row[key], max_length=None)['input_ids']
        # The end-of-sequence token (id 2 for Mistral 7B) seems to have been present in the
        # Mistral 7B training data, but is not added automatically by the tokenizer.
        ids.append(2)
        cache.extend(ids)
        while len(cache) >= context_length:
            chunk = cache[:context_length]
            yield {'input_ids': chunk,
                   'attention_mask': [1] * context_length,
                   'labels': chunk}
            cache = cache[context_length:]
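

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original file): wraps `pack` in a
# Hugging Face `datasets.Dataset.from_generator` call so the packed chunks can
# be materialized as a dataset and fed to a Trainer. The tokenizer checkpoint,
# source dataset and context length below are assumptions chosen for the example.
if __name__ == '__main__':
    from datasets import Dataset, load_dataset
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained('mistralai/Mistral-7B-v0.1')
    raw = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train')

    # `from_generator` expects a zero-argument callable that returns an iterator,
    # so the generator from `pack` is wrapped in a lambda.
    packed = Dataset.from_generator(
        lambda: pack(raw, tokenizer, context_length=1024))

    print(f'{len(packed)} packed chunks of 1024 tokens each')
    print(packed[0].keys())  # input_ids, attention_mask, labels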