Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Correct typo and Add Episode Sampling Seed functionality #98

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
{
"print.colourScheme": "nnfx"
}
11 changes: 8 additions & 3 deletions meta_dataset/data/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,7 +373,8 @@ def make_one_source_episode_pipeline(dataset_spec,
image_size=None,
num_to_take=None,
ignore_hierarchy_probability=0.0,
simclr_episode_fraction=0.0):
simclr_episode_fraction=0.0,
episode_sampling_seed=None):
"""Returns a pipeline emitting data from one single source as Episodes.

Args:
Expand Down Expand Up @@ -428,7 +429,8 @@ def make_one_source_episode_pipeline(dataset_spec,
use_dag_hierarchy=use_dag_ontology,
use_bilevel_hierarchy=use_bilevel_ontology,
use_all_classes=use_all_classes,
ignore_hierarchy_probability=ignore_hierarchy_probability)
ignore_hierarchy_probability=ignore_hierarchy_probability,
episode_sampling_seed=episode_sampling_seed)
dataset = episode_reader.create_dataset_input_pipeline(sampler, pool=pool)
# Episodes coming out of `dataset` contain flushed examples and are internally
# padded with placeholder examples. `process_episode` discards flushed
Expand Down Expand Up @@ -463,6 +465,7 @@ def make_multisource_episode_pipeline(dataset_spec_list,
image_size=None,
num_to_take=None,
source_sampling_seed=None,
episode_sampling_seed=None,
simclr_episode_fraction=0.0):
"""Returns a pipeline emitting data from multiple sources as Episodes.

Expand Down Expand Up @@ -492,6 +495,7 @@ def make_multisource_episode_pipeline(dataset_spec_list,
length must be the same as len(dataset_spec). If None, no restrictions are
applied to any dataset and all data per class is used.
source_sampling_seed: random seed for source sampling.
episode_sampling_seed: random seed for episode sampling.
simclr_episode_fraction: Float, fraction of episodes that will be converted
to SimCLR Episodes as described in the CrossTransformers paper.

Expand Down Expand Up @@ -524,7 +528,8 @@ def make_multisource_episode_pipeline(dataset_spec_list,
episode_descr_config,
pool=pool,
use_dag_hierarchy=use_dag_ontology,
use_bilevel_hierarchy=use_bilevel_ontology)
use_bilevel_hierarchy=use_bilevel_ontology,
episode_sampling_seed=episode_sampling_seed)
dataset = episode_reader.create_dataset_input_pipeline(sampler, pool=pool)
# Create a dataset to zip with the above for identifying the source.
source_id_dataset = tf.data.Dataset.from_tensors(source_id).repeat()
Expand Down
14 changes: 11 additions & 3 deletions meta_dataset/data/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,8 @@ def __init__(self,
use_dag_hierarchy=False,
use_bilevel_hierarchy=False,
use_all_classes=False,
ignore_hierarchy_probability=0.0):
ignore_hierarchy_probability=0.0,
episode_sampling_seed=None):
"""Initializes an EpisodeDescriptionSampler.episode_config.

Args:
Expand All @@ -251,6 +252,8 @@ def __init__(self,
ignore_hierarchy_probability: Float, if using a hierarchy, this flag makes
the sampler ignore the hierarchy for this proportion of episodes and
instead sample categories uniformly.
episode_sampling_seed: random seed for making episode description sampling
deterministic within individual data sources

Raises:
RuntimeError: if required parameters are missing.
Expand All @@ -259,8 +262,13 @@ def __init__(self,
# Each instance has its own RNG which is seeded from the module-level RNG,
# which makes episode description sampling deterministic within individual
# data sources.
self._rng = np.random.RandomState(
seed=RNG.randint(0, 2**32, size=None, dtype='uint32'))
if episode_sampling_seed == None:
self._rng = np.random.RandomState(
seed=RNG.randint(0, 2**32, size=None, dtype='uint32'))
else:
self._rng = np.random.RandomState(
seed=episode_sampling_seed)

self.dataset_spec = dataset_spec
self.split = split
self.pool = pool
Expand Down
2 changes: 1 addition & 1 deletion meta_dataset/learn/gin/best/prototypical_imagenet.gin
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ include 'meta_dataset/learn/gin/setups/imagenet.gin'
include 'meta_dataset/learn/gin/learners/prototypical_config.gin'

# Backbone hypers.
include 'meta_dataset/learn/gin/best_v2/pretrained_resnet.gin'
include 'meta_dataset/learn/gin/best/pretrained_resnet.gin'

# Data hypers.
DataConfig.image_height = 126
Expand Down