diff --git a/vizier/_src/algorithms/designers/gp_ucb_pe.py b/vizier/_src/algorithms/designers/gp_ucb_pe.py
index 2145e5dd9..9fbde843d 100644
--- a/vizier/_src/algorithms/designers/gp_ucb_pe.py
+++ b/vizier/_src/algorithms/designers/gp_ucb_pe.py
@@ -67,7 +67,8 @@ class UCBPEConfig(eqx.Module):
       default=10.0, converter=jnp.asarray
   )
   # Probability of selecting the UCB acquisition function when there are no new
-  # completed trials.
+  # completed trials. No-op if `optimize_set_acquisition_for_exploration` below
+  # is True.
   ucb_overwrite_probability: jt.Float[jt.Array, ''] = eqx.field(
       default=0.25, converter=jnp.asarray
   )
@@ -76,6 +77,10 @@
   pe_overwrite_probability: jt.Float[jt.Array, ''] = eqx.field(
       default=0.1, converter=jnp.asarray
   )
+  # Whether to optimize the set acquisition function for exploration.
+  optimize_set_acquisition_for_exploration: bool = eqx.field(
+      default=False, static=True
+  )
 
   def __repr__(self):
     return eqx.tree_pformat(self, short_arrays=False)
@@ -170,6 +175,33 @@ def _apply_trust_region(
   )
 
 
+def _apply_trust_region_to_set(
+    tr: acquisitions.TrustRegion, xs: types.ModelInput, acq_values: jax.Array
+) -> jax.Array:
+  """Applies the trust region to a batch of set acquisition function values.
+
+  Args:
+    tr: Trust region.
+    xs: A batch of predictive index point sets of a fixed size.
+    acq_values: A batch of acquisition function values at predictive index point
+      sets, shaped as [batch_size].
+
+  Returns:
+    Acquisition function values with trust region applied, shaped as
+    [batch_size].
+  """
+  distance = tr.min_linf_distance(xs)  # [batch_size, index_point_set_size]
+  # Due to output normalization, acquisition values can't be as low as -1e12.
+  # We penalize the acquisition values by an amount that decreases in the
+  # total distance to the trust region, so that the acquisition optimizer can
+  # follow the gradient and escape untrusted regions.
+  return acq_values + jnp.sum(
+      ((distance > tr.trust_radius) & (tr.trust_radius <= 0.5))
+      * (-1e12 - distance),
+      axis=1,
+  )
+
+
 def _get_features_shape(
     features: types.ModelInput,
 ) -> types.ContinuousAndCategorical:
@@ -233,11 +265,11 @@ def score_with_aux(
 class PEScoreFunction(eqx.Module):
   """Computes the Pure-Exploration acquisition value.
 
-  The PE acquisition value is the predicted standard deviation based on
-  all suggestions, completed and pending, plus a penalty term that grows
-  linearly in the amount of violation of the constraint
-  `UCB(xs) >= threshold`. This class follows the `acquisitions.ScoreFunction`
-  protocol.
+  The PE acquisition value is the predicted standard deviation (eq. (9)
+  in https://arxiv.org/pdf/1304.5350) based on all completed and active trials,
+  plus a penalty term that grows linearly in the amount of violation of the
+  constraint `UCB(xs) >= threshold`. This class follows the
+  `acquisitions.ScoreFunction` protocol.
 
   Attributes:
     predictive: Predictive model with cached Cholesky conditioned on completed
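Aside: a standalone sketch (not part of this patch) of how the penalty in `_apply_trust_region_to_set` above behaves. The distances are made up and stand in for `tr.min_linf_distance(xs)`; the point is that sets with any point outside a small trust radius are pushed far below the normalized acquisition range while the extra `-distance` term keeps a slope back toward the trusted region.

```python
import jax.numpy as jnp

# Hypothetical L-infinity distances from each point in two candidate sets to
# the trusted region, shaped [batch_size=2, set_size=3].
distance = jnp.array([[0.0, 0.1, 0.2],   # all points inside the region
                      [0.0, 0.3, 0.6]])  # one point outside
trust_radius = 0.4  # the penalty is only active while the radius is <= 0.5
acq_values = jnp.array([1.3, 2.5])

penalty = jnp.sum(
    ((distance > trust_radius) & (trust_radius <= 0.5)) * (-1e12 - distance),
    axis=1,
)
print(acq_values + penalty)
# [1.3, 2.5 - 1e12 - 0.6]: the second set is effectively rejected, and the
# extra -distance term keeps the gradient pointing back toward the region.
```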
+ """ + cholesky_matrix = jnp.linalg.cholesky(matrix) + output = 2.0 * jnp.sum(jnp.log(jnp.linalg.diagonal(cholesky_matrix)), axis=-1) + return jnp.where(jnp.isnan(output), -jnp.inf, output) + + +class SetPEScoreFunction(eqx.Module): + """Computes the Pure-Exploration acquisition value over sets. + + The PE acquisition value over a set of points is the log-determinant of the + predicted covariance matrix evaluated at the points (eq. (8) in + https://arxiv.org/pdf/1304.5350) based on all completed and active trials, + plus a penalty term that grows linearly in the amount of violation of the + constraint `UCB(xs) >= threshold`. This class follows the + `acquisitions.ScoreFunction` protocol. + + Attributes: + predictive: Predictive model with cached Cholesky conditioned on completed + trials. + predictive_all_features: Predictive model with cached Cholesky conditioned + on completed and pending trials. + ucb_coefficient: The UCB coefficient used to compute the threshold. + explore_ucb_coefficient: The UCB coefficient used for computing the UCB + values on `xs`. + penalty_coefficient: Multiplier on the constraint violation penalty. + trust_region: + + Returns: + The Pure-Exploration acquisition value. + """ + + predictive: sp.UniformEnsemblePredictive + predictive_all_features: sp.UniformEnsemblePredictive + ucb_coefficient: jt.Float[jt.Array, ''] + explore_ucb_coefficient: jt.Float[jt.Array, ''] + penalty_coefficient: jt.Float[jt.Array, ''] + trust_region: Optional[acquisitions.TrustRegion] + + def score( + self, xs: types.ModelInput, seed: Optional[jax.Array] = None + ) -> jax.Array: + return self.score_with_aux(xs, seed=seed)[0] + + def aux( + self, xs: types.ModelInput, seed: Optional[jax.Array] = None + ) -> chex.ArrayTree: + return self.score_with_aux(xs, seed=seed)[1] + + def score_with_aux( + self, xs: types.ModelInput, seed: Optional[jax.Array] = None + ) -> tuple[jax.Array, chex.ArrayTree]: + del seed + features = self.predictive_all_features.predictives.observed_data.features + is_missing = ( + features.continuous.is_missing[0] | features.categorical.is_missing[0] + ) + gprm_threshold = self.predictive.predict(features) + threshold = _compute_ucb_threshold( + gprm_threshold, is_missing, self.ucb_coefficient + ) + gprm = self.predictive.predict(xs) + mean = gprm.mean() + stddev = gprm.stddev() + explore_ucb = mean + stddev * self.explore_ucb_coefficient + + gprm_all = self.predictive_all_features.predict(xs) + cov = gprm_all.covariance() + acq_values = _logdet(cov) + self.penalty_coefficient * jnp.sum( + jnp.minimum( + explore_ucb - threshold, + 0.0, + ), + axis=1, + ) + if self.trust_region is not None: + acq_values = _apply_trust_region_to_set(self.trust_region, xs, acq_values) + return acq_values, { + 'mean': mean, + 'stddev': stddev, + 'stddev_from_all': jnp.sqrt(jnp.diagonal(cov, axis1=1, axis2=2)), + } + + def default_ard_optimizer() -> optimizers.Optimizer[types.ParameterDict]: return optimizers.JaxoptScipyLbfgsB( options=optimizers.LbfgsBOptions( @@ -637,6 +761,234 @@ def _get_predictive_all_features( predictives=eqx.filter_jit(model.precompute_predictive)(all_data) ) + def _suggest_one( + self, + active_trials: Sequence[vz.Trial], + data: types.ModelData, + model: sp.StochasticProcessWithCoroutine, + predictive: sp.UniformEnsemblePredictive, + tr: acquisitions.TrustRegion, + acquisition_problem: vz.ProblemStatement, + ) -> vz.TrialSuggestion: + """Generates one suggestion.""" + start_time = datetime.datetime.now() + self._rng, rng = jax.random.split(self._rng, 2) + if 
@@ -637,6 +761,234 @@ def _get_predictive_all_features(
         predictives=eqx.filter_jit(model.precompute_predictive)(all_data)
     )
 
+  def _suggest_one(
+      self,
+      active_trials: Sequence[vz.Trial],
+      data: types.ModelData,
+      model: sp.StochasticProcessWithCoroutine,
+      predictive: sp.UniformEnsemblePredictive,
+      tr: acquisitions.TrustRegion,
+      acquisition_problem: vz.ProblemStatement,
+  ) -> vz.TrialSuggestion:
+    """Generates one suggestion."""
+    start_time = datetime.datetime.now()
+    self._rng, rng = jax.random.split(self._rng, 2)
+    if _has_new_completed_trials(
+        completed_trials=self._all_completed_trials,
+        active_trials=active_trials,
+    ):
+      # When there are trials completed after all active trials were created,
+      # we optimize the UCB acquisition function except with a small
+      # probability the PE acquisition function to ensure exploration.
+      use_ucb = not jax.random.bernoulli(
+          key=rng, p=self._config.pe_overwrite_probability
+      )
+    else:
+      has_completed_trials = len(self._all_completed_trials) > 0  # pylint:disable=g-explicit-length-test
+      # When there are no trials completed after all active trials were
+      # created, we optimize the PE acquisition function except with a small
+      # probability the UCB acquisition function, in case the UCB acquisition
+      # function is not well optimized.
+      use_ucb = has_completed_trials and jax.random.bernoulli(
+          key=rng, p=self._config.ucb_overwrite_probability
+      )
+
+    # TODO: Feed the eagle strategy with completed trials.
+    # TODO: Change budget based on requested suggestion count.
+    acquisition_optimizer = self._acquisition_optimizer_factory(self._converter)
+
+    if active_trials:
+      pending_features = self._converter.to_features(active_trials)
+      predictive_all_features = self._get_predictive_all_features(
+          pending_features, data, model
+      )
+    else:
+      predictive_all_features = predictive
+
+    # When `use_ucb` is true, the acquisition function computes the UCB
+    # values. Otherwise, it computes the Pure-Exploration acquisition values.
+    if use_ucb:
+      scoring_fn = UCBScoreFunction(
+          predictive,
+          predictive_all_features,
+          ucb_coefficient=self._config.ucb_coefficient,
+          trust_region=tr if self._use_trust_region else None,
+      )
+    else:
+      scoring_fn = PEScoreFunction(
+          predictive,
+          predictive_all_features,
+          penalty_coefficient=self._config.cb_violation_penalty_coefficient,
+          ucb_coefficient=self._config.ucb_coefficient,
+          explore_ucb_coefficient=self._config.explore_region_ucb_coefficient,
+          trust_region=tr if self._use_trust_region else None,
+      )
+
+    if isinstance(acquisition_optimizer, vb.VectorizedOptimizer):
+      acq_rng, self._rng = jax.random.split(self._rng)
+      prior_features = None
+      if self._all_completed_trials:
+        prior_features = vb.trials_to_sorted_array(
+            self._all_completed_trials, self._converter
+        )
+      with profiler.timeit('acquisition_optimizer', also_log=True):
+        best_candidates = eqx.filter_jit(acquisition_optimizer)(
+            scoring_fn.score,
+            prior_features=prior_features,
+            count=1,
+            seed=acq_rng,
+            score_with_aux_fn=scoring_fn.score_with_aux,
+        )
+        jax.block_until_ready(best_candidates)
+      with profiler.timeit('best_candidates_to_trials', also_log=True):
+        best_candidate = vb.best_candidates_to_trials(
+            best_candidates, self._converter
+        )[0]
+    elif isinstance(acquisition_optimizer, vza.GradientFreeOptimizer):
+      # Seed the optimizer with previous trials.
+      acquisition = self.get_score_fn_on_trials(scoring_fn.score)
+      best_candidate = acquisition_optimizer.optimize(
+          acquisition,
+          acquisition_problem,
+          count=1,
+          seed_candidates=copy.deepcopy(self._all_completed_trials),
+      )[0]
+    else:
+      raise ValueError(
+          f'Unrecognized acquisition_optimizer: {type(acquisition_optimizer)}'
+      )
+
+    # Make predictions (in the warped space).
+    logging.info('Converting the optimization result into suggestion...')
+    optimal_features = self._converter.to_features([best_candidate])  # [1, D]
+    aux = eqx.filter_jit(scoring_fn.aux)(optimal_features)
+    predict_mean = aux['mean']  # [1,]
+    predict_stddev = aux['stddev']  # [1,]
+    predict_stddev_from_all = aux['stddev_from_all']  # [1,]
+    acquisition = best_candidate.final_measurement_or_die.metrics.get_value(
+        'acquisition', float('nan')
+    )
+    logging.info(
+        'Created predictions for the best candidates which were converted to'
+        f' an array of shape: {_get_features_shape(optimal_features)}. mean'
+        f' has shape {predict_mean.shape}. stddev has shape'
+        f' {predict_stddev.shape}. stddev_from_all has shape'
+        f' {predict_stddev_from_all.shape}. acquisition value of'
+        f' best_candidate: {acquisition}, use_ucb: {use_ucb}'
+    )
+
+    # Create a suggestion, injecting the predictions as metadata for
+    # debugging needs.
+    metadata = best_candidate.metadata.ns(self._metadata_ns)
+    metadata.ns('prediction_in_warped_y_space').update({
+        'mean': f'{predict_mean[0]}',
+        'stddev': f'{predict_stddev[0]}',
+        'stddev_from_all': f'{predict_stddev_from_all[0]}',
+        'acquisition': f'{acquisition}',
+        'use_ucb': f'{use_ucb}',
+        'trust_radius': f'{tr.trust_radius}',
+        'params': f'{model.params}',
+    })
+    metadata.ns('timing').update(
+        {'time': f'{datetime.datetime.now() - start_time}'}
+    )
+    return vz.TrialSuggestion(
+        best_candidate.parameters, metadata=best_candidate.metadata
+    )
+
+  def _suggest_batch_with_exploration(
+      self,
+      count: int,
+      active_trials: Sequence[vz.Trial],
+      data: types.ModelData,
+      model: sp.StochasticProcessWithCoroutine,
+      predictive: sp.UniformEnsemblePredictive,
+      tr: acquisitions.TrustRegion,
+  ):
+    """Generates a batch of suggestions with exploration."""
+    start_time = datetime.datetime.now()
+    if active_trials:
+      pending_features = self._converter.to_features(active_trials)
+      predictive_all_features = self._get_predictive_all_features(
+          pending_features, data, model
+      )
+    else:
+      predictive_all_features = predictive
+
+    scoring_fn = SetPEScoreFunction(
+        predictive,
+        predictive_all_features,
+        penalty_coefficient=self._config.cb_violation_penalty_coefficient,
+        ucb_coefficient=self._config.ucb_coefficient,
+        explore_ucb_coefficient=self._config.explore_region_ucb_coefficient,
+        trust_region=tr if self._use_trust_region else None,
+    )
+
+    acquisition_optimizer = self._acquisition_optimizer_factory(self._converter)
+
+    acq_rng, self._rng = jax.random.split(self._rng)
+    with profiler.timeit('acquisition_optimizer', also_log=True):
+      best_candidates = eqx.filter_jit(acquisition_optimizer)(
+          scoring_fn.score,
+          count=1,
+          seed=acq_rng,
+          score_with_aux_fn=scoring_fn.score_with_aux,
+          n_parallel=count,
+      )
+      jax.block_until_ready(best_candidates)
+    with profiler.timeit('best_candidates_to_trials', also_log=True):
+      trials = vb.best_candidates_to_trials(best_candidates, self._converter)[
+          :count
+      ]
+
+    optimal_features = self._converter.to_features(trials)  # [count, D]
+    aux = eqx.filter_jit(scoring_fn.aux)(
+        jax.tree_util.tree_map(
+            lambda x: jnp.expand_dims(x, axis=0), optimal_features
+        )
+    )
+    predict_mean = aux['mean']  # [1, count]
+    predict_stddev = aux['stddev']  # [1, count]
+    predict_stddev_from_all = aux['stddev_from_all']  # [1, count]
+    acquisition = trials[0].final_measurement_or_die.metrics.get_value(
+        'acquisition', float('nan')
+    )
+    logging.info(
+        'Created predictions for the best candidates which were converted to'
+        f' an array of shape: {_get_features_shape(optimal_features)}. mean'
+        f' has shape {predict_mean.shape}. stddev has shape'
+        f' {predict_stddev.shape}. stddev_from_all has shape'
+        f' {predict_stddev_from_all.shape}. acquisition value of'
+        f' best_candidate: {acquisition}, use_ucb: False'
+    )
+
+    logging.info(
+        'Converting the optimization result into %d suggestions...', count
+    )
+    suggestions = []
+    end_time = datetime.datetime.now()
+    for idx, best_candidate in enumerate(trials):
+      # Make predictions (in the warped space).
+      # Create suggestions, injecting the predictions as metadata for
+      # debugging needs.
+      metadata = best_candidate.metadata.ns(self._metadata_ns)
+      metadata.ns('prediction_in_warped_y_space').update({
+          'mean': f'{predict_mean[0, idx]}',
+          'stddev': f'{predict_stddev[0, idx]}',
+          'stddev_from_all': f'{predict_stddev_from_all[0, idx]}',
+          'acquisition': f'{acquisition}',
+          'use_ucb': 'False',
+          'trust_radius': f'{tr.trust_radius}',
+          'params': f'{model.params}',
+      })
+      metadata.ns('timing').update({'time': f'{end_time - start_time}'})
+      suggestions.append(
+          vz.TrialSuggestion(
+              best_candidate.parameters, metadata=best_candidate.metadata
+          )
+      )
+    return suggestions
+
   @profiler.record_runtime(name_prefix='VizierGPUCBPEBandit', name='suggest')
   def suggest(
       self, count: Optional[int] = None
@@ -650,7 +1002,6 @@ def suggest(
       jax.clear_caches()
     self._rng, rng = jax.random.split(self._rng, 2)
-    next_suggestion_start_time = datetime.datetime.now()
     data = self._trials_to_data(self._all_completed_trials)
     model = self._build_gp_model_and_optimize_parameters(data, rng)
     predictive = sp.UniformEnsemblePredictive(
         predictives=eqx.filter_jit(model.precompute_predictive)(data)
     )
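Aside: a standalone sketch (not from this patch) of the `jax.tree_util.tree_map`/`expand_dims` step in `_suggest_batch_with_exploration` above. `SetPEScoreFunction` scores batches of point sets, so the single set of `count` chosen candidates is given a leading batch dimension of 1 before `aux` is called, which is why the aux entries come back shaped `[1, count]`. The dict pytree here is a stand-in for `types.ModelInput`.

```python
import jax
import jax.numpy as jnp

count, dim = 3, 2
# Stand-in for the `types.ModelInput` pytree returned by the converter.
optimal_features = {'continuous': jnp.zeros((count, dim))}

# Add a leading batch dimension: one set containing `count` points.
batched = jax.tree_util.tree_map(
    lambda x: jnp.expand_dims(x, axis=0), optimal_features
)
print(batched['continuous'].shape)  # (1, 3, 2)
```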
@@ -692,137 +1043,34 @@ def suggest(
 
     # TODO: Feed the eagle strategy with completed trials.
     # TODO: Change budget based on requested suggestion count.
-    suggestions = []
     active_trials = list(self._all_active_trials)
-    for _ in range(count):
-      self._rng, rng = jax.random.split(self._rng, 2)
+    if count <= 1:
+      return [
+          self._suggest_one(
+              active_trials, data, model, predictive, tr, acquisition_problem
+          )
+      ]
+
+    suggestions = []
+    if self._config.optimize_set_acquisition_for_exploration:
       if _has_new_completed_trials(
           completed_trials=self._all_completed_trials,
          active_trials=active_trials,
       ):
-        # When there are trials completed after all active trials were created,
-        # we optimize the UCB acquisition function except with a small
-        # probability the PE acquisition function to ensure exploration.
-        use_ucb = not jax.random.bernoulli(
-            key=rng, p=self._config.pe_overwrite_probability
-        )
-      else:
-        has_completed_trials = len(self._all_completed_trials) > 0  # pylint:disable=g-explicit-length-test
-        # When there are no trials completed after all active trials were
-        # created, we optimize the PE acquisition function except with a small
-        # probability the UCB acquisition function, in case the UCB acquisition
-        # function is not well optimized.
-        use_ucb = has_completed_trials and jax.random.bernoulli(
-            key=rng, p=self._config.ucb_overwrite_probability
+        suggestions.append(
+            self._suggest_one(
+                active_trials, data, model, predictive, tr, acquisition_problem
+            )
         )
-
-      # TODO: Feed the eagle strategy with completed trials.
-      # TODO: Change budget based on requested suggestion count.
-      acquisition_optimizer = self._acquisition_optimizer_factory(
-          self._converter
+      return suggestions + self._suggest_batch_with_exploration(
+          count - len(suggestions), active_trials, data, model, predictive, tr
       )
-
-      if active_trials:
-        pending_features = self._converter.to_features(active_trials)
-        predictive_all_features = self._get_predictive_all_features(
-            pending_features, data, model
-        )
-      else:
-        predictive_all_features = predictive
-
-      # When `use_ucb` is true, the acquisition function computes the UCB
-      # values. Otherwise, it computes the Pure-Exploration acquisition values.
-      if use_ucb:
-        scoring_fn = UCBScoreFunction(
-            predictive,
-            predictive_all_features,
-            ucb_coefficient=self._config.ucb_coefficient,
-            trust_region=tr if self._use_trust_region else None,
-        )
-      else:
-        scoring_fn = PEScoreFunction(
-            predictive,
-            predictive_all_features,
-            penalty_coefficient=self._config.cb_violation_penalty_coefficient,
-            ucb_coefficient=self._config.ucb_coefficient,
-            explore_ucb_coefficient=self._config.explore_region_ucb_coefficient,
-            trust_region=tr if self._use_trust_region else None,
-        )
-
-      if isinstance(acquisition_optimizer, vb.VectorizedOptimizer):
-        acq_rng, self._rng = jax.random.split(self._rng)
-        prior_features = None
-        if self._all_completed_trials:
-          prior_features = vb.trials_to_sorted_array(
-              self._all_completed_trials, self._converter
-          )
-        with profiler.timeit('acquisition_optimizer', also_log=True):
-          best_candidates = eqx.filter_jit(acquisition_optimizer)(
-              scoring_fn.score,
-              prior_features=prior_features,
-              count=1,
-              seed=acq_rng,
-              score_with_aux_fn=scoring_fn.score_with_aux,
-          )
-          jax.block_until_ready(best_candidates)
-        with profiler.timeit('best_candidates_to_trials', also_log=True):
-          best_candidate = vb.best_candidates_to_trials(
-              best_candidates, self._converter
-          )[0]
-      elif isinstance(acquisition_optimizer, vza.GradientFreeOptimizer):
-        # Seed the optimizer with previous trials.
-        acquisition = self.get_score_fn_on_trials(scoring_fn.score)
-        best_candidate = acquisition_optimizer.optimize(
-            acquisition,
-            acquisition_problem,
-            count=1,
-            seed_candidates=copy.deepcopy(self._all_completed_trials),
-        )[0]
-      else:
-        raise ValueError(
-            f'Unrecognized acquisition_optimizer: {type(acquisition_optimizer)}'
+    else:
+      for _ in range(count):
+        suggestions.append(
+            self._suggest_one(
+                active_trials, data, model, predictive, tr, acquisition_problem
+            )
         )
-
-      # Make predictions (in the warped space).
-      logging.info('Converting the optimization result into suggestion...')
-      optimal_features = self._converter.to_features([best_candidate])  # [1, D]
-      aux = eqx.filter_jit(scoring_fn.aux)(optimal_features)
-      predict_mean = aux['mean']  # [1,]
-      predict_stddev = aux['stddev']  # [1,]
-      predict_stddev_from_all = aux['stddev_from_all']  # [1,]
-      acquisition = best_candidate.final_measurement_or_die.metrics.get_value(
-          'acquisition', float('nan')
-      )
-      logging.info(
-          'Created predictions for the best candidates which were converted to'
-          f' an array of shape: {_get_features_shape(optimal_features)}. mean'
-          f' has shape {predict_mean.shape}. stddev has shape'
-          f' {predict_stddev.shape}.stddev_from_all has shape'
-          f' {predict_stddev_from_all.shape}. acquisition value of'
-          f' best_candidate: {acquisition}, use_ucb: {use_ucb}'
-      )
-
-      # Create suggestions, injecting the predictions as metadata for
-      # debugging needs.
-      metadata = best_candidate.metadata.ns(self._metadata_ns)
-      metadata.ns('prediction_in_warped_y_space').update({
-          'mean': f'{predict_mean[0]}',
-          'stddev': f'{predict_stddev[0]}',
-          'stddev_from_all': f'{predict_stddev_from_all[0]}',
-          'acquisition': f'{acquisition}',
-          'use_ucb': f'{use_ucb}',
-          'trust_radius': f'{tr.trust_radius}',
-          'params': f'{model.params}',
-      })
-      metadata.ns('timing').update(
-          {'time': f'{datetime.datetime.now() - next_suggestion_start_time}'}
-      )
-      suggestions.append(
-          vz.TrialSuggestion(
-              best_candidate.parameters, metadata=best_candidate.metadata
-          )
-      )
-      active_trials.append(suggestions[-1].to_trial())
-      next_suggestion_start_time = datetime.datetime.now()
-
-    return suggestions
+        active_trials.append(suggestions[-1].to_trial())
+    return suggestions
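Aside: a standalone sketch (not from this patch) of why the set-acquisition path above rewards diverse batches. With identical per-point variances, the log-determinant of the joint posterior covariance (eq. (8)) is much lower for two nearly redundant points than for two decorrelated ones, so maximizing it spreads the exploration suggestions out. The covariance values below are made up.

```python
import jax.numpy as jnp

cov_redundant = jnp.array([[1.0, 0.9], [0.9, 1.0]])  # nearly duplicate points
cov_diverse = jnp.array([[1.0, 0.0], [0.0, 1.0]])    # decorrelated points

for name, cov in [('redundant', cov_redundant), ('diverse', cov_diverse)]:
  print(name, float(jnp.linalg.slogdet(cov)[1]))
# redundant ~ -1.66, diverse = 0.0: same marginal stddevs, very different
# joint information, so the set objective prefers the diverse pair.
```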
diff --git a/vizier/_src/algorithms/designers/gp_ucb_pe_test.py b/vizier/_src/algorithms/designers/gp_ucb_pe_test.py
index 47b8f89a1..12ff1e549 100644
--- a/vizier/_src/algorithms/designers/gp_ucb_pe_test.py
+++ b/vizier/_src/algorithms/designers/gp_ucb_pe_test.py
@@ -20,6 +20,7 @@
 from typing import Any, Tuple
 
 import jax
+import numpy as np
 from vizier import pyvizier as vz
 from vizier._src.algorithms.core import abstractions
 from vizier._src.algorithms.designers import gp_ucb_pe
@@ -56,6 +57,19 @@ class GpUcbPeTest(parameterized.TestCase):
       dict(iters=5, batch_size=3, num_seed_trials=2, ensemble_size=3),
       dict(iters=3, batch_size=5, num_seed_trials=5, applies_padding=True),
       dict(iters=5, batch_size=1, num_seed_trials=2, pe_overwrite=True),
+      dict(
+          iters=3,
+          batch_size=5,
+          num_seed_trials=5,
+          optimize_set_acquisition_for_exploration=True,
+      ),
+      dict(
+          iters=3,
+          batch_size=5,
+          num_seed_trials=5,
+          applies_padding=True,
+          optimize_set_acquisition_for_exploration=True,
+      ),
   )
   def test_on_flat_continuous_space(
       self,
@@ -66,6 +80,7 @@ def test_on_flat_continuous_space(
       ensemble_size: int = 1,
       applies_padding: bool = False,
       pe_overwrite: bool = False,
+      optimize_set_acquisition_for_exploration: bool = False,
   ):
     # We use string names so that test case names are readable. Convert them
     # to objects.
@@ -95,6 +110,9 @@ def test_on_flat_continuous_space(
             cb_violation_penalty_coefficient=10.0,
             ucb_overwrite_probability=0.0,
             pe_overwrite_probability=1.0 if pe_overwrite else 0.0,
+            optimize_set_acquisition_for_exploration=(
+                optimize_set_acquisition_for_exploration
+            ),
         ),
         ensemble_size=ensemble_size,
         padding_schedule=padding.PaddingSchedule(
@@ -156,15 +174,17 @@ def test_on_flat_continuous_space(
       _, _, _, acq, use_ucb = _extract_predictions(
           all_trials[jdx].metadata.ns('gp_ucb_pe_bandit_test')
       )
-      self.assertGreaterEqual(acq, 0.0, msg=f'suggestion: {jdx}')
       self.assertFalse(use_ucb)
+      if not optimize_set_acquisition_for_exploration:
+        self.assertGreaterEqual(acq, 0.0, msg=f'suggestion: {jdx}')
 
     for idx in range(2, iters + 2):
       # Skips seed trials, which are not generated by acquisition function
       # optimization.
       if idx * batch_size < num_seed_trials:
         continue
-
+      set_acq_value = None
+      stddev_from_all_list = []
       for jdx in range(batch_size):
         mean, stddev, stddev_from_all, acq, use_ucb = _extract_predictions(
             all_trials[idx * batch_size + jdx].metadata.ns(
@@ -178,6 +198,15 @@ def test_on_flat_continuous_space(
           # predicted standard deviation based on all trials.
           self.assertAlmostEqual(mean + 10.0 * stddev_from_all, acq)
           self.assertTrue(use_ucb)
+          continue
+
+        self.assertFalse(use_ucb)
+        if optimize_set_acquisition_for_exploration:
+          stddev_from_all_list.append(stddev_from_all)
+          if set_acq_value is None:
+            set_acq_value = acq
+          else:
+            self.assertAlmostEqual(set_acq_value, acq)
         else:
           # Because `ucb_overwrite_probability` is set to 0.0, when the designer
           # makes suggestions without seeing newer completed trials, it uses the
@@ -190,7 +219,16 @@ def test_on_flat_continuous_space(
           self.assertLessEqual(
               acq, 2 * stddev, msg=f'batch: {idx}, suggestion: {jdx}'
           )
-          self.assertFalse(use_ucb)
+      if optimize_set_acquisition_for_exploration:
+        geometric_mean_of_pred_cov_eigs = np.exp(
+            set_acq_value / (batch_size - 1)
+        )
+        arithmetic_mean_of_pred_cov_eigs = np.mean(
+            np.square(stddev_from_all_list)
+        )
+        self.assertLessEqual(
+            geometric_mean_of_pred_cov_eigs, arithmetic_mean_of_pred_cov_eigs
+        )
 
   def test_ucb_overwrite(self):
     problem = vz.ProblemStatement(
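Aside: the new geometric-vs-arithmetic-mean assertion above follows from AM-GM on the eigenvalues of the posterior covariance over the batch's exploration points. The set acquisition value is the log-determinant of that covariance plus non-positive penalty terms, so `exp(set_acq_value / k)` cannot exceed the geometric mean of the eigenvalues, which in turn cannot exceed their arithmetic mean `trace(cov) / k`, i.e. the average of the squared `stddev_from_all` values. A standalone numeric check (random SPD matrix, not from the test):

```python
import numpy as np

rng = np.random.default_rng(0)
k = 4  # plays the role of batch_size - 1 in the test
a = rng.normal(size=(k, k))
cov = a @ a.T + 1e-3 * np.eye(k)  # random symmetric positive-definite matrix

_, logdet = np.linalg.slogdet(cov)
geometric_mean_of_eigs = np.exp(logdet / k)
arithmetic_mean_of_eigs = np.mean(np.diag(cov))  # == trace(cov) / k
assert geometric_mean_of_eigs <= arithmetic_mean_of_eigs + 1e-9
```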