-
Notifications
You must be signed in to change notification settings - Fork 6
/
dataset_new.py
92 lines (76 loc) · 3.59 KB
/
dataset_new.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
from torch.utils.data import Dataset
import os
import tifffile as tiff
import numpy as np
import torch
class Sen2_MTC(Dataset):
    """Sen2_MTC multi-temporal cloud-removal dataset.

    Each sample pairs three cloudy images
    ``<root>/Sen2_MTC/<tile>/cloud/<name>_{0,1,2}.tif`` with the cloud-free
    target ``<root>/Sen2_MTC/<tile>/cloudless/<name>.tif``. The tiles of a
    split are listed one per line in ``<root>/<mode>.txt``.

    Args:
        opt: Config object; only ``opt.root`` (dataset root directory) is read.
        mode: One of ``'train'``, ``'val'`` or ``'test'``. Random flips and
            90-degree rotations are applied only in ``'train'`` mode.

    Raises:
        ValueError: If ``mode`` is not one of the known splits.
    """

    # Split name -> file (relative to ``opt.root``) listing that split's tiles.
    _SPLIT_FILES = {'train': 'train.txt', 'val': 'val.txt', 'test': 'test.txt'}
    # Only files with this suffix in ``cloudless/`` define samples; the same
    # suffix is used to rebuild every image path below.
    _EXT = '.tif'

    def __init__(self, opt, mode='train'):
        if mode not in self._SPLIT_FILES:
            raise ValueError(
                f"mode must be one of {sorted(self._SPLIT_FILES)}, got {mode!r}")
        self.root = opt.root
        self.mode = mode
        self.filepair = []
        self.image_name = []
        if mode != 'train':
            self.data_augmentation = None
        # ndmin=1: without it a split file holding a single tile yields a 0-d
        # array, which cannot be iterated.
        self.tile_list = np.loadtxt(
            os.path.join(self.root, self._SPLIT_FILES[mode]), dtype=str, ndmin=1)
        for tile in self.tile_list:
            tile_dir = os.path.join(self.root, 'Sen2_MTC', tile)
            # Strip exactly the extension (names may contain dots), skip
            # stray non-image files, and sort so sample order does not depend
            # on the filesystem's listing order.
            image_name_list = sorted(
                file_name[:-len(self._EXT)]
                for file_name in os.listdir(os.path.join(tile_dir, 'cloudless'))
                if file_name.endswith(self._EXT))
            for image_name in image_name_list:
                cloud_paths = [
                    os.path.join(tile_dir, 'cloud', f'{image_name}_{i}{self._EXT}')
                    for i in range(3)]
                cloudless_path = os.path.join(
                    tile_dir, 'cloudless', image_name + self._EXT)
                self.filepair.append(cloud_paths + [cloudless_path])
                self.image_name.append(image_name)
        # One pre-drawn augmentation per sample slot: rotation k in {0,1,2,3}
        # quarter turns, flip axis in {0,1,2} where 0 means "no flip".
        self.augment_rotation_param = np.random.randint(
            0, 4, len(self.filepair))
        self.augment_flip_param = np.random.randint(0, 3, len(self.filepair))
        # Next augmentation slot; advanced once per training sample and
        # wrapped at len(self.filepair).
        self.index = 0

    def _next_augmentation(self):
        """Return ``(flip_axis, rot_k)`` for the next sample; ``(0, 0)`` outside training."""
        if self.mode != 'train' or not self.filepair:
            return 0, 0
        slot = self.index % len(self.filepair)
        # Advance before any image is read, so a failed read cannot leave the
        # slot counter out of step with the sample being returned.
        self.index = (slot + 1) % len(self.filepair)
        return (int(self.augment_flip_param[slot]),
                int(self.augment_rotation_param[slot]))

    def __getitem__(self, index):
        """Return ``([cloud0, cloud1, cloud2], cloudless, image_name)``.

        In training mode a single (flip, rotation) pair is used for all four
        images of the sample, so the cloudy inputs stay spatially aligned with
        the cloud-free target.
        """
        *cloud_paths, cloudless_path = self.filepair[index]
        flip_axis, rot_k = self._next_augmentation()
        cloud_images = [
            self.image_read(path, flip_axis, rot_k) for path in cloud_paths]
        image_cloudless = self.image_read(cloudless_path, flip_axis, rot_k)
        return cloud_images, image_cloudless, self.image_name[index]

    def __len__(self):
        return len(self.filepair)

    def image_read(self, image_path, flip_axis=0, rot_k=0):
        """Load one TIFF, optionally flip/rotate it, and normalise it.

        Args:
            image_path: File read with ``tifffile.imread``. The last axis is
                moved to the front (presumably HWC -> CHW; TODO confirm).
            flip_axis: Axis of the transposed array to flip (1 or 2); 0 = none.
            rot_k: Number of 90-degree rotations in the plane of axes (1, 2);
                0 = none.

        Returns:
            float32 tensor equal to ``(raw / 10000 - 0.5) / 0.5``.
        """
        img = tiff.imread(image_path).transpose((2, 0, 1))
        if flip_axis:
            img = np.flip(img, flip_axis)
        if rot_k:
            img = np.rot90(img, rot_k, (1, 2))
        # flip/rot90 return negative-stride views, which torch.from_numpy
        # rejects; ascontiguousarray copies into a fresh float32 buffer.
        image = torch.from_numpy(np.ascontiguousarray(img, dtype=np.float32))
        # Scale by 1/10000 (presumably Sentinel-2 reflectance units -- TODO
        # confirm), then standardise with mean = std = 0.5 on every channel.
        return image.div_(10000.0).sub_(0.5).div_(0.5)