-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdata.py
118 lines (100 loc) · 3.3 KB
/
data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import torch
import numpy as np
import os
from torchvision import datasets, transforms
from torch.utils.data import Dataset, DataLoader
# Augmentation pipeline used for the CIFAR-100 training split.
# Steps run in the order listed; ToTensor is last, so every earlier
# transform sees the image exactly as the dataset hands it over.
transform_aug = transforms.Compose(
    [
        # Random left-right mirror (torchvision's default probability).
        transforms.RandomHorizontalFlip(),
        # Random rotation within +/-30 degrees and rescale in [0.7, 1.3].
        # NOTE(review): torchvision reads `translate` as (max horizontal
        # fraction, max vertical fraction), so [0.0, 0.5] means no horizontal
        # shift and up to half the image height vertically -- confirm that
        # this is what was intended.
        transforms.RandomAffine(
            degrees=[-30.0, 30.0],
            translate=[0.0, 0.5],
            scale=[0.7, 1.3],
        ),
        # Random photometric perturbation of all four colour properties.
        transforms.ColorJitter(
            brightness=0.5,
            contrast=0.5,
            saturation=0.5,
            hue=0.5,
        ),
        # Pad by 4 pixels, then take a random 32x32 window.
        transforms.RandomCrop(32, padding=4),
        transforms.ToTensor(),
    ]
)
def get_dataset_and_dataloader(cifar_dir='cifar', batch_size=256, num_workers=4):
    """Build the CIFAR-100 train/test datasets and a loader for each.

    The training split is wrapped with ``transform_aug``; the test split is
    only converted with ``ToTensor``. Neither dataset is downloaded here, so
    the files must already exist under ``cifar_dir``.

    Args:
        cifar_dir: Root directory holding the CIFAR-100 files.
        batch_size: Batch size used by both loaders.
        num_workers: Number of worker processes used by both loaders.

    Returns:
        A 4-tuple ``(train_data, clean_test_data, train_loader,
        full_test_loader)``. The train loader shuffles; the test loader
        does not.
    """
    train_data = datasets.CIFAR100(
        cifar_dir,
        train=True,
        transform=transform_aug,
    )
    clean_test_data = datasets.CIFAR100(
        cifar_dir,
        train=False,
        transform=transforms.ToTensor(),
        download=False,
    )

    # NOTE(review): this 20/80 split of the test set is never used or
    # returned. The call is kept (result discarded) because random_split
    # draws a permutation from torch's global RNG by default, so removing it
    # would shift the random stream seen by seeded runs. Drop it once that
    # is confirmed not to matter.
    train_ratio = 0.2
    train_len = int(train_ratio * len(clean_test_data))
    test_len = len(clean_test_data) - train_len
    torch.utils.data.random_split(clean_test_data, [train_len, test_len])

    # Options common to both loaders; only `shuffle` differs.
    loader_opts = dict(
        batch_size=batch_size,
        pin_memory=True,
        num_workers=num_workers,
    )
    train_loader = DataLoader(train_data, shuffle=True, **loader_opts)
    full_test_loader = DataLoader(clean_test_data, shuffle=False, **loader_opts)

    return train_data, clean_test_data, train_loader, full_test_loader
# Names of the supported corruption types. The loader functions in this
# module build a file name as `<corruption>.npy`, so these presumably
# double as the file stems inside the corrupted-data directory.
CORRUPTIONS = [
    'gaussian_noise',
    'shot_noise',
    'impulse_noise',
    'defocus_blur',
    'glass_blur',
    'motion_blur',
    'zoom_blur',
    'snow',
    'frost',
    'fog',
    'brightness',
    'contrast',
    'elastic_transform',
    'pixelate',
    'jpeg_compression',
]
def get_corrupt_data(corruption, cifar_corrupt_dir='cifar_corrupt/CIFAR-100-C',
                     cifar_dir='cifar'):
    """Return a CIFAR-100 test dataset whose contents are swapped for a corrupted set.

    A regular CIFAR-100 test dataset is created first (so the clean files
    must exist under ``cifar_dir``), then its ``data`` and ``targets``
    attributes are overwritten with arrays loaded from ``cifar_corrupt_dir``.

    Args:
        corruption: Corruption name; ``<corruption>.npy`` is loaded as the
            image array.
        cifar_corrupt_dir: Directory containing the corruption ``.npy``
            files and ``labels.npy``.
        cifar_dir: Root directory of the clean CIFAR-100 files. Defaults to
            the previously hard-coded ``'cifar'``.

    Returns:
        The modified ``datasets.CIFAR100`` instance; samples go through
        ``ToTensor`` only.

    Raises:
        FileNotFoundError: If either ``.npy`` file is missing (from
            ``np.load``).
    """
    test_data = datasets.CIFAR100(
        cifar_dir,
        train=False,
        transform=transforms.ToTensor(),
        download=False,
    )
    # Assumes the .npy image array has the layout CIFAR100 expects for
    # `.data` -- TODO confirm against the corrupted-data files.
    images_path = os.path.join(cifar_corrupt_dir, corruption + '.npy')
    labels_path = os.path.join(cifar_corrupt_dir, 'labels.npy')
    test_data.data = np.load(images_path)
    test_data.targets = torch.LongTensor(np.load(labels_path))
    return test_data
def get_corrupt_loader(corruption, batch_size, cifar_corrupt_dir='cifar_corrupt/CIFAR-100-C',
                       num_workers=4):
    """Return a non-shuffling ``DataLoader`` over one corrupted test set.

    The dataset is built by ``get_corrupt_data`` rather than by repeating
    the same construction code here.

    Args:
        corruption: Corruption name; selects ``<corruption>.npy``.
        batch_size: Batch size of the returned loader.
        cifar_corrupt_dir: Directory containing the corruption ``.npy``
            files and ``labels.npy``.
        num_workers: Number of loader worker processes. Defaults to the
            previously hard-coded 4.

    Returns:
        A ``DataLoader`` that iterates the corrupted dataset in order, with
        pinned memory enabled.
    """
    test_data = get_corrupt_data(corruption, cifar_corrupt_dir)
    return DataLoader(
        test_data,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True,
    )
class CycleLoader:
    """Endless batch iterator over a dataset.

    Wraps a shuffling ``DataLoader`` and transparently starts a new one each
    time the current pass is exhausted, so ``next(loader)`` can be called
    indefinitely.

    Args:
        data: Dataset handed to ``DataLoader``.
        batch_size: Batch size of every pass.
    """

    def __init__(self, data, batch_size=50):
        super().__init__()
        self.data = data
        self.batch_size = batch_size
        # Iterator for the current pass over `data`.
        self.iter = self.get_iter()

    def get_iter(self):
        """Create a fresh shuffled ``DataLoader`` and return an iterator over it."""
        return iter(DataLoader(
            self.data,
            batch_size=self.batch_size,
            shuffle=True,
            pin_memory=True,
            num_workers=4,
        ))

    def __iter__(self):
        """Return ``self`` so the object can be used directly in ``for`` loops."""
        return self

    def __next__(self):
        """Return the next batch, restarting the pass when it runs out.

        Only ``StopIteration`` from the current pass is handled. Any other
        error raised while fetching a batch propagates to the caller; the
        previous ``return`` inside ``finally`` discarded such errors and
        yielded ``None`` instead.
        """
        try:
            return next(self.iter)
        except StopIteration:
            # Current pass is finished: begin a new shuffled pass.
            self.iter = self.get_iter()
            return next(self.iter)