-
Notifications
You must be signed in to change notification settings - Fork 1
/
data.py
131 lines (112 loc) · 3.75 KB
/
data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import os
import glob
import torch
import numpy as np
import pandas as pd
from PIL import Image, ImageFile
from collections import defaultdict
from torchvision import transforms
from albumentations import (
Compose,
OneOf,
RandomBrightnessContrast,
RandomGamma,
ShiftScaleRotate
)
# Allow PIL to open partially written / truncated PNGs instead of raising
# an IOError mid-training (some dataset exports contain truncated files).
ImageFile.LOAD_TRUNCATED_IMAGES = True
class SIIM_ACRDataset(torch.utils.data.Dataset):
    def __init__(self,
                 image_ids,
                 arguments,
                 transform=True,
                 preprocessing_fn=None):
        """
        Dataset class for the SIIM-ACR pneumothorax segmentation problem.

        :param image_ids: list of image ids; a trailing file extension,
            if present, is stripped and ``.png`` is appended
        :param arguments: object exposing ``image_path``, ``mask_path``
            and ``resize_img`` attributes
        :param transform: True/False — augmentation is applied only when
            True (i.e. disabled for validation data)
        :param preprocessing_fn: optional callable applied to the image
            array (e.g. encoder-specific normalization); skipped when None
        """
        self.args = arguments
        self.transform = transform
        self.preprocessing_fn = preprocessing_fn
        # Albumentations pipeline: shift/scale/rotate is applied with 80%
        # probability; then, with 50% probability, ONE of gamma jitter or
        # brightness/contrast jitter. Albumentations guarantees the same
        # spatial transform is applied to both the image and its mask.
        self.aug = Compose(
            [
                ShiftScaleRotate(
                    shift_limit=0.0625,
                    scale_limit=0.1,
                    rotate_limit=10, p=0.8
                ),
                OneOf(
                    [
                        RandomGamma(
                            gamma_limit=(90, 110)
                        ),
                        RandomBrightnessContrast(
                            brightness_limit=0.1,
                            contrast_limit=0.1
                        ),
                    ],
                    p=0.5,
                ),
            ]
        )
        # Map a contiguous integer index to the image/mask path pair for
        # every id (replaces the original defaultdict + manual counter).
        self.data = {}
        for counter, imgid in enumerate(image_ids):
            # Drop any extension from the id before appending ".png".
            imgid = imgid.split('.')[0]
            self.data[counter] = {
                "img_path": os.path.join(
                    self.args.image_path, imgid + ".png"
                ),
                "mask_path": os.path.join(
                    self.args.mask_path, imgid + ".png"
                ),
            }

    def __len__(self):
        # Number of (image, mask) pairs in the dataset.
        return len(self.data)

    def __getitem__(self, item):
        """
        Return a dict with "image" and "mask" tensors for the given index.

        The image is resized, converted to RGB, optionally augmented and
        optionally preprocessed; the mask is resized and binarized.
        """
        img_path = self.data[item]["img_path"]
        mask_path = self.data[item]["mask_path"]
        # Read the image, resize it and convert to a 3-channel RGB array.
        img = Image.open(img_path)
        img = img.resize((self.args.resize_img, self.args.resize_img))
        img = img.convert("RGB")
        img = np.array(img)
        # Read the mask, resize it and binarize: any pixel >= 1 is
        # foreground (float32 so it converts cleanly to a tensor).
        mask = Image.open(mask_path)
        mask = mask.resize((self.args.resize_img, self.args.resize_img))
        mask = np.array(mask)
        mask = (mask >= 1).astype("float32")
        # Apply augmentation only to training data.
        if self.transform is True:
            augmented = self.aug(image=img, mask=mask)
            img = augmented["image"]
            mask = augmented["mask"]
        # BUG FIX: preprocessing_fn defaults to None, but the original
        # code called it unconditionally, crashing with a TypeError when
        # no preprocessing function was supplied. Only call it when set.
        if self.preprocessing_fn is not None:
            img = self.preprocessing_fn(img)
        return {
            "image": transforms.ToTensor()(img),
            "mask": transforms.ToTensor()(mask).float(),
        }