From 570d43fd5a7e4e365401dfc0ede89f63b8c49743 Mon Sep 17 00:00:00 2001
From: Sebastian Ament
Date: Fri, 10 Mar 2023 12:03:56 -0500
Subject: [PATCH] Cast Tensor to LinearOperator in the MultivariateNormal constructor to ensure PSD-safe factorization

---
 .../multitask_multivariate_normal.py          |  5 +-
 gpytorch/distributions/multivariate_normal.py | 78 +++++++++----------
 .../distributions/test_multivariate_normal.py | 14 ++++
 3 files changed, 57 insertions(+), 40 deletions(-)

diff --git a/gpytorch/distributions/multitask_multivariate_normal.py b/gpytorch/distributions/multitask_multivariate_normal.py
index b692217f4..9a35ca131 100644
--- a/gpytorch/distributions/multitask_multivariate_normal.py
+++ b/gpytorch/distributions/multitask_multivariate_normal.py
@@ -3,6 +3,7 @@
 import torch
 from linear_operator import LinearOperator, to_linear_operator
 from linear_operator.operators import BlockDiagLinearOperator, BlockInterleavedLinearOperator, CatLinearOperator
+from torch import Tensor
 
 from .multivariate_normal import MultivariateNormal
 
@@ -24,7 +25,9 @@ class MultitaskMultivariateNormal(MultivariateNormal):
         w.r.t. inter-observation covariance for each task.
     """
 
-    def __init__(self, mean, covariance_matrix, validate_args=False, interleaved=True):
+    def __init__(
+        self, mean: Tensor, covariance_matrix: LinearOperator, validate_args: bool = False, interleaved: bool = True
+    ):
         if not torch.is_tensor(mean) and not isinstance(mean, LinearOperator):
             raise RuntimeError("The mean of a MultitaskMultivariateNormal must be a Tensor or LinearOperator")
 
diff --git a/gpytorch/distributions/multivariate_normal.py b/gpytorch/distributions/multivariate_normal.py
index d97110aba..3a26aac96 100644
--- a/gpytorch/distributions/multivariate_normal.py
+++ b/gpytorch/distributions/multivariate_normal.py
@@ -42,27 +42,36 @@ class MultivariateNormal(TMultivariateNormal, Distribution):
     :ivar torch.Tensor variance: The variance.
     """
 
-    def __init__(self, mean: Tensor, covariance_matrix: Union[Tensor, LinearOperator], validate_args: bool = False):
-        self._islazy = isinstance(mean, LinearOperator) or isinstance(covariance_matrix, LinearOperator)
-        if self._islazy:
-            if validate_args:
-                ms = mean.size(-1)
-                cs1 = covariance_matrix.size(-1)
-                cs2 = covariance_matrix.size(-2)
-                if not (ms == cs1 and ms == cs2):
-                    raise ValueError(f"Wrong shapes in {self._repr_sizes(mean, covariance_matrix)}")
-            self.loc = mean
-            self._covar = covariance_matrix
-            self.__unbroadcasted_scale_tril = None
-            self._validate_args = validate_args
-            batch_shape = torch.broadcast_shapes(self.loc.shape[:-1], covariance_matrix.shape[:-2])
-
-            event_shape = self.loc.shape[-1:]
-
-            # TODO: Integrate argument validation for LinearOperators into torch.distribution validation logic
-            super(TMultivariateNormal, self).__init__(batch_shape, event_shape, validate_args=False)
-        else:
-            super().__init__(loc=mean, covariance_matrix=covariance_matrix, validate_args=validate_args)
+    def __init__(
+        self,
+        mean: Union[Tensor, LinearOperator],
+        covariance_matrix: Union[Tensor, LinearOperator],
+        validate_args: bool = False,
+    ):
+        self._islazy = True
+        # Cast a Tensor covariance to DenseLinearOperator: the torch MVN constructor calls cholesky, which
+        # fails if the covariance matrix is only positive semi-definite, whereas DenseLinearOperator ends
+        # up calling _psd_safe_cholesky, which factorizes semi-definite matrices by adding jitter to the diagonal.
+        if isinstance(covariance_matrix, Tensor):
+            self._islazy = False  # to allow the _unbroadcasted_scale_tril setter below
+            covariance_matrix = to_linear_operator(covariance_matrix)
+
+        if validate_args:
+            ms = mean.size(-1)
+            cs1 = covariance_matrix.size(-1)
+            cs2 = covariance_matrix.size(-2)
+            if not (ms == cs1 and ms == cs2):
+                raise ValueError(f"Wrong shapes in {self._repr_sizes(mean, covariance_matrix)}")
+        self.loc = mean
+        self._covar = covariance_matrix
+        self.__unbroadcasted_scale_tril = None
+        self._validate_args = validate_args
+        batch_shape = torch.broadcast_shapes(self.loc.shape[:-1], covariance_matrix.shape[:-2])
+
+        event_shape = self.loc.shape[-1:]
+
+        # TODO: Integrate argument validation for LinearOperators into torch.distribution validation logic
+        super(TMultivariateNormal, self).__init__(batch_shape, event_shape, validate_args=False)
 
     def _extended_shape(self, sample_shape: torch.Size = torch.Size()) -> torch.Size:
         """
@@ -81,16 +90,16 @@ def _extended_shape(self, sample_shape: torch.Size = torch.Size()) -> torch.Size
     def _repr_sizes(mean: Tensor, covariance_matrix: Union[Tensor, LinearOperator]) -> str:
         return f"MultivariateNormal(loc: {mean.size()}, scale: {covariance_matrix.size()})"
 
-    @property
+    @property  # not using lazy_property here, because it does not allow for the setter below
     def _unbroadcasted_scale_tril(self) -> Tensor:
-        if self.islazy and self.__unbroadcasted_scale_tril is None:
+        if self.__unbroadcasted_scale_tril is None:
+            # cache the root decomposition
             ust = to_dense(self.lazy_covariance_matrix.cholesky())
             self.__unbroadcasted_scale_tril = ust
         return self.__unbroadcasted_scale_tril
 
     @_unbroadcasted_scale_tril.setter
-    def _unbroadcasted_scale_tril(self, ust: Tensor):
+    def _unbroadcasted_scale_tril(self, ust: Tensor) -> None:
         if self.islazy:
             raise NotImplementedError("Cannot set _unbroadcasted_scale_tril for lazy MVN distributions")
         else:
@@ -114,10 +123,7 @@ def base_sample_shape(self) -> torch.Size:
 
     @lazy_property
     def covariance_matrix(self) -> Tensor:
-        if self.islazy:
-            return self._covar.to_dense()
-        else:
-            return super().covariance_matrix
+        return self._covar.to_dense()
 
     def confidence_region(self) -> Tuple[Tensor, Tensor]:
         """
@@ -157,10 +163,7 @@ def get_base_samples(self, sample_shape: torch.Size = torch.Size()) -> Tensor:
 
     @lazy_property
     def lazy_covariance_matrix(self) -> LinearOperator:
-        if self.islazy:
-            return self._covar
-        else:
-            return to_linear_operator(super().covariance_matrix)
+        return self._covar
 
     def log_prob(self, value: Tensor) -> Tensor:
         r"""
@@ -304,13 +307,10 @@ def to_data_independent_dist(self) -> torch.distributions.Normal:
 
     @property
     def variance(self) -> Tensor:
-        if self.islazy:
-            # overwrite this since torch MVN uses unbroadcasted_scale_tril for this
-            diag = self.lazy_covariance_matrix.diagonal(dim1=-1, dim2=-2)
-            diag = diag.view(diag.shape[:-1] + self._event_shape)
-            variance = diag.expand(self._batch_shape + self._event_shape)
-        else:
-            variance = super().variance
+        # overwrite this, since torch MVN computes it from unbroadcasted_scale_tril
+        diag = self.lazy_covariance_matrix.diagonal(dim1=-1, dim2=-2)
+        diag = diag.view(diag.shape[:-1] + self._event_shape)
+        variance = diag.expand(self._batch_shape + self._event_shape)
 
         # Check to make sure that variance isn't lower than minimum allowed value (default 1e-6).
         # This ensures that all variances are positive
diff --git a/test/distributions/test_multivariate_normal.py b/test/distributions/test_multivariate_normal.py
index b9cc2568a..e306e4729 100644
--- a/test/distributions/test_multivariate_normal.py
+++ b/test/distributions/test_multivariate_normal.py
@@ -47,6 +47,20 @@ def test_multivariate_normal_non_lazy(self, cuda=False):
         self.assertTrue(mvn.sample(torch.Size([2])).shape == torch.Size([2, 3]))
         self.assertTrue(mvn.sample(torch.Size([2, 4])).shape == torch.Size([2, 4, 3]))
 
+        # test with a positive semi-definite (rank-deficient) covariance matrix
+        A = torch.randn(len(mean), 1)
+        covmat = A @ A.T
+        handles_psd = False
+        try:
+            # the plain torch constructor fails its positive-definiteness check here:
+            # mvn = TMultivariateNormal(loc=mean, covariance_matrix=covmat, validate_args=True)
+            mvn = MultivariateNormal(mean=mean, covariance_matrix=covmat, validate_args=True)
+            mvn.sample()
+            handles_psd = True
+        except ValueError:
+            handles_psd = False
+        self.assertTrue(handles_psd)
+
     def test_multivariate_normal_non_lazy_cuda(self):
         if torch.cuda.is_available():
             with least_used_cuda_device():
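
As a usage sketch of the behavior this patch enables (not part of the patch; assumes a gpytorch build with this change applied, and the shapes and values are purely illustrative):

    import torch
    from gpytorch.distributions import MultivariateNormal

    mean = torch.zeros(3)
    A = torch.randn(3, 1)
    covmat = A @ A.T  # rank-1: positive semi-definite but singular

    # Before this change, a dense Tensor covariance went through torch's
    # MultivariateNormal constructor, whose Cholesky call raises on covmat.
    # Now the Tensor is wrapped in a DenseLinearOperator, so factorization
    # goes through a PSD-safe Cholesky that adds jitter to the diagonal.
    mvn = MultivariateNormal(mean, covmat)
    samples = mvn.sample(torch.Size([4]))  # shape: (4, 3)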
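
The PSD-safe fallback itself lives in the linear_operator package; roughly, assuming its psd_safe_cholesky utility (the exact jitter schedule is version-dependent):

    import torch
    from linear_operator.utils.cholesky import psd_safe_cholesky

    A = torch.randn(3, 1)
    covmat = A @ A.T
    # torch.linalg.cholesky(covmat) raises, since covmat is singular;
    # psd_safe_cholesky retries with increasing jitter on the diagonal.
    L = psd_safe_cholesky(covmat)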