Ensure PSD-safe factorization in constructor of MultivariateNormal #2297

Open · wants to merge 1 commit into main
gpytorch/distributions/multitask_multivariate_normal.py (4 additions, 1 deletion)
@@ -3,6 +3,7 @@
import torch
from linear_operator import LinearOperator, to_linear_operator
from linear_operator.operators import BlockDiagLinearOperator, BlockInterleavedLinearOperator, CatLinearOperator
from torch import Tensor

from .multivariate_normal import MultivariateNormal

@@ -24,7 +25,9 @@ class MultitaskMultivariateNormal(MultivariateNormal):
w.r.t. inter-observation covariance for each task.
"""

def __init__(self, mean, covariance_matrix, validate_args=False, interleaved=True):
def __init__(
self, mean: Tensor, covariance_matrix: LinearOperator, validate_args: bool = False, interleaved: bool = True
):
if not torch.is_tensor(mean) and not isinstance(mean, LinearOperator):
raise RuntimeError("The mean of a MultitaskMultivariateNormal must be a Tensor or LinearOperator")

gpytorch/distributions/multivariate_normal.py (39 additions, 39 deletions)
@@ -42,27 +42,36 @@ class MultivariateNormal(TMultivariateNormal, Distribution):
:ivar torch.Tensor variance: The variance.
"""

def __init__(self, mean: Tensor, covariance_matrix: Union[Tensor, LinearOperator], validate_args: bool = False):
self._islazy = isinstance(mean, LinearOperator) or isinstance(covariance_matrix, LinearOperator)
if self._islazy:
if validate_args:
ms = mean.size(-1)
cs1 = covariance_matrix.size(-1)
cs2 = covariance_matrix.size(-2)
if not (ms == cs1 and ms == cs2):
raise ValueError(f"Wrong shapes in {self._repr_sizes(mean, covariance_matrix)}")
self.loc = mean
self._covar = covariance_matrix
self.__unbroadcasted_scale_tril = None
self._validate_args = validate_args
batch_shape = torch.broadcast_shapes(self.loc.shape[:-1], covariance_matrix.shape[:-2])

event_shape = self.loc.shape[-1:]

# TODO: Integrate argument validation for LinearOperators into torch.distribution validation logic
super(TMultivariateNormal, self).__init__(batch_shape, event_shape, validate_args=False)
else:
super().__init__(loc=mean, covariance_matrix=covariance_matrix, validate_args=validate_args)
def __init__(
self,
mean: Union[Tensor, LinearOperator],
covariance_matrix: Union[Tensor, LinearOperator],
validate_args: bool = False,
):
self._islazy = True
# casting Tensor to DenseLinearOperator because the super constructor calls cholesky, which
# will fail if the covariance matrix is semi-definite, whereas DenseLinearOperator ends up
# calling _psd_safe_cholesky, which factorizes semi-definite matrices by adding to the diagonal.
if isinstance(covariance_matrix, Tensor):
self._islazy = False # to allow _unbroadcasted_scale_tril setter
Collaborator: It seems odd to have _islazy set to True if the covariance matrix is indeed a LinearOperator. I guess the "lazy" nomenclature is a bit outdated anyway with the move to LinearOperator.

covariance_matrix = to_linear_operator(covariance_matrix)

if validate_args:
ms = mean.size(-1)
cs1 = covariance_matrix.size(-1)
cs2 = covariance_matrix.size(-2)
if not (ms == cs1 and ms == cs2):
raise ValueError(f"Wrong shapes in {self._repr_sizes(mean, covariance_matrix)}")
self.loc = mean
self._covar = covariance_matrix
self.__unbroadcasted_scale_tril = None
self._validate_args = validate_args
batch_shape = torch.broadcast_shapes(self.loc.shape[:-1], covariance_matrix.shape[:-2])

event_shape = self.loc.shape[-1:]

# TODO: Integrate argument validation for LinearOperators into torch.distribution validation logic
Collaborator: Do you mean changing the torch code to validate LinearOperator inputs? That might be somewhat challenging to do if we want to use LinearOperators there explicitly. What would work is to make changes in pure torch that would make it easier to use LinearOperator objects by means of the __torch_function__ interface we define in LinearOperator.
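As a reference for the __torch_function__ interface mentioned here, a minimal sketch of the protocol with a hypothetical wrapper class (not LinearOperator's actual implementation):

```python
import torch

class WrappedTensor:
    # Hypothetical wrapper: defining __torch_function__ lets plain torch
    # functions dispatch on this type even though it is not a Tensor subclass.
    def __init__(self, tensor: torch.Tensor):
        self.tensor = tensor

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        # Unwrap WrappedTensor arguments, call the underlying torch function,
        # and rewrap the result.
        unwrapped = tuple(a.tensor if isinstance(a, WrappedTensor) else a for a in args)
        return cls(func(*unwrapped, **kwargs))

# torch.add now routes through WrappedTensor.__torch_function__:
out = torch.add(WrappedTensor(torch.ones(2)), torch.ones(2))
```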

super(TMultivariateNormal, self).__init__(batch_shape, event_shape, validate_args=False)

def _extended_shape(self, sample_shape: torch.Size = torch.Size()) -> torch.Size:
"""
@@ -81,16 +90,16 @@ def _extended_shape(self, sample_shape: torch.Size = torch.Size()) -> torch.Size
def _repr_sizes(mean: Tensor, covariance_matrix: Union[Tensor, LinearOperator]) -> str:
return f"MultivariateNormal(loc: {mean.size()}, scale: {covariance_matrix.size()})"

@property
@property # not using lazy_property here, because it does not allow for setter below
def _unbroadcasted_scale_tril(self) -> Tensor:
if self.islazy and self.__unbroadcasted_scale_tril is None:
if self.__unbroadcasted_scale_tril is None:
# cache root decomposition
ust = to_dense(self.lazy_covariance_matrix.cholesky())
self.__unbroadcasted_scale_tril = ust
return self.__unbroadcasted_scale_tril

@_unbroadcasted_scale_tril.setter
def _unbroadcasted_scale_tril(self, ust: Tensor):
def _unbroadcasted_scale_tril(self, ust: Tensor) -> None:
if self.islazy:
raise NotImplementedError("Cannot set _unbroadcasted_scale_tril for lazy MVN distributions")
else:
@@ -114,10 +123,7 @@ def base_sample_shape(self) -> torch.Size:

@lazy_property
def covariance_matrix(self) -> Tensor:
if self.islazy:
return self._covar.to_dense()
else:
return super().covariance_matrix
return self._covar.to_dense()

def confidence_region(self) -> Tuple[Tensor, Tensor]:
"""
@@ -157,10 +163,7 @@ def get_base_samples(self, sample_shape: torch.Size = torch.Size()) -> Tensor:

@lazy_property
def lazy_covariance_matrix(self) -> LinearOperator:
if self.islazy:
return self._covar
else:
return to_linear_operator(super().covariance_matrix)
return self._covar

def log_prob(self, value: Tensor) -> Tensor:
r"""
@@ -304,13 +307,10 @@ def to_data_independent_dist(self) -> torch.distributions.Normal:

@property
def variance(self) -> Tensor:
if self.islazy:
# overwrite this since torch MVN uses unbroadcasted_scale_tril for this
diag = self.lazy_covariance_matrix.diagonal(dim1=-1, dim2=-2)
diag = diag.view(diag.shape[:-1] + self._event_shape)
variance = diag.expand(self._batch_shape + self._event_shape)
else:
variance = super().variance
# overwrite this since torch MVN uses unbroadcasted_scale_tril for this
diag = self.lazy_covariance_matrix.diagonal(dim1=-1, dim2=-2)
diag = diag.view(diag.shape[:-1] + self._event_shape)
variance = diag.expand(self._batch_shape + self._event_shape)

# Check to make sure that variance isn't lower than minimum allowed value (default 1e-6).
# This ensures that all variances are positive
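For context, a minimal sketch of the jitter mechanism the constructor comment above relies on (the psd_safe_cholesky import path is assumed from the linear_operator package; jitter defaults may differ across versions):

```python
import torch
from linear_operator.utils.cholesky import psd_safe_cholesky  # assumed import path

A = torch.randn(5, 2)
covmat = A @ A.T  # rank-2, so torch.linalg.cholesky(covmat) raises

# psd_safe_cholesky retries the factorization with a small, growing diagonal
# jitter until it succeeds, so semi-definite covariances remain factorizable.
L = psd_safe_cholesky(covmat)  # lower-triangular; L @ L.mT ≈ covmat + jitter * I
```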
test/distributions/test_multivariate_normal.py (14 additions, 0 deletions)
@@ -47,6 +47,20 @@ def test_multivariate_normal_non_lazy(self, cuda=False):
self.assertTrue(mvn.sample(torch.Size([2])).shape == torch.Size([2, 3]))
self.assertTrue(mvn.sample(torch.Size([2, 4])).shape == torch.Size([2, 4, 3]))

# testing with semi-definite input
A = torch.randn(len(mean), 1)
covmat = A @ A.T
handles_psd = False
try:
# the regular call fails:
# mvn = TMultivariateNormal(loc=mean, covariance_matrix=covmat, validate_args=True)
mvn = MultivariateNormal(mean=mean, covariance_matrix=covmat, validate_args=True)
mvn.sample()
handles_psd = True
except ValueError:
handles_psd = False
self.assertTrue(handles_psd)
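For clarity, the failure the commented-out regular call hits is torch's eager Cholesky of a singular matrix (a sketch; the exact exception type varies across torch versions):

```python
import torch

A = torch.randn(3, 1)
covmat = A @ A.T  # rank-1: positive semi-definite but singular

try:
    torch.linalg.cholesky(covmat)  # eager factorization of a singular matrix
except Exception as err:  # torch.linalg.LinAlgError on recent torch versions
    print(f"cholesky failed as expected: {err}")
```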

def test_multivariate_normal_non_lazy_cuda(self):
if torch.cuda.is_available():
with least_used_cuda_device():