-
Notifications
You must be signed in to change notification settings - Fork 29
/
Copy pathEpisodeSampler.py
54 lines (41 loc) · 2.05 KB
/
EpisodeSampler.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import torch
import random
import typing
from scipy.special import comb
class EpisodeSampler(torch.utils.data.BatchSampler):
    """Sample batches of image indices forming N-way, K-shot classification episodes.

    Each yielded batch picks one dataset at random, samples ``num_ways``
    distinct labels from it, and draws ``num_samples_per_class`` image
    indices per label. Indices are global across the concatenated datasets
    (i.e. offset by the total length of all preceding datasets).

    Args:
        sampler: sampler whose ``data_source`` holds a ``datasets`` sequence;
            each dataset exposes ``targets`` (per-sample labels) and ``__len__``.
        num_ways: number of classes per episode (the "N" in N-way).
        num_samples_per_class: samples drawn per class (the "K" in K-shot).
        drop_last: forwarded to ``BatchSampler``.
    """
    def __init__(self, sampler: torch.utils.data.Sampler[typing.List[int]], num_ways: int, num_samples_per_class: int, drop_last: bool = True) -> None:
        super().__init__(sampler=sampler, batch_size=num_ways, drop_last=drop_last)
        self.num_ways = num_ways
        self.num_samples_per_class = num_samples_per_class
        # Per-dataset mapping: label -> list of GLOBAL image indices.
        self.class_img_idx: typing.List[typing.Dict[typing.Any, typing.List[int]]] = []
        offset = 0  # cumulative length of all preceding datasets
        for dataset in sampler.data_source.datasets:
            label_to_idx: typing.Dict[typing.Any, typing.List[int]] = {}
            for i in range(len(dataset)):
                label_to_idx.setdefault(dataset.targets[i], []).append(i + offset)
            self.class_img_idx.append(label_to_idx)
            # BUG FIX: accumulate the offset. The original assigned
            # `j = len(current dataset)` instead of adding, which corrupted
            # global indices for the third dataset onward.
            offset += len(dataset)

    def __iter__(self) -> typing.Iterator[typing.List[int]]:
        """Yield episodes forever; each episode is a flat list of image indices."""
        while True:
            # Randomly pick one dataset per episode.
            dataset_id = random.randrange(len(self.class_img_idx))
            # BUG FIX: random.sample() requires a sequence; passing a dict
            # keys view raises TypeError on Python >= 3.11.
            labels = random.sample(list(self.class_img_idx[dataset_id].keys()), k=self.num_ways)
            batch: typing.List[int] = []
            for label in labels:
                batch.extend(random.sample(self.class_img_idx[dataset_id][label], k=self.num_samples_per_class))
            yield batch

    def __len__(self) -> int:
        # BUG FIX: the original referenced non-existent attributes
        # (self.label_list, self.batch) and raised AttributeError when called.
        # NOTE(review): interpreted as the number of distinct label
        # combinations across all datasets — confirm intended semantics.
        # exact=True makes scipy's comb return a Python int (required by len()).
        return sum(comb(len(class_idx), self.num_ways, exact=True)
                   for class_idx in self.class_img_idx)