From aad6dc4ac551e0a13e6af2579cc4bbe63884b5d7 Mon Sep 17 00:00:00 2001 From: Nicolay Rusnachenko Date: Tue, 20 Apr 2021 12:24:26 +0300 Subject: [PATCH] Embedding seed parameter. Not utilized by default (#96) --- common/model/base.py | 2 +- contrib/networks/context/architectures/base/base.py | 6 +++++- contrib/networks/core/model.py | 4 ++-- contrib/networks/core/nn.py | 2 +- contrib/networks/multi/architectures/base/base.py | 6 +++++- contrib/networks/run_training.py | 8 ++++++-- contrib/networks/tests/test_tf_ctx_compile.py | 2 +- contrib/networks/tests/test_tf_ctx_feed.py | 2 +- contrib/networks/tests/test_tf_mi_compile.py | 2 +- 9 files changed, 23 insertions(+), 11 deletions(-) diff --git a/common/model/base.py b/common/model/base.py index 90b3f359..e5e54399 100644 --- a/common/model/base.py +++ b/common/model/base.py @@ -20,7 +20,7 @@ def IO(self): return self.__io # TODO. Remove epochs count, since it is related to NeuralNetworks only. - def run_training(self, epochs_count): + def run_training(self, epochs_count, seed): raise NotImplementedError() def predict(self, data_type=DataType.Test): diff --git a/contrib/networks/context/architectures/base/base.py b/contrib/networks/context/architectures/base/base.py index 158f1553..4db71f09 100644 --- a/contrib/networks/context/architectures/base/base.py +++ b/contrib/networks/context/architectures/base/base.py @@ -131,15 +131,19 @@ def compile_hidden_states_only(self, config): self.__init_embedding_hidden_states() self.init_body_dependent_hidden_states() - def compile(self, config, reset_graph): + def compile(self, config, reset_graph, graph_seed=None): assert(isinstance(config, DefaultNetworkConfig)) assert(isinstance(reset_graph, bool)) + assert(isinstance(graph_seed, int) or graph_seed is None) self.__cfg = config if reset_graph: tf.reset_default_graph() + if graph_seed is not None: + tf.set_random_seed(graph_seed) + self.init_input() self.__init_embedding_hidden_states() self.init_body_dependent_hidden_states() diff --git a/contrib/networks/core/model.py b/contrib/networks/core/model.py index d54fe9f3..01ce83d9 100644 --- a/contrib/networks/core/model.py +++ b/contrib/networks/core/model.py @@ -104,8 +104,8 @@ def __dispose_session(self): """ self.__sess.close() - def run_training(self, epochs_count): - self.__network.compile(self.Config, reset_graph=True) + def run_training(self, epochs_count, seed): + self.__network.compile(self.Config, reset_graph=True, graph_seed=seed) self.set_optimiser() self.__notify_initialized() diff --git a/contrib/networks/core/nn.py b/contrib/networks/core/nn.py index 7a93263b..f1a0be65 100644 --- a/contrib/networks/core/nn.py +++ b/contrib/networks/core/nn.py @@ -23,7 +23,7 @@ def iter_input_dependent_hidden_parameters(self): return yield - def compile(self, config, reset_graph): + def compile(self, config, reset_graph, graph_seed): raise NotImplementedError() def create_feed_dict(self, input, data_type): diff --git a/contrib/networks/multi/architectures/base/base.py b/contrib/networks/multi/architectures/base/base.py index de0247f9..eb3c864d 100644 --- a/contrib/networks/multi/architectures/base/base.py +++ b/contrib/networks/multi/architectures/base/base.py @@ -64,12 +64,16 @@ def DropoutKeepProb(self): # region body - def compile(self, config, reset_graph): + def compile(self, config, reset_graph, graph_seed=None): assert(isinstance(config, BaseMultiInstanceConfig)) + assert(isinstance(graph_seed, int) or graph_seed is None) self.__cfg = config tf.reset_default_graph() + if graph_seed is not None: + tf.set_random_seed(graph_seed) + with tf.variable_scope(self.__ctx_network_scope): self.__context_network.compile_hidden_states_only(config=config.ContextConfig) diff --git a/contrib/networks/run_training.py b/contrib/networks/run_training.py index 9e403d28..0746ad6b 100644 --- a/contrib/networks/run_training.py +++ b/contrib/networks/run_training.py @@ -21,11 +21,14 @@ class NetworksTrainingEngine(ExperimentEngine): def __init__(self, bags_collection_type, experiment, load_model, config, create_network_func, - prepare_model_root=True): + prepare_model_root=True, + seed=None): assert(callable(create_network_func)) assert(isinstance(config, DefaultNetworkConfig)) assert(issubclass(bags_collection_type, BagsCollection)) assert(isinstance(load_model, bool)) + assert(isinstance(seed, int) or seed is None) + super(NetworksTrainingEngine, self).__init__(experiment) self.__clear_model_root_before_experiment = prepare_model_root @@ -33,6 +36,7 @@ def __init__(self, bags_collection_type, experiment, self.__create_network_func = create_network_func self.__bags_collection_type = bags_collection_type self.__load_model = load_model + self.__seed = seed def __get_model_dir(self): return self._experiment.DataIO.ModelIO.get_model_dir() @@ -93,7 +97,7 @@ def _handle_iteration(self, it_index): # Run model with callback: - model.run_training(epochs_count=callback.Epochs) + model.run_training(epochs_count=callback.Epochs, seed=self.__seed) del network del model diff --git a/contrib/networks/tests/test_tf_ctx_compile.py b/contrib/networks/tests/test_tf_ctx_compile.py index fb0ab199..b8ca5309 100755 --- a/contrib/networks/tests/test_tf_ctx_compile.py +++ b/contrib/networks/tests/test_tf_ctx_compile.py @@ -25,7 +25,7 @@ def test(self): logger.info("Clases count: {}".format(config.ClassesCount)) init_config(config) - network.compile(config, reset_graph=True) + network.compile(config, reset_graph=True, graph_seed=42) if __name__ == '__main__': diff --git a/contrib/networks/tests/test_tf_ctx_feed.py b/contrib/networks/tests/test_tf_ctx_feed.py index c952d011..2016965e 100755 --- a/contrib/networks/tests/test_tf_ctx_feed.py +++ b/contrib/networks/tests/test_tf_ctx_feed.py @@ -58,7 +58,7 @@ def run_feeding(network, network_config, create_minibatch_func, logger, labels_scaler = ThreeLabelScaler() init_config(network_config) # Init network. - network.compile(config=network_config, reset_graph=True) + network.compile(config=network_config, reset_graph=True, graph_seed=42) minibatch = create_minibatch_func(config=network_config, labels_scaler=labels_scaler) diff --git a/contrib/networks/tests/test_tf_mi_compile.py b/contrib/networks/tests/test_tf_mi_compile.py index 8fda8ba5..90fa1ff2 100755 --- a/contrib/networks/tests/test_tf_mi_compile.py +++ b/contrib/networks/tests/test_tf_mi_compile.py @@ -27,7 +27,7 @@ def mpmi(context_config, context_network): network = MaxPoolingOverSentences(context_network=context_network) init_config(config) - network.compile(config, reset_graph=True) + network.compile(config, reset_graph=True, graph_seed=42) def test(self): logging.basicConfig(level=logging.INFO)