forked from Sohl-Dickstein/Diffusion-Probabilistic-Models
-
Notifications
You must be signed in to change notification settings - Fork 0
/
imagenet_data.py
118 lines (88 loc) · 3.58 KB
/
imagenet_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
from fuel.streams import AbstractDataStream
from fuel.iterator import DataIterator
import numpy as np
import theano
class IMAGENET(AbstractDataStream):
"""
A fuel DataStream for imagenet data
from fuel:
A data stream is an iterable stream of examples/minibatches. It shares
similarities with Python file handles return by the ``open`` method.
Data streams can be closed using the :meth:`close` method and reset
using :meth:`reset` (similar to ``f.seek(0)``).
"""
def __init__(self, partition_label='train', datadir='/home/jascha/data/imagenet/JPEG/', seed=12345, fraction=0.9, width=256, **kwargs):
# ignore axis labels if not given
kwargs.setdefault('axis_labels', '')
# call __init__ of the AbstractDataStream
super(self.__class__, self).__init__(**kwargs)
# get a list of the images
import glob
print "getting imagenet images"
image_files = glob.glob(datadir + "*.JPEG")
print "filenames loaded"
self.sources = ('features',)
self.width = width
# shuffle indices, subselect a fraction
np.random.seed(seed=seed)
np.random.shuffle(image_files)
num_train = int(np.round(fraction * np.float32(len(image_files))))
train_files = image_files[:num_train]
test_files = image_files[num_train:]
if 'train' in partition_label:
self.X = train_files
elif 'test' in partition_label:
self.X = test_files
self.num_examples = len(self.X)
self.current_index = 0
def get_data(self, data_state, request=None):
"""Get a new sample of data"""
if request is None:
request = [self.current_index]
self.current_index += 1
return self.load_images(request)
def apply_default_transformers(self, data_stream):
return data_stream
def open(self):
return None
def close(self):
"""Close the hdf5 file"""
pass
def reset(self):
"""Reset the current data index"""
self.current_index = 0
def get_epoch_iterator(self, **kwargs):
return super(self.__class__, self).get_epoch_iterator(**kwargs)
# return None
# TODO: implement iterator
def next_epoch(self, *args, **kwargs):
self.current_index = 0
return super(self.__class__, self).next_epoch(**kwargs)
# return None
def load_images(self, inds):
print ".",
output = np.zeros((len(inds), 3, self.width, self.width), dtype=theano.config.floatX)
for ii, idx in enumerate(inds):
output[ii] = self.load_image(idx)
return [output]
def load_image(self, idx):
filename = self.X[idx]
import Image
import ImageOps
# print "loading ", self.X[idx]
image = Image.open(self.X[idx])
width, height = image.size
if width > height:
delta2 = int((width - height)/2)
image = ImageOps.expand(image, border=(0, delta2, 0, delta2))
else:
delta2 = int((height - width)/2)
image = ImageOps.expand(image, border=(delta2, 0, delta2, 0))
image = image.resize((self.width, self.width), resample=Image.BICUBIC)
try:
imagenp = np.array(image.getdata()).reshape((self.width,self.width,3))
imagenp = imagenp.transpose((2,0,1)) # move color channels to beginning
except:
# print "reshape failure (black and white?)"
imagenp = self.load_image(np.random.randint(len(self.X)))
return imagenp.astype(theano.config.floatX)