Skip to content

Commit

Permalink
implement suggestions
Browse files Browse the repository at this point in the history
  • Loading branch information
jduerholt committed Dec 2, 2024
1 parent a059ca8 commit d5a4d4a
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 26 deletions.
8 changes: 4 additions & 4 deletions botorch/optim/optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -960,6 +960,7 @@ def optimize_acqf_mixed(
ic_gen_kwargs = ic_gen_kwargs or {}

if q == 1:
timeout_sec = timeout_sec / len(fixed_features_list) if timeout_sec else None
ff_candidate_list, ff_acq_value_list = [], []
num_candidate_generation_failures = 0
for fixed_features in fixed_features_list:
Expand All @@ -980,9 +981,7 @@ def optimize_acqf_mixed(
ic_generator=ic_generator,
return_best_only=False,
gen_candidates=gen_candidates,
timeout_sec=timeout_sec / len(fixed_features_list)
if timeout_sec
else None,
timeout_sec=timeout_sec,
retry_on_optimization_warning=retry_on_optimization_warning,
**ic_gen_kwargs,
)
Expand Down Expand Up @@ -1024,6 +1023,7 @@ def optimize_acqf_mixed(
base_X_pending = acq_function.X_pending
candidates = torch.tensor([], device=bounds.device, dtype=bounds.dtype)

timeout_sec = timeout_sec / q if timeout_sec else None
for _ in range(q):
candidate, acq_value = optimize_acqf_mixed(
acq_function=acq_function,
Expand All @@ -1041,7 +1041,7 @@ def optimize_acqf_mixed(
gen_candidates=gen_candidates,
ic_generator=ic_generator,
ic_gen_kwargs=ic_gen_kwargs,
timeout_sec=timeout_sec / q if timeout_sec else None,
timeout_sec=timeout_sec,
retry_on_optimization_warning=retry_on_optimization_warning,
return_best_only=True,
)
Expand Down
19 changes: 9 additions & 10 deletions botorch/optim/optimize_homotopy.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,16 +226,15 @@ def optimize_acqf_homotopy(
).unsqueeze(1)

# Optimize one more time with the final options
if fixed_features_list:
candidates, acq_values = optimization_fn(
acq_function=acq_function,
bounds=bounds,
q=1,
options=final_options,
batch_initial_conditions=candidates,
**fixed_features_kwargs,
**shared_optimize_acqf_kwargs,
)
candidates, acq_values = optimization_fn(
acq_function=acq_function,
bounds=bounds,
q=1,
options=final_options,
batch_initial_conditions=candidates,
**fixed_features_kwargs,
**shared_optimize_acqf_kwargs,
)

# Post-process the candidates and grab the best candidate
if post_processing_func is not None:
Expand Down
12 changes: 6 additions & 6 deletions test/optim/test_homotopy.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def test_optimize_acqf_homotopy(self):
candidate, acqf_val = optimize_acqf_homotopy(
q=1,
acq_function=acqf,
bounds=torch.tensor([[-10], [5]]).to(**tkwargs),
bounds=torch.tensor([[-10], [5]], **tkwargs),
homotopy=Homotopy(homotopy_parameters=[hp]),
num_restarts=2,
raw_samples=16,
Expand Down Expand Up @@ -151,7 +151,7 @@ def test_optimize_acqf_homotopy(self):
candidate, acqf_val = optimize_acqf_homotopy(
q=1,
acq_function=acqf,
bounds=torch.tensor([[-10, -10], [5, 5]]).to(**tkwargs),
bounds=torch.tensor([[-10, -10], [5, 5]], **tkwargs),
homotopy=Homotopy(homotopy_parameters=[hp]),
num_restarts=2,
raw_samples=16,
Expand All @@ -173,7 +173,7 @@ def test_optimize_acqf_homotopy(self):
optimize_acqf_homotopy(
q=1,
acq_function=acqf,
bounds=torch.tensor([[-10, -10], [5, 5]]).to(**tkwargs),
bounds=torch.tensor([[-10, -10], [5, 5]], **tkwargs),
homotopy=Homotopy(homotopy_parameters=[hp]),
num_restarts=2,
raw_samples=16,
Expand All @@ -183,7 +183,7 @@ def test_optimize_acqf_homotopy(self):
candidate, acqf_val = optimize_acqf_homotopy(
q=1,
acq_function=acqf,
bounds=torch.tensor([[-10, -10], [5, 5]]).to(**tkwargs),
bounds=torch.tensor([[-10, -10], [5, 5]], **tkwargs),
homotopy=Homotopy(homotopy_parameters=[hp]),
num_restarts=2,
raw_samples=16,
Expand All @@ -196,7 +196,7 @@ def test_optimize_acqf_homotopy(self):
candidate, acqf_val = optimize_acqf_homotopy(
q=3,
acq_function=acqf,
bounds=torch.tensor([[-10, -10], [5, 5]]).to(**tkwargs),
bounds=torch.tensor([[-10, -10], [5, 5]], **tkwargs),
homotopy=Homotopy(homotopy_parameters=[hp]),
num_restarts=2,
raw_samples=16,
Expand All @@ -218,7 +218,7 @@ def test_optimize_acqf_homotopy(self):
candidate, acqf_val = optimize_acqf_homotopy(
q=1,
acq_function=acqf,
bounds=torch.tensor([[-10, -10], [5, 5]]).to(**tkwargs),
bounds=torch.tensor([[-10, -10], [5, 5]], **tkwargs),
homotopy=Homotopy(homotopy_parameters=[hp]),
num_restarts=2,
raw_samples=16,
Expand Down
13 changes: 7 additions & 6 deletions test/optim/test_optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -1467,7 +1467,7 @@ def test_optimize_acqf_mixed_q1(self, mock_optimize_acqf):
# compute expected output
best_acq_values = torch.tensor(
[torch.max(acq_values) for acq_values in acq_val_rvs]
).to(acq_val_rvs[0])
)
best_batch_idx = torch.argmax(best_acq_values)

if return_best_only:
Expand All @@ -1476,12 +1476,13 @@ def test_optimize_acqf_mixed_q1(self, mock_optimize_acqf):
best_idx = torch.argmax(best_batch_acq_values)
expected_candidates = best_batch_candidates[best_idx]
expected_acq_value = best_batch_acq_values[best_idx]
assert expected_candidates.dim() == 2
self.assertEqual(expected_candidates.dim(), 2)

else:
expected_candidates = candidate_rvs[best_batch_idx]
expected_acq_value = acq_val_rvs[best_batch_idx]
assert expected_candidates.dim() == 3
assert expected_acq_value.dim() == 1
self.assertEqual(expected_candidates.dim(), 3)
self.assertEqual(expected_acq_value.dim(), 1)

self.assertTrue(torch.equal(candidates, expected_candidates))
self.assertTrue(torch.equal(acq_value, expected_acq_value))
Expand Down Expand Up @@ -1549,7 +1550,7 @@ def test_optimize_acqf_mixed_q2(self, mock_optimize_acqf):

best_acq_values = torch.tensor(
[torch.max(acq_values) for acq_values in acq_val_rvs_q]
).to(acq_val_rvs_q[0])
)
best_batch_idx = torch.argmax(best_acq_values)

best_batch_candidates = candidate_rvs_q[best_batch_idx]
Expand Down Expand Up @@ -1599,11 +1600,11 @@ def test_optimize_acqf_mixed_empty_ff(self):
)

def test_optimize_acqf_mixed_return_best_only_q2(self):
mock_acq_function = MockAcquisitionFunction()
with self.assertRaises(
NotImplementedError,
msg="`return_best_only=False` is only supported for q=1.",
):
mock_acq_function = MockAcquisitionFunction()
optimize_acqf_mixed(
acq_function=mock_acq_function,
q=2,
Expand Down

0 comments on commit d5a4d4a

Please sign in to comment.