-
Notifications
You must be signed in to change notification settings - Fork 0
/
dataset.py
56 lines (43 loc) · 1.66 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import glob
import torch
from torch.utils.data import Dataset
from PIL import Image
import os
from torchvision.transforms import ToTensor
class FishermanDatasetPreparer:
    """Build a transformed copy of a class-per-folder PNG image tree."""

    def prepare(self, raw_img_path, result_path, classes, transform):
        """Transform every ``<raw_img_path>/<class>/*.png`` into ``result_path``.

        For each class name, every PNG in the matching raw sub-directory is
        opened, converted to RGB, passed through ``transform`` and saved
        under ``<result_path>/<class>/`` with its original file name.

        Args:
            raw_img_path: Root directory containing one sub-directory per class.
            result_path: Output root; it and its per-class sub-directories are
                created as needed (including missing parent directories).
            classes: Iterable of class sub-directory names to process.
            transform: Callable taking an RGB PIL image and returning an object
                with a ``save(path)`` method (e.g. a PIL image).

        Raises:
            OSError: propagated from directory creation, image decoding or saving.
        """
        # makedirs(exist_ok=True) also creates missing parents and avoids the
        # exists()/mkdir() race of a check-then-create sequence.
        os.makedirs(result_path, exist_ok=True)
        for cls in classes:
            raw_class_dir = os.path.join(raw_img_path, cls)
            result_class_dir = os.path.join(result_path, cls)
            os.makedirs(result_class_dir, exist_ok=True)
            # Escape the directory part so names containing glob
            # metacharacters ("[", "*", "?") are matched literally.
            pattern = os.path.join(glob.escape(raw_class_dir), "*.png")
            for file in glob.glob(pattern):
                result_file = os.path.join(result_class_dir, os.path.basename(file))
                # Close the source file promptly; convert() returns a loaded
                # copy that remains valid after the handle is closed.
                with Image.open(file) as img:
                    image = img.convert("RGB")
                image = transform(image)
                image.save(result_file)
class FishermanSimplifiedDataset(Dataset):
    """In-memory dataset of PNG images stored in one sub-directory per class.

    All images are decoded eagerly during ``__init__``. Each item is a
    ``(label, image)`` tuple: ``label`` is a one-hot float tensor of length
    ``len(classes)`` and ``image`` is the RGB picture converted with
    torchvision's ``ToTensor``.
    """

    def __init__(self, dir, classes):
        """Load every ``<dir>/<class>/*.png`` image.

        Args:
            dir: Root directory holding one sub-directory per class. (The name
                shadows the builtin but is kept for keyword-argument callers.)
            classes: Ordered class names; the class at position ``i`` gets a
                label with a 1.0 at index ``i``.
        """
        super().__init__()
        self.dir = dir
        self.classes = classes
        self.items = []
        self.load_dataset()

    def __len__(self):
        """Return the number of loaded images."""
        return len(self.items)

    def __getitem__(self, item):
        """Return the ``(label, image)`` tuple stored at index ``item``."""
        return self.items[item]

    def load_dataset(self):
        """Discover and decode all class images, appending them to ``self.items``.

        Files are visited class by class and, within a class, in sorted path
        order, so item indices are reproducible across runs and filesystems.
        A class whose directory is missing or has no PNGs contributes nothing.
        """
        # Phase 1: discover files. glob.escape keeps directory names that
        # contain glob metacharacters from being interpreted as patterns.
        entries = []
        for class_idx, cls in enumerate(self.classes):
            pattern = os.path.join(glob.escape(os.path.join(self.dir, cls)), "*.png")
            entries.extend((class_idx, path) for path in sorted(glob.glob(pattern)))
        if not entries:
            return

        # Phase 2: decode. A single converter instance serves every file.
        num_classes = len(self.classes)
        to_tensor = ToTensor()
        for class_idx, path in entries:
            # Fresh tensor per item: one shared per class would let an in-place
            # edit of one sample's label silently alter all its siblings.
            label = torch.zeros(num_classes)
            label[class_idx] = 1.0
            # Close the file handle promptly; convert() returns a fully loaded
            # copy that stays valid after the file is closed.
            with Image.open(path) as img:
                image = img.convert("RGB")
            self.items.append((label, to_tensor(image)))