diff --git a/botorch/optim/optimize.py b/botorch/optim/optimize.py index 0128246107..a38e0fe022 100644 --- a/botorch/optim/optimize.py +++ b/botorch/optim/optimize.py @@ -109,12 +109,43 @@ def __post_init__(self) -> None: "3-dimensional. Its shape is " f"{batch_initial_conditions_shape}." ) + if batch_initial_conditions_shape[-1] != d: raise ValueError( f"batch_initial_conditions.shape[-1] must be {d}. The " f"shape is {batch_initial_conditions_shape}." ) + if len(batch_initial_conditions_shape) == 2: + warnings.warn( + "If using a 2-dim `batch_initial_conditions` botorch will " + "default to old behavior of ignoring `num_restarts` and just " + "use the given `batch_initial_conditions` by setting " + "`raw_samples` to None.", + RuntimeWarning, + stacklevel=3, + ) + # Use object.__setattr__ to bypass immutability and set a value + object.__setattr__(self, "raw_samples", None) + + if ( + len(batch_initial_conditions_shape) == 3 + and batch_initial_conditions_shape[0] < self.num_restarts + and batch_initial_conditions_shape[-2] != self.q + ): + warnings.warn( + "If using a 3-dim `batch_initial_conditions` where the " + "first dimension is less than `num_restarts` and the second " + "dimension is not equal to `q`, botorch will default to " + "old behavior of ignoring `num_restarts` and just use the " + "given `batch_initial_conditions` by setting `raw_samples` " + "to None.", + RuntimeWarning, + stacklevel=3, + ) + # Use object.__setattr__ to bypass immutability and set a value + object.__setattr__(self, "raw_samples", None) + elif self.ic_generator is None: if self.nonlinear_inequality_constraints is not None: raise RuntimeError( @@ -126,6 +157,7 @@ def __post_init__(self) -> None: "Must specify `raw_samples` when " "`batch_initial_conditions` is None`." 
) + if self.fixed_features is not None and any( (k < 0 for k in self.fixed_features) ): @@ -253,20 +285,49 @@ def _optimize_acqf_sequential_q( return candidates, torch.stack(acq_value_list) +def _combine_initial_conditions( + provided_initial_conditions: Tensor | None = None, + generated_initial_conditions: Tensor | None = None, + dim=0, +) -> Tensor: + if ( + provided_initial_conditions is not None + and generated_initial_conditions is not None + ): + return torch.cat( + [provided_initial_conditions, generated_initial_conditions], dim=dim + ) + elif provided_initial_conditions is not None: + return provided_initial_conditions + elif generated_initial_conditions is not None: + return generated_initial_conditions + else: + raise ValueError( + "Either `batch_initial_conditions` or `raw_samples` must be set." + ) + + def _optimize_acqf_batch(opt_inputs: OptimizeAcqfInputs) -> tuple[Tensor, Tensor]: options = opt_inputs.options or {} - initial_conditions_provided = opt_inputs.batch_initial_conditions is not None + required_num_restarts = opt_inputs.num_restarts + provided_initial_conditions = opt_inputs.batch_initial_conditions + generated_initial_conditions = None - if initial_conditions_provided: - batch_initial_conditions = opt_inputs.batch_initial_conditions - else: - # pyre-ignore[28]: Unexpected keyword argument `acq_function` to anonymous call. - batch_initial_conditions = opt_inputs.get_ic_generator()( + if ( + provided_initial_conditions is not None + and len(provided_initial_conditions.shape) == 3 + ): + required_num_restarts -= provided_initial_conditions.shape[0] + + if opt_inputs.raw_samples is not None and required_num_restarts > 0: + # pyre-ignore[28]: Unexpected keyword argument `acq_function` + # to anonymous call. 
+ generated_initial_conditions = opt_inputs.get_ic_generator()( acq_function=opt_inputs.acq_function, bounds=opt_inputs.bounds, q=opt_inputs.q, - num_restarts=opt_inputs.num_restarts, + num_restarts=required_num_restarts, raw_samples=opt_inputs.raw_samples, fixed_features=opt_inputs.fixed_features, options=options, @@ -275,6 +336,11 @@ def _optimize_acqf_batch(opt_inputs: OptimizeAcqfInputs) -> tuple[Tensor, Tensor **opt_inputs.ic_gen_kwargs, ) + batch_initial_conditions = _combine_initial_conditions( + provided_initial_conditions=provided_initial_conditions, + generated_initial_conditions=generated_initial_conditions, + ) + batch_limit: int = options.get( "batch_limit", ( @@ -344,23 +410,24 @@ def _optimize_batch_candidates() -> tuple[Tensor, Tensor, list[Warning]]: first_warn_msg = ( "Optimization failed in `gen_candidates_scipy` with the following " f"warning(s):\n{[w.message for w in ws]}\nBecause you specified " - "`batch_initial_conditions`, optimization will not be retried with " - "new initial conditions and will proceed with the current solution." - " Suggested remediation: Try again with different " - "`batch_initial_conditions`, or don't provide `batch_initial_conditions.`" - if initial_conditions_provided + "`batch_initial_conditions` larger than required `num_restarts`, " + "optimization will not be retried with new initial conditions and " + "will proceed with the current solution. Suggested remediation: " + "Try again with different `batch_initial_conditions`, don't provide " + "`batch_initial_conditions`, or increase `num_restarts`." + if batch_initial_conditions is not None and required_num_restarts <= 0 else "Optimization failed in `gen_candidates_scipy` with the following " f"warning(s):\n{[w.message for w in ws]}\nTrying again with a new " "set of initial conditions." 
def _gen_starting_points_local_search(
    discrete_choices: list[Tensor],
    raw_samples: int,
    batch_initial_conditions: Tensor | None,
    X_avoid: Tensor,
    inequality_constraints: list[tuple[Tensor, Tensor, float]],
    min_points: int,
    acq_function: AcquisitionFunction,
    max_batch_size: int = 2048,
    max_tries: int = 100,
) -> Tensor:
    r"""Assemble the starting points for one round of discrete local search.

    Provided initial conditions (if any) are passed through `_filter_invalid`
    (against `X_avoid`) and `_filter_infeasible` (against
    `inequality_constraints`). If fewer than `min_points` of them remain, the
    shortfall is filled with generated points: candidates come from
    `_gen_batch_initial_conditions_local_search` and the ones with the highest
    acquisition value are kept.

    Args:
        discrete_choices: Forwarded to
            `_gen_batch_initial_conditions_local_search`.
        raw_samples: Forwarded to `_gen_batch_initial_conditions_local_search`.
        batch_initial_conditions: Optional provided starting points. Dimension
            1 is squeezed before filtering and restored afterwards, so an
            `n x 1 x d` tensor is expected.
        X_avoid: Points handed to `_filter_invalid` and to the generator.
        inequality_constraints: Constraints handed to `_filter_infeasible` and
            to the generator.
        min_points: Target number of starting points.
        acq_function: Acquisition function used to rank generated candidates.
        max_batch_size: Maximum batch size used when evaluating `acq_function`.
        max_tries: Forwarded to `_gen_batch_initial_conditions_local_search`.

    Returns:
        The surviving provided points followed by the generated points,
        concatenated along dimension 0. If enough provided points survive,
        only those are returned.
    """
    required_min_points = min_points
    provided_X0 = None
    generated_X0 = None

    if batch_initial_conditions is not None:
        provided_X0 = _filter_invalid(
            X=batch_initial_conditions.squeeze(1), X_avoid=X_avoid
        )
        provided_X0 = _filter_infeasible(
            X=provided_X0, inequality_constraints=inequality_constraints
        ).unsqueeze(1)
        # Count only the points that survived filtering. Counting the raw
        # input would let discarded points "use up" restarts and could leave
        # the search with too few (or zero) starting points.
        required_min_points -= provided_X0.shape[0]

    if required_min_points > 0:
        # Generate only the shortfall, mirroring how `_optimize_acqf_batch`
        # requests `required_num_restarts` from its initial condition generator.
        generated_X0 = _gen_batch_initial_conditions_local_search(
            discrete_choices=discrete_choices,
            raw_samples=raw_samples,
            X_avoid=X_avoid,
            inequality_constraints=inequality_constraints,
            min_points=required_min_points,
            max_tries=max_tries,
        )

        # pick the best starting points
        with torch.no_grad():
            acqvals_init = _split_batch_eval_acqf(
                acq_function=acq_function,
                X=generated_X0.unsqueeze(1),
                max_batch_size=max_batch_size,
            ).unsqueeze(-1)

        # Indexing with the `k x 1` index tensor yields a `k x 1 x d` tensor,
        # matching the layout of the provided points.
        generated_X0 = generated_X0[
            acqvals_init.topk(k=required_min_points, largest=True, dim=0).indices
        ]

    return _combine_initial_conditions(
        provided_initial_conditions=provided_X0,
        generated_initial_conditions=generated_X0,
    )
""" + if batch_initial_conditions is not None: + if not ( + len(batch_initial_conditions.shape) == 3 + and batch_initial_conditions.shape[-2] == 1 + ): + raise ValueError( + "batch_initial_conditions must have shape `n x 1 x d` if " + f"given (received shape {batch_initial_conditions.shape})." + ) + candidate_list = [] base_X_pending = acq_function.X_pending if q > 1 else None base_X_avoid = X_avoid @@ -1259,27 +1396,18 @@ def optimize_acqf_discrete_local_search( inequality_constraints = inequality_constraints or [] for i in range(q): # generate some starting points - if i == 0 and batch_initial_conditions is not None: - X0 = _filter_invalid(X=batch_initial_conditions.squeeze(1), X_avoid=X_avoid) - X0 = _filter_infeasible( - X=X0, inequality_constraints=inequality_constraints - ).unsqueeze(1) - else: - X_init = _gen_batch_initial_conditions_local_search( - discrete_choices=discrete_choices, - raw_samples=raw_samples, - X_avoid=X_avoid, - inequality_constraints=inequality_constraints, - min_points=num_restarts, - ) - # pick the best starting points - with torch.no_grad(): - acqvals_init = _split_batch_eval_acqf( - acq_function=acq_function, - X=X_init.unsqueeze(1), - max_batch_size=max_batch_size, - ).unsqueeze(-1) - X0 = X_init[acqvals_init.topk(k=num_restarts, largest=True, dim=0).indices] + X0 = _gen_starting_points_local_search( + discrete_choices=discrete_choices, + raw_samples=raw_samples, + batch_initial_conditions=batch_initial_conditions, + X_avoid=X_avoid, + inequality_constraints=inequality_constraints, + min_points=num_restarts, + acq_function=acq_function, + max_batch_size=max_batch_size, + max_tries=max_tries, + ) + batch_initial_conditions = None # optimize from the best starting points best_xs = torch.zeros(len(X0), dim, device=device, dtype=dtype) diff --git a/botorch/optim/optimize_homotopy.py b/botorch/optim/optimize_homotopy.py index 84ffbcaf91..136897fe60 100644 --- a/botorch/optim/optimize_homotopy.py +++ b/botorch/optim/optimize_homotopy.py @@ 
class TestCombineInitialConditions(BotorchTestCase):
    """Unit tests for `_combine_initial_conditions`."""

    def test_combine_both_conditions(self):
        provided = torch.randn(1, 3, 4)
        generated = torch.randn(2, 3, 4)

        result = _combine_initial_conditions(
            provided_initial_conditions=provided,
            generated_initial_conditions=generated,
        )

        self.assertEqual(result.shape, torch.Size([3, 3, 4]))
        # Provided conditions must come first, followed by the generated ones.
        self.assertTrue(torch.equal(result[:1], provided))
        self.assertTrue(torch.equal(result[1:], generated))

    def test_only_provided_conditions(self):
        provided = torch.tensor([[5.0, 6.0], [7.0, 8.0]])

        result = _combine_initial_conditions(
            provided_initial_conditions=provided,
            generated_initial_conditions=None,
        )

        self.assertTrue(torch.equal(result, provided))

    def test_only_generated_conditions(self):
        generated = torch.tensor([[1.0, 2.0], [3.0, 4.0]])

        result = _combine_initial_conditions(
            provided_initial_conditions=None,
            generated_initial_conditions=generated,
        )

        self.assertTrue(torch.equal(result, generated))

    def test_no_conditions_raises_error(self):
        with self.assertRaisesRegex(
            ValueError,
            "Either `batch_initial_conditions` or `raw_samples` must be set.",
        ):
            _combine_initial_conditions(
                provided_initial_conditions=None,
                generated_initial_conditions=None,
            )
acqval in gcs_return_vals] + ) self.assertTrue(torch.equal(candidates, expected_candidates)) self.assertTrue(torch.equal(acq_value, expected_val)) # verify error when using a OneShotAcquisitionFunction @@ -520,22 +573,39 @@ def test_optimize_acqf_sequential_q_constraint_notimplemented(self): ) def test_optimize_acqf_batch_limit(self) -> None: - num_restarts = 3 - raw_samples = 5 + num_restarts = 5 + raw_samples = 16 dim = 4 q = 4 batch_limit = 2 options = {"batch_limit": batch_limit} - initial_conditions = [ - torch.ones(shape) for shape in [(1, 2, dim), (2, 1, dim), (1, dim)] - ] + [None] + initial_conditions = [(1, 2, dim), (3, 1, dim), (3, q, dim), (1, dim), None] + expected_acqf_shapes = [1, 3, num_restarts, 1, num_restarts] + expected_candidates_shapes = [ + (1, 2, dim), + (3, 1, dim), + (num_restarts, q, dim), + (1, dim), + (num_restarts, q, dim), + ] - for gen_candidates, ics in zip( - [gen_candidates_scipy, gen_candidates_torch], initial_conditions + for gen_candidates, ( + ic_shape, + expected_acqf_shape, + expected_candidates_shape, + ) in product( + [gen_candidates_scipy, gen_candidates_torch], + zip( + initial_conditions, + expected_acqf_shapes, + expected_candidates_shapes, + strict=True, + ), ): + ics = torch.ones(ic_shape) if ic_shape is not None else None with self.subTest(gen_candidates=gen_candidates, initial_conditions=ics): - _, acq_value_list = optimize_acqf( + _candidates, acq_value_list = optimize_acqf( acq_function=SinOneOverXAcqusitionFunction(), bounds=torch.stack([-1 * torch.ones(dim), torch.ones(dim)]), q=q, @@ -546,8 +616,29 @@ def test_optimize_acqf_batch_limit(self) -> None: gen_candidates=gen_candidates, batch_initial_conditions=ics, ) - expected_shape = (num_restarts,) if ics is None else (ics.shape[0],) - self.assertEqual(acq_value_list.shape, expected_shape) + + self.assertEqual(acq_value_list.shape, (expected_acqf_shape,)) + self.assertEqual(_candidates.shape, expected_candidates_shape) + + for ic_shape, expected_shape in [((2, 
1, dim), 2), ((2, dim), 1)]: + with self.subTest(gen_candidates=gen_candidates): + with self.assertWarnsRegex( + RuntimeWarning, "botorch will default to old behavior" + ): + ics = torch.ones((ic_shape)) + _candidates, acq_value_list = optimize_acqf( + acq_function=SinOneOverXAcqusitionFunction(), + bounds=torch.stack([-1 * torch.ones(dim), torch.ones(dim)]), + q=q, + num_restarts=num_restarts, + raw_samples=raw_samples, + options=options, + return_best_only=False, + gen_candidates=gen_candidates, + batch_initial_conditions=ics, + ) + + self.assertEqual(acq_value_list.shape, (expected_shape,)) def test_optimize_acqf_runs_given_batch_initial_conditions(self): num_restarts, raw_samples, dim = 1, 2, 3 @@ -563,7 +654,7 @@ def test_optimize_acqf_runs_given_batch_initial_conditions(self): ] q = 1 - ic_shapes = [(1, 2, dim), (2, 1, dim), (1, dim)] + ic_shapes = [(1, 2, dim), (1, dim)] torch.manual_seed(0) for shape in ic_shapes: @@ -633,16 +724,16 @@ def test_optimize_acqf_warns_on_opt_failure(self): raw_samples=raw_samples, batch_initial_conditions=initial_conditions, ) - message = ( "Optimization failed in `gen_candidates_scipy` with the following " "warning(s):\n[OptimizationWarning('Optimization failed within " - "`scipy.optimize.minimize` with status 2 and message ABNORMAL_TERMINATION" - "_IN_LNSRCH.')]\nBecause you specified `batch_initial_conditions`, " - "optimization will not be retried with new initial conditions and will " - "proceed with the current solution. Suggested remediation: Try again with " - "different `batch_initial_conditions`, or don't provide " - "`batch_initial_conditions.`" + "`scipy.optimize.minimize` with status 2 and message " + "ABNORMAL_TERMINATION_IN_LNSRCH.')]\nBecause you specified " + "`batch_initial_conditions` larger than required `num_restarts`, " + "optimization will not be retried with new initial conditions and " + "will proceed with the current solution. 
Suggested remediation: " + "Try again with different `batch_initial_conditions`, don't provide " + "`batch_initial_conditions`, or increase `num_restarts`." ) expected_warning_raised = any( issubclass(w.category, RuntimeWarning) and message in str(w.message) @@ -947,8 +1038,9 @@ def nlc4(x): acq_function=mock_acq_function, bounds=bounds, q=3, - nonlinear_inequality_constraints=[(nlc1, True)], num_restarts=1, + raw_samples=16, + nonlinear_inequality_constraints=[(nlc1, True)], ic_generator=ic_generator, ) self.assertEqual(candidates.size(), torch.Size([1, 3])) @@ -1062,7 +1154,7 @@ def nlc(x): torch.cat( [ expected_acq_value - for _, expected_acq_value in gcs_return_vals[ + for _candidate, expected_acq_value in gcs_return_vals[ num_restarts - 1 :: num_restarts ] ] @@ -1845,6 +1937,16 @@ def test_optimize_acqf_discrete_local_search(self): ) ) + # test ValueError for batch_initial_conditions shape + with self.assertRaisesRegex(ValueError, "must have shape `n x 1 x d`"): + candidates, _acq_value = optimize_acqf_discrete_local_search( + acq_function=mock_acq_function, + q=q, + discrete_choices=discrete_choices, + X_avoid=torch.tensor([[6, 4, 9]], **tkwargs), + batch_initial_conditions=torch.tensor([[0, 2, 5]], **tkwargs), + ) + # test _gen_batch_initial_conditions_local_search with self.assertRaisesRegex(RuntimeError, "Failed to generate"): _gen_batch_initial_conditions_local_search(