-
Notifications
You must be signed in to change notification settings - Fork 4
/
GeneratorCustom.py
69 lines (68 loc) · 3.1 KB
/
GeneratorCustom.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import numpy as np
from tensorflow.keras.utils import Sequence
import cv2
import tensorflow as tf
class DataGenerator(Sequence):
    """Keras ``Sequence`` yielding ``(images, masks)`` batches for segmentation.

    Every entry of ``all_filenames`` is an ``(image_path, mask_path)`` pair.
    Images are loaded with OpenCV, converted to the requested colour space,
    resized with nearest-neighbour interpolation and divided by 255.  Masks
    are loaded as RGB, resized the same way and then turned into a single
    label channel, either through the ``encode`` lookup table or through the
    ``encode_with_kmean`` model.

    Args:
        all_filenames: Indexable collection of ``(image_path, mask_path)``
            pairs.
        input_size: Target ``(height, width)`` passed to ``tf.image.resize``.
        batch_size: Number of samples per batch.
        shuffle: Whether sample order is reshuffled by ``on_epoch_end``.
        seed: Seed passed to ``np.random.seed``.
        encode: Mapping from a 3-tuple pixel value of the mask to its label.
            Exactly one of ``encode`` / ``encode_with_kmean`` must be given.
        encode_with_kmean: Object whose ``predict`` method receives the mask
            pixels as an ``(N, 3)`` array and returns one label per pixel
            (presumably a fitted scikit-learn ``KMeans`` -- confirm with the
            caller).
        color_mode: ``'hsv'``, ``'rgb'`` or ``'gray'``.
        function: Optional callable applied to the resized RGB mask array
            before it is encoded.

    Raises:
        ValueError: If neither or both encoders are supplied, or if
            ``color_mode`` is not one of the supported values.
    """

    def __init__(
        self,
        all_filenames,
        input_size=(256, 256),
        batch_size=32,
        shuffle=True,
        seed=123,
        encode: dict = None,
        encode_with_kmean=None,
        color_mode='hsv',
        function=None,
    ) -> None:
        super().__init__()
        # Exactly one encoding strategy must be selected.
        if (encode is None) == (encode_with_kmean is None):
            raise ValueError(
                "Provide exactly one of 'encode' or 'encode_with_kmean'."
            )
        if color_mode not in ('hsv', 'rgb', 'gray'):
            raise ValueError(
                "color_mode must be 'hsv', 'rgb' or 'gray', "
                f"got {color_mode!r}."
            )
        self.all_filenames = all_filenames
        self.input_size = input_size
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.color_mode = color_mode
        self.encode = encode
        self.function = function
        self.kmean = encode_with_kmean
        # NOTE: this seeds NumPy's *global* RNG, which on_epoch_end relies on.
        np.random.seed(seed)
        self.on_epoch_end()

    def processing(self, mask):
        """Map an RGB mask to a label map of shape ``(*input_size, 1)``.

        Each pixel (a row of ``mask.reshape(-1, 3)``) is looked up in
        ``self.encode``.  A ``KeyError`` is raised for a colour that has no
        entry in the table.
        """
        pixels = mask.reshape(-1, 3)
        # Look every *distinct* colour up once instead of once per pixel.
        colors, inverse = np.unique(pixels, axis=0, return_inverse=True)
        labels = np.array(
            [self.encode[tuple(color)] for color in colors.tolist()]
        )
        # ravel(): some NumPy releases return a 2-D inverse when axis is set.
        return labels[np.ravel(inverse)].reshape(*self.input_size, 1)

    def __len__(self):
        """Return the number of full batches; a trailing partial batch is dropped."""
        return int(np.floor(len(self.all_filenames) / self.batch_size))

    def __getitem__(self, index):
        """Return the ``(X, Y)`` arrays for batch number ``index``."""
        start = index * self.batch_size
        indexes = self.indexes[start:start + self.batch_size]
        all_filenames_temp = [self.all_filenames[k] for k in indexes]
        return self.__data_generation(all_filenames_temp)

    def on_epoch_end(self):
        """Reset the sample order, shuffling it when ``self.shuffle`` is set."""
        self.indexes = np.arange(len(self.all_filenames))
        if self.shuffle:
            np.random.shuffle(self.indexes)

    def _load_image(self, fn):
        """Read one image, convert its colour space, resize and scale by 1/255."""
        img = cv2.imread(fn)
        if img is None:
            # cv2.imread signals failure by returning None, not by raising.
            raise FileNotFoundError(f"Could not read image file: {fn!r}")
        if self.color_mode == 'hsv':
            img = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
        elif self.color_mode == 'rgb':
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        elif self.color_mode == 'gray':
            img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
            # Restore the channel axis dropped by the grayscale conversion.
            img = tf.expand_dims(img, axis=2)
        img = tf.image.resize(img, self.input_size, method='nearest')
        img = tf.cast(img, tf.float32)
        return img / 255.

    def _load_mask(self, label_fn):
        """Read one mask as RGB, resize it and encode it to a single label channel."""
        # The mask must be read in colour: BGR2RGB rejects single-channel
        # input and both encoders reshape the mask to (-1, 3).
        mask = cv2.imread(label_fn, cv2.IMREAD_COLOR)
        if mask is None:
            raise FileNotFoundError(f"Could not read mask file: {label_fn!r}")
        mask = cv2.cvtColor(mask, cv2.COLOR_BGR2RGB)
        # Nearest-neighbour keeps the original colours (no blended values).
        mask = tf.image.resize(mask, self.input_size, method='nearest')
        mask = np.array(mask)
        if self.function is not None:
            mask = self.function(mask)
        if self.encode is not None:
            mask = self.processing(mask)
        if self.kmean is not None:
            mask = self.kmean.predict(mask.reshape(-1, 3)).reshape(*self.input_size, 1)
        return tf.cast(mask, tf.float32)

    def __data_generation(self, all_filenames_temp):
        """Build one batch from a list of ``(image_path, mask_path)`` pairs.

        Returns:
            ``(X, Y)`` where ``X`` has shape ``(batch, *input_size, C)`` with
            ``C == 1`` for ``'gray'`` and ``3`` otherwise, and ``Y`` has shape
            ``(batch, *input_size, 1)``.
        """
        batch = len(all_filenames_temp)
        channels = 1 if self.color_mode == 'gray' else 3
        X = np.empty(shape=(batch, *self.input_size, channels))
        Y = np.empty(shape=(batch, *self.input_size, 1))
        for i, (fn, label_fn) in enumerate(all_filenames_temp):
            X[i] = self._load_image(fn)
            Y[i] = self._load_mask(label_fn)
        return X, Y