dataset.py
import torch
from torch.utils.data import Dataset


def causal_mask(size):
    # Upper-triangular ones above the diagonal mark the "future" positions;
    # comparing against 0 flips them so that allowed positions are True.
    mask = torch.triu(torch.ones((1, size, size)), diagonal=1).type(torch.int)
    return mask == 0


class BilingualDataset(Dataset):
    def __init__(self, dataset, source_tokenizer, target_tokenizer, source_language, target_language, sequence_length):
        super().__init__()
        self.dataset = dataset
        self.source_tokenizer = source_tokenizer
        self.target_tokenizer = target_tokenizer
        self.source_language = source_language
        self.target_language = target_language
        self.sequence_length = sequence_length
        # Special-token ids are looked up once; both tokenizers are assumed
        # to share the same ids for [SOS], [PAD] and [EOS].
        self.SOS_token = torch.tensor([target_tokenizer.token_to_id("[SOS]")], dtype=torch.int64)
        self.PAD_token = torch.tensor([target_tokenizer.token_to_id("[PAD]")], dtype=torch.int64)
        self.EOS_token = torch.tensor([target_tokenizer.token_to_id("[EOS]")], dtype=torch.int64)
    def __len__(self):
        return len(self.dataset)
    def __getitem__(self, index):
        source_target_pair = self.dataset[index]
        source_text = source_target_pair['translation'][self.source_language]
        target_text = source_target_pair['translation'][self.target_language]

        source_token_ids = self.source_tokenizer.encode(source_text).ids
        target_token_ids = self.target_tokenizer.encode(target_text).ids

        # The encoder input gains [SOS] and [EOS]; the decoder input only [SOS].
        source_padding_count = self.sequence_length - len(source_token_ids) - 2
        target_padding_count = self.sequence_length - len(target_token_ids) - 1
        if source_padding_count < 0 or target_padding_count < 0:
            raise ValueError("sequence is too long")

        # [SOS] + source tokens + [EOS], padded to sequence_length.
        encoder_input = torch.cat(
            [
                self.SOS_token,
                torch.tensor(source_token_ids, dtype=torch.int64),
                self.EOS_token,
                self.PAD_token.repeat(source_padding_count)
            ]
        )
        # [SOS] + target tokens, padded; the decoder never sees [EOS] as input.
        decoder_input = torch.cat(
            [
                self.SOS_token,
                torch.tensor(target_token_ids, dtype=torch.int64),
                self.PAD_token.repeat(target_padding_count)
            ]
        )
        # Target tokens + [EOS], padded: what the decoder should predict at
        # each position ([EOS] comes right after the tokens, before padding).
        target = torch.cat(
            [
                torch.tensor(target_token_ids, dtype=torch.int64),
                self.EOS_token,
                self.PAD_token.repeat(target_padding_count)
            ]
        )

        assert encoder_input.size(0) == self.sequence_length
        assert decoder_input.size(0) == self.sequence_length
        assert target.size(0) == self.sequence_length

        return {
            "encoder_input": encoder_input,  # (sequence_length,)
            "decoder_input": decoder_input,  # (sequence_length,)
            # (1, 1, sequence_length): hides padding from encoder self-attention.
            "encoder_input_mask": (encoder_input != self.PAD_token).unsqueeze(0).unsqueeze(0).int(),
            # (1, sequence_length, sequence_length): hides padding and future tokens.
            "decoder_input_mask": (decoder_input != self.PAD_token).unsqueeze(0).int() & causal_mask(decoder_input.size(0)),
            "Target": target,
            "source_text": source_text,
            "target_text": target_text
        }
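

# --- Usage sketch (illustrative, not part of the module above) ---
# A minimal, self-contained demo of BilingualDataset. The `_StubTokenizer`
# below is a hypothetical stand-in for a real trained tokenizer (e.g. a
# `tokenizers.Tokenizer` with a WordLevel model); it implements only the two
# methods this dataset actually calls: `token_to_id` and `encode(...).ids`.
# The row layout mirrors Hugging Face translation datasets such as opus_books.
if __name__ == "__main__":
    from types import SimpleNamespace

    class _StubTokenizer:
        def __init__(self, vocab):
            self.vocab = vocab

        def token_to_id(self, token):
            return self.vocab[token]

        def encode(self, text):
            # Naive whitespace "tokenization"; unknown words map to [UNK].
            ids = [self.vocab.get(word, self.vocab["[UNK]"]) for word in text.split()]
            return SimpleNamespace(ids=ids)

    vocab = {"[UNK]": 0, "[PAD]": 1, "[SOS]": 2, "[EOS]": 3,
             "hello": 4, "world": 5, "hallo": 6, "welt": 7}
    tokenizer = _StubTokenizer(vocab)

    rows = [{"translation": {"en": "hello world", "de": "hallo welt"}}]
    dataset = BilingualDataset(rows, tokenizer, tokenizer, "en", "de", sequence_length=8)

    sample = dataset[0]
    print(sample["encoder_input"])             # [SOS] hello world [EOS] then [PAD]s
    print(sample["decoder_input_mask"].shape)  # torch.Size([1, 8, 8])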