diff --git a/botorch/models/latent_kronecker_gp.py b/botorch/models/latent_kronecker_gp.py
new file mode 100644
index 0000000000..a205c9c0fe
--- /dev/null
+++ b/botorch/models/latent_kronecker_gp.py
@@ -0,0 +1,517 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+r"""
+References
+
+.. [lin2024scaling]
+    J. A. Lin, S. Ament, M. Balandat, E. Bakshy. Scaling Gaussian Processes
+    for Learning Curve Prediction via Latent Kronecker Structure. NeurIPS 2024
+    Bayesian Decision-making and Uncertainty Workshop.
+
+.. [lin2023sampling]
+    J. A. Lin, J. Antorán, S. Padhy, D. Janz, J. M. Hernández-Lobato, A. Terenin.
+    Sampling from Gaussian Process Posteriors using Stochastic Gradient Descent.
+    Advances in Neural Information Processing Systems 2023.
+"""
+
+import contextlib
+import warnings
+from typing import Any
+
+import torch
+from botorch.acquisition.objective import PosteriorTransform
+from botorch.exceptions.errors import BotorchTensorDimensionError
+from botorch.models.gpytorch import GPyTorchModel
+from botorch.models.model import FantasizeMixin, Model
+from botorch.models.transforms.input import InputTransform
+from botorch.models.transforms.outcome import OutcomeTransform, Standardize
+from botorch.posteriors.gpytorch import GPyTorchPosterior
+from botorch.posteriors.latent_kronecker import LatentKroneckerGPPosterior
+from botorch.utils.types import _DefaultType, DEFAULT
+from gpytorch.distributions import MultivariateNormal
+from gpytorch.kernels import MaternKernel, ScaleKernel
+from gpytorch.likelihoods import GaussianLikelihood
+from gpytorch.likelihoods.likelihood import Likelihood
+from gpytorch.means import Mean, ZeroMean
+
+from gpytorch.models.exact_gp import ExactGP
+from gpytorch.module import Module
+from linear_operator import settings
+from linear_operator.operators import (
+    ConstantDiagLinearOperator,
+    KroneckerProductLinearOperator,
+    MaskedLinearOperator,
+)
+from linear_operator.utils.warnings import PerformanceWarning
+from torch import Tensor
+
+
+class MinMaxStandardize(Standardize):
+    r"""Standardize outcomes to unit variance, centered about the minimum
+    (or maximum) of the observations instead of their mean.
+    Otherwise equivalent to `Standardize`.
+    """
+
+    def __init__(
+        self,
+        m: int = 1,
+        use_min: bool = False,
+        outputs: list[int] | None = None,
+        batch_shape: torch.Size = torch.Size(),  # noqa: B008
+        min_stdv: float = 1e-8,
+    ) -> None:
+        r"""Standardize outcomes to unit variance, centered about the
+        minimum (or maximum).
+
+        Args:
+            m: The output dimension.
+            use_min: If True, center about the minimum of the observations;
+                otherwise, center about the maximum.
+            outputs: Which of the outputs to standardize. If omitted, all
+                outputs will be standardized.
+            batch_shape: The batch_shape of the training targets.
+            min_stdv: The minimum standard deviation for which to perform
+                standardization (if lower, only de-mean the data).
+        """
+        super().__init__(
+            m=m, outputs=outputs, batch_shape=batch_shape, min_stdv=min_stdv
+        )
+        self._use_min = use_min
+
+    def forward(
+        self, Y: Tensor, Yvar: Tensor | None = None
+    ) -> tuple[Tensor, Tensor | None]:
+        r"""Standardize outcomes.
+
+        If the module is in train mode, this updates the module state (i.e. the
+        mean/std normalizing constants). If the module is in eval mode, simply
+        applies the normalization using the module state.
+
+        Args:
+            Y: A `batch_shape x n x m`-dim tensor of training targets.
+ Yvar: A `batch_shape x n x m`-dim tensor of observation noises + associated with the training targets (if applicable). + + Returns: + A two-tuple with the transformed outcomes: + + - The transformed outcome observations. + - The transformed observation noise (if applicable). + """ + if self.training: + if Y.shape[:-2] != self._batch_shape: + raise RuntimeError( + f"Expected Y.shape[:-2] to be {self._batch_shape}, matching " + "the `batch_shape` argument to `Standardize`, but got " + f"Y.shape[:-2]={Y.shape[:-2]}." + ) + if Y.size(-1) != self._m: + raise RuntimeError( + f"Wrong output dimension. Y.size(-1) is {Y.size(-1)}; expected " + f"{self._m}." + ) + if Y.shape[-2] < 1: + raise ValueError(f"Can't standardize with no observations. {Y.shape=}.") + + elif Y.shape[-2] == 1: + stdvs = torch.ones( + (*Y.shape[:-2], 1, Y.shape[-1]), dtype=Y.dtype, device=Y.device + ) + else: + stdvs = Y.std(dim=-2, keepdim=True) + stdvs = stdvs.where(stdvs >= self._min_stdv, torch.full_like(stdvs, 1.0)) + means = ( + Y.min(dim=-2, keepdim=True).values + if self._use_min + else Y.max(dim=-2, keepdim=True).values + ) + if self._outputs is not None: + unused = [i for i in range(self._m) if i not in self._outputs] + means[..., unused] = 0.0 + stdvs[..., unused] = 1.0 + self.means = means + self.stdvs = stdvs + self._stdvs_sq = stdvs.pow(2) + self._is_trained = torch.tensor(True) + + Y_tf = (Y - self.means) / self.stdvs + Yvar_tf = Yvar / self._stdvs_sq if Yvar is not None else None + return Y_tf, Yvar_tf + + +class LatentKroneckerGP(GPyTorchModel, ExactGP, FantasizeMixin): + r""" + A multi-task GP model which uses Kronecker structure despite missing entries. + + Leverages pathwise conditioning and iterative linear system solvers to + efficiently draw samples from the GP posterior. See [lin2024scaling]_ + for details. + + For more information about pathwise conditioning, see [wilson2021pathwise]_ + and [Maddox2021bohdo]_. Details about iterative linear system solvers for GPs + with pathwise conditioning can be found in [lin2023sampling]_. + + NOTE: This model requires iterative methods for efficient posterior inference. + To enable iterative methods, the `use_iterative_methods` helper function can be + used as a context manager. + + Example: + >>> model = LatentKroneckerGP(train_X, train_Y) + >>> mll = ExactMarginalLogLikelihood(model.likelihood, model) + >>> with model.use_iterative_methods(): + >>> fit_gpytorch_mll(mll) + >>> samples = model.posterior(test_X).rsample() + """ + + def __init__( + self, + train_X: Tensor, + train_Y: Tensor, + train_Y_valid: Tensor | None = None, + T: Tensor | None = None, + likelihood: Likelihood | None = None, + mean_module_X: Mean | None = None, + mean_module_T: Mean | None = None, + covar_module_X: Module | None = None, + covar_module_T: Module | None = None, + input_transform: InputTransform | None = None, + outcome_transform: OutcomeTransform | _DefaultType | None = DEFAULT, + ) -> None: + r""" + Args: + train_X: A `batch_shape x n x d` tensor of training features. + train_Y: A `batch_shape x n x t` tensor of training observations. + train_Y_valid: A `n x t` boolean tensor of valid values. + True indicates that the corresponding value is valid. + False indicates that the corresponding value is missing. + Does not allow explicit `batch_shape` because + the mask must be shared across batch dimensions. + T: A `batch_shape x t` tensor of training time steps. + If omitted, use [1, ..., t]. + likelihood: A likelihood. 
If omitted, use a standard
+                `GaussianLikelihood` with inferred noise level.
+            mean_module_X: The mean function to be used for X.
+                If omitted, use a `ZeroMean`.
+            mean_module_T: The mean function to be used for T.
+                If omitted, use a `ZeroMean`.
+            covar_module_X: The module computing the covariance matrix of X.
+                If omitted, use a `MaternKernel`.
+            covar_module_T: The module computing the covariance matrix of T.
+                If omitted, use a `ScaleKernel` with a `MaternKernel` base kernel.
+            input_transform: An input transform that is applied to X.
+            outcome_transform: An outcome transform that is applied to Y.
+        """
+        with torch.no_grad():
+            # transform inputs here to check resulting shapes
+            # actual transforms will be applied in forward() and posterior()
+            transformed_X = self.transform_inputs(
+                X=train_X, input_transform=input_transform
+            )
+
+        self._validate_tensor_args(X=transformed_X, Y=train_Y)
+        batch_shape, ard_num_dims = transformed_X.shape[:-2], transformed_X.shape[-1]
+
+        self.T = self._init_T(T, batch_shape, train_Y)
+
+        self._num_outputs = self.T.shape[-1]
+
+        if likelihood is None:
+            likelihood = GaussianLikelihood(batch_shape=batch_shape)
+
+        if train_Y_valid is not None:
+            if train_Y_valid.shape != train_Y.shape[-2:]:
+                raise BotorchTensorDimensionError(
+                    "Explicit batch_shape not allowed for train_Y_valid, "
+                    "because the mask must be shared across batch dimensions. "
+                    f"Expected train_Y_valid with shape: {train_Y.shape[-2:]} "
+                    f"(got {train_Y_valid.shape})."
+                )
+            assert train_Y_valid.dtype == torch.bool
+            self.mask = train_Y_valid.reshape(-1)
+        else:
+            mask_len = train_Y.shape[-2] * train_Y.shape[-1]
+            self.mask = torch.ones(mask_len, dtype=torch.bool, device=train_Y.device)
+
+        train_Y = train_Y.reshape(*batch_shape, -1)[..., self.mask]
+
+        if outcome_transform == DEFAULT:
+            outcome_transform = MinMaxStandardize(batch_shape=batch_shape)
+        if outcome_transform is not None:
+            # transform outputs once and keep the results
+            train_Y = outcome_transform(train_Y.unsqueeze(-1))[0].squeeze(-1)
+
+        ExactGP.__init__(
+            self,
+            train_inputs=train_X,
+            train_targets=train_Y,
+            likelihood=likelihood,
+        )
+
+        if mean_module_X is None:
+            mean_module_X = ZeroMean(batch_shape=batch_shape)
+        self.mean_module_X: Module = mean_module_X
+
+        if mean_module_T is None:
+            mean_module_T = ZeroMean(batch_shape=batch_shape)
+        self.mean_module_T: Module = mean_module_T
+
+        if covar_module_X is None:
+            covar_module_X = MaternKernel(
+                ard_num_dims=ard_num_dims, batch_shape=batch_shape
+            )
+
+        if covar_module_T is None:
+            covar_module_T = ScaleKernel(
+                base_kernel=MaternKernel(ard_num_dims=1, batch_shape=batch_shape),
+            )
+
+        self.covar_module_X: Module = covar_module_X
+        self.covar_module_T: Module = covar_module_T
+
+        if input_transform is not None:
+            self.input_transform = input_transform
+        if outcome_transform is not None:
+            self.outcome_transform = outcome_transform
+
+        self._cached_base_samples = None
+        self._cached_L_train_train_X = None
+        self._cached_L_T = None
+        self._cached_H_inv_v = None
+
+        self.to(train_X)
+
+    def _init_T(
+        self, T: Tensor | None, batch_shape: torch.Size, train_Y: Tensor
+    ) -> Tensor:
+        if T is not None:
+            expected_shape = torch.Size([*batch_shape, train_Y.shape[-1]])
+            if T.shape != expected_shape:
+                raise BotorchTensorDimensionError(
+                    f"Expected T with shape: {expected_shape} (got {T.shape})."
+                )
+            return T
+        else:
+            T = torch.linspace(
+                0, 1, train_Y.shape[-1], dtype=train_Y.dtype, device=train_Y.device
+            )
+            T = T.expand(*batch_shape, -1)
+            return T
+
+    def use_iterative_methods(
+        self,
+        tol: float = 0.01,
+        max_iter: int = 10000,
+        covar_root_decomposition: bool = False,
+        log_prob: bool = True,
+        solves: bool = True,
+    ):
+        r"""Create a context manager which enables iterative methods.
+
+        The `covar_root_decomposition`, `log_prob`, and `solves` flags are
+        forwarded to `linear_operator.settings.fast_computations`, while `tol`
+        and `max_iter` set the conjugate gradients tolerance and the maximum
+        number of conjugate gradients iterations, respectively.
+        """
+        with contextlib.ExitStack() as stack:
+            stack.enter_context(
+                settings.fast_computations(
+                    covar_root_decomposition=covar_root_decomposition,
+                    log_prob=log_prob,
+                    solves=solves,
+                )
+            )
+            stack.enter_context(settings.cg_tolerance(tol))
+            stack.enter_context(settings.max_cg_iterations(max_iter))
+            return stack.pop_all()
+
+    def _get_mean(self, X: Tensor, mask: Tensor | None = None) -> Tensor:
+        mean_X = self.mean_module_X(X).unsqueeze(-1)
+        mean_T = self.mean_module_T(self.T.unsqueeze(-1)).unsqueeze(-1)
+        mean = KroneckerProductLinearOperator(mean_X, mean_T).squeeze(-1)
+        return mean[..., mask] if mask is not None else mean
+
+    def forward(self, X: Tensor) -> MultivariateNormal:
+        if self.training:
+            X = self.transform_inputs(X)
+            mask = self.mask
+        else:
+            total_len = X.shape[-2] * self._num_outputs
+            mask = torch.ones(total_len, dtype=torch.bool, device=X.device)
+            mask[: self.mask.shape[-1]] = self.mask
+
+        mean = self._get_mean(X, mask)
+
+        covar_X = self.covar_module_X(X)
+        covar_T = self.covar_module_T(self.T.unsqueeze(-1))
+        covar = KroneckerProductLinearOperator(covar_X, covar_T)
+        covar = MaskedLinearOperator(covar, row_mask=mask, col_mask=mask)
+
+        return MultivariateNormal(mean, covar)
+
+    def posterior(
+        self,
+        X: Tensor,
+        observation_noise: bool | Tensor = False,
+        posterior_transform: PosteriorTransform | None = None,
+        **kwargs: Any,
+    ) -> GPyTorchPosterior:
+        if posterior_transform is not None:
+            raise NotImplementedError(
+                "Posterior transforms currently not supported for "
+                f"{self.__class__.__name__}"
+            )
+        if not isinstance(self.likelihood, GaussianLikelihood):
+            raise NotImplementedError(
+                "Only GaussianLikelihood currently supported for "
+                f"{self.__class__.__name__}"
+            )
+        if observation_noise is not False:
+            raise NotImplementedError(
+                "Observation noise currently not supported for "
+                f"{self.__class__.__name__}"
+            )
+        return LatentKroneckerGPPosterior(self, X)
+
+    def _rsample_from_base_samples(
+        self,
+        X: Tensor,
+        base_samples: Tensor,
+        observation_noise: bool | Tensor = False,
+    ) -> Tensor:
+        r"""Sample from the posterior distribution at the provided points `X`
+        using Matheron's rule, requiring `n_train_full + n_train + n_test`
+        base samples (see `LatentKroneckerGPPosterior.base_sample_shape`).
+
+        Args:
+            X: A `(batch_shape) x q x d`-dim Tensor, where `d` is the dimension
+                of the feature space and `q` is the number of points considered
+                jointly.
+            base_samples: A Tensor of `N(0, I)` base samples of shape
+                `sample_shape x base_sample_shape`, typically obtained from
+                a `Sampler`. This is used for deterministic optimization.
+
+        Returns:
+            Samples from the posterior, a tensor of shape
+            `self._extended_shape(sample_shape=sample_shape)`.
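+
+        Note:
+            In rough outline (informal notation, a sketch of what the code
+            below does): the base samples are split into blocks of sizes
+            `(n_train_full, n_train, n_test)`. The first and last blocks
+            generate a joint prior sample `f_prior` at the training and test
+            points via Cholesky factors of the Kronecker factors, the middle
+            block generates a noise sample `eps`, and the posterior sample is
+            formed via the pathwise (Matheron) update
+            `f_post(X) = f_prior(X) + K(X, X_train) (K(X_train, X_train)
+            + noise * I)^{-1} (y_train - f_prior(X_train) - eps)`,
+            with training-data quantities restricted to the observed
+            (non-missing) entries.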
+ """ + # toggle eval mode to switch the behavior of input / outcome transforms + # this also implicitly applies the input transform to the train_inputs + self.eval() + X_train = self.train_inputs[0] + X_test = self.transform_inputs(X) + n_train_full = X_train.shape[-2] * self._num_outputs + n_train = self.train_targets.shape[-1] + n_test = X_test.shape[-2] * self._num_outputs + + sample_shape = base_samples.shape[: -len(self.batch_shape) - 1] + w_train, eps_base, w_test = torch.split( + base_samples, [n_train_full, n_train, n_test], dim=-1 + ) + eps = torch.sqrt(self.likelihood.noise) * eps_base + + K_T = self.covar_module_T(self.T.unsqueeze(-1)) + + if self._cached_base_samples is not None and torch.equal( + base_samples, self._cached_base_samples + ): + L_train_train_X = self._cached_L_train_train_X + L_T = self._cached_L_T + H_inv_v = self._cached_H_inv_v + else: + # Evaluate prior mean at training data + m_train = self._get_mean(X_train, self.mask) + + # Calculate prior sample + K_train_train_X = self.covar_module_X(X_train) + L_train_train_X = K_train_train_X.cholesky(upper=False) + L_T = K_T.cholesky(upper=False) + + L_train_train = KroneckerProductLinearOperator(L_train_train_X, L_T) + + f_prior_train = L_train_train @ w_train.unsqueeze(-1) + f_prior_train = m_train + f_prior_train.squeeze(-1)[..., self.mask] + + K_train_train = KroneckerProductLinearOperator(K_train_train_X, K_T) + K_train_train = MaskedLinearOperator( + K_train_train, row_mask=self.mask, col_mask=self.mask + ) + noise_covar = ConstantDiagLinearOperator( + self.likelihood.noise + * torch.ones(*self.batch_shape, 1, dtype=X.dtype, device=X.device), + diag_shape=n_train, + ) + H = K_train_train + noise_covar + + v = self.train_targets - (f_prior_train + eps) + # Expand once here to avoid repeated expansion + # by MaskedLinearOperator later + H_inv_v = torch.zeros( + *sample_shape, + *self.batch_shape, + n_train_full, + dtype=X.dtype, + device=X.device, + ) + if settings._fast_solves.off(): + warn_msg = ( + "Iterative methods are disabled. Performing linear solve using " + "full joint covariance matrix, which might be slow and require " + "a lot of memory. Iterative methods can be enabled using " + "'with model.use_iterative_methods():'." 
+ ) + warnings.warn( + warn_msg, + PerformanceWarning, + stacklevel=2, + ) + H_inv_v[..., self.mask] = H.solve(v.unsqueeze(-1)).squeeze(-1) + + self._cached_base_samples = base_samples + self._cached_L_train_train_X = L_train_train_X + self._cached_L_T = L_T + self._cached_H_inv_v = H_inv_v + + # Evaluate prior mean at test data + m_test = self._get_mean(X_test) + + K_train_test_X = self.covar_module_X(X_train, X_test).evaluate_kernel() + K_test_test_X = self.covar_module_X(X_test).evaluate_kernel() + + L_train_test_X = L_train_train_X.solve_triangular( + K_train_test_X.tensor, upper=False + ) + L_test_test_X = ( + K_test_test_X - L_train_test_X.transpose(-2, -1) @ L_train_test_X + ).cholesky(upper=False) + + L_test_train = KroneckerProductLinearOperator( + L_train_test_X.transpose(-2, -1), L_T + ) + + L_test_test = KroneckerProductLinearOperator(L_test_test_X, L_T) + + # match dimensions for broadcasting + broadcast_shape = L_test_train.shape[:-2] + extra_batch_dims = len(broadcast_shape) - len(self.batch_shape) + for _ in range(extra_batch_dims): + w_train = w_train.unsqueeze(len(sample_shape)) + w_test = w_test.unsqueeze(len(sample_shape)) + H_inv_v = H_inv_v.unsqueeze(len(sample_shape)) + + f_prior_test = L_test_train @ w_train.unsqueeze(-1) + f_prior_test = f_prior_test + L_test_test @ w_test.unsqueeze(-1) + f_prior_test = m_test + f_prior_test.squeeze(-1) + + K_train_test = KroneckerProductLinearOperator(K_train_test_X, K_T) + # no MaskedLinearOperator here because H_inv_v is already expanded + samples = K_train_test.transpose(-2, -1) @ H_inv_v.unsqueeze(-1) + samples = samples + f_prior_test.unsqueeze(-1) + # reshape samples to separate X and T dimensions + # samples.shape = (*sample_shape, *broadcast_shape, n_test_x * n_t, 1) + samples = samples.reshape( + *samples.shape[:-2], X_test.shape[-2], self._num_outputs + ) + # samples.shape = (*sample_shape, *broadcast_shape, n_test_x, n_t) + if hasattr(self, "outcome_transform") and self.outcome_transform is not None: + samples, _ = self.outcome_transform.untransform(samples) + return samples + + def condition_on_observations( + self, X: Tensor, Y: Tensor, noise: Tensor | None = None, **kwargs: Any + ) -> Model: + raise NotImplementedError( + f"Conditioning currently not supported for {self.__class__.__name__}" + ) diff --git a/botorch/posteriors/latent_kronecker.py b/botorch/posteriors/latent_kronecker.py new file mode 100644 index 0000000000..57b94137ee --- /dev/null +++ b/botorch/posteriors/latent_kronecker.py @@ -0,0 +1,168 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +import torch +from botorch.models.gpytorch import GPyTorchModel +from botorch.posteriors.gpytorch import GPyTorchPosterior +from gpytorch.distributions import MultivariateNormal +from linear_operator.operators import IdentityLinearOperator, ZeroLinearOperator +from torch import Tensor + + +r""" +References + +.. [wilson2020sampling] + J. Wilson, V. Borovitskiy, A. Terenin, P. Mostowsky, and M. Deisenroth. Efficiently + sampling functions from Gaussian process posteriors. International Conference on + Machine Learning (2020). + +.. [wilson2021pathwise] + J. Wilson, V. Borovitskiy, A. Terenin, P. Mostowsky, and M. Deisenroth. Pathwise + Conditioning of Gaussian Processes. Journal of Machine Learning Research (2021). 
+""" + + +class LatentKroneckerGPPosterior(GPyTorchPosterior): + r""" + Dummy posterior class for a LatentKroneckerGP model. + Internally calls model._rsample_from_base_samples to draw posterior samples via + pathwise conditioning aka Matheron's rule [wilson2020sampling, wilson2021pathwise]. + + This is necessary because BoTorch instantiates the posterior object before creating + base samples, whereas pathwise conditioning requires the base samples first to + calculate the posterior samples. To cache expensive computations, which only have + to be performed once for the same base samples, the results are stored in the model + instead of the posterior object, because a new posterior object is created in each + acquisition function call. + """ + + def __init__( + self, + model: GPyTorchModel, + X: Tensor, + ) -> None: + r"""A dummy posterior for LatentKroneckerGP models. + + Args: + model: The LatentKroneckerGP model to which this posterior belongs to. + X: A `(batch_shape) x q x d`-dim Tensor, where `d` is the dimension + of the feature space and `q` is the number of points considered + jointly, on which the posterior shall be evaluated. + """ + self._dtype = X.dtype + self._device = X.device + self.batch_shape = model.batch_shape + self.output_batch_shape = torch.broadcast_shapes( + model.batch_shape, X.shape[:-2] + ) + self.q = X.shape[-2] + output_dim = self.q * model.T.shape[-1] + mean = ZeroLinearOperator( + *self.output_batch_shape, output_dim, dtype=X.dtype, device=X.device + ) + covar = IdentityLinearOperator( + output_dim, + batch_shape=self.output_batch_shape, + dtype=X.dtype, + device=X.device, + ) + dummy_mvn = MultivariateNormal(mean=mean, covariance_matrix=covar) + super().__init__(distribution=dummy_mvn) + self.model = model + self.X = X + self._is_mt = True + + @property + def base_sample_shape(self): + r"""The shape of a base sample used for constructing posterior samples. + + Overwrites the standard `base_sample_shape` call to inform samplers that + `n_train_full + n_train + n_test` samples are needed rather than n samples. + """ + n_train_full = self.model.train_inputs[0].shape[-2] * self.model.T.shape[-1] + n_train = self.model.train_targets.shape[-1] + n_test = self.q * self.model.T.shape[-1] + return self.batch_shape + torch.Size([n_train_full + n_train + n_test]) + + @property + def batch_range(self) -> tuple[int, int]: + r"""The t-batch range. + + This is used in samplers to identify the t-batch component of the + `base_sample_shape`. The base samples are expanded over the t-batches to + provide consistency in the acquisition values, i.e., to ensure that a + candidate produces same value regardless of its position on the t-batch. + """ + return (0, -1) + + def _extended_shape( + self, + sample_shape: torch.Size = torch.Size(), # noqa: B008 + ) -> torch.Size: + r"""Returns the shape of the samples produced by the distribution with + the given `sample_shape`. + """ + time_shape = torch.Size([self.model.T.shape[-1]]) + q_shape = torch.Size([self.q]) + return sample_shape + self.output_batch_shape + q_shape + time_shape + + def rsample_from_base_samples( + self, + sample_shape: torch.Size, + base_samples: Tensor, + ) -> Tensor: + r"""Sample from the posterior (with gradients) using base samples. + + This is intended to be used with a sampler that produces the corresponding base + samples, and enables acquisition optimization via Sample Average Approximation. + + Since this posterior is a dummy object, call the model to perform sampling. 
+ + Args: + sample_shape: A `torch.Size` object specifying the sample shape. To + draw `n` samples, set to `torch.Size([n])`. To draw `b` batches + of `n` samples each, set to `torch.Size([b, n])`. + base_samples: A Tensor of `N(0, I)` base samples of shape + `sample_shape x base_sample_shape`, typically obtained from + a `Sampler`. This is used for deterministic optimization. + + Returns: + Samples from the posterior, a tensor of shape + `self._extended_shape(sample_shape=sample_shape)`. + """ + if base_samples.shape[: len(sample_shape)] != sample_shape: + raise RuntimeError( + "`sample_shape` disagrees with shape of `base_samples`. " + f"Got {sample_shape=} and {base_samples.shape=}." + ) + + return self.model._rsample_from_base_samples(self.X, base_samples) + + def rsample( + self, + sample_shape: torch.Size | None = None, + ) -> Tensor: + r"""Sample from the posterior (with gradients). + + Args: + sample_shape: A `torch.Size` object specifying the sample shape. To + draw `n` samples, set to `torch.Size([n])`. To draw `b` batches + of `n` samples each, set to `torch.Size([b, n])`. + + Returns: + Samples from the posterior, a tensor of shape + `self._extended_shape(sample_shape=sample_shape)`. + """ + if sample_shape is None: + sample_shape = torch.Size([1]) + base_samples = torch.randn( + sample_shape + self.base_sample_shape, + dtype=self.X.dtype, + device=self.X.device, + ) + return self.rsample_from_base_samples(sample_shape, base_samples) diff --git a/sphinx/source/models.rst b/sphinx/source/models.rst index 4880d22931..dbeb43998d 100644 --- a/sphinx/source/models.rst +++ b/sphinx/source/models.rst @@ -69,6 +69,11 @@ Higher Order GP Models .. automodule:: botorch.models.higher_order_gp :members: +Latent Kronecker GP Models +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +.. automodule:: botorch.models.latent_kronecker_gp + :members: + Pairwise GP Models ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. automodule:: botorch.models.pairwise_gp diff --git a/sphinx/source/posteriors.rst b/sphinx/source/posteriors.rst index 665a3d9d44..0bf9ec3f0c 100644 --- a/sphinx/source/posteriors.rst +++ b/sphinx/source/posteriors.rst @@ -44,6 +44,11 @@ Higher Order GP Posterior .. automodule:: botorch.posteriors.higher_order :members: +Latent Kronecker GP Posterior +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +.. automodule:: botorch.posteriors.latent_kronecker + :members: + Multitask GP Posterior ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. automodule:: botorch.posteriors.multitask diff --git a/test/models/test_latent_kronecker_gp.py b/test/models/test_latent_kronecker_gp.py new file mode 100644 index 0000000000..ea42ce689c --- /dev/null +++ b/test/models/test_latent_kronecker_gp.py @@ -0,0 +1,605 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
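+
+# Tests for `LatentKroneckerGP` and `MinMaxStandardize`: model construction,
+# fitting, posterior sampling shapes and values, iterative-method settings,
+# and errors for currently unsupported features.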
+ +import itertools +import warnings + +import torch +from botorch.acquisition.objective import ScalarizedPosteriorTransform +from botorch.exceptions.errors import BotorchTensorDimensionError +from botorch.exceptions.warnings import OptimizationWarning +from botorch.fit import fit_gpytorch_mll +from botorch.models.latent_kronecker_gp import LatentKroneckerGP, MinMaxStandardize +from botorch.models.transforms import Normalize +from botorch.utils.testing import _get_random_data, BotorchTestCase +from botorch.utils.types import DEFAULT +from gpytorch.kernels import MaternKernel, RBFKernel, ScaleKernel +from gpytorch.likelihoods import FixedNoiseGaussianLikelihood, GaussianLikelihood +from gpytorch.means import ConstantMean, ZeroMean +from gpytorch.mlls.exact_marginal_log_likelihood import ExactMarginalLogLikelihood +from linear_operator import settings +from linear_operator.utils.warnings import NumericalWarning, PerformanceWarning + + +def _get_data_with_missing_entries( + n_train: int, d: int, m: int, batch_shape: torch.Size, tkwargs: dict +): + train_X, train_Y = _get_random_data( + batch_shape=batch_shape, m=m, d=d, n=n_train, **tkwargs + ) + + # randomly mask half of the training data + train_Y_valid = torch.ones(n_train * m, dtype=torch.bool, device=tkwargs["device"]) + train_Y_valid[torch.randperm(n_train * m)[: n_train * m // 2]] = False + train_Y_valid = train_Y_valid.reshape(n_train, m) + train_Y[..., ~train_Y_valid] = torch.nan + + return train_X, train_Y, train_Y_valid + + +class TestLatentKroneckerGP(BotorchTestCase): + def test_default_init(self): + for ( + batch_shape, + n_train, + d, + m, + dtype, + use_transforms, + ) in itertools.product( + ( # batch_shape + torch.Size([]), + torch.Size([1]), + torch.Size([2, 3]), + ), + (10,), # n_train + (1, 2), # d + (1, 2), # m + (torch.float, torch.double), # dtype + (False, True), # use_transforms + ): + tkwargs = {"device": self.device, "dtype": dtype} + + if use_transforms: + intf = Normalize(d=d, batch_shape=batch_shape) + octf = DEFAULT + else: + intf = None + octf = None + + train_X, train_Y, train_Y_valid = _get_data_with_missing_entries( + n_train=n_train, d=d, m=m, batch_shape=batch_shape, tkwargs=tkwargs + ) + + model = LatentKroneckerGP( + train_X=train_X, + train_Y=train_Y, + train_Y_valid=train_Y_valid, + input_transform=intf, + outcome_transform=octf, + ) + model.to(**tkwargs) + + # test init + train_Y_flat = train_Y.reshape(*batch_shape, -1)[ + ..., train_Y_valid.reshape(-1) + ] + if use_transforms: + self.assertIsInstance(model.input_transform, Normalize) + self.assertIsInstance(model.outcome_transform, MinMaxStandardize) + else: + self.assertFalse(hasattr(model, "input_transform")) + self.assertFalse(hasattr(model, "outcome_transform")) + train_Y_flat = ( + model.outcome_transform(train_Y_flat.unsqueeze(-1))[0].squeeze(-1) + if use_transforms + else train_Y_flat + ) + self.assertAllClose(model.train_inputs[0], train_X, atol=0.0) + self.assertAllClose(model.train_targets, train_Y_flat, atol=0.0) + self.assertIsInstance(model.likelihood, GaussianLikelihood) + self.assertIsInstance(model.mean_module_X, ZeroMean) + self.assertIsInstance(model.mean_module_T, ZeroMean) + self.assertIsInstance(model.covar_module_X, MaternKernel) + self.assertIsInstance(model.covar_module_T, ScaleKernel) + self.assertIsInstance(model.covar_module_T.base_kernel, MaternKernel) + + def test_custom_init(self): + # test whether custom likelihoods and mean/covar modules are set correctly. 
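+        # In addition to custom modules, this exercises the error handling for
+        # batched validity masks and for `T` tensors with mismatched shapes
+        # (both raise `BotorchTensorDimensionError`).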
+ for batch_shape, n_train, d, m, dtype in itertools.product( + ( # batch_shape + torch.Size([]), + torch.Size([1]), + torch.Size([2]), + torch.Size([2, 3]), + ), + (10,), # n_train + (1, 2), # d + (1, 2), # m + (torch.float, torch.double), # dtype + ): + tkwargs = {"device": self.device, "dtype": dtype} + + train_X, train_Y, train_Y_valid = _get_data_with_missing_entries( + n_train=n_train, d=d, m=m, batch_shape=batch_shape, tkwargs=tkwargs + ) + + train_Y_valid_batched = train_Y_valid.expand(*batch_shape, n_train, m) + if len(batch_shape) > 0: + err_msg = ( + "Explicit batch_shape not allowed for train_Y_valid, " + "because the mask must be shared across batch dimensions. " + f"Expected train_Y_valid with shape: {train_Y.shape[-2:]} " + f"(got {train_Y_valid_batched.shape})." + ) + with self.assertRaises(BotorchTensorDimensionError) as e: + LatentKroneckerGP( + train_X=train_X, + train_Y=train_Y, + train_Y_valid=train_Y_valid_batched, + ) + self.assertEqual(err_msg, str(e.exception)) + + T = torch.linspace(0, 1, m, **tkwargs) + if len(batch_shape) > 0: + expected_shape = torch.Size([*batch_shape, train_Y.shape[-1]]) + err_msg = f"Expected T with shape: {expected_shape} (got {T.shape})." + with self.assertRaises(BotorchTensorDimensionError) as e: + LatentKroneckerGP( + train_X=train_X, + train_Y=train_Y, + T=T, + ) + self.assertEqual(err_msg, str(e.exception)) + + T = T.expand(*batch_shape, m) + + likelihood = GaussianLikelihood(batch_shape=batch_shape) + mean_module_X = ConstantMean(batch_shape=batch_shape) + mean_module_T = ConstantMean(batch_shape=batch_shape) + covar_module_X = RBFKernel(ard_num_dims=d, batch_shape=batch_shape) + covar_module_T = RBFKernel(ard_num_dims=1, batch_shape=batch_shape) + + model = LatentKroneckerGP( + train_X=train_X, + train_Y=train_Y, + train_Y_valid=train_Y_valid, + T=T, + likelihood=likelihood, + mean_module_X=mean_module_X, + mean_module_T=mean_module_T, + covar_module_X=covar_module_X, + covar_module_T=covar_module_T, + ) + model.to(**tkwargs) + + self.assertAllClose(model.T, T, atol=0.0) + self.assertEqual(model.likelihood, likelihood) + self.assertEqual(model.mean_module_X, mean_module_X) + self.assertEqual(model.mean_module_T, mean_module_T) + self.assertEqual(model.covar_module_X, covar_module_X) + self.assertEqual(model.covar_module_T, covar_module_T) + + # check devices + def _get_index(device): + return device.index if device.index is not None else 0 + + device_type = self.device.type + device_idx = _get_index(self.device) + + self.assertEqual(model.train_inputs[0].device.type, device_type) + self.assertEqual(_get_index(model.train_inputs[0].device), device_idx) + self.assertEqual(model.train_targets.device.type, device_type) + self.assertEqual(_get_index(model.train_targets.device), device_idx) + self.assertEqual(model.mask.device.type, device_type) + self.assertEqual(_get_index(model.mask.device), device_idx) + self.assertEqual(model.T.device.type, device_type) + self.assertEqual(_get_index(model.T.device), device_idx) + for p in model.parameters(): + self.assertEqual(p.device.type, device_type) + self.assertEqual(_get_index(p.device), device_idx) + + def test_custom_octf(self): + for ( + batch_shape, + n_train, + d, + m, + dtype, + ) in itertools.product( + ( # batch_shape + torch.Size([]), + torch.Size([1]), + torch.Size([2, 3]), + ), + (10,), # n_train + (1, 2), # d + (1, 2), # m + (torch.float, torch.double), # dtype + ): + tkwargs = {"device": self.device, "dtype": dtype} + + octf = DEFAULT + + train_X, train_Y, train_Y_valid = 
_get_data_with_missing_entries( + n_train=n_train, d=d, m=m, batch_shape=batch_shape, tkwargs=tkwargs + ) + + model = LatentKroneckerGP( + train_X=train_X, + train_Y=train_Y, + train_Y_valid=train_Y_valid, + outcome_transform=octf, + ) + model.to(**tkwargs) + + # test init + train_Y_flat = train_Y.reshape(*batch_shape, -1)[ + ..., train_Y_valid.reshape(-1) + ] + + self.assertIsInstance(model.outcome_transform, MinMaxStandardize) + octf = model.outcome_transform + + # test MinMaxStandardize + octf._is_trained = torch.tensor(False) + # wrong batch shape + with self.assertRaises(RuntimeError): + octf(train_Y_flat.unsqueeze(-1).unsqueeze(0)) + octf._is_trained = torch.tensor(False) + # wrong output dimension + with self.assertRaises(RuntimeError): + octf(train_Y_flat.unsqueeze(-1).repeat(1, 2)) + octf._is_trained = torch.tensor(False) + # missing output dimension + with self.assertRaises(ValueError): + octf(torch.zeros(*batch_shape, 0, 1, **tkwargs)) + octf._is_trained = torch.tensor(False) + # stdvs calculation with single observation + octf(torch.zeros(*batch_shape, 1, 1, **tkwargs)) + self.assertAllClose(octf.stdvs, torch.ones(*batch_shape, 1, 1, **tkwargs)) + octf._is_trained = torch.tensor(False) + # standardize specific output dimensions + octf._outputs = [] + octf(train_Y_flat.unsqueeze(-1)) + self.assertAllClose(octf.means, torch.zeros_like(octf.means)) + self.assertAllClose(octf.stdvs, torch.ones_like(octf.stdvs)) + octf._outputs = None + octf._is_trained = torch.tensor(False) + + def test_gp_train(self): + for ( + batch_shape, + n_train, + d, + m, + dtype, + use_transforms, + ) in itertools.product( + ( # batch_shape + torch.Size([]), + torch.Size([1]), + torch.Size([2, 3]), + ), + (10,), # n_train + (1, 2), # d + (1, 2), # m + (torch.float, torch.double), # dtype + (False, True), # use_transforms + ): + tkwargs = {"device": self.device, "dtype": dtype} + + if use_transforms: + intf = Normalize(d=d, batch_shape=batch_shape) + octf = DEFAULT + else: + intf = None + octf = None + + train_X, train_Y, train_Y_valid = _get_data_with_missing_entries( + n_train=n_train, d=d, m=m, batch_shape=batch_shape, tkwargs=tkwargs + ) + + model = LatentKroneckerGP( + train_X=train_X, + train_Y=train_Y, + train_Y_valid=train_Y_valid, + input_transform=intf, + outcome_transform=octf, + ) + model.to(**tkwargs) + + # test optim + model.train() + mll = ExactMarginalLogLikelihood(model.likelihood, model) + mll.to(**tkwargs) + with warnings.catch_warnings(), model.use_iterative_methods(): + warnings.filterwarnings("ignore", category=OptimizationWarning) + fit_gpytorch_mll( + mll, optimizer_kwargs={"options": {"maxiter": 1}}, max_attempts=1 + ) + + def test_gp_eval_shapes(self): + for ( + batch_shape, + n_train, + n_test, + d, + m, + dtype, + use_transforms, + ) in itertools.product( + ( # batch_shape + torch.Size([]), + torch.Size([1]), + torch.Size([2, 3]), + ), + (10,), # n_train + (7,), # n_test + (1,), # d + (1,), # m + (torch.float, torch.double), # dtype + (False, True), # use_transforms + ): + tkwargs = {"device": self.device, "dtype": dtype} + + if use_transforms: + intf = Normalize(d=d, batch_shape=batch_shape) + octf = DEFAULT + else: + intf = None + octf = None + + train_X, train_Y, train_Y_valid = _get_data_with_missing_entries( + n_train=n_train, d=d, m=m, batch_shape=batch_shape, tkwargs=tkwargs + ) + + model = LatentKroneckerGP( + train_X=train_X, + train_Y=train_Y, + train_Y_valid=train_Y_valid, + input_transform=intf, + outcome_transform=octf, + ) + model.to(**tkwargs) + model.eval() + + 
torch.manual_seed(12345) + for test_shape in ( + torch.Size([]), + torch.Size([3]), + torch.Size([*batch_shape]), + torch.Size([2, *batch_shape]), + ): + test_X = torch.rand(*test_shape, n_test, d, **tkwargs) + + # we expect an error if test_shape and batch_shape cannot be broadcasted + try: + broadcast_shape = torch.broadcast_shapes(test_shape, batch_shape) + except RuntimeError as e: + with self.assertRaisesRegex(RuntimeError, str(e)): + model.posterior(test_X) + continue + pred_shape = torch.Size([*broadcast_shape, n_test, m]) + + # custom posterior samples + posterior = model.posterior(test_X) + self.assertEqual(posterior.batch_range, (0, -1)) + for sample_shape in ((), (1,), (2, 3)): + # test posterior.rsample + with warnings.catch_warnings(), model.use_iterative_methods(): + warnings.filterwarnings("ignore", category=NumericalWarning) + pred_samples = posterior.rsample(sample_shape=sample_shape) + self.assertEqual( + pred_samples.shape, torch.Size([*sample_shape, *pred_shape]) + ) + self.assertEqual( + pred_samples.shape, + posterior._extended_shape(torch.Size(sample_shape)), + ) + # test posterior.rsample_from_base_samples + base_samples = torch.randn( + torch.Size(sample_shape) + posterior.base_sample_shape, + **tkwargs, + ) + with warnings.catch_warnings(), model.use_iterative_methods(): + warnings.filterwarnings("ignore", category=NumericalWarning) + pred_samples = posterior.rsample_from_base_samples( + sample_shape, base_samples + ) + self.assertEqual( + pred_samples.shape, torch.Size([*sample_shape, *pred_shape]) + ) + # run again to test caching when using the same base samples + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=NumericalWarning) + pred_samples = posterior.rsample_from_base_samples( + sample_shape, base_samples + ) + self.assertEqual( + pred_samples.shape, torch.Size([*sample_shape, *pred_shape]) + ) + if len(sample_shape) > 0: + # test incorrect base sample shape + incorrect_base_samples = torch.randn( + torch.Size([5]) + posterior.base_sample_shape, + **tkwargs, + ) + with self.assertRaises(RuntimeError): + posterior.rsample_from_base_samples( + sample_shape, incorrect_base_samples + ) + + def test_gp_eval_values(self): + for ( + batch_shape, + n_train, + n_test, + d, + m, + dtype, + use_transforms, + ) in itertools.product( + ( # batch_shape + torch.Size([]), + torch.Size([1]), + torch.Size([2, 3]), + ), + (10,), # n_train + (7,), # n_test + (1,), # d + (1,), # m + (torch.float, torch.double), # dtype + (False, True), # use_transforms + ): + tkwargs = {"device": self.device, "dtype": dtype} + + if use_transforms: + intf = Normalize(d=d, batch_shape=batch_shape) + octf = DEFAULT + else: + intf = None + octf = None + + train_X, train_Y, train_Y_valid = _get_data_with_missing_entries( + n_train=n_train, d=d, m=m, batch_shape=batch_shape, tkwargs=tkwargs + ) + + model = LatentKroneckerGP( + train_X=train_X, + train_Y=train_Y, + train_Y_valid=train_Y_valid, + input_transform=intf, + outcome_transform=octf, + ) + model.to(**tkwargs) + model.eval() + + torch.manual_seed(12345) + for test_shape in ( + torch.Size([]), + torch.Size([3]), + torch.Size([*batch_shape]), + torch.Size([2, *batch_shape]), + ): + test_X = torch.rand(*test_shape, n_test, d, **tkwargs) + + # we expect an error if test_shape and batch_shape cannot be broadcasted + try: + broadcast_shape = torch.broadcast_shapes(test_shape, batch_shape) + except RuntimeError as e: + with self.assertRaisesRegex(RuntimeError, str(e)): + model.posterior(test_X) + continue + pred_shape = 
torch.Size([*broadcast_shape, n_test, m]) + + posterior = model.posterior(test_X) + with warnings.catch_warnings(), model.use_iterative_methods(): + warnings.filterwarnings("ignore", category=NumericalWarning) + pred_samples = posterior.rsample(sample_shape=(2048,)) + self.assertEqual(pred_samples.shape, torch.Size([2048, *pred_shape])) + + # GPyTorch predictions + with model.use_iterative_methods(): + pred = model(intf(test_X)) if intf is not None else model(test_X) + pred_mean, pred_var = pred.mean, pred.variance + pred_mean = pred_mean.reshape(*pred_mean.shape[:-1], n_test, m) + pred_var = pred_var.reshape(*pred_var.shape[:-1], n_test, m) + pred_mean, pred_var = ( + model.outcome_transform.untransform(pred_mean, pred_var) + if octf is not None + else (pred_mean, pred_var) + ) + self.assertEqual(pred_mean.shape, pred_shape) + self.assertEqual(pred_var.shape, pred_shape) + + # check custom predictions and GPyTorch are roughly the same + self.assertLess( + (pred_mean - pred_samples.mean(dim=0)).norm() / pred_mean.norm(), + 0.1, + ) + self.assertLess( + (pred_var - pred_samples.var(dim=0)).norm() / pred_var.norm(), 0.1 + ) + + def test_iterative_methods(self): + for batch_shape, n_train, d, m, dtype in itertools.product( + (torch.Size([]),), # batch_shape + (10,), # n_train + (1,), # d + (1,), # m + (torch.float, torch.double), # dtype + ): + tkwargs = {"device": self.device, "dtype": dtype} + train_X, train_Y = _get_random_data( + batch_shape=batch_shape, m=m, d=d, n=n_train, **tkwargs + ) + + model = LatentKroneckerGP( + train_X=train_X, + train_Y=train_Y, + ) + model.to(**tkwargs) + posterior = model.posterior(train_X) + + warn_msg = ( + "Iterative methods are disabled. Performing linear solve using " + "full joint covariance matrix, which might be slow and require " + "a lot of memory. Iterative methods can be enabled using " + "'with model.use_iterative_methods():'." 
+ ) + + with self.assertWarns(PerformanceWarning) as w: + posterior.rsample() + # Using this because self.assertWarnsRegex does not work for some reason + self.assertEqual(warn_msg, str(w.warning)) + + with model.use_iterative_methods(): + self.assertTrue(settings._fast_covar_root_decomposition.off()) + self.assertTrue(settings._fast_log_prob.on()) + self.assertTrue(settings._fast_solves.on()) + self.assertEqual(settings.cg_tolerance.value(), 0.01) + self.assertEqual(settings.max_cg_iterations.value(), 10000) + + def test_not_implemented(self): + batch_shape = torch.Size([]) + tkwargs = {"device": self.device, "dtype": torch.double} + train_X, train_Y = _get_random_data(batch_shape=batch_shape, m=1, **tkwargs) + + model = LatentKroneckerGP( + train_X=train_X, + train_Y=train_Y, + ) + model.to(**tkwargs) + + cls_name = model.__class__.__name__ + + transform = ScalarizedPosteriorTransform(torch.tensor([1.0], **tkwargs)) + err_msg = f"Posterior transforms currently not supported for {cls_name}" + with self.assertRaisesRegex(NotImplementedError, err_msg): + model.posterior(train_X, posterior_transform=transform) + + err_msg = f"Observation noise currently not supported for {cls_name}" + with self.assertRaisesRegex(NotImplementedError, err_msg): + model.posterior(train_X, observation_noise=True) + + err_msg = f"Conditioning currently not supported for {cls_name}" + with self.assertRaisesRegex(NotImplementedError, err_msg): + model.condition_on_observations(train_X, train_Y) + + likelihood = FixedNoiseGaussianLikelihood( + torch.tensor([1.0]), batch_shape=batch_shape, **tkwargs + ) + model = LatentKroneckerGP( + train_X=train_X, + train_Y=train_Y, + likelihood=likelihood, + ) + model.to(**tkwargs) + + err_msg = f"Only GaussianLikelihood currently supported for {cls_name}" + with self.assertRaisesRegex(NotImplementedError, err_msg): + model.posterior(train_X)