From d5a4d4ae73d971c900b196d6ae16d4f8e6f85310 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20P=2E=20D=C3=BCrholt?= Date: Mon, 2 Dec 2024 16:25:36 +0100 Subject: [PATCH] implement suggestions --- botorch/optim/optimize.py | 8 ++++---- botorch/optim/optimize_homotopy.py | 19 +++++++++---------- test/optim/test_homotopy.py | 12 ++++++------ test/optim/test_optimize.py | 13 +++++++------ 4 files changed, 26 insertions(+), 26 deletions(-) diff --git a/botorch/optim/optimize.py b/botorch/optim/optimize.py index 76ad6b71bb..4c9e75ee1f 100644 --- a/botorch/optim/optimize.py +++ b/botorch/optim/optimize.py @@ -960,6 +960,7 @@ def optimize_acqf_mixed( ic_gen_kwargs = ic_gen_kwargs or {} if q == 1: + timeout_sec = timeout_sec / len(fixed_features_list) if timeout_sec else None ff_candidate_list, ff_acq_value_list = [], [] num_candidate_generation_failures = 0 for fixed_features in fixed_features_list: @@ -980,9 +981,7 @@ def optimize_acqf_mixed( ic_generator=ic_generator, return_best_only=False, gen_candidates=gen_candidates, - timeout_sec=timeout_sec / len(fixed_features_list) - if timeout_sec - else None, + timeout_sec=timeout_sec, retry_on_optimization_warning=retry_on_optimization_warning, **ic_gen_kwargs, ) @@ -1024,6 +1023,7 @@ def optimize_acqf_mixed( base_X_pending = acq_function.X_pending candidates = torch.tensor([], device=bounds.device, dtype=bounds.dtype) + timeout_sec = timeout_sec / q if timeout_sec else None for _ in range(q): candidate, acq_value = optimize_acqf_mixed( acq_function=acq_function, @@ -1041,7 +1041,7 @@ def optimize_acqf_mixed( gen_candidates=gen_candidates, ic_generator=ic_generator, ic_gen_kwargs=ic_gen_kwargs, - timeout_sec=timeout_sec / q if timeout_sec else None, + timeout_sec=timeout_sec, retry_on_optimization_warning=retry_on_optimization_warning, return_best_only=True, ) diff --git a/botorch/optim/optimize_homotopy.py b/botorch/optim/optimize_homotopy.py index fac41a149b..a9d2ba00ee 100644 --- a/botorch/optim/optimize_homotopy.py +++ b/botorch/optim/optimize_homotopy.py @@ -226,16 +226,15 @@ def optimize_acqf_homotopy( ).unsqueeze(1) # Optimize one more time with the final options - if fixed_features_list: - candidates, acq_values = optimization_fn( - acq_function=acq_function, - bounds=bounds, - q=1, - options=final_options, - batch_initial_conditions=candidates, - **fixed_features_kwargs, - **shared_optimize_acqf_kwargs, - ) + candidates, acq_values = optimization_fn( + acq_function=acq_function, + bounds=bounds, + q=1, + options=final_options, + batch_initial_conditions=candidates, + **fixed_features_kwargs, + **shared_optimize_acqf_kwargs, + ) # Post-process the candidates and grab the best candidate if post_processing_func is not None: diff --git a/test/optim/test_homotopy.py b/test/optim/test_homotopy.py index 62d2ef317e..8fb6a6f810 100644 --- a/test/optim/test_homotopy.py +++ b/test/optim/test_homotopy.py @@ -117,7 +117,7 @@ def test_optimize_acqf_homotopy(self): candidate, acqf_val = optimize_acqf_homotopy( q=1, acq_function=acqf, - bounds=torch.tensor([[-10], [5]]).to(**tkwargs), + bounds=torch.tensor([[-10], [5]], **tkwargs), homotopy=Homotopy(homotopy_parameters=[hp]), num_restarts=2, raw_samples=16, @@ -151,7 +151,7 @@ def test_optimize_acqf_homotopy(self): candidate, acqf_val = optimize_acqf_homotopy( q=1, acq_function=acqf, - bounds=torch.tensor([[-10, -10], [5, 5]]).to(**tkwargs), + bounds=torch.tensor([[-10, -10], [5, 5]], **tkwargs), homotopy=Homotopy(homotopy_parameters=[hp]), num_restarts=2, raw_samples=16, @@ -173,7 +173,7 @@ def test_optimize_acqf_homotopy(self): optimize_acqf_homotopy( q=1, acq_function=acqf, - bounds=torch.tensor([[-10, -10], [5, 5]]).to(**tkwargs), + bounds=torch.tensor([[-10, -10], [5, 5]], **tkwargs), homotopy=Homotopy(homotopy_parameters=[hp]), num_restarts=2, raw_samples=16, @@ -183,7 +183,7 @@ def test_optimize_acqf_homotopy(self): candidate, acqf_val = optimize_acqf_homotopy( q=1, acq_function=acqf, - bounds=torch.tensor([[-10, -10], [5, 5]]).to(**tkwargs), + bounds=torch.tensor([[-10, -10], [5, 5]], **tkwargs), homotopy=Homotopy(homotopy_parameters=[hp]), num_restarts=2, raw_samples=16, @@ -196,7 +196,7 @@ def test_optimize_acqf_homotopy(self): candidate, acqf_val = optimize_acqf_homotopy( q=3, acq_function=acqf, - bounds=torch.tensor([[-10, -10], [5, 5]]).to(**tkwargs), + bounds=torch.tensor([[-10, -10], [5, 5]], **tkwargs), homotopy=Homotopy(homotopy_parameters=[hp]), num_restarts=2, raw_samples=16, @@ -218,7 +218,7 @@ def test_optimize_acqf_homotopy(self): candidate, acqf_val = optimize_acqf_homotopy( q=1, acq_function=acqf, - bounds=torch.tensor([[-10, -10], [5, 5]]).to(**tkwargs), + bounds=torch.tensor([[-10, -10], [5, 5]], **tkwargs), homotopy=Homotopy(homotopy_parameters=[hp]), num_restarts=2, raw_samples=16, diff --git a/test/optim/test_optimize.py b/test/optim/test_optimize.py index 026bbc8842..6a3cdede35 100644 --- a/test/optim/test_optimize.py +++ b/test/optim/test_optimize.py @@ -1467,7 +1467,7 @@ def test_optimize_acqf_mixed_q1(self, mock_optimize_acqf): # compute expected output best_acq_values = torch.tensor( [torch.max(acq_values) for acq_values in acq_val_rvs] - ).to(acq_val_rvs[0]) + ) best_batch_idx = torch.argmax(best_acq_values) if return_best_only: @@ -1476,12 +1476,13 @@ def test_optimize_acqf_mixed_q1(self, mock_optimize_acqf): best_idx = torch.argmax(best_batch_acq_values) expected_candidates = best_batch_candidates[best_idx] expected_acq_value = best_batch_acq_values[best_idx] - assert expected_candidates.dim() == 2 + self.assertEqual(expected_candidates.dim(), 2) + else: expected_candidates = candidate_rvs[best_batch_idx] expected_acq_value = acq_val_rvs[best_batch_idx] - assert expected_candidates.dim() == 3 - assert expected_acq_value.dim() == 1 + self.assertEqual(expected_candidates.dim(), 3) + self.assertEqual(expected_acq_value.dim(), 1) self.assertTrue(torch.equal(candidates, expected_candidates)) self.assertTrue(torch.equal(acq_value, expected_acq_value)) @@ -1549,7 +1550,7 @@ def test_optimize_acqf_mixed_q2(self, mock_optimize_acqf): best_acq_values = torch.tensor( [torch.max(acq_values) for acq_values in acq_val_rvs_q] - ).to(acq_val_rvs_q[0]) + ) best_batch_idx = torch.argmax(best_acq_values) best_batch_candidates = candidate_rvs_q[best_batch_idx] @@ -1599,11 +1600,11 @@ def test_optimize_acqf_mixed_empty_ff(self): ) def test_optimize_acqf_mixed_return_best_only_q2(self): + mock_acq_function = MockAcquisitionFunction() with self.assertRaises( NotImplementedError, msg="`return_best_only=False` is only supported for q=1.", ): - mock_acq_function = MockAcquisitionFunction() optimize_acqf_mixed( acq_function=mock_acq_function, q=2,