diff --git a/tensorboard/plugins/hparams/backend_context.py b/tensorboard/plugins/hparams/backend_context.py index 700c7b50fc..dc767252ce 100644 --- a/tensorboard/plugins/hparams/backend_context.py +++ b/tensorboard/plugins/hparams/backend_context.py @@ -60,6 +60,7 @@ def experiment_from_metadata( include_metrics, hparams_run_to_tag_to_content, data_provider_hparams, + hparams_limit=None, ): """Returns the experiment proto defining the experiment. @@ -85,6 +86,8 @@ def experiment_from_metadata( data_provider_hparams: The ouput from an hparams_from_data_provider() call, corresponding to DataProvider.list_hyperparameters(). A provider.ListHyperpararametersResult. + hparams_limit: Optional number of hyperparameter metadata to include in the + result. If unset or zero, all metadata will be included. Returns: The experiment proto. If no data is found for an experiment proto to @@ -94,12 +97,15 @@ def experiment_from_metadata( hparams_run_to_tag_to_content, include_metrics ) if experiment: + _sort_and_reduce_to_hparams_limit(experiment, hparams_limit) return experiment experiment_from_runs = self._compute_experiment_from_runs( ctx, experiment_id, include_metrics, hparams_run_to_tag_to_content ) if experiment_from_runs: + # TODO(yatbear): Apply `hparams_limit` to `experiment_from_runs` after `differs` + # fields are populated in `_compute_hparam_info_from_values()`. return experiment_from_runs experiment_from_data_provider_hparams = ( @@ -325,6 +331,7 @@ def _compute_hparam_info_from_values(self, name, values): if result.type == api_pb2.DATA_TYPE_UNSET: return None + # TODO(yatbear): Populate `differs` fields for hparams once go/tbpr/6574 is merged. if result.type == api_pb2.DATA_TYPE_STRING: distinct_string_values = set( _protobuf_value_to_string(v) @@ -576,3 +583,28 @@ def _protobuf_value_to_string(value): # Remove the quotations. 
def _sort_and_reduce_to_hparams_limit(experiment, hparams_limit=None):
    """Sorts and applies limit to the hparams in the given experiment proto.

    HParamInfo entries whose `differs` field is true are placed ahead of the
    ones whose `differs` field is false. The sort is stable, so entries keep
    their original relative order within each of the two groups.

    Args:
      experiment: An api_pb2.Experiment proto, which will be modified in place.
      hparams_limit: Optional number of hyperparameter metadata to include in the
        result. If unset or zero, no limit will be applied.

    Returns:
      None. `experiment` proto will be modified in place.
    """
    # Prioritizes returning HParamInfo protos with `differed` values.
    # `sorted` builds a new list, so the entries survive the ClearField() below.
    sorted_hparam_infos = sorted(
        experiment.hparam_infos,
        key=lambda hparam_info: hparam_info.differs,
        reverse=True,
    )

    # Only truncate when a limit was actually requested. Using -1 as a
    # "no limit" sentinel in a slice (`[:-1]`) would drop the last entry
    # instead of keeping everything.
    if hparams_limit:
        sorted_hparam_infos = sorted_hparam_infos[:hparams_limit]

    experiment.ClearField("hparam_infos")
    experiment.hparam_infos.extend(sorted_hparam_infos)
self.assertProtoEquals(expected_exp, actual_exp) + def test_experiment_from_tags_with_hparams_limit_no_differed_hparams(self): + experiment = """ + name: 'Test experiment' + hparam_infos: { + name: 'batch_size' + type: DATA_TYPE_FLOAT64 + differs: false + } + hparam_infos: { + name: 'lr' + type: DATA_TYPE_FLOAT64 + differs: false + } + hparam_infos: { + name: 'use_batch_norm' + type: DATA_TYPE_BOOL + differs: false + } + hparam_infos: { + name: 'model_type' + type: DATA_TYPE_STRING + differs: false + } + """ + t = provider.TensorTimeSeries( + max_step=0, + max_wall_time=0, + plugin_content=self._serialized_plugin_data( + DATA_TYPE_EXPERIMENT, experiment + ), + description="", + display_name="", + ) + self._mock_tb_context.data_provider.list_tensors.side_effect = None + self._mock_tb_context.data_provider.list_tensors.return_value = { + "train": {metadata.EXPERIMENT_TAG: t} + } + expected_exp = """ + name: 'Test experiment' + hparam_infos: { + name: 'batch_size' + type: DATA_TYPE_FLOAT64 + differs: false + } + hparam_infos: { + name: 'lr' + type: DATA_TYPE_FLOAT64 + differs: false + } + """ + actual_exp = self._experiment_from_metadata( + include_metrics=False, hparams_limit=2 + ) + self.assertProtoEquals(expected_exp, actual_exp) + + def test_experiment_from_tags_with_hparams_limit_returns_differed_hparams_first( + self, + ): + experiment = """ + name: 'Test experiment' + hparam_infos: { + name: 'batch_size' + type: DATA_TYPE_FLOAT64 + differs: false + } + hparam_infos: { + name: 'lr' + type: DATA_TYPE_FLOAT64 + differs: true + } + hparam_infos: { + name: 'use_batch_norm' + type: DATA_TYPE_BOOL + differs: false + } + hparam_infos: { + name: 'model_type' + type: DATA_TYPE_STRING + differs: true + } + """ + t = provider.TensorTimeSeries( + max_step=0, + max_wall_time=0, + plugin_content=self._serialized_plugin_data( + DATA_TYPE_EXPERIMENT, experiment + ), + description="", + display_name="", + ) + self._mock_tb_context.data_provider.list_tensors.side_effect = None + 
self._mock_tb_context.data_provider.list_tensors.return_value = { + "train": {metadata.EXPERIMENT_TAG: t} + } + expected_exp = """ + name: 'Test experiment' + hparam_infos: { + name: 'lr' + type: DATA_TYPE_FLOAT64 + differs: true + }, + hparam_infos: { + name: 'model_type' + type: DATA_TYPE_STRING + differs: true + } + """ + actual_exp = self._experiment_from_metadata( + include_metrics=False, hparams_limit=2 + ) + self.assertProtoEquals(expected_exp, actual_exp) + + def test_experiment_from_tags_sorts_differed_hparams_first(self): + experiment = """ + name: 'Test experiment' + hparam_infos: { + name: 'batch_size' + type: DATA_TYPE_FLOAT64 + differs: false + } + hparam_infos: { + name: 'lr' + type: DATA_TYPE_FLOAT64 + differs: true + } + hparam_infos: { + name: 'use_batch_norm' + type: DATA_TYPE_BOOL + differs: false + } + hparam_infos: { + name: 'model_type' + type: DATA_TYPE_STRING + differs: true + } + """ + t = provider.TensorTimeSeries( + max_step=0, + max_wall_time=0, + plugin_content=self._serialized_plugin_data( + DATA_TYPE_EXPERIMENT, experiment + ), + description="", + display_name="", + ) + self._mock_tb_context.data_provider.list_tensors.side_effect = None + self._mock_tb_context.data_provider.list_tensors.return_value = { + "train": {metadata.EXPERIMENT_TAG: t} + } + expected_exp = """ + name: 'Test experiment' + hparam_infos: { + name: 'lr' + type: DATA_TYPE_FLOAT64 + differs: true + } + hparam_infos: { + name: 'model_type' + type: DATA_TYPE_STRING + differs: true + } + hparam_infos: { + name: 'batch_size' + type: DATA_TYPE_FLOAT64 + differs: false + } + """ + actual_exp = self._experiment_from_metadata( + include_metrics=False, hparams_limit=None + ) + self.assertProtoEquals(expected_exp, actual_exp) + def _serialized_plugin_data(self, data_oneof_field, text_protobuffer): oneof_type_dict = { DATA_TYPE_EXPERIMENT: api_pb2.Experiment, diff --git a/tensorboard/plugins/hparams/get_experiment.py b/tensorboard/plugins/hparams/get_experiment.py index 
51bcaf2eab..e249b23599 100644 --- a/tensorboard/plugins/hparams/get_experiment.py +++ b/tensorboard/plugins/hparams/get_experiment.py @@ -46,6 +46,13 @@ def run(self): Returns: An Experiment object. """ + data_provider_hparams = ( + self._backend_context.hparams_from_data_provider( + self._request_context, + self._experiment_id, + limit=self._hparams_limit, + ) + ) return self._backend_context.experiment_from_metadata( self._request_context, self._experiment_id, @@ -53,9 +60,6 @@ def run(self): self._backend_context.hparams_metadata( self._request_context, self._experiment_id ), - self._backend_context.hparams_from_data_provider( - self._request_context, - self._experiment_id, - limit=self._hparams_limit, - ), + data_provider_hparams, + self._hparams_limit, )