import random
from typing import List

import torch
from torch.utils.data import IterableDataset
from tqdm import tqdm
class ConstantLengthDataset(IterableDataset):
"""
Iterable dataset that returns constant length chunks of tokens from stream of text files.
Args:
tokenizer (Tokenizer): The processor used for proccessing the data.
dataset (dataset.Dataset): Dataset with text files.
infinite (bool): If True the iterator is reset after dataset reaches end else stops.
seq_length (int): Length of token sequences to return.
num_of_sequences (int): Number of token sequences to keep in buffer.
chars_per_token (int): Number of characters per token used to estimate number of tokens in text buffer.
"""
def __init__(
self,
tokenizer,
dataset,
infinite=False,
seq_length=2048,
num_of_sequences=1024,
chars_per_token=3.6,
content_field="content",
concat_token_id=None,
):
self.tokenizer = tokenizer
self.concat_token_id = concat_token_id if concat_token_id is not None else tokenizer.eos_token_id
print(f"Concat token id (EOS token): {self.concat_token_id}")
self.dataset = dataset
self.seq_length = seq_length
self.infinite = infinite
self.current_size = 0
self.max_buffer_size = seq_length * chars_per_token * num_of_sequences
self.content_field = content_field
def __iter__(self):
iterator = iter(self.dataset)
more_examples = True
while more_examples:
buffer, buffer_len = [], 0
while True:
if buffer_len >= self.max_buffer_size:
break
try:
buffer.append(next(iterator)[self.content_field])
buffer_len += len(buffer[-1])
except StopIteration:
if self.infinite:
iterator = iter(self.dataset)
else:
more_examples = False
break
            tokenized_inputs = self.tokenizer(buffer, truncation=False)["input_ids"]
all_token_ids = []
examples = []
for tokenized_input in tokenized_inputs:
all_token_ids.extend(tokenized_input + [self.concat_token_id])
for i in range(0, len(all_token_ids), self.seq_length):
input_ids = all_token_ids[i: i + self.seq_length]
if len(input_ids) == self.seq_length:
examples.append(input_ids)
random.shuffle(examples)
for input_ids in examples:
self.current_size += 1
yield {
"input_ids": torch.tensor(input_ids, dtype=torch.long),
"labels": torch.tensor(input_ids, dtype=torch.long),
"attention_mask": torch.ones(len(input_ids)),
}
def get_tokenizer(self):
return self.tokenizer
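# Example usage of ConstantLengthDataset (a minimal sketch; "gpt2" and the dataset
# name below are illustrative placeholders, not part of this repository):
#
#     from datasets import load_dataset
#     from transformers import AutoTokenizer
#
#     tok = AutoTokenizer.from_pretrained("gpt2")
#     ds = load_dataset("some/text-dataset", split="train")
#     cld = ConstantLengthDataset(tok, ds, seq_length=1024, infinite=True)
#     batch = next(iter(cld))  # dict of input_ids / labels / attention_mask, each of length 1024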
class PaddedDataset(IterableDataset):
"""
Unlike ConstantLengthDataset this dataset returns padded sequences of tokens concatenated together,
which all have a fixed length of seq_length. The dataset will panic if a sequence is longer
than seq_length, except if trim_longer is set to True, in which case the sequence
will be trimmed to seq_length. It is important to set pad_token_id to the id of the
padding token in the tokenizer, otherwise the model will be trained on a wrong padding token.
By default, we set the pad_token_id to the pad_token_id of the tokenizer, if it exists,
otherwise we set it to 0. The padding is done at the end of the concatenated sequence (right padding).
"""
def __init__(
self,
tokenizer,
dataset,
infinite=False,
seq_length=2048,
content_field="content",
concat_token_id=None,
pad_token_id=None,
trim_longer=False,
        # niche option for some instruct-tuning tasks: disables the loss for every token up to and including this token id (applied per example)
        # IMPORTANT: this token is only a marker and gets removed from the input_ids and labels
mask_loss_till_token_id=None,
):
self.tokenizer = tokenizer
self.concat_token_id = concat_token_id if concat_token_id is not None else tokenizer.eos_token_id
self.dataset = dataset
self.seq_length = seq_length
self.infinite = infinite
self.content_field = content_field
self.trim_longer = trim_longer
self.mask_loss_till_token_id = mask_loss_till_token_id
if pad_token_id is not None:
self.pad_token_id = pad_token_id
elif self.tokenizer.pad_token_id is None:
# default to self.concat_token_id if pad_token_id is not set
self.pad_token_id = self.concat_token_id
else:
            # the tokenizer already provides a pad token id, so use it
self.pad_token_id = self.tokenizer.pad_token_id
print(f"Concat token id (EOS token): {self.concat_token_id}")
print(f"Pad token id: {self.pad_token_id}")
if self.mask_loss_till_token_id is not None:
            print(f"Masking loss till token id: {self.mask_loss_till_token_id}")
def __iter__(self):
iterator = iter(self.dataset)
more_examples = True
prev_iter_skipped = None
while more_examples:
buffer, buffer_len = [], 0
while True:
try:
if prev_iter_skipped is not None:
encoded = prev_iter_skipped
prev_iter_skipped = None
else:
new = next(iterator)[self.content_field]
                        encoded = self.tokenizer.encode(new) + [self.concat_token_id]
if len(encoded) > self.seq_length:
if self.trim_longer:
                            encoded = encoded[:self.seq_length - 1] + [self.concat_token_id]
else:
raise ValueError(
f"Sequence of length {len(encoded)} is longer than seq_length {self.seq_length}."
)
if len(encoded) + buffer_len > self.seq_length:
prev_iter_skipped = encoded
break
buffer.append(encoded)
buffer_len += len(encoded)
except StopIteration:
if self.infinite:
iterator = iter(self.dataset)
else:
more_examples = False
break
assert buffer_len <= self.seq_length
# shuffle the buffer
# random.shuffle(buffer)
# concatenate all sequences
token_ids = []
for tokenized_input in buffer:
token_ids.extend(tokenized_input)
# pad to seq_length
            token_ids.extend([self.pad_token_id] * (self.seq_length - len(token_ids)))
labels = token_ids
# TODO: this is awful and just for experimentation, clean it up
if self.mask_loss_till_token_id is not None:
labels = apply_mask_till_token_id(
labels,
mask_till_token_id=self.mask_loss_till_token_id,
concat_token_id=self.concat_token_id,
pad_token_id=self.pad_token_id,
)
# remove the mask_till_token_id from the input_ids and labels
                token_ids = [t for t in token_ids if t != self.mask_loss_till_token_id]
assert len(token_ids) == len(labels)
# pad to seq_length
                token_ids.extend([self.pad_token_id] * (self.seq_length - len(token_ids)))
                labels.extend([self.pad_token_id] * (self.seq_length - len(labels)))
assert len(token_ids) == len(labels)
yield {
"input_ids": torch.tensor(token_ids, dtype=torch.long),
"labels": torch.tensor(labels, dtype=torch.long),
                # TODO: not 100% optimal: we use right padding and the attention_mask does not mask out the padded positions
"attention_mask": torch.ones(len(token_ids)),
}
def get_tokenizer(self):
return self.tokenizer
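# Example usage of PaddedDataset (a minimal sketch; the dataset name is an
# illustrative placeholder). Unlike ConstantLengthDataset, examples are never split
# across chunks: an example that would overflow the current chunk is carried over
# to the next yielded sequence, and the remainder is right-padded.
#
#     tok = AutoTokenizer.from_pretrained("codellama/CodeLlama-13b-hf")
#     ds = load_dataset("some/instruction-dataset", split="train")
#     pd = PaddedDataset(tok, ds, seq_length=2048, trim_longer=True)
#     batch = next(iter(pd))  # whole examples packed into 2048 tokens, right-padded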
def apply_mask_till_token_id(
token_ids: List[int],
mask_till_token_id: int,
concat_token_id: int,
pad_token_id: int,
mask=-100,
):
    # notes:
    # - stop condition: masking is turned off for good once we reach a pad_token_id that is followed only by pad/concat tokens (or the end of the sequence)
    # - concat_token_id delimits examples; masking restarts at this token
    # - mask_till_token_id dictates the end of masking for an example (the marker token itself is dropped from the output)
masked = []
masking = True
masking_off_forever = False
for i, token in enumerate(token_ids):
if token == pad_token_id and all([t == pad_token_id or t == concat_token_id for t in token_ids[i:]]):
masking_off_forever = True
if token == mask_till_token_id:
masking = False
else: # do not add if mask token
if masking and not masking_off_forever:
masked.append(mask)
else:
masked.append(token)
if token == concat_token_id:
masking = True
return masked
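# Worked example for apply_mask_till_token_id (illustrative token ids, not from a real tokenizer):
# with mask_till_token_id=9, concat_token_id=2, pad_token_id=0 and
#     token_ids = [5, 6, 9, 7, 2, 8, 9, 3, 2, 0, 0]
# the function returns
#     [-100, -100, 7, 2, -100, 3, 2, 0, 0]
# i.e. every token up to each marker `9` becomes -100, the marker itself is dropped,
# masking restarts after each concat_token_id (`2`), and the trailing padding is kept
# as-is once only pad/concat tokens remain.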
class TQDMWraper(IterableDataset):
def __init__(self, dataset, num_iters=None, desc=""):
self.dataset = dataset
self.num_iters = num_iters
self.desc = desc
def __iter__(self):
for example in tqdm(self.dataset, total=self.num_iters, desc=self.desc):
yield example
def get_tokenizer(self):
return self.dataset.get_tokenizer()
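# Example (illustrative): wrap any of the datasets above to get a progress bar while iterating.
#
#     dataset = TQDMWraper(PaddedDataset(tok, ds, seq_length=2048), num_iters=5000, desc="packing")
#     for batch in dataset:
#         ...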
if __name__ == "__main__":
# testing out the padded dataset
import datasets
from transformers import AutoTokenizer
ds = datasets.load_dataset(
"nuprl-staging/multiplt-python-instrs-5k-train-codellama", split="train")
tokenizer = AutoTokenizer.from_pretrained("codellama/CodeLlama-13b-hf")
dataset = PaddedDataset(tokenizer, ds, seq_length=2048)
num_exs = 0
for i, example in enumerate(dataset):
decoded = tokenizer.decode(example["input_ids"])
if i < 4:
print("#" * 80)
print(decoded)
num_exs += decoded.count("### Instruction:")
print(ds)
print(f"Total number of examples: {num_exs}")