-
Notifications
You must be signed in to change notification settings - Fork 4
/
datamodule.py
63 lines (56 loc) · 2.03 KB
/
datamodule.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import lightning as L
import torch
import torch.utils.data as data
import torchvision.transforms as transforms
from torchvision.datasets import MNIST
class _PixelsToTokens:
    """Quantise a ``ToTensor`` image into integer tokens.

    The first channel of the input is scaled by ``dict_size - 2``, shifted by
    one and truncated to ``long``.  For inputs in ``[0, 1]`` (what
    ``transforms.ToTensor`` produces) the resulting tokens lie in
    ``[1, dict_size - 1]``, so token ``0`` is never emitted and stays free for
    use as the mask token.

    This is a module-level callable class rather than a ``lambda`` so that the
    transform can be pickled: ``DataLoader`` workers started with the
    ``spawn``/``forkserver`` methods (Windows, macOS) receive a pickled copy of
    the dataset, and lambdas cannot be pickled.
    """

    def __init__(self, dict_size: int):
        self.scale = dict_size - 2

    def __call__(self, x):
        # x[0] drops the channel axis: (1, H, W) -> (H, W).
        return (x[0] * self.scale + 1).long()


class DigitDataModule(L.LightningDataModule):
    """Lightning data module serving MNIST digits as grids of integer tokens.

    Each image is converted to a tensor, quantised to tokens in
    ``[1, dict_size - 1]`` and padded by two pixels on every side with token
    ``1`` (28x28 MNIST digits become 32x32).  Token ``0`` is reserved for the
    mask token and is never produced here.

    Args:
        dict_size: Size of the token dictionary, including the mask token.
            Must be at least 2.
        batch_size: Batch size used by every dataloader.
        num_workers: Number of worker processes used by every dataloader.
        data_dir: Root directory where MNIST is downloaded / read from.

    Raises:
        ValueError: If ``dict_size`` is smaller than 2, which would make the
            quantisation emit token ``0`` (the mask token) or negative values.
    """

    def __init__(
        self,
        dict_size: int,
        batch_size: int = 64,
        num_workers: int = 10,
        data_dir: str = "MNIST",
    ):
        super().__init__()
        if dict_size < 2:
            raise ValueError(
                f"dict_size must be >= 2 (0 is the mask token), got {dict_size}"
            )
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.data_dir = data_dir
        # Setup the transforms
        self.transform = transforms.Compose(
            [
                transforms.ToTensor(),
                _PixelsToTokens(dict_size),
                # we fill the padding with 1 since 0 is the mask token
                transforms.Pad((2, 2, 2, 2), fill=1, padding_mode="constant"),
            ]
        )

    def prepare_data(self):
        """Download both MNIST splits (no state is assigned here)."""
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage: str = "fit"):
        """Create the datasets needed for ``stage``.

        Args:
            stage: ``"fit"`` or ``"validate"`` build the train/validation
                split; ``"test"`` builds the test set; ``None`` builds both.
                Any other value (e.g. ``"predict"``) is a no-op.
        """
        # "validate" needs val_set too: Trainer.validate() calls
        # setup("validate") and then val_dataloader().
        if stage in ("fit", "validate", None):
            full_set = MNIST(
                root=self.data_dir,
                train=True,
                transform=self.transform,
                download=True,
            )
            # 80/20 train/validation split; the fixed seed keeps the split
            # identical across runs and across processes.
            train_set_size = int(len(full_set) * 0.8)
            val_set_size = len(full_set) - train_set_size
            seed = torch.Generator().manual_seed(42)
            self.train_set, self.val_set = data.random_split(
                full_set, [train_set_size, val_set_size], generator=seed
            )
        if stage in ("test", None):
            self.test_set = MNIST(
                root=self.data_dir,
                train=False,
                transform=self.transform,
            )

    def _make_loader(self, dataset):
        """Build a ``DataLoader`` over ``dataset`` with the module's settings."""
        return data.DataLoader(
            dataset, batch_size=self.batch_size, num_workers=self.num_workers
        )

    def train_dataloader(self):
        """Return the dataloader over the training split.

        NOTE(review): no ``shuffle=True`` is passed, so every epoch iterates
        the split in the same order (the order produced by the seeded
        ``random_split``).  Confirm this is intended.
        """
        return self._make_loader(self.train_set)

    def val_dataloader(self):
        """Return the dataloader over the validation split."""
        return self._make_loader(self.val_set)

    def test_dataloader(self):
        """Return the dataloader over the MNIST test set."""
        return self._make_loader(self.test_set)