Unit tests for integration of acqfs with their input constructors, with LearnedObjective and constraints (#2112)

Summary:
Pull Request resolved: #2112

Additional tests using acqfs with LearnedObjective

Reviewed By: ItsMrLin

Differential Revision: D51269389

fbshipit-source-id: 6a1603dec06ab0dc51c2389ef894ea9110066e89
esantorella authored and facebook-github-bot committed Nov 21, 2023
1 parent 6bb9e31 commit 0d66aa0
Showing 7 changed files with 244 additions and 15 deletions.
2 changes: 1 addition & 1 deletion botorch/acquisition/input_constructors.py
@@ -1363,7 +1363,7 @@ def get_best_f_mc(
         objective=objective,
         posterior_transform=posterior_transform,
         X_baseline=X_baseline,
-    )
+    ).squeeze()
 
 
 def optimize_objective(
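The trailing `.squeeze()` collapses the singleton output dimension so that `get_best_f_mc` returns a scalar, matching the updated expectations in `test/acquisition/test_input_constructors.py` below. A small illustrative sketch of the shape change (toy values, not the BoTorch internals):

```python
import torch

# Toy stand-in for the single-outcome training targets reduced by get_best_f_mc.
Y = torch.tensor([[0.1], [0.7], [0.4]])  # n x m with m = 1

# Previously the reduction kept a trailing singleton dimension:
best_f_old = Y.max(dim=0).values  # shape: torch.Size([1])

# With .squeeze(), the result is a 0-d scalar tensor, equal to Y.max():
best_f_new = best_f_old.squeeze()  # shape: torch.Size([])
assert torch.equal(best_f_new, Y.max())
```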
3 changes: 3 additions & 0 deletions botorch/acquisition/objective.py
@@ -605,6 +605,9 @@ def forward(self, samples: Tensor, X: Optional[Tensor] = None) -> Tensor:
             )
         samples = samples.to(torch.float64)
 
+        if samples.ndim < 3:
+            raise ValueError("samples should have at least 3 dimensions.")
+
         posterior = self.pref_model.posterior(samples)
         if isinstance(self.pref_model, DeterministicModel):
             # return preference posterior mean
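The new guard makes `LearnedObjective.forward` fail fast when `samples` lacks the expected `sample_shape x batch_shape x q x m` layout. A hedged sketch of the behavior (toy preference model, not the exact setup used in the tests):

```python
import torch
from botorch.acquisition.objective import LearnedObjective
from botorch.models import SingleTaskGP

# Toy preference model: 2-d outcomes -> scalar utility (here, their sum).
Y = torch.rand(8, 2, dtype=torch.float64)
pref_model = SingleTaskGP(Y, Y.sum(-1, keepdim=True))
pref_obj = LearnedObjective(pref_model=pref_model)

# OK: 4-d input with layout sample_shape x batch_shape x q x m.
pref_obj(torch.rand(4, 2, 5, 2, dtype=torch.float64))

# A 2-d input now raises instead of failing obscurely downstream.
try:
    pref_obj(torch.rand(5, 2, dtype=torch.float64))
except ValueError as e:
    print(e)  # samples should have at least 3 dimensions.
```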
3 changes: 1 addition & 2 deletions botorch/acquisition/utils.py
@@ -56,7 +56,7 @@ def repeat_to_match_aug_dim(target_tensor: Tensor, reference_tensor: Tensor) ->
         matches that of `reference_tensor`.
 
         The shape will be `(augmented_sample * sample_size) x batch_shape x q x m`.
-    Example:
+    Examples:
         >>> import torch
         >>> target_tensor = torch.arange(3).repeat(2, 1).T
         >>> target_tensor
@@ -71,7 +71,6 @@ def repeat_to_match_aug_dim(target_tensor: Tensor, reference_tensor: Tensor) ->
                 [1, 1],
                 [2, 2]])
     """
-
     augmented_sample_num, remainder = divmod(
         reference_tensor.shape[0], target_tensor.shape[0]
     )
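For reference, `repeat_to_match_aug_dim` tiles `target_tensor` along its leading sample dimension until it matches `reference_tensor`, so objective values can be aligned with samples whose leading dimension was augmented (e.g. by `LearnedObjective`). A usage sketch following the docstring example above:

```python
import torch
from botorch.acquisition.utils import repeat_to_match_aug_dim

target_tensor = torch.arange(3).repeat(2, 1).T  # shape 3 x 2: [[0, 0], [1, 1], [2, 2]]
reference_tensor = torch.zeros(6, 2)            # leading dim 6 = 2 * 3

repeated = repeat_to_match_aug_dim(target_tensor, reference_tensor)
assert repeated.shape == torch.Size([6, 2])
# Rows are tiled along dim 0: [[0, 0], [1, 1], [2, 2], [0, 0], [1, 1], [2, 2]]
```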
6 changes: 3 additions & 3 deletions test/acquisition/test_input_constructors.py
@@ -182,21 +182,21 @@ def test_get_best_f_mc(self) -> None:
         best_f = get_best_f_mc(training_data=self.blockX_blockY)
         self.assertEqual(best_f, get_best_f_mc(self.blockX_blockY[0]))
 
-        best_f_expected = self.blockX_blockY[0].Y.max(dim=0).values
+        best_f_expected = self.blockX_blockY[0].Y.max()
         self.assertAllClose(best_f, best_f_expected)
         with self.assertRaisesRegex(UnsupportedError, "require an objective"):
             get_best_f_mc(training_data=self.blockX_multiY)
         obj = LinearMCObjective(weights=torch.rand(2))
         best_f = get_best_f_mc(training_data=self.blockX_multiY, objective=obj)
 
         multi_Y = torch.cat([d.Y for d in self.blockX_multiY.values()], dim=-1)
-        best_f_expected = (multi_Y @ obj.weights).amax(dim=-1, keepdim=True)
+        best_f_expected = (multi_Y @ obj.weights).max()
         self.assertAllClose(best_f, best_f_expected)
         post_tf = ScalarizedPosteriorTransform(weights=torch.ones(2))
         best_f = get_best_f_mc(
             training_data=self.blockX_multiY, posterior_transform=post_tf
         )
-        best_f_expected = (multi_Y.sum(dim=-1)).amax(dim=-1, keepdim=True)
+        best_f_expected = multi_Y.sum(dim=-1).max()
         self.assertAllClose(best_f, best_f_expected)
 
     @mock.patch("botorch.acquisition.input_constructors.optimize_acqf")
224 changes: 224 additions & 0 deletions test/acquisition/test_integration.py
@@ -0,0 +1,224 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from itertools import product
from typing import Dict
from warnings import catch_warnings, simplefilter

import torch
from botorch.acquisition.input_constructors import get_acqf_input_constructor
from botorch.acquisition.logei import (
    qLogExpectedImprovement,
    qLogNoisyExpectedImprovement,
)
from botorch.acquisition.monte_carlo import (
    qExpectedImprovement,
    qNoisyExpectedImprovement,
    qProbabilityOfImprovement,
)
from botorch.acquisition.objective import LearnedObjective
from botorch.exceptions.warnings import InputDataWarning
from botorch.fit import fit_gpytorch_mll
from botorch.models import SingleTaskGP
from botorch.sampling.normal import SobolQMCNormalSampler
from botorch.utils.datasets import SupervisedDataset
from botorch.utils.testing import BotorchTestCase
from gpytorch.mlls.exact_marginal_log_likelihood import ExactMarginalLogLikelihood


class TestObjectiveAndConstraintIntegration(BotorchTestCase):
    def setUp(self) -> None:
        self.q = 3
        self.d = 2
        self.tkwargs = {"device": self.device, "dtype": torch.double}

    def _get_acqf_inputs(self, train_batch_shape: torch.Size, m: int) -> Dict:
        train_x = torch.rand((*train_batch_shape, 5, self.d), **self.tkwargs)
        y = torch.rand((*train_batch_shape, 5, m), **self.tkwargs)

        training_data = SupervisedDataset(
            X=train_x,
            Y=y,
            feature_names=[f"x{i}" for i in range(self.d)],
            outcome_names=[f"y{i}" for i in range(m)],
        )
        utility = y.sum(-1).unsqueeze(-1)

        with catch_warnings():
            simplefilter("ignore", category=InputDataWarning)
            model = SingleTaskGP(train_x, y)
            mll = ExactMarginalLogLikelihood(model.likelihood, model)
            fit_gpytorch_mll(mll=mll)

        with catch_warnings():
            simplefilter("ignore", category=InputDataWarning)
            pref_model = SingleTaskGP(y, utility)
            pref_mll = ExactMarginalLogLikelihood(pref_model.likelihood, pref_model)
            fit_gpytorch_mll(mll=pref_mll)
        return {
            "training_data": training_data,
            "model": model,
            "pref_model": pref_model,
            "train_x": train_x,
        }

    def _base_test_with_learned_objective(
        self,
        train_batch_shape: torch.Size,
        prune_baseline: bool,
        test_batch_shape: torch.Size,
    ) -> None:
        acq_inputs = self._get_acqf_inputs(train_batch_shape=train_batch_shape, m=4)

        pref_sample_shapes = [1, 8]
        test_acqf_classes_and_kws = [
            # Not yet working
            # (qExpectedImprovement, {}),
            # (qProbabilityOfImprovement, {}),
            # (qLogExpectedImprovement, {}),
            (qNoisyExpectedImprovement, {"prune_baseline": prune_baseline}),
            (qLogNoisyExpectedImprovement, {"prune_baseline": prune_baseline}),
        ]

        for (acqf_cls, kws), pref_sample_shape in product(
            test_acqf_classes_and_kws, pref_sample_shapes
        ):
            with self.subTest(
                train_batch_shape=train_batch_shape,
                test_batch_shape=test_batch_shape,
                prune_baseline=prune_baseline,
                acqf_cls=acqf_cls,
                pref_sample_shape=pref_sample_shape,
            ):
                objective = LearnedObjective(
                    pref_model=acq_inputs["pref_model"],
                    sample_shape=torch.Size([pref_sample_shape]),
                )
                test_x = torch.rand(
                    (*test_batch_shape, *train_batch_shape, self.q, self.d),
                    **self.tkwargs,
                )
                input_constructor = get_acqf_input_constructor(acqf_cls=acqf_cls)

                inputs = input_constructor(
                    objective=objective,
                    model=acq_inputs["model"],
                    training_data=acq_inputs["training_data"],
                    X_baseline=acq_inputs["train_x"],
                    sampler=SobolQMCNormalSampler(torch.Size([4])),
                    **kws,
                )
                acqf = acqf_cls(**inputs)
                acq_val = acqf(test_x)
                self.assertEqual(acq_val.shape.numel(), test_x.shape[:-2].numel())

    def test_with_learned_objective_train_data_not_batched(self) -> None:
        train_batch_shape = []
        test_batch_shapes = [[], [1], [2]]
        for test_batch_shape in test_batch_shapes:
            self._base_test_with_learned_objective(
                train_batch_shape=torch.Size(train_batch_shape),
                prune_baseline=True,
                test_batch_shape=torch.Size(test_batch_shape),
            )

    def test_with_learned_objective_train_data_1d_batch(self) -> None:
        train_batch_shape = [1]
        test_batch_shapes = [[], [1], [2]]
        for test_batch_shape in test_batch_shapes:
            self._base_test_with_learned_objective(
                train_batch_shape=torch.Size(train_batch_shape),
                # Batched inputs `X_baseline` are currently unsupported by
                # prune_inferior_points
                prune_baseline=False,
                test_batch_shape=torch.Size(test_batch_shape),
            )

    def test_with_learned_objective_train_data_batched(self) -> None:
        train_batch_shape = [3]
        test_batch_shapes = [[], [1], [2]]
        for test_batch_shape in test_batch_shapes:
            self._base_test_with_learned_objective(
                train_batch_shape=torch.Size(train_batch_shape),
                # Batched inputs `X_baseline` are currently unsupported by
                # prune_inferior_points
                prune_baseline=False,
                test_batch_shape=torch.Size(test_batch_shape),
            )

    def _base_test_without_learned_objective(
        self,
        train_batch_shape: torch.Size,
        prune_baseline: bool,
        test_batch_shape: torch.Size,
    ) -> None:
        inputs = self._get_acqf_inputs(train_batch_shape=train_batch_shape, m=1)
        constraints = [lambda y: y[..., 0]]
        test_x = torch.rand(
            (*test_batch_shape, *train_batch_shape, self.q, self.d), **self.tkwargs
        )

        input_constructor_kwargs = {
            "model": inputs["model"],
            "training_data": inputs["training_data"],
            "X_baseline": inputs["train_x"],
            "sampler": SobolQMCNormalSampler(torch.Size([4])),
        }

        for acqf_cls, kws in [
            (qNoisyExpectedImprovement, {"prune_baseline": prune_baseline}),
            (qLogNoisyExpectedImprovement, {"prune_baseline": prune_baseline}),
            (qExpectedImprovement, {}),
            (qProbabilityOfImprovement, {}),
            (qLogExpectedImprovement, {}),
        ]:
            # Not working.
            if train_batch_shape.numel() > 1 and acqf_cls == qLogExpectedImprovement:
                continue
            input_constructor = get_acqf_input_constructor(acqf_cls=acqf_cls)

            with self.subTest(
                "no objective or constraints",
                train_batch_shape=train_batch_shape,
                prune_baseline=prune_baseline,
                test_batch_shape=test_batch_shape,
                acqf_cls=acqf_cls,
            ):
                acqf = acqf_cls(**input_constructor(**input_constructor_kwargs, **kws))
                acq_val = acqf(test_x)
                self.assertEqual(acq_val.shape.numel(), test_x.shape[:-2].numel())

            with self.subTest(
                "constrained",
                train_batch_shape=train_batch_shape,
                prune_baseline=prune_baseline,
                test_batch_shape=test_batch_shape,
                acqf_cls=acqf_cls,
            ):
                acqf = acqf_cls(
                    **input_constructor(
                        constraints=constraints, **input_constructor_kwargs, **kws
                    )
                )
                acq_val = acqf(test_x)
                self.assertEqual(acq_val.shape.numel(), test_x.shape[:-2].numel())

    def test_without_learned_objective(self) -> None:
        train_batch_shapes = [[], [1], [2]]
        test_batch_shapes = [[], [1], [3]]
        for train_batch_shape, test_batch_shape in product(
            train_batch_shapes, test_batch_shapes
        ):
            # Batched inputs `X_baseline` are currently unsupported by
            # prune_inferior_points
            prune_baseline_ = [False] if len(train_batch_shape) > 0 else [False, True]
            for prune_baseline in prune_baseline_:
                self._base_test_without_learned_objective(
                    train_batch_shape=torch.Size(train_batch_shape),
                    prune_baseline=prune_baseline,
                    test_batch_shape=torch.Size(test_batch_shape),
                )
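Outside the test harness, the pattern these tests exercise looks roughly like the following standalone sketch (condensed from the test code above; shapes and hyperparameters are illustrative):

```python
import torch
from botorch.acquisition.input_constructors import get_acqf_input_constructor
from botorch.acquisition.logei import qLogNoisyExpectedImprovement
from botorch.acquisition.objective import LearnedObjective
from botorch.fit import fit_gpytorch_mll
from botorch.models import SingleTaskGP
from botorch.sampling.normal import SobolQMCNormalSampler
from botorch.utils.datasets import SupervisedDataset
from gpytorch.mlls.exact_marginal_log_likelihood import ExactMarginalLogLikelihood

d, m, q = 2, 4, 3
train_x = torch.rand(5, d, dtype=torch.double)
y = torch.rand(5, m, dtype=torch.double)

# Outcome model over the m outputs.
model = SingleTaskGP(train_x, y)
fit_gpytorch_mll(mll=ExactMarginalLogLikelihood(model.likelihood, model))

# Preference model mapping outcomes to a scalar utility (here, their sum).
pref_model = SingleTaskGP(y, y.sum(-1, keepdim=True))
fit_gpytorch_mll(mll=ExactMarginalLogLikelihood(pref_model.likelihood, pref_model))

objective = LearnedObjective(pref_model=pref_model, sample_shape=torch.Size([8]))
training_data = SupervisedDataset(
    X=train_x,
    Y=y,
    feature_names=[f"x{i}" for i in range(d)],
    outcome_names=[f"y{i}" for i in range(m)],
)

# The input constructor assembles the acqf's arguments from raw inputs.
input_constructor = get_acqf_input_constructor(acqf_cls=qLogNoisyExpectedImprovement)
inputs = input_constructor(
    model=model,
    training_data=training_data,
    objective=objective,
    X_baseline=train_x,
    sampler=SobolQMCNormalSampler(torch.Size([4])),
    prune_baseline=True,
)
acqf = qLogNoisyExpectedImprovement(**inputs)
acq_val = acqf(torch.rand(q, d, dtype=torch.double))  # one value per t-batch
```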
2 changes: 0 additions & 2 deletions test/acquisition/test_monte_carlo.py
@@ -221,8 +221,6 @@ def test_q_expected_improvement_batch(self):
             acqf(X)
             self.assertTrue(torch.equal(acqf.sampler.base_samples, bs))
 
-        # TODO: Test different objectives (incl. constraints)
-
 
 class TestQNoisyExpectedImprovement(BotorchTestCase):
     def test_q_noisy_expected_improvement(self):
19 changes: 12 additions & 7 deletions test/acquisition/test_objective.py
@@ -459,13 +459,13 @@ def test_learned_preference_objective(self) -> None:
         og_sample_shape = 3
         large_sample_shape = 256
         batch_size = 2
-        n = 8
+        q = 8
         test_X = torch.rand(
-            torch.Size((og_sample_shape, batch_size, n, self.x_dim)),
+            torch.Size((og_sample_shape, batch_size, q, self.x_dim)),
             dtype=torch.float64,
         )
         large_X = torch.rand(
-            torch.Size((large_sample_shape, batch_size, n, self.x_dim)),
+            torch.Size((large_sample_shape, batch_size, q, self.x_dim)),
             dtype=torch.float64,
         )
 
@@ -476,19 +476,24 @@ def test_learned_preference_objective(self) -> None:
             first_call_output = pref_obj(test_X)
             self.assertEqual(
                 first_call_output.shape,
-                torch.Size([og_sample_shape * DEFAULT_NUM_PREF_SAMPLES, batch_size, n]),
+                torch.Size([og_sample_shape * DEFAULT_NUM_PREF_SAMPLES, batch_size, q]),
             )
             # Making sure the sampler has correct base_samples shape
             self.assertEqual(
                 pref_obj.sampler.base_samples.shape,
-                torch.Size([DEFAULT_NUM_PREF_SAMPLES, og_sample_shape, 1, n]),
+                torch.Size([DEFAULT_NUM_PREF_SAMPLES, og_sample_shape, 1, q]),
             )
             # Passing through a same-shaped X again shouldn't change the base sample
             previous_base_samples = pref_obj.sampler.base_samples
             another_test_X = torch.rand_like(test_X)
             pref_obj(another_test_X)
             self.assertIs(pref_obj.sampler.base_samples, previous_base_samples)
 
+            with self.assertRaisesRegex(
+                ValueError, "samples should have at least 3 dimensions."
+            ):
+                pref_obj(torch.rand(q, self.x_dim))
+
         # test when sampler has multiple preference samples
         with self.subTest("Multiple samples"):
             num_samples = 256
@@ -498,7 +503,7 @@ def test_learned_preference_objective(self) -> None:
             )
             self.assertEqual(
                 pref_obj(test_X).shape,
-                torch.Size([num_samples * og_sample_shape, batch_size, n]),
+                torch.Size([num_samples * og_sample_shape, batch_size, q]),
             )
 
             avg_obj_val = pref_obj(large_X).mean(dim=0)
@@ -513,7 +518,7 @@ def test_learned_preference_objective(self) -> None:
         pref_obj = LearnedObjective(pref_model=mean_pref_model)
         self.assertEqual(
             pref_obj(test_X).shape,
-            torch.Size([og_sample_shape, batch_size, n]),
+            torch.Size([og_sample_shape, batch_size, q]),
         )
 
         # the order of samples shouldn't matter
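The `n` -> `q` rename matches BoTorch's convention that the second-to-last input dimension indexes the q candidates. As the assertions above encode, a stochastic preference model multiplies the leading sample dimension by the number of preference samples, while a deterministic one leaves it unchanged. A sketch of the deterministic case, using `GenericDeterministicModel` as a stand-in for the test's `mean_pref_model`:

```python
import torch
from botorch.acquisition.objective import LearnedObjective
from botorch.models.deterministic import GenericDeterministicModel

# Deterministic stand-in preference model: the posterior mean is used
# directly, so the leading sample dimension is not augmented.
mean_pref_model = GenericDeterministicModel(f=lambda X: X.sum(-1, keepdim=True))
pref_obj = LearnedObjective(pref_model=mean_pref_model)

test_X = torch.rand(3, 2, 8, 4, dtype=torch.float64)  # sample x batch x q x d
assert pref_obj(test_X).shape == torch.Size([3, 2, 8])
```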
