-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathmydataset.py
143 lines (114 loc) · 5.08 KB
/
mydataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
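# Builds a PyTorch dataset from a JSONL file (optionally ChaCha20-encrypted, ".enc") for fine-tuning OpenBuddy models.
# Supported record shapes: plain text {"txt": ...} and OpenAI-style chat {"messages": [{"role": ..., "content": ...}, ...]}.
# (An earlier version of this note also mentioned FastChat samples; only the two shapes above are parsed by the code below.)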
import torch
import os
import json
import random
import torch.multiprocessing as mp
from functools import partial
from tqdm import tqdm
from torch.utils.data import Dataset
# Label id used to mask positions out of the loss. Must be the same value as
# the one defined in transformers (-100).
IGNORE_INDEX: int = -100

# When True, SupervisedDataset.__getitem__ trains on every token
# (labels become a copy of input_ids).
NO_MASK: bool = False

# Separator strings; the final token id of each one is collected into
# SupervisedDataset.sep_ids_set.
SEPS: list = [
    '\n',
    '\n\n',
    '、',
    ',',
    '。',
    ';',
]
import hashlib
import os
# Password used to decrypt ".enc" dataset files. It is read once at import time
# from the PASS environment variable and is the empty string when unset.
PASS = os.environ.get('PASS', '')
def derive_key(password, salt):
    """Return a 32-byte key computed as SHA-256(utf8(password) + salt).

    Args:
        password: Password string; encoded with ``str.encode()`` (UTF-8).
        salt: Bytes appended to the encoded password before hashing.

    Returns:
        bytes: The raw 32-byte SHA-256 digest.

    NOTE(review): a single SHA-256 pass has no work factor and is a weak
    password KDF. It is presumably kept because existing ``.enc`` files were
    produced with exactly this derivation; moving to scrypt/PBKDF2 would
    require re-encrypting them.
    """
    material = password.encode() + salt
    hasher = hashlib.sha256(material)
    return hasher.digest()
def decrypt_file(filename, password):
    """Decrypt a ChaCha20-encrypted text file and return its lines.

    Expected file layout: a 12-byte nonce followed by the ciphertext. The key
    is ``derive_key(password, b"mysalt")``.

    Args:
        filename: Path to the encrypted file.
        password: Password string the key is derived from.

    Returns:
        list[str]: The plaintext decoded as UTF-8 and split with
        ``str.splitlines()``.

    Raises:
        ValueError: If the file is too short to contain the 12-byte nonce.
        UnicodeDecodeError: Likely when the password is wrong. Plain ChaCha20
            carries no authentication tag, so a bad key yields garbage bytes
            rather than a dedicated error.
    """
    # Imported locally so this function does not depend on a module-global
    # ``ChaCha20`` being bound by some other code path first.
    from Crypto.Cipher import ChaCha20

    with open(filename, 'rb') as file:
        nonce = file.read(12)
        ciphertext = file.read()
    # A short file would otherwise yield a truncated nonce; an 8-byte one is
    # even accepted by ChaCha20 and would silently "decrypt" to nothing.
    if len(nonce) != 12:
        raise ValueError(f"{filename!r} is too short to contain a 12-byte nonce")
    key = derive_key(password, b"mysalt")
    cipher = ChaCha20.new(key=key, nonce=nonce)
    plaintext = cipher.decrypt(ciphertext)
    return plaintext.decode('utf-8').splitlines()
class SupervisedDataset(Dataset):
    """Tokenized, right-padded supervised fine-tuning samples from a JSONL file.

    Each non-blank JSONL line is a record of one of two shapes:

    * ``{"txt": ...}``: plain text; every token is a training target.
    * ``{"messages": [{"role": ..., "content": ...}, ...]}``: a chat. Each
      message is rendered as ``<|role|>{role}<|says|>{content}<|end|>`` plus a
      newline. Only assistant content and its ``<|end|>`` token are targets.

    Files whose name ends in ``.enc`` (case-insensitive) are decrypted with
    ``decrypt_file`` using the module-level ``PASS``.

    Items are dicts of three 1-D tensors of length ``max_length``:

    * ``input_ids`` (long, padded with ``pad_token_id``).
    * ``attention_mask`` (bool, True on real tokens).
    * ``labels`` (long, ``IGNORE_INDEX`` on masked and padded positions).

    Encoded items are memoized in ``self.cached_items``.
    """

    def __init__(self, file_path, tokenizer, max_length=100, sample_format='fourfourml'):
        """Load the records and precompute the special-token ids.

        Args:
            file_path: Path to a JSONL file, or an encrypted ``.enc`` file.
            tokenizer: Object with ``encode(text, add_special_tokens=...)``
                returning a list of ids, and a non-None ``pad_token_id``.
            max_length: Length every returned tensor is truncated/padded to.
            sample_format: Accepted for compatibility; not used.

        Raises:
            ValueError: If the tokenizer has no pad token, or does not encode
                the newline and ``<|end|>`` as single tokens that concatenate
                cleanly.
        """
        # Only sets the variable if the user has not chosen a value already.
        os.environ.setdefault('TOKENIZERS_PARALLELISM', 'false')
        super().__init__()
        self.tokenizer = tokenizer
        self.max_length = max_length
        if self.tokenizer.pad_token_id is None:
            raise ValueError("Tokenizer must have a pad token.")
        # index -> encoded item dict (see __getitem__).
        self.cached_items = {}
        self.data = self._load_records(file_path)

        self.bos_ids = tokenizer.encode("", add_special_tokens=True)
        self.role_assistant_says_ids = tokenizer.encode("<|role|>assistant<|says|>", add_special_tokens=False)
        # _encode_messages masks the last token of each assistant turn, which is
        # meant to be the newline after <|end|>. These checks guarantee that
        # token is a single, separate newline token. They are explicit raises
        # (not asserts) so they still run under `python -O`.
        self.nl_ids = tokenizer.encode('\n', add_special_tokens=False)
        if len(self.nl_ids) != 1:
            raise ValueError(f"Tokenizer must encode '\\n' as one token, got {self.nl_ids!r}")
        self.end_ids = tokenizer.encode('<|end|>', add_special_tokens=False)
        if len(self.end_ids) != 1:
            raise ValueError(f"Tokenizer must encode '<|end|>' as one token, got {self.end_ids!r}")
        self.end_nl_ids = tokenizer.encode('<|end|>\n', add_special_tokens=False)
        if self.end_nl_ids != self.end_ids + self.nl_ids:
            raise ValueError(
                f"Tokenizer must encode '<|end|>\\n' as <|end|> followed by '\\n', got {self.end_nl_ids!r}")

        # Last token id of each separator in SEPS (not used within this class).
        self.sep_ids_set = set()
        for sep in SEPS:
            enc = tokenizer.encode(sep, add_special_tokens=False)
            self.sep_ids_set.add(enc[-1])

    @staticmethod
    def _load_records(file_path):
        """Parse a (possibly encrypted) JSONL file into a list of records, skipping blank lines."""
        if file_path.lower().endswith('.enc'):
            # Bind the module-global ChaCha20 so a decrypt_file that looks the
            # cipher up globally can find it.
            global ChaCha20
            from Crypto.Cipher import ChaCha20
            return [json.loads(line) for line in decrypt_file(file_path, PASS) if line.strip()]
        with open(file_path, 'r', encoding='utf8') as f:
            # Blank lines (e.g. a trailing empty line) would make json.loads raise.
            return [json.loads(line) for line in tqdm(f) if line.strip()]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        """Return the encoded, padded sample at ``index`` (memoized after first use)."""
        if index in self.cached_items:
            return self.cached_items[index]
        dobj = self.data[index]
        if 'txt' in dobj:
            # Plain text: the tokenizer's default special tokens apply and
            # every token is a target.
            input_ids = self.tokenizer.encode(dobj['txt'])
            labels = input_ids.copy()
        else:
            input_ids, labels = self._encode_messages(dobj['messages'])
        if NO_MASK:
            labels = input_ids.copy()
        assert len(input_ids) == len(labels)
        ret = self._pad(input_ids, labels)
        self.cached_items[index] = ret
        return ret

    def _encode_messages(self, messages):
        """Encode a chat into (input_ids, labels); only assistant replies are targets."""
        input_ids = list(self.bos_ids)
        labels = [IGNORE_INDEX] * len(self.bos_ids)
        for msg in messages:
            content = msg['content'].strip()
            if msg['role'] != 'assistant':
                ids = self.tokenizer.encode(f'<|role|>{msg["role"]}<|says|>{content}<|end|>\n', add_special_tokens=False)
                input_ids += ids
                labels += [IGNORE_INDEX] * len(ids)
            else:
                input_ids += self.role_assistant_says_ids
                labels += [IGNORE_INDEX] * len(self.role_assistant_says_ids)
                ids = self.tokenizer.encode(f'{content}<|end|>\n', add_special_tokens=False)
                input_ids += ids
                # Train on the reply and <|end|>, but not on the trailing newline.
                labels += ids[:-1] + [IGNORE_INDEX]
        return input_ids, labels

    def _pad(self, input_ids, labels):
        """Truncate to max_length, right-pad, and build the attention mask."""
        seq_length = min(len(input_ids), self.max_length)
        input_ids_tensor = torch.full((self.max_length,), self.tokenizer.pad_token_id, dtype=torch.long)
        attention_mask_tensor = torch.zeros(self.max_length, dtype=torch.bool)
        labels_tensor = torch.full((self.max_length,), IGNORE_INDEX, dtype=torch.long)
        input_ids_tensor[:seq_length] = torch.tensor(input_ids[:seq_length], dtype=torch.long)
        attention_mask_tensor[:seq_length] = True
        labels_tensor[:seq_length] = torch.tensor(labels[:seq_length], dtype=torch.long)
        return {
            'input_ids': input_ids_tensor,
            'attention_mask': attention_mask_tensor,
            'labels': labels_tensor,
        }