
Commit

removed useless print
mboudiaf committed Apr 16, 2021
1 parent 90edd48 commit c6d6922
Showing 4 changed files with 19 additions and 10 deletions.
5 changes: 3 additions & 2 deletions pytorch_meta_dataset/pipeline.py
@@ -103,7 +103,7 @@ def __init__(self,
         self.transforms = transforms
         self.max_query_size = max_query_size
         self.max_support_size = max_support_size
-        self.random_gen = None
+        self.random_gen = np.random.RandomState()
 
     def __iter__(self):
         while True:
@@ -154,6 +154,7 @@ def __init__(self,
         super(BatchDataset).__init__()
         self.class_datasets = class_datasets
         self.transforms = transforms
+        self.random_gen = np.random.RandomState()
 
     def __iter__(self):
         while True:
@@ -176,7 +177,7 @@ class ZipDataset(torch.utils.data.IterableDataset):
     def __init__(self,
                  dataset_list: List[EpisodicDataset]):
         self.dataset_list = dataset_list
-        self.random_gen = None
+        self.random_gen = np.random.RandomState()
 
     def __iter__(self):
         while True:
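Note (not part of this commit): replacing self.random_gen = None with a freshly constructed np.random.RandomState() means every dataset owns a working generator even in single-process loading, while a DataLoader worker_init_fn can still swap in a per-worker seeded generator afterwards. A minimal sketch of the pattern, using a hypothetical ToyEpisodicDataset:

import numpy as np
import torch

class ToyEpisodicDataset(torch.utils.data.IterableDataset):
    def __init__(self, num_classes: int):
        super().__init__()
        self.num_classes = num_classes
        # Default generator; a worker_init_fn may overwrite it with a seeded one.
        self.random_gen = np.random.RandomState()

    def __iter__(self):
        while True:
            # All randomness goes through self.random_gen, never the global np.random.
            yield self.random_gen.choice(self.num_classes, size=5, replace=False)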
14 changes: 9 additions & 5 deletions pytorch_meta_dataset/tfrecord/reader.py
@@ -13,6 +13,7 @@
 
 
 def tfrecord_iterator(data_path: str,
+                      random_gen: np.random.RandomState,
                       index_path: typing.Optional[str] = None,
                       shard: typing.Optional[typing.Tuple[int, int]] = None,
                       shuffle: bool = False,
@@ -48,8 +49,9 @@ def tfrecord_iterator(data_path: str,
     crc_bytes = bytearray(4)
     datum_bytes = bytearray(1024 * 1024)
 
-    def random_reader(indexes):
-        random_permutation = np.random.permutation(range(indexes.shape[0]))
+    def random_reader(indexes: np.ndarray,
+                      random_gen: np.random.RandomState,):
+        random_permutation = random_gen.permutation(range(indexes.shape[0]))
         for i in random_permutation:
             start = indexes[i, 0]
             end = indexes[i, 0] + indexes[i, 1]
@@ -83,7 +85,7 @@ def read_records(start_offset=None, end_offset=None):
     indexes = np.loadtxt(index_path, dtype=np.int64)
     # if shard is None:
     if shuffle:
-        yield from random_reader(indexes)
+        yield from random_reader(indexes=indexes, random_gen=random_gen)
     else:
         yield from read_records()
     # else:
@@ -165,6 +167,7 @@ def get_value(typename, typename_mapping, key):
 
 
 def example_loader(data_path: str,
+                   random_gen: np.random.RandomState,
                    index_path: typing.Union[str, None],
                    description: typing.Union[typing.List[str], typing.Dict[str, str], None] = None,
                    shard: typing.Optional[typing.Tuple[int, int]] = None,
@@ -210,7 +213,7 @@ def example_loader(data_path: str,
         "int": "int64_list"
     }
 
-    record_iterator = tfrecord_iterator(data_path, index_path, shard, shuffle)
+    record_iterator = tfrecord_iterator(data_path, random_gen, index_path, shard, shuffle)
 
     for record, (start, end) in record_iterator:
         # yield record
@@ -291,6 +294,7 @@ def sequence_loader(data_path: str,
 
 def tfrecord_loader(data_path: str,
                     index_path: typing.Union[str, None],
+                    random_gen: np.random.RandomState,
                     description: typing.Union[typing.List[str], typing.Dict[str, str], None] = None,
                     shard: typing.Optional[typing.Tuple[int, int]] = None,
                     shuffle: bool = False,
@@ -343,7 +347,7 @@ def tfrecord_loader(data_path: str,
     """
     if sequence_description is not None:
         return sequence_loader(data_path, index_path, description, sequence_description, shard)
-    return example_loader(data_path, index_path, description, shard, shuffle)
+    return example_loader(data_path, random_gen, index_path, description, shard, shuffle)
 
 
 def multi_tfrecord_loader(data_pattern: str,
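Note (not part of this commit): threading random_gen down to random_reader puts the record shuffle under the caller's control instead of the global np.random state, which is shared (and initially identical) across DataLoader worker processes. A standalone sketch of the permutation logic on a toy index array:

import numpy as np

def random_reader(indexes: np.ndarray, random_gen: np.random.RandomState):
    # Each row of indexes is assumed to hold (start offset, record length).
    for i in random_gen.permutation(indexes.shape[0]):
        start = indexes[i, 0]
        end = indexes[i, 0] + indexes[i, 1]
        yield start, end  # the real reader seeks to these offsets and yields the record

indexes = np.array([[0, 10], [10, 20], [30, 5]], dtype=np.int64)
print(list(random_reader(indexes, np.random.RandomState(0))))  # same order on every run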
4 changes: 3 additions & 1 deletion pytorch_meta_dataset/tfrecord/torch/dataset.py
@@ -65,6 +65,7 @@ def __init__(self,
         self.sequence_description = sequence_description
         self.shuffle = shuffle
         self.transform = transform
+        self.random_gen = np.random.RandomState()
 
     def __iter__(self):
         worker_info = torch.utils.data.get_worker_info()
@@ -78,7 +79,8 @@ def __iter__(self):
                              description=self.description,
                              shard=shard,
                              shuffle=self.shuffle,
-                             sequence_description=self.sequence_description)
+                             sequence_description=self.sequence_description,
+                             random_gen=self.random_gen)
         if self.transform:
             it = map(self.transform, it)
         return it
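Note (not part of this commit): in __iter__ the shard tuple passed to tfrecord_loader is presumably derived from the DataLoader worker info, so each worker reads a distinct slice of the tfrecord while its own self.random_gen drives the shuffle. A hedged sketch of that sharding pattern:

import torch

def current_shard():
    worker_info = torch.utils.data.get_worker_info()
    if worker_info is None:
        return None  # single-process loading: read the whole file
    return worker_info.id, worker_info.num_workers  # (shard index, shard count)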
6 changes: 4 additions & 2 deletions pytorch_meta_dataset/utils.py
@@ -8,8 +8,10 @@ def worker_init_fn_(worker_id, seed):
     dataset = worker_info.dataset  # the dataset copy in this worker process
     random_gen = np.random.RandomState(seed + worker_id)
     dataset.random_gen = random_gen
-    for d in dataset.dataset_list:
-        d.random_gen = random_gen
+    for source_dataset in dataset.dataset_list:
+        source_dataset.random_gen = random_gen
+        for class_dataset in source_dataset.class_datasets:
+            class_dataset.random_gen = random_gen
 
 
 def cycle_(iterable):
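Note (not part of this commit): worker_init_fn_ is what ties the new per-dataset generators together: each DataLoader worker reseeds the generator of its own dataset copy, of every source dataset, and of every class dataset, so workers draw different episodes while staying reproducible given the base seed. A hedged usage sketch, assuming zip_dataset was built as in pipeline.py and that 2021 is an arbitrary base seed:

from functools import partial

import torch

from pytorch_meta_dataset.utils import worker_init_fn_

loader = torch.utils.data.DataLoader(
    zip_dataset,          # assumed: exposes .dataset_list, each with .class_datasets
    batch_size=None,      # episodes are assumed to arrive already collated
    num_workers=4,
    worker_init_fn=partial(worker_init_fn_, seed=2021),
)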
