-
Notifications
You must be signed in to change notification settings - Fork 0
/
image_dataset.py
74 lines (57 loc) · 2.23 KB
/
image_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import os
import os.path
import random
import torch.utils.data as data
import torchvision.transforms as transforms
from PIL import Image
# File-name suffixes accepted as images. Both lower- and upper-case spellings
# are listed; is_image_file additionally matches mixed case (e.g. '.Jpg').
IMG_EXTENSIONS = [
    '.jpg', '.JPG', '.jpeg', '.JPEG',
    '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
    '.tif', '.TIF', '.tiff', '.TIFF',
]


def is_image_file(filename):
    """Return True if *filename* ends with one of ``IMG_EXTENSIONS``.

    The comparison is case-insensitive, so names such as ``photo.Jpg`` are
    accepted in addition to the spellings listed explicitly.

    Args:
        filename: File name or path, as a ``str``.

    Returns:
        ``True`` when the name carries an image suffix, else ``False``.
    """
    # str.endswith accepts a tuple and checks every suffix in one C-level
    # call. The tuple is rebuilt per call so that runtime additions to
    # IMG_EXTENSIONS are still honored.
    suffixes = tuple(ext.lower() for ext in IMG_EXTENSIONS)
    return filename.lower().endswith(suffixes)
def make_dataset(dir, max_dataset_size=float("inf")):
    """Collect the paths of image files found recursively under *dir*.

    Args:
        dir: Root directory to search. (The name shadows the builtin but is
            kept so existing keyword callers keep working.)
        max_dataset_size: Maximum number of paths to return. The default,
            ``float("inf")``, means no limit.

    Returns:
        A list of ``os.path.join(root, fname)`` strings for every file whose
        name passes ``is_image_file``, in a deterministic order: walked
        directories sorted by path, and files sorted by name within each
        directory.

    Raises:
        AssertionError: if *dir* is not an existing directory.
    """
    # Explicit raise instead of ``assert`` so the check survives ``python -O``;
    # the exception type stays AssertionError for existing callers.
    if not os.path.isdir(dir):
        raise AssertionError('%s is not a valid directory' % dir)
    images = []
    for root, _, fnames in sorted(os.walk(dir)):
        # os.walk reports file names in filesystem order. Sort them so the
        # truncation below keeps the same subset on every machine.
        for fname in sorted(fnames):
            if is_image_file(fname):
                images.append(os.path.join(root, fname))
    # min() yields an int when the limit is inf; int() makes float limits
    # such as 10.0 usable as a slice bound.
    return images[:int(min(max_dataset_size, len(images)))]
def get_transform(load_size=286, crop_size=256):
    """Build the per-image augmentation and normalization pipeline.

    Steps: bicubic resize to ``load_size`` x ``load_size``, random square
    crop of ``crop_size``, random horizontal flip, conversion to a tensor,
    and per-channel normalization with mean 0.5 and std 0.5 (mapping
    ``ToTensor``'s [0, 1] range to [-1, 1]).

    Args:
        load_size: Side length images are resized to before cropping.
        crop_size: Side length of the random crop; must not exceed
            ``load_size``.

    Returns:
        A ``transforms.Compose`` applying the steps above.

    Raises:
        ValueError: if ``crop_size`` is larger than ``load_size``.
    """
    if crop_size > load_size:
        raise ValueError('crop_size (%d) must not exceed load_size (%d)'
                         % (crop_size, load_size))
    # Passing a PIL resampling constant as ``interpolation`` is deprecated in
    # torchvision in favour of the InterpolationMode enum. Fall back to the
    # PIL constant on old torchvision versions that lack the enum.
    if hasattr(transforms, 'InterpolationMode'):
        bicubic = transforms.InterpolationMode.BICUBIC
    else:
        bicubic = Image.BICUBIC
    transform_list = [
        transforms.Resize([load_size, load_size], interpolation=bicubic),
        transforms.RandomCrop(crop_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
    return transforms.Compose(transform_list)
class ImageDataset(data.Dataset):
    """Unpaired two-domain image dataset.

    Images are read from ``<dataroot>/<phase>A`` and ``<dataroot>/<phase>B``.
    Domain-A images are visited in sorted order by index. Each one is paired
    with a domain-B image drawn uniformly at random, so the two folders need
    neither aligned contents nor equal sizes.
    """

    def __init__(self, dataroot, phase='train'):
        """Index both image folders and build the shared transform.

        Args:
            dataroot: Directory containing the ``<phase>A`` and ``<phase>B``
                sub-folders.
            phase: Prefix of the sub-folder names, e.g. ``'train'``.

        Raises:
            ValueError: if either sub-folder yields no image files.
        """
        self.dir_A = os.path.join(dataroot, phase + 'A')
        self.dir_B = os.path.join(dataroot, phase + 'B')
        self.A_paths = sorted(make_dataset(self.dir_A))
        self.B_paths = sorted(make_dataset(self.dir_B))
        self.A_size = len(self.A_paths)
        self.B_size = len(self.B_paths)
        # Fail here with a clear message rather than later with a
        # ZeroDivisionError (index % 0) or a ValueError from randint(0, -1).
        for directory, size in ((self.dir_A, self.A_size),
                                (self.dir_B, self.B_size)):
            if size == 0:
                raise ValueError('no images found in %s' % directory)
        self.transform = get_transform()

    def __getitem__(self, index):
        """Return one unpaired sample.

        Args:
            index: Sample index. It is wrapped modulo the number of domain-A
                images, so every index below ``len(self)`` is valid.

        Returns:
            A dict holding the transformed images under ``'A'`` and ``'B'``
            and their source file paths under ``'A_paths'`` and ``'B_paths'``.
        """
        # Wrap the index into range: len(self) may exceed the A-domain size.
        A_path = self.A_paths[index % self.A_size]
        # B is drawn at random rather than by index, so A/B pairings are not
        # fixed; the domains are unpaired.
        B_path = self.B_paths[random.randint(0, self.B_size - 1)]
        A = self.transform(self._load_rgb(A_path))
        B = self.transform(self._load_rgb(B_path))
        return {'A': A, 'B': B, 'A_paths': A_path, 'B_paths': B_path}

    @staticmethod
    def _load_rgb(path):
        """Open the image at *path*, returning an RGB copy and closing the file."""
        # convert() returns a new image, so the file handle can be released
        # as soon as the ``with`` block exits.
        with Image.open(path) as img:
            return img.convert('RGB')

    def __len__(self):
        """Return the number of samples: the larger of the two folder sizes."""
        return max(self.A_size, self.B_size)
if __name__ == '__main__':
    # Smoke test: load the training split once and print every batch.
    loader = data.DataLoader(ImageDataset('./dataset/bart2lisa'))
    for batch in loader:
        print(batch)