From c6d6922003380342ab2e3509425d96307aa925c5 Mon Sep 17 00:00:00 2001
From: Malik
Date: Fri, 16 Apr 2021 18:46:11 +0200
Subject: [PATCH] Pass an explicit np.random.RandomState through the data
 pipeline for reproducible per-worker shuffling

---
 pytorch_meta_dataset/pipeline.py               |  5 +++--
 pytorch_meta_dataset/tfrecord/reader.py        | 14 +++++++++-----
 pytorch_meta_dataset/tfrecord/torch/dataset.py |  4 +++-
 pytorch_meta_dataset/utils.py                  |  6 ++++--
 4 files changed, 19 insertions(+), 10 deletions(-)

diff --git a/pytorch_meta_dataset/pipeline.py b/pytorch_meta_dataset/pipeline.py
index e4a71d5..4724068 100644
--- a/pytorch_meta_dataset/pipeline.py
+++ b/pytorch_meta_dataset/pipeline.py
@@ -103,7 +103,7 @@ def __init__(self,
         self.transforms = transforms
         self.max_query_size = max_query_size
         self.max_support_size = max_support_size
-        self.random_gen = None
+        self.random_gen = np.random.RandomState()
 
     def __iter__(self):
         while True:
@@ -154,6 +154,7 @@ def __init__(self,
         super(BatchDataset).__init__()
         self.class_datasets = class_datasets
         self.transforms = transforms
+        self.random_gen = np.random.RandomState()
 
     def __iter__(self):
         while True:
@@ -176,7 +177,7 @@ class ZipDataset(torch.utils.data.IterableDataset):
     def __init__(self,
                  dataset_list: List[EpisodicDataset]):
         self.dataset_list = dataset_list
-        self.random_gen = None
+        self.random_gen = np.random.RandomState()
 
     def __iter__(self):
         while True:
diff --git a/pytorch_meta_dataset/tfrecord/reader.py b/pytorch_meta_dataset/tfrecord/reader.py
index e8a86d5..780a3b3 100644
--- a/pytorch_meta_dataset/tfrecord/reader.py
+++ b/pytorch_meta_dataset/tfrecord/reader.py
@@ -13,6 +13,7 @@
 
 
 def tfrecord_iterator(data_path: str,
+                      random_gen: np.random.RandomState,
                       index_path: typing.Optional[str] = None,
                       shard: typing.Optional[typing.Tuple[int, int]] = None,
                       shuffle: bool = False,
@@ -48,8 +49,9 @@ def tfrecord_iterator(data_path: str,
     crc_bytes = bytearray(4)
     datum_bytes = bytearray(1024 * 1024)
 
-    def random_reader(indexes):
-        random_permutation = np.random.permutation(range(indexes.shape[0]))
+    def random_reader(indexes: np.ndarray,
+                      random_gen: np.random.RandomState):
+        random_permutation = random_gen.permutation(range(indexes.shape[0]))
         for i in random_permutation:
             start = indexes[i, 0]
             end = indexes[i, 0] + indexes[i, 1]
@@ -83,7 +85,7 @@ def read_records(start_offset=None, end_offset=None):
         indexes = np.loadtxt(index_path, dtype=np.int64)
         # if shard is None:
         if shuffle:
-            yield from random_reader(indexes)
+            yield from random_reader(indexes=indexes, random_gen=random_gen)
         else:
             yield from read_records()
         # else:
@@ -165,6 +167,7 @@ def get_value(typename, typename_mapping, key):
 
 
 def example_loader(data_path: str,
+                   random_gen: np.random.RandomState,
                    index_path: typing.Union[str, None],
                    description: typing.Union[typing.List[str], typing.Dict[str, str], None] = None,
                    shard: typing.Optional[typing.Tuple[int, int]] = None,
@@ -210,7 +213,7 @@
         "int": "int64_list"
     }
 
-    record_iterator = tfrecord_iterator(data_path, index_path, shard, shuffle)
+    record_iterator = tfrecord_iterator(data_path, random_gen, index_path, shard, shuffle)
 
     for record, (start, end) in record_iterator:
         # yield record
@@ -291,6 +294,7 @@ def sequence_loader(data_path: str,
 
 def tfrecord_loader(data_path: str,
                     index_path: typing.Union[str, None],
+                    random_gen: np.random.RandomState,
                     description: typing.Union[typing.List[str], typing.Dict[str, str], None] = None,
                     shard: typing.Optional[typing.Tuple[int, int]] = None,
                     shuffle: bool = False,
@@ -343,7 +347,7 @@
     """
     if sequence_description is not None:
         return sequence_loader(data_path, index_path, description, sequence_description, shard)
-    return example_loader(data_path, index_path, description, shard, shuffle)
+    return example_loader(data_path, random_gen, index_path, description, shard, shuffle)
 
 
 def multi_tfrecord_loader(data_pattern: str,
diff --git a/pytorch_meta_dataset/tfrecord/torch/dataset.py b/pytorch_meta_dataset/tfrecord/torch/dataset.py
index a03d5d8..59a1b71 100644
--- a/pytorch_meta_dataset/tfrecord/torch/dataset.py
+++ b/pytorch_meta_dataset/tfrecord/torch/dataset.py
@@ -65,6 +65,7 @@ def __init__(self,
         self.sequence_description = sequence_description
         self.shuffle = shuffle
         self.transform = transform
+        self.random_gen = np.random.RandomState()
 
     def __iter__(self):
         worker_info = torch.utils.data.get_worker_info()
@@ -78,7 +79,8 @@ def __iter__(self):
                              description=self.description,
                              shard=shard,
                              shuffle=self.shuffle,
-                             sequence_description=self.sequence_description)
+                             sequence_description=self.sequence_description,
+                             random_gen=self.random_gen)
         if self.transform:
             it = map(self.transform, it)
         return it
diff --git a/pytorch_meta_dataset/utils.py b/pytorch_meta_dataset/utils.py
index ba9a2cb..9b2124e 100644
--- a/pytorch_meta_dataset/utils.py
+++ b/pytorch_meta_dataset/utils.py
@@ -8,8 +8,10 @@ def worker_init_fn_(worker_id, seed):
     dataset = worker_info.dataset  # the dataset copy in this worker process
     random_gen = np.random.RandomState(seed + worker_id)
     dataset.random_gen = random_gen
-    for d in dataset.dataset_list:
-        d.random_gen = random_gen
+    for source_dataset in dataset.dataset_list:
+        source_dataset.random_gen = random_gen
+        for class_dataset in source_dataset.class_datasets:
+            class_dataset.random_gen = random_gen
 
 
 def cycle_(iterable):
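
Usage note: worker_init_fn_ is designed to be handed to a
torch.utils.data.DataLoader so that each worker process re-seeds its own
copy of the dataset tree (the top-level dataset, each per-source dataset,
and their class datasets). A minimal sketch of that wiring, assuming a
`dataset` already built by pipeline.py, four workers, and a seed of 2021
(all illustrative, not part of the patch itself):

    from functools import partial

    import torch

    from pytorch_meta_dataset.utils import worker_init_fn_

    # `dataset` is assumed to be a ZipDataset/BatchDataset produced by
    # pipeline.py; its construction (tfrecord paths, transforms, ...) is
    # omitted here.
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=None,   # the iterable datasets already yield whole tasks
        num_workers=4,
        # DataLoader invokes this as worker_init_fn(worker_id); each worker
        # then builds np.random.RandomState(seed + worker_id), so shuffling
        # is reproducible yet decorrelated across workers.
        worker_init_fn=partial(worker_init_fn_, seed=2021),
    )

With num_workers=0 the init function is never called, which is why the
classes above now default random_gen to a fresh np.random.RandomState()
instead of None.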