-
Notifications
You must be signed in to change notification settings - Fork 0
/
siamese_datasets.py
131 lines (109 loc) · 4.28 KB
/
siamese_datasets.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
from attr import has
import torch
import numpy as np
from multiprocessing import Pool
import os
import torchvision
class TripletDataset(torch.utils.data.Dataset):
    """Serve ``(anchor, positive, negative)`` triplets drawn from a labelled dataset.

    The wrapped dataset must yield ``(sample, label)`` pairs. Indexing this
    dataset returns ``((anchor, pos, neg), label)`` where ``pos`` shares the
    anchor's label and ``neg`` comes from a different, randomly chosen label.

    Args:
        ds: Indexable, sized dataset of ``(sample, label)`` pairs. If it has a
            ``class_to_idx`` mapping, every class index in it gets a (possibly
            empty) entry in ``label_groups``.
        validation: Stored on the instance; sampling is currently *not* made
            deterministic for validation.
    """

    def __init__(self, ds, validation=False):
        self.ds = ds
        self.validation = validation
        print("len ds", len(ds))

        print("Creating label groups")
        # Pre-register declared classes when available; labels that are not
        # declared (or datasets without ``class_to_idx``) are added lazily.
        class_to_idx = getattr(self.ds, "class_to_idx", None)
        if class_to_idx is None:
            self.label_groups = {}
        else:
            self.label_groups = {label: [] for label in class_to_idx.values()}

        # label -> list of dataset indices carrying that label
        for i, (_, l) in enumerate(self.ds):
            self.label_groups.setdefault(l, []).append(i)

    def __getitem__(self, i):
        """Return ``((anchor, pos, neg), label)`` for dataset index ``i``.

        Raises:
            ValueError: If no label other than the anchor's has any samples.
        """
        anchor, label = self.ds[i]

        # Positive example. It may be the anchor itself; the original author
        # relies on random augmentation to make the two views differ.
        pos_idx = np.random.choice(self.label_groups[label])
        pos = self.ds[pos_idx][0]

        # Negative example. Skip labels without samples: a class declared in
        # ``class_to_idx`` can be empty, and sampling from an empty group fails.
        potential_labels = [
            other
            for other in self.label_groups.keys() - {label}
            if self.label_groups[other]
        ]
        if not potential_labels:
            raise ValueError(
                f"cannot sample a negative for label {label!r}: "
                "no other label has any samples"
            )
        neg_label = np.random.choice(potential_labels)
        neg_idx = np.random.choice(self.label_groups[neg_label])
        neg = self.ds[neg_idx][0]

        return (anchor, pos, neg), label

    def __len__(self):
        """Return the number of anchors, i.e. the length of the wrapped dataset."""
        return len(self.ds)
class PairDataset(torch.utils.data.Dataset):
    """Serve ``(anchor, other)`` pairs labelled 1 (same class) or 0 (different class).

    The anchor is taken from ``anchor_ds``; the partner is drawn at random from
    ``other_ds`` using an index of labels built once in ``__init__``.

    Args:
        anchor_ds: Indexable, sized dataset of ``(sample, label)`` pairs.
        other_ds: Dataset the partner sample is drawn from. Defaults to
            ``anchor_ds``.
        anchor_transform: Optional callable applied to the anchor sample.
        other_transform: Optional callable applied to the partner sample.
        validation: When true, numpy's global RNG is seeded with the item
            index for the duration of ``__getitem__`` so that the same index
            always produces the same pair.
    """

    def __init__(
        self,
        anchor_ds,
        other_ds=None,
        anchor_transform=None,
        other_transform=None,
        validation=False,
    ):
        self.ds = anchor_ds
        self.other_ds = anchor_ds if other_ds is None else other_ds
        self.anchor_transform = anchor_transform
        self.other_transform = other_transform
        self.validation = validation
        print("len ds", len(anchor_ds))

        # Pre-register the anchor dataset's declared classes when it has any.
        try:
            self.label_groups = {label: [] for label in self.ds.class_to_idx.values()}
        except AttributeError:
            self.label_groups = {}
        print("Created label groups")

        # Only labels are needed for the scan, so skip image loading and
        # transforms on the dataset that is actually iterated, and always put
        # them back, even if the scan fails part-way.
        restore = self._suspend_loading(self.other_ds)
        try:
            # label -> list of indices into ``other_ds`` carrying that label
            for i, (_, l) in enumerate(self.other_ds):
                # Labels missing from ``class_to_idx`` do occur in practice.
                self.label_groups.setdefault(l, []).append(i)
        finally:
            restore()
        print(f"Len label groups: {len(self.label_groups)}")

    @staticmethod
    def _suspend_loading(ds):
        """Turn off sample loading/transforms on ``ds``; return a callable that undoes it."""
        # ImageNet is an ImageFolder subclass, so one check covers both.
        if isinstance(ds, torchvision.datasets.ImageFolder):
            print("ImageFolder")
            saved_transform, saved_loader = ds.transform, ds.loader
            ds.transform = None
            # Hand back the path instead of decoding the image.
            ds.loader = lambda path: path

            def restore():
                ds.transform = saved_transform
                ds.loader = saved_loader

        elif isinstance(ds, torchvision.datasets.VisionDataset):
            print("VisionDataset")
            # NOTE(review): this only helps datasets whose ``__getitem__`` reads
            # ``.transforms``; those reading ``.transform`` are unaffected.
            # Confirm per dataset.
            saved_transforms = ds.transforms
            ds.transforms = None

            def restore():
                ds.transforms = saved_transforms

        else:
            print("other")

            def restore():
                pass

        return restore

    def __getitem__(self, i):
        """Return ``((anchor, other), target)`` with ``target`` 1 for a same-class pair, else 0.

        Raises:
            ValueError: If a negative pair is requested but no label other
                than the anchor's has any samples in ``other_ds``.
        """
        if self.validation:
            # Make the pair for index ``i`` reproducible. This touches numpy's
            # *global* RNG, so it is re-randomised again before returning.
            np.random.seed(i)
        try:
            anchor, label = self.ds[i]

            is_pos = np.random.random() > 0.5
            if is_pos:
                # Positive: any sample of the anchor's class (may be the anchor).
                other_idx = np.random.choice(self.label_groups[label])
                label = 1
            else:
                # Negative: skip labels without samples. A class declared in
                # ``class_to_idx`` can be empty, and sampling from it fails.
                potential_labels = [
                    candidate
                    for candidate in self.label_groups.keys() - {label}
                    if self.label_groups[candidate]
                ]
                if not potential_labels:
                    raise ValueError(
                        f"cannot sample a negative for label {label!r}: "
                        "no other label has any samples"
                    )
                other_label = np.random.choice(potential_labels)
                other_idx = np.random.choice(self.label_groups[other_label])
                label = 0
            other = self.other_ds[other_idx][0]

            if self.anchor_transform is not None:
                anchor = self.anchor_transform(anchor)
            if self.other_transform is not None:
                other = self.other_transform(other)
        finally:
            if self.validation:
                # Re-seed from fresh entropy so the fixed validation seed does
                # not leak into other users of numpy's global RNG.
                np.random.seed()

        return (anchor, other), label

    def __len__(self):
        """Return the number of anchors, i.e. the length of the anchor dataset."""
        return len(self.ds)