Variational GP with derivatives and monotonic gp #2272

Open · wants to merge 4 commits into base: main
1,621 changes: 1,621 additions & 0 deletions examples/08_Advanced_Usage/ApproxGP_Derivative_Information_MonotonicGP.ipynb

Large diffs are not rendered by default.
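Since the notebook diff is not rendered, here is a minimal, hypothetical sketch of how a user-defined ApproximateGP might wire in the VariationalStrategyIndexed class added further down in this PR. The class name DerivativeApproximateGP, the ConstantMeanGrad/RBFKernelGrad modules, the (n, d+1) boolean index layout, and the slicing inside forward() are assumptions, not taken from the notebook.

import gpytorch
from gpytorch.variational import CholeskyVariationalDistribution, VariationalStrategyIndexed


class DerivativeApproximateGP(gpytorch.models.ApproximateGP):
    def __init__(self, inducing_points, inducing_index):
        # One variational dimension per True flag in the boolean inducing index,
        # matching num_induc = torch.sum(self.inducing_index) in the strategy below.
        num_induc = int(inducing_index.sum())
        variational_distribution = CholeskyVariationalDistribution(num_induc)
        variational_strategy = VariationalStrategyIndexed(
            self,
            inducing_points,
            variational_distribution,
            inducing_index=inducing_index,
            learn_inducing_locations=True,
        )
        super().__init__(variational_strategy)
        self.mean_module = gpytorch.means.ConstantMeanGrad()
        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernelGrad())

    def forward(self, x, index):
        # Joint mean/covariance over function values and derivatives at x:
        # RBFKernelGrad returns an n*(d+1) x n*(d+1) covariance, interleaved per point.
        mean_x = self.mean_module(x).reshape(-1)  # non-batch sketch
        covar_x = self.covar_module(x).to_dense()
        # Keep only the entries flagged by the boolean index (assumed shape n x (d+1),
        # flattened in the same interleaved order as the kernel).
        keep = index.reshape(-1)
        return gpytorch.distributions.MultivariateNormal(mean_x[keep], covar_x[keep][:, keep])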

3 changes: 2 additions & 1 deletion gpytorch/variational/__init__.py
@@ -19,10 +19,11 @@
from .orthogonally_decoupled_variational_strategy import OrthogonallyDecoupledVariationalStrategy
from .tril_natural_variational_distribution import TrilNaturalVariationalDistribution
from .unwhitened_variational_strategy import UnwhitenedVariationalStrategy
from .variational_strategy import VariationalStrategy
from .variational_strategy import VariationalStrategy, VariationalStrategyIndexed

__all__ = [
"_VariationalStrategy",
"_VariationalStrategyIndexed",
"AdditiveGridInterpolationVariationalStrategy",
"BatchDecoupledVariationalStrategy",
"CiqVariationalStrategy",
146 changes: 145 additions & 1 deletion gpytorch/variational/variational_strategy.py
@@ -1,5 +1,4 @@
#!/usr/bin/env python3

import warnings

import torch
@@ -244,3 +243,148 @@ def __call__(self, x, prior=False, **kwargs):
self.updated_strategy.fill_(True)

return super().__call__(x, prior=prior, **kwargs)


class VariationalStrategyIndexed(_VariationalStrategy):
r"""
This is an indexed version of :class:`VariationalStrategy` that can be used when
function and derivative values are observed at different locations.

:param ~gpytorch.models.ApproximateGP model: Model this strategy is applied to.
Typically passed in when the VariationalStrategy is created in the
__init__ method of the user-defined model.
:param torch.Tensor inducing_points: Tensor containing a set of inducing
points to use for variational inference.
:param ~gpytorch.variational.VariationalDistribution variational_distribution: A
VariationalDistribution object that represents the form of the variational distribution :math:`q(\mathbf u)`
:param torch.Tensor inducing_index: Boolean tensor of True/False flags at the inducing
points indicating which function/derivative values are used; it is used to slice the full covariance matrix.
:param learn_inducing_locations: (Default True): Whether or not
the inducing point locations :math:`\mathbf Z` should be learned (i.e. are they
parameters of the model).
:type learn_inducing_locations: `bool`, optional
:param jitter_val: Amount of diagonal jitter to add for covariance matrix numerical stability.
:type jitter_val: `float`, optional

.. _Hensman et al. (2015):
http://proceedings.mlr.press/v38/hensman15.pdf
.. _Matthews (2017):
https://www.repository.cam.ac.uk/handle/1810/278022
"""

def __init__(
self,
model,
inducing_points,
variational_distribution,
inducing_index,
learn_inducing_locations=True,
jitter_val=None,
):
super().__init__(
model, inducing_points, variational_distribution, learn_inducing_locations, jitter_val=jitter_val
)
self.inducing_index = inducing_index
self.register_buffer("updated_strategy", torch.tensor(True))
self._register_load_state_dict_pre_hook(_ensure_updated_strategy_flag_set)
self.has_fantasy_strategy = True

@cached(name="cholesky_factor", ignore_args=True)
def _cholesky_factor(self, induc_induc_covar):
L = psd_safe_cholesky(to_dense(induc_induc_covar).type(_linalg_dtype_cholesky.value()))
return TriangularLinearOperator(L)

@property
@cached(name="prior_distribution_memo")
def prior_distribution(self):
zeros = torch.zeros(
self._variational_distribution.shape(),
dtype=self._variational_distribution.dtype,
device=self._variational_distribution.device,
)
ones = torch.ones_like(zeros)
res = MultivariateNormal(zeros, DiagLinearOperator(ones))
return res

def forward(self, x, inducing_points, inducing_values, variational_inducing_covar=None, x_index=None, **kwargs):
# Compute full prior distribution
full_inputs = torch.cat([inducing_points, x], dim=-2)
full_indices = torch.cat([self.inducing_index, x_index], dim=-2)
full_output = self.model.forward(full_inputs, full_indices, **kwargs)
full_covar = full_output.lazy_covariance_matrix

# Covariance terms
num_induc = torch.sum(self.inducing_index).item()
test_mean = full_output.mean[..., num_induc:]
induc_induc_covar = full_covar[..., :num_induc, :num_induc].add_jitter(self.jitter_val)
induc_data_covar = full_covar[..., :num_induc, num_induc:].to_dense()
data_data_covar = full_covar[..., num_induc:, num_induc:]

# Compute interpolation terms
# K_ZZ^{-1/2} K_ZX
# K_ZZ^{-1/2} \mu_Z
L = self._cholesky_factor(induc_induc_covar)
if L.shape != induc_induc_covar.shape:
# Aggressive caching can cause nasty shape incompatibilities when evaluating with different batch shapes
# TODO: Use a hook for this
try:
pop_from_cache_ignore_args(self, "cholesky_factor")
except CachingError:
pass
L = self._cholesky_factor(induc_induc_covar)
interp_term = L.solve(induc_data_covar.type(_linalg_dtype_cholesky.value())).to(full_inputs.dtype)

# Compute the mean of q(f)
# k_XZ K_ZZ^{-1/2} (m - K_ZZ^{-1/2} \mu_Z) + \mu_X
predictive_mean = (interp_term.transpose(-1, -2) @ inducing_values.unsqueeze(-1)).squeeze(-1) + test_mean

# Compute the covariance of q(f)
# K_XX + k_XZ K_ZZ^{-1/2} (S - I) K_ZZ^{-1/2} k_ZX
middle_term = self.prior_distribution.lazy_covariance_matrix.mul(-1)
if variational_inducing_covar is not None:
middle_term = SumLinearOperator(variational_inducing_covar, middle_term)

if trace_mode.on():
predictive_covar = (
data_data_covar.add_jitter(self.jitter_val).to_dense()
+ interp_term.transpose(-1, -2) @ middle_term.to_dense() @ interp_term
)
else:
predictive_covar = SumLinearOperator(
data_data_covar.add_jitter(self.jitter_val),
MatmulLinearOperator(interp_term.transpose(-1, -2), middle_term @ interp_term),
)

# Return the distribution
return MultivariateNormal(predictive_mean, predictive_covar)

def __call__(self, x, prior=False, **kwargs):
if not self.updated_strategy.item() and not prior:
with torch.no_grad():
# Get unwhitened p(u)
prior_function_dist = self(self.inducing_points, prior=True)
prior_mean = prior_function_dist.loc
L = self._cholesky_factor(prior_function_dist.lazy_covariance_matrix.add_jitter(self.jitter_val))

# Temporarily turn off noise that's added to the mean
orig_mean_init_std = self._variational_distribution.mean_init_std
self._variational_distribution.mean_init_std = 0.0

# Change the variational parameters to be whitened
variational_dist = self.variational_distribution
mean_diff = (variational_dist.loc - prior_mean).unsqueeze(-1).type(_linalg_dtype_cholesky.value())
whitened_mean = L.solve(mean_diff).squeeze(-1).to(variational_dist.loc.dtype)
covar_root = variational_dist.lazy_covariance_matrix.root_decomposition().root.to_dense()
covar_root = covar_root.type(_linalg_dtype_cholesky.value())
whitened_covar = RootLinearOperator(L.solve(covar_root).to(variational_dist.loc.dtype))
whitened_variational_distribution = variational_dist.__class__(whitened_mean, whitened_covar)
self._variational_distribution.initialize_variational_distribution(whitened_variational_distribution)

# Reset the random noise parameter of the model
self._variational_distribution.mean_init_std = orig_mean_init_std

# Reset the cache
clear_cache_hook(self)

# Mark that we have updated the variational strategy
self.updated_strategy.fill_(True)

return super().__call__(x, prior=prior, **kwargs)
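
For orientation, a hedged usage sketch building on the hypothetical DerivativeApproximateGP model shown earlier: it assumes (n, d+1) boolean masks for both the inducing points and the training inputs, and relies on the x_index keyword being forwarded through ApproximateGP.__call__ to VariationalStrategyIndexed.forward. The training targets here are random placeholders, not data from the notebook.

import torch
import gpytorch

d = 1
train_x = torch.linspace(0, 1, 20).unsqueeze(-1)           # n x d training inputs
train_index = torch.ones(20, d + 1, dtype=torch.bool)      # observe f and df/dx at every point
inducing_points = torch.linspace(0, 1, 10).unsqueeze(-1)   # m x d inducing locations
inducing_index = torch.ones(10, d + 1, dtype=torch.bool)   # keep f and df/dx at every inducing point

model = DerivativeApproximateGP(inducing_points, inducing_index)
likelihood = gpytorch.likelihoods.GaussianLikelihood()
mll = gpytorch.mlls.VariationalELBO(likelihood, model, num_data=int(train_index.sum()))

# Placeholder targets: one value per True flag in train_index (function values and derivatives stacked).
train_y = torch.randn(int(train_index.sum()))

optimizer = torch.optim.Adam(list(model.parameters()) + list(likelihood.parameters()), lr=0.01)
for _ in range(10):
    optimizer.zero_grad()
    output = model(train_x, x_index=train_index)   # x_index is forwarded to VariationalStrategyIndexed.forward
    loss = -mll(output, train_y)
    loss.backward()
    optimizer.step()

Passing x_index as a keyword works because extra keyword arguments flow from ApproximateGP.__call__ through the variational strategy into forward(); nothing here should be read as the notebook's actual training loop.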