Skip to content

Commit

Permalink
Set seed for dropout
Browse files Browse the repository at this point in the history
  • Loading branch information
mhuen committed Apr 21, 2024
1 parent c261fdd commit 3f6c6a4
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 5 deletions.
20 changes: 16 additions & 4 deletions dnn_reco/modules/models/general_IC86_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,9 @@ def general_model_IC86(

# apply DOM dropout, split and reshape DeepCore input
X_IC78, X_DeepCore_upper, X_DeepCore_lower = preprocess_icecube_data(
is_training, shared_objects
is_training,
shared_objects,
seed=config["tf_random_seed"],
)

# -----------------------------------
Expand Down Expand Up @@ -162,7 +164,11 @@ def general_model_IC86(
)

# dropout
layer_flat = tf.nn.dropout(layer_flat, rate=1 - (keep_prob_list[2]))
layer_flat = tf.nn.dropout(
layer_flat,
rate=1 - (keep_prob_list[2]),
seed=config["tf_random_seed"],
)

# -----------------------------------
# fully connected layers
Expand Down Expand Up @@ -387,7 +393,9 @@ def general_model_IC86_opt4(

# apply DOM dropout, split and reshape DeepCore input
X_IC78, X_DeepCore_upper, X_DeepCore_lower = preprocess_icecube_data(
is_training, shared_objects
is_training,
shared_objects,
seed=config["tf_random_seed"],
)

# -----------------------------------
Expand Down Expand Up @@ -450,7 +458,11 @@ def general_model_IC86_opt4(
)

# dropout
layer_flat = tf.nn.dropout(layer_flat, rate=1 - (keep_prob_list[2]))
layer_flat = tf.nn.dropout(
layer_flat,
rate=1 - (keep_prob_list[2]),
seed=config["tf_random_seed"],
)

# -----------------------------------
# fully connected layers
Expand Down
6 changes: 5 additions & 1 deletion dnn_reco/modules/models/utils/model_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import tensorflow as tf


def preprocess_icecube_data(is_training, shared_objects):
def preprocess_icecube_data(is_training, shared_objects, seed=None):
"""Performs some basic preprocessing of IceCube input data.
Applies drop out for whole DOMs.
Expand All @@ -20,6 +20,8 @@ def preprocess_icecube_data(is_training, shared_objects):
shared_objects : dict
A dictionary containing settings and objects that are shared and passed
on to sub modules.
seed : int, optional
Random seed for reproducibility.
Returns
-------
Expand Down Expand Up @@ -56,12 +58,14 @@ def preprocess_icecube_data(is_training, shared_objects):
X_IC78,
rate=1 - (keep_prob_list[0]),
noise_shape=noise_shape_IC78,
seed=seed,
)

X_DeepCore = tf.nn.dropout(
X_DeepCore,
rate=1 - (keep_prob_list[0]),
noise_shape=noise_shape_DeepCore,
seed=seed,
)

# -----------------------------------
Expand Down

0 comments on commit 3f6c6a4

Please sign in to comment.