-
Notifications
You must be signed in to change notification settings - Fork 0
/
alpaca_dataset.py
78 lines (66 loc) · 2.66 KB
/
alpaca_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
# For dataset details visit: https://crfm.stanford.edu/2023/03/13/alpaca.html
import copy
import json
import torch
from torch.utils.data import Dataset
# Alpaca prompt templates. "prompt_input" is used for records that carry a
# non-empty "input" field, "prompt_no_input" otherwise; the {placeholders}
# are filled from the record with str.format_map.
PROMPT_DICT = dict(
    prompt_input=(
        "Below is an instruction that describes a task, paired with an input that provides further context. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"
    ),
    prompt_no_input=(
        "Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n### Response:"
    ),
)
class InstructionDataset(Dataset):
    """Alpaca-style instruction-tuning dataset backed by a JSON file.

    Each record is rendered with one of the ``PROMPT_DICT`` templates, its
    ``"output"`` text is appended, and the result is tokenized into a
    fixed-length example of ``max_words`` tokens whose prompt positions are
    excluded from the loss.

    Args:
        dataset_config: Object exposing a ``data_path`` attribute pointing to
            a JSON file. The file presumably holds a list of records with
            ``"instruction"`` and ``"output"`` keys and an optional
            ``"input"`` key (inferred from how records are used below --
            confirm against the data source).
        tokenizer: Object providing ``encode(str) -> list[int]`` and an
            ``eos_token_id`` attribute.
        partition: ``"train"`` keeps every record; any other value keeps
            only the first ``EVAL_RECORDS`` records.
        max_words: Fixed sequence length every example is padded or
            truncated to.
    """

    # The default ignore_index of torch.nn.CrossEntropyLoss; labels with this
    # value do not contribute to the loss.
    IGNORE_INDEX = -100
    # Number of leading records kept for any partition other than "train".
    EVAL_RECORDS = 200

    def __init__(self, dataset_config, tokenizer, partition="train", max_words=30):
        # Context manager closes the handle; JSON text is UTF-8, so do not
        # rely on the platform's default encoding.
        with open(dataset_config.data_path, encoding="utf-8") as f:
            self.ann = json.load(f)
        if partition != "train":
            self.ann = self.ann[: self.EVAL_RECORDS]
        self.max_words = max_words
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.ann)

    def __getitem__(self, index):
        """Return ``input_ids``, ``labels`` and ``attention_mask`` for one record."""
        ann = self.ann[index]
        # A missing or empty "input" field selects the shorter template.
        template_key = "prompt_no_input" if ann.get("input", "") == "" else "prompt_input"
        prompt = PROMPT_DICT[template_key].format_map(ann)
        return self._encode_example(prompt, ann["output"])

    def _encode_example(self, prompt, output):
        """Tokenize ``prompt + output`` into fixed-length training tensors.

        The sequence is ``encode(prompt + output)`` followed by the EOS token.
        It is right-padded or truncated to ``self.max_words``.

        Returns:
            dict with
              - ``input_ids``: int64 tensor with padding positions set to 0.
              - ``labels``: int64 copy of the ids with prompt and padding
                positions set to ``IGNORE_INDEX``.
              - ``attention_mask``: float tensor, 1.0 on real tokens and 0.0
                on padding.
        """
        # Only the prompt's token count matters: it marks how many leading
        # label positions belong to the prompt and must not be trained on.
        prompt_len = len(self.tokenizer.encode(prompt))
        # Copy before appending EOS so the tokenizer's return value is never
        # mutated in place.
        token_ids = list(self.tokenizer.encode(prompt + output))
        token_ids.append(self.tokenizer.eos_token_id)
        example = torch.tensor(token_ids, dtype=torch.int64)

        padding = self.max_words - example.shape[0]
        if padding > 0:
            # -1 is a temporary "padding" sentinel, turned into masks below.
            example = torch.cat((example, torch.full((padding,), -1, dtype=torch.int64)))
        elif padding < 0:
            # NOTE: truncation can drop the trailing EOS token.
            example = example[: self.max_words]

        labels = example.clone()
        labels[:prompt_len] = -1
        example_mask = example.ge(0)
        label_mask = labels.ge(0)
        example[~example_mask] = 0
        labels[~label_mask] = self.IGNORE_INDEX
        return {
            "input_ids": example,
            "labels": labels,
            "attention_mask": example_mask.float(),
        }