From f8d64c80f042074a692589a8fe8df8d2ba650dbd Mon Sep 17 00:00:00 2001 From: Yating Date: Mon, 25 Sep 2023 11:31:33 -0400 Subject: [PATCH] Hparams: Sort and apply limit to hparams from runs. (#6588) Apply `hparams_limit` to hyperparameters from tensors, and sort them by `differs` field (differed first). #hparams --- .../plugins/hparams/backend_context.py | 4 +- .../plugins/hparams/backend_context_test.py | 161 ++++++++++++++++++ 2 files changed, 164 insertions(+), 1 deletion(-) diff --git a/tensorboard/plugins/hparams/backend_context.py b/tensorboard/plugins/hparams/backend_context.py index ff248d4c5d..b05ef8df04 100644 --- a/tensorboard/plugins/hparams/backend_context.py +++ b/tensorboard/plugins/hparams/backend_context.py @@ -104,7 +104,9 @@ def experiment_from_metadata( ctx, experiment_id, include_metrics, hparams_run_to_tag_to_content ) if experiment_from_runs: - # TODO(yatbear): Apply `hparams_limit` to `experiment_from_runs`. + _sort_and_reduce_to_hparams_limit( + experiment_from_runs, hparams_limit + ) return experiment_from_runs experiment_from_data_provider_hparams = ( diff --git a/tensorboard/plugins/hparams/backend_context_test.py b/tensorboard/plugins/hparams/backend_context_test.py index 2f36c4a6d9..a8e1aa4a2c 100644 --- a/tensorboard/plugins/hparams/backend_context_test.py +++ b/tensorboard/plugins/hparams/backend_context_test.py @@ -1185,6 +1185,167 @@ def test_experiment_from_tags_sorts_differed_hparams_first(self): ) self.assertProtoEquals(expected_exp, actual_exp) + def test_experiment_from_runs_with_hparams_limit_no_differed_hparams(self): + self.session_1_start_info_ = """ + hparams: [ + {key: 'lr' value: {number_value: 100}}, + {key: 'model_type' value: {string_value: 'LATTICE'}}, + {key: 'use_batch_norm' value: {bool_value: true}} + ] + """ + self.session_2_start_info_ = """ + hparams: [ + {key: 'lr' value: {number_value: 100}}, + {key: 'model_type' value: {string_value: 'LATTICE'}}, + {key: 'use_batch_norm' value: {bool_value: true}} + ] + """ + self.session_3_start_info_ = """ + hparams: [ + {key: 'lr' value: {number_value: 100}}, + {key: 'model_type' value: {string_value: 'LATTICE'}}, + {key: 'use_batch_norm' value: {bool_value: true}} + ] + """ + expected_exp = """ + hparam_infos: { + name: 'use_batch_norm' + type: DATA_TYPE_BOOL + domain_discrete: { + values: [{bool_value: true}] + } + differs: false + } + hparam_infos: { + name: 'model_type' + type: DATA_TYPE_STRING + domain_discrete: { + values: [{string_value: 'LATTICE'}] + } + differs: false + } + """ + actual_exp = self._experiment_from_metadata( + include_metrics=False, hparams_limit=2 + ) + self.assertProtoEquals(expected_exp, actual_exp) + + def test_experiment_from_runs_with_hparams_limit_returns_differed_hparams_first( + self, + ): + self.session_1_start_info_ = """ + hparams: [ + {key: 'batch_size' value: {number_value: 200}}, + {key: 'lr' value: {number_value: 0.01}}, + {key: 'model_type' value: {string_value: 'CNN'}} + ] + """ + self.session_2_start_info_ = """ + hparams: [ + {key: 'batch_size' value: {number_value: 200}}, + {key: 'lr' value: {number_value: 0.02}}, + {key: 'model_type' value: {string_value: 'LATTICE'}} + ] + """ + self.session_3_start_info_ = """ + hparams: [ + {key: 'batch_size' value: {number_value: 200}}, + {key: 'lr' value: {number_value: 0.05}}, + {key: 'model_type' value: {string_value: 'CNN'}} + ] + """ + expected_exp = """ + hparam_infos: { + name: 'lr' + type: DATA_TYPE_FLOAT64 + domain_interval { + min_value: 0.01 + max_value: 0.05 + } + differs: true + } + hparam_infos: { + name: 'model_type' + type: DATA_TYPE_STRING + domain_discrete: { + values: [{string_value: 'CNN'}, + {string_value: 'LATTICE'}] + } + differs: true + } + """ + actual_exp = self._experiment_from_metadata( + include_metrics=False, hparams_limit=2 + ) + _canonicalize_experiment(actual_exp) + self.assertProtoEquals(expected_exp, actual_exp) + + def test_experiment_from_runs_sorts_differed_hparams_first(self): + self.session_1_start_info_ = """ + hparams: [ + {key: 'batch_size' value: {number_value: 200}}, + {key: 'lr' value: {number_value: 0.01}}, + {key: 'model_type' value: {string_value: 'CNN'}}, + {key: 'use_batch_norm' value: {bool_value: false}} + ] + """ + self.session_2_start_info_ = """ + hparams: [ + {key: 'batch_size' value: {number_value: 300}}, + {key: 'lr' value: {number_value: 0.01}}, + {key: 'model_type' value: {string_value: 'CNN'}}, + {key: 'use_batch_norm' value: {bool_value: false}} + ] + """ + self.session_3_start_info_ = """ + hparams: [ + {key: 'batch_size' value: {number_value: 100}}, + {key: 'lr' value: {number_value: 0.01}}, + {key: 'model_type' value: {string_value: 'CNN'}}, + {key: 'use_batch_norm' value: {bool_value: true}} + ] + """ + expected_exp = """ + hparam_infos: { + name: 'use_batch_norm' + type: DATA_TYPE_BOOL + domain_discrete: { + values: [{bool_value: false}, {bool_value: true}] + } + differs: true + } + hparam_infos: { + name: 'batch_size' + type: DATA_TYPE_FLOAT64 + domain_interval { + min_value: 100 + max_value: 300 + } + differs: true + } + hparam_infos: { + name: 'model_type' + type: DATA_TYPE_STRING + domain_discrete: { + values: [{string_value: 'CNN'}] + } + differs: false + } + hparam_infos: { + name: 'lr' + type: DATA_TYPE_FLOAT64 + domain_interval { + min_value: 0.01 + max_value: 0.01 + } + differs: false + } + """ + actual_exp = self._experiment_from_metadata( + include_metrics=False, hparams_limit=None + ) + self.assertProtoEquals(expected_exp, actual_exp) + def _serialized_plugin_data(self, data_oneof_field, text_protobuffer): oneof_type_dict = { DATA_TYPE_EXPERIMENT: api_pb2.Experiment,