-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdatasets_utils.py
61 lines (53 loc) · 2.16 KB
/
datasets_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import torchvision.transforms as transforms
from PIL import Image, ImageOps
class HistogramEqualization:
    """Torchvision-style transform that equalizes the histogram of a PIL image.

    Wraps ``PIL.ImageOps.equalize`` in a callable object so it can be used as
    a step of ``transforms.Compose`` (it must come before ``ToTensor``, since
    it operates on PIL images).
    """

    def __call__(self, img):
        """Return a histogram-equalized version of ``img``.

        Args:
            img: A PIL image. NOTE(review): ``ImageOps.equalize`` is intended
                for "L" and "RGB" images; other modes (e.g. "RGBA", "P") may
                raise — confirm against the dataset's image modes.
        """
        return ImageOps.equalize(img)

    def __repr__(self):
        # Gives a readable step name when a Compose pipeline is printed.
        return f"{type(self).__name__}()"
# Training pipeline. We use the transformation recommended in:
# https://www.frontiersin.org/journals/medicine/articles/10.3389/fmed.2021.629134/full
# NOTE: some images in the dataset have low contrast, so histogram equalization
# (HistogramEqualization(), as the first step) was tried to improve it. It is
# currently DISABLED, as are RandomRotation(10) and RandomHorizontalFlip().
# Add them back to the list below to re-enable them.
train_transforms = transforms.Compose([
    transforms.Resize((256, 256)),
    # Collapse to one channel so every sample becomes a (1, 256, 256) tensor.
    transforms.Grayscale(num_output_channels=1),
    # PIL image -> float tensor with values in [0, 1].
    transforms.ToTensor(),
    # The only active augmentation: a random shift of up to 10% of the
    # width/height with no rotation; exposed pixels are filled with 0.
    transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
    # (x - 0.5) / 0.5 maps [0, 1] to [-1, 1].
    transforms.Normalize((0.5,), (0.5,)),
])
def _build_eval_transforms(size: int) -> transforms.Compose:
    """Return an augmentation-free pipeline for ``size`` x ``size`` inputs.

    Resizes to a square, converts to a single channel, converts to a float
    tensor in [0, 1] and normalizes with mean 0.5 / std 0.5, giving [-1, 1].
    A new ``Compose`` is built on every call, so each pipeline name below
    refers to its own independent object.

    Args:
        size: Target height and width in pixels.
    """
    return transforms.Compose([
        transforms.Resize((size, size)),
        transforms.Grayscale(num_output_channels=1),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ])


# "verify" (presumably validation) and test pipelines apply no augmentation,
# so for a given resolution they are identical. Both names exist because each
# is registered and looked up separately by name.
verify_transforms_512 = _build_eval_transforms(512)
test_transforms_512 = _build_eval_transforms(512)
verify_transforms_256 = _build_eval_transforms(256)
test_transforms_256 = _build_eval_transforms(256)
# Registry of the preprocessing pipelines defined above, keyed by the name
# accepted by get_transform().
transformations = {
    'verify_transforms_512': verify_transforms_512,
    'test_transforms_512': test_transforms_512,
    'verify_transforms_256': verify_transforms_256,
    'test_transforms_256': test_transforms_256,
    'train_transforms': train_transforms,
}


def get_transform(transform_name):
    """Return the preprocessing pipeline registered under ``transform_name``.

    Args:
        transform_name: One of the keys of ``transformations``.

    Returns:
        The registered pipeline object itself (shared, not a copy).

    Raises:
        KeyError: If ``transform_name`` is not registered; the message lists
            the valid names so typos are easy to spot.
    """
    try:
        return transformations[transform_name]
    except KeyError:
        available = ", ".join(sorted(transformations))
        # Re-raise as KeyError so existing ``except KeyError`` callers still work.
        raise KeyError(
            f"Unknown transform {transform_name!r}; expected one of: {available}"
        ) from None