from argparse import ArgumentParser
from typing import Any, Callable, Optional, Union

from pl_bolts.datamodules.vision_datamodule import VisionDataModule
from pl_bolts.transforms.dataset_normalizations import cifar10_normalization
from pl_bolts.utils import _TORCHVISION_AVAILABLE
from pl_bolts.utils.warnings import warn_missing_pkg

if _TORCHVISION_AVAILABLE:
    from torchvision import transforms as transform_lib
    from torchvision.datasets import ImageFolder
else:  # pragma: no cover
    warn_missing_pkg("torchvision")
    ImageFolder = None


class TinyImagenetDataModule(VisionDataModule):
    """DataModule for the Tiny ImageNet dataset: 200 classes of 64x64 RGB images.

    Expects the ``tiny-imagenet-200`` directory to already exist under ``data_dir``.
    """

    name = "tiny_imagenet"
    dims = (3, 64, 64)

    def __init__(
        self,
        data_dir: Optional[str] = None,
        val_split: Union[int, float] = 0.1,
        num_workers: int = 0,
        normalize: bool = False,
        batch_size: int = 32,
        seed: int = 42,
        shuffle: bool = True,
        pin_memory: bool = True,
        drop_last: bool = False,
        *args: Any,
        **kwargs: Any,
    ) -> None:
"""
Args:
data_dir: Where to save/load the data
val_split: Percent (float) or number (int) of samples to use for the validation split
num_workers: How many workers to use for loading data
normalize: If true applies image normalize
batch_size: How many samples per batch to load
seed: Random seed to be used for train/val/test splits
shuffle: If true shuffles the train data every epoch
pin_memory: If true, the data loader will copy Tensors into CUDA pinned memory before
returning them
drop_last: If true drops the last incomplete batch
"""
        super().__init__(  # type: ignore[misc]
            data_dir=data_dir,
            val_split=val_split,
            num_workers=num_workers,
            normalize=normalize,
            batch_size=batch_size,
            seed=seed,
            shuffle=shuffle,
            pin_memory=pin_memory,
            drop_last=drop_last,
            *args,
            **kwargs,
        )

    @property
    def num_samples(self) -> int:
        # The Tiny ImageNet train split has 100,000 images (200 classes x 500 images each).
        train_len, _ = self._get_splits(len_dataset=100000)
        return train_len

    @property
    def num_classes(self) -> int:
        return 200

    def default_transforms(self) -> Callable:
        if self.normalize:
            cf10_transforms = transform_lib.Compose([transform_lib.ToTensor(), cifar10_normalization()])
        else:
            cf10_transforms = transform_lib.Compose([transform_lib.ToTensor()])
        return cf10_transforms
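
    # Note: when ``normalize=True`` the default transform reuses CIFAR-10 channel
    # statistics via ``cifar10_normalization()``. A possible refinement (not done here)
    # would be to compute per-channel mean/std over the Tiny ImageNet train split once,
    # for example by iterating an ``ImageFolder`` wrapped in ``ToTensor()`` and averaging
    # ``img.mean(dim=(1, 2))`` and ``img.std(dim=(1, 2))``, then passing those values to
    # ``transform_lib.Normalize`` instead.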

    def prepare_data(self, *args: Any, **kwargs: Any) -> None:
        """No-op: the dataset is expected to already be downloaded and extracted under ``data_dir``."""
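
    # The archive is not fetched automatically. A minimal manual-download sketch
    # (assumption: the dataset is still hosted at the Stanford CS231n URL):
    #
    #   import urllib.request, zipfile
    #
    #   urllib.request.urlretrieve(
    #       "http://cs231n.stanford.edu/tiny-imagenet-200.zip", "tiny-imagenet-200.zip"
    #   )
    #   with zipfile.ZipFile("tiny-imagenet-200.zip") as zf:
    #       zf.extractall(data_dir)  # ``data_dir`` as passed to this datamodule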

    def setup(self, stage: Optional[str] = None) -> None:
        """Creates the train, val, and test datasets."""
        if stage == "fit" or stage is None:
            train_transforms = self.default_transforms() if self.train_transforms is None else self.train_transforms
            val_transforms = self.default_transforms() if self.val_transforms is None else self.val_transforms

            # Both datasets read the official train folder; the validation set is carved
            # out of it below via ``_split_dataset`` (controlled by ``val_split``).
            dataset_train = ImageFolder(
                self.data_dir + "/tiny-imagenet-200/train", transform=train_transforms, **self.EXTRA_ARGS
            )
            dataset_val = ImageFolder(
                self.data_dir + "/tiny-imagenet-200/train", transform=val_transforms, **self.EXTRA_ARGS
            )

            # Split
            self.dataset_train = self._split_dataset(dataset_train)
            self.dataset_val = self._split_dataset(dataset_val, train=False)

        if stage == "test" or stage is None:
            test_transforms = self.default_transforms() if self.test_transforms is None else self.test_transforms
            # The official val split is used as the test set. ``ImageFolder`` expects one
            # sub-folder per class, whereas the raw archive keeps all val images in a single
            # ``images`` folder, so the val directory may need to be reorganized by class first.
            self.dataset_test = ImageFolder(
                self.data_dir + "/tiny-imagenet-200/val", transform=test_transforms, **self.EXTRA_ARGS
            )

    @staticmethod
    def add_dataset_specific_args(parent_parser: ArgumentParser) -> ArgumentParser:
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument("--data_dir", type=str, default=".")
        parser.add_argument("--num_workers", type=int, default=0)
        parser.add_argument("--batch_size", type=int, default=32)
        return parser
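

if __name__ == "__main__":
    # Minimal smoke-test sketch (not part of the original module). It assumes that
    # ``tiny-imagenet-200`` has already been downloaded and extracted under --data_dir
    # and that torchvision is installed.
    parent_parser = ArgumentParser()
    parser = TinyImagenetDataModule.add_dataset_specific_args(parent_parser)
    args = parser.parse_args()

    dm = TinyImagenetDataModule(
        data_dir=args.data_dir,
        num_workers=args.num_workers,
        batch_size=args.batch_size,
        normalize=True,
    )
    dm.prepare_data()
    dm.setup("fit")

    # Pull one batch to confirm shapes: images should be (batch_size, 3, 64, 64).
    images, targets = next(iter(dm.train_dataloader()))
    print(images.shape, targets.shape, dm.num_classes)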