-
Notifications
You must be signed in to change notification settings - Fork 3
/
dataset_DiabeticRet.py
86 lines (65 loc) · 2.27 KB
/
dataset_DiabeticRet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
"""
11-05-2022 Linde S. Hesse
File containing the dataset for the diabetic retinopathy challenge
"""
from pathlib import Path, PosixPath, PurePath

import cv2
import torch
from torch.utils.data import Dataset
from tqdm import tqdm
class DiabeticRet(Dataset):
    """Dataset of diabetic retinopathy images.

    Each item is a tuple ``(image, label, file_name)``. ``image`` is a
    float tensor of shape (3, IM_SIZE, IM_SIZE), or whatever ``transform``
    returns. ``label`` is a Python float. ``file_name`` is the image's base
    file name.
    """

    # Side length every image is expected to have. read_im raises if a file
    # does not decode to exactly this size.
    IM_SIZE = 540

    def __init__(self, datapath, labels, preload=True, datapart=1.0, transform=None):
        """Load the diabetic retinopathy dataset.

        Args:
            datapath (str / PathLike / sequence): Either a directory that is
                searched recursively for ``*.jpeg`` files, or a sequence of
                image paths (str or PathLike).
            labels (sequence or None): Per-image labels, indexed in the same
                order as the image paths. If None, every image gets label 0.
            preload (bool, optional): If True, read all images into one
                in-memory tensor up front. Defaults to True.
            datapart (float, optional): Currently unused; kept for API
                compatibility. Defaults to 1.0.
            transform (callable, optional): Applied to each image tensor in
                ``__getitem__``. Defaults to None.
        """
        self.datapath = datapath
        self.preload = preload
        # Accept any path flavour (PosixPath, WindowsPath, subclasses) or a
        # plain string as "directory to scan"; anything else is treated as an
        # explicit sequence of image paths.
        if isinstance(datapath, (str, PurePath)):
            # NOTE: glob order is filesystem-dependent; labels passed
            # alongside a directory must follow this same enumeration order.
            self.image_paths = list(Path(datapath).glob('**/*.jpeg'))
        else:
            self.image_paths = datapath
        if preload:
            self.preload_ims()
        self.transform = transform
        # Without labels (e.g. an unlabelled test set) default every image to 0.
        if labels is not None:
            self.labels = labels
        else:
            self.labels = len(self.image_paths) * [0]

    def __len__(self):
        """Return the number of images in the dataset."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, label, file_name)`` for sample ``idx``.

        Args:
            idx (int): Index into ``self.image_paths``.

        Returns:
            tuple: The image tensor (after ``transform`` if set), the label as
            a float, and the base name of the image file as a str.
        """
        im_path = self.image_paths[idx]
        # Take the image from the preloaded cache, or read it from disk now.
        if self.preload:
            im = self.ims[idx]
        else:
            im = self.read_im(im_path)[0]
        label = float(self.labels[idx])
        if self.transform is not None:
            im = self.transform(im)
        # Path() also accepts plain strings, so a list of str paths works too.
        return im, label, Path(im_path).name

    def preload_ims(self):
        """Read every image in ``self.image_paths`` into memory.

        Populates ``self.ims`` with a tensor of shape
        (N, 3, IM_SIZE, IM_SIZE) in torch's default float dtype. That is
        roughly 3.5 MB per image at float32.
        """
        n_ims = len(self.image_paths)
        self.ims = torch.zeros(n_ims, 3, self.IM_SIZE, self.IM_SIZE)
        for i, pathx in enumerate(tqdm(self.image_paths)):
            self.ims[i] = self.read_im(pathx)[0]

    def read_im(self, pathx):
        """Read a single image from disk as a tensor scaled by 1/255.

        Args:
            pathx (str or PathLike): Path of the image file.

        Returns:
            torch.Tensor: Tensor of shape (1, 3, IM_SIZE, IM_SIZE).

        Raises:
            OSError: If OpenCV cannot read or decode the file.
        """
        jpeg_im = cv2.imread(str(pathx))
        if jpeg_im is None:
            # cv2.imread signals a missing or undecodable file by returning None.
            raise OSError(f"cv2.imread could not read image: {pathx}")
        norm = jpeg_im / 255
        # Allocating a fixed-size buffer and assigning into it enforces the
        # expected IM_SIZE x IM_SIZE shape (a mismatch raises).
        im = torch.zeros(1, 3, self.IM_SIZE, self.IM_SIZE)
        # NOTE(review): cv2.imread yields an (H, W, C) array in BGR order, so
        # permute([2, 1, 0]) gives (C, W, H): spatially transposed and still
        # BGR. Kept unchanged so existing checkpoints see identical inputs;
        # confirm whether this is intended.
        im[0] = torch.from_numpy(norm).permute([2, 1, 0])
        return im