-
Notifications
You must be signed in to change notification settings - Fork 2
/
data_transform.py
89 lines (59 loc) · 2.26 KB
/
data_transform.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import torchvision.transforms.functional as F
import random
import torch
import numpy as np
from PIL import Image, ImageEnhance, ImageFilter
from torchvision import transforms
import PIL
from util import mask_to_semantic
class HorizontalFlip(object):
    """Flip an (H, W, ...) image and its mask left-right with probability ``p``."""

    def __init__(self, p=0.5):
        # Probability of applying the flip on a given call.
        self.p = p

    def __call__(self, image, mask):
        # Single random draw decides for both image and mask so they stay aligned.
        do_flip = random.random() < self.p
        if do_flip:
            # np.fliplr is equivalent to np.flip(..., axis=1) for 2-D+ arrays.
            image, mask = np.fliplr(image), np.fliplr(mask)
        return image, mask
class VerticalFlip(object):
    """Flip an (H, W, ...) image and its mask top-bottom with probability ``p``."""

    def __init__(self, p=0.5):
        # Probability of applying the flip on a given call.
        self.p = p

    def __call__(self, image, mask):
        # One draw for both arrays keeps image and mask spatially consistent.
        do_flip = random.random() < self.p
        if do_flip:
            # np.flipud is equivalent to np.flip(..., axis=0).
            image, mask = np.flipud(image), np.flipud(mask)
        return image, mask
class Rotate(object):
    """Rotate image and mask together by an angle drawn from ``degrees``.

    Only 90, 180 and 270 trigger a rotation; any other drawn angle is a
    no-op, exactly as in the original if/elif chain.
    """

    def __init__(self, degrees):
        # Candidate angles; one is picked uniformly at random per call.
        self.degrees = degrees

    def __call__(self, image, mask):
        angle = random.choice(self.degrees)
        if angle in (90, 180, 270):
            # Each supported angle is a whole number of quarter-turns;
            # axes (1, 0) matches the rotation direction of the original code.
            quarter_turns = angle // 90
            image = np.rot90(image, quarter_turns, (1, 0))
            mask = np.rot90(mask, quarter_turns, (1, 0))
        return image, mask
class Resize(object):
    """Randomly resize a PIL image and its segmentation mask to one of ``scales``.

    With probability ``p`` a target size is drawn uniformly from ``scales``
    and both inputs are resized to it; otherwise they are returned unchanged.

    The image is resampled bilinearly.  The mask is resampled with
    NEAREST: bilinear interpolation on a label mask blends class ids at
    region boundaries, producing fractional values that belong to no class
    (the original code used BILINEAR for the mask, which corrupted labels).
    """

    # NOTE: default changed from a list to a tuple — avoids the shared
    # mutable-default pitfall; random.choice works on either.
    def __init__(self, p=0.5, scales=((320, 320), (192, 192), (384, 384), (128, 128))):
        self.scales = scales
        self.p = p

    def __call__(self, image, mask):
        if random.random() < self.p:
            scale = random.choice(self.scales)
            image = image.resize(scale, resample=PIL.Image.BILINEAR)
            # BUGFIX: NEAREST keeps mask values exactly on the label grid.
            mask = mask.resize(scale, resample=PIL.Image.NEAREST)
        return image, mask
class ToTensor(object):
    """Convert an (H, W, C) numpy image and a label mask to torch tensors.

    Each image channel is min-max normalized to [0, 1], the image is
    transposed to (C, H, W), and the mask is expanded to a semantic
    encoding via ``mask_to_semantic``.
    """

    def __call__(self, image, mask, labels=None, mode="train", smooth=False):
        """Return ``(image_tensor, mask_tensor)``.

        Parameters
        ----------
        image : np.ndarray, (H, W, C)
        mask : np.ndarray — label mask, consumed by ``mask_to_semantic``.
        labels : sequence of int, optional — class labels; defaults to
            ``[0, 1, 2]`` (``None`` sentinel replaces the original mutable
            default argument).
        mode : str — unused here; kept for caller compatibility.
        smooth : bool — forwarded to ``mask_to_semantic``.
        """
        if labels is None:
            labels = [0, 1, 2]
        # BUGFIX: if the input array has an integer dtype, writing the
        # normalized floats back into it would truncate everything to 0/1.
        # Promote to float32 first; float inputs keep their dtype as before.
        if not np.issubdtype(image.dtype, np.floating):
            image = image.astype(np.float32)
        for i in range(image.shape[2]):
            channel = image[:, :, i]
            lo, hi = np.min(channel), np.max(channel)
            if hi > lo:
                image[:, :, i] = (channel - lo) / (hi - lo)
            else:
                # BUGFIX: a constant channel previously divided by zero
                # (NaN/inf); map it to all zeros instead.
                image[:, :, i] = 0.0
        # (H, W, C) -> (C, H, W); copy() gives torch.from_numpy a
        # contiguous, positively-strided buffer.
        image = torch.from_numpy(image.transpose((2, 0, 1)).copy())
        mask = torch.from_numpy(mask_to_semantic(mask, labels, smooth=smooth))
        return image, mask