From 6bf4ba191fef5bd772391fb56a044792c2925335 Mon Sep 17 00:00:00 2001 From: Tilman Hoffbauer Date: Wed, 1 Mar 2023 18:38:06 +0100 Subject: [PATCH 01/23] Fix prediction with NaN values in training labels --- gpytorch/models/exact_prediction_strategies.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/gpytorch/models/exact_prediction_strategies.py b/gpytorch/models/exact_prediction_strategies.py index c2107a26e..596bdfc4a 100644 --- a/gpytorch/models/exact_prediction_strategies.py +++ b/gpytorch/models/exact_prediction_strategies.py @@ -236,7 +236,15 @@ def mean_cache(self): train_mean, train_train_covar = mvn.loc, mvn.lazy_covariance_matrix train_labels_offset = (self.train_labels - train_mean).unsqueeze(-1) - mean_cache = train_train_covar.evaluate_kernel().solve(train_labels_offset).squeeze(-1) + + nan_labels = torch.isnan(self.train_labels) + if not torch.any(nan_labels): + mean_cache = train_train_covar.evaluate_kernel().solve(train_labels_offset).squeeze(-1) + else: + non_nan_labels = torch.where(~nan_labels)[0] + mean_cache = torch.full_like(self.train_labels, torch.nan) + non_nan_kernel = train_train_covar[..., non_nan_labels, :][..., :, non_nan_labels].evaluate_kernel() + mean_cache[~nan_labels] = non_nan_kernel.solve(train_labels_offset[non_nan_labels]).squeeze(-1) if settings.detach_test_caches.on(): mean_cache = mean_cache.detach() @@ -285,7 +293,12 @@ def exact_predictive_mean(self, test_mean: Tensor, test_train_covar: LinearOpera # NOTE TO FUTURE SELF: # You **cannot* use addmv here, because test_train_covar may not actually be a non lazy tensor even for an exact # GP, and using addmv requires you to to_dense test_train_covar, which is obviously a huge no-no! - res = (test_train_covar @ self.mean_cache.unsqueeze(-1)).squeeze(-1) + nan_means = torch.isnan(self.mean_cache) + if not torch.any(nan_means): + res = (test_train_covar @ self.mean_cache.unsqueeze(-1)).squeeze(-1) + else: + non_nan_idx = torch.where(~nan_means)[0] + res = (test_train_covar[..., non_nan_idx] @ self.mean_cache[non_nan_idx].unsqueeze(-1)).squeeze(-1) res = res + test_mean return res From 1c4e19ce448110d9a66aae0b3f056b5db1b7c9ad Mon Sep 17 00:00:00 2001 From: Tilman Hoffbauer Date: Thu, 2 Mar 2023 14:35:12 +0100 Subject: [PATCH 02/23] Missing observation support for multitask and allow MultivariateMultitaskNormal indexing --- .../multitask_multivariate_normal.py | 111 +++++++++++++++++- gpytorch/distributions/multivariate_normal.py | 3 +- gpytorch/mlls/__init__.py | 3 +- .../mlls/exact_marginal_log_likelihood.py | 25 ++++ .../models/exact_prediction_strategies.py | 18 +-- 5 files changed, 147 insertions(+), 13 deletions(-) diff --git a/gpytorch/distributions/multitask_multivariate_normal.py b/gpytorch/distributions/multitask_multivariate_normal.py index b692217f4..bf6d79191 100644 --- a/gpytorch/distributions/multitask_multivariate_normal.py +++ b/gpytorch/distributions/multitask_multivariate_normal.py @@ -2,7 +2,12 @@ import torch from linear_operator import LinearOperator, to_linear_operator -from linear_operator.operators import BlockDiagLinearOperator, BlockInterleavedLinearOperator, CatLinearOperator +from linear_operator.operators import ( + BlockDiagLinearOperator, + BlockInterleavedLinearOperator, + CatLinearOperator, + DiagLinearOperator, +) from .multivariate_normal import MultivariateNormal @@ -18,7 +23,7 @@ class MultitaskMultivariateNormal(MultivariateNormal): :param torch.Tensor mean: An `n x t` or batch `b x n x t` matrix of means for the MVN distribution. :param ~linear_operator.operators.LinearOperator covar: An `... x NT x NT` (batch) matrix. covariance matrix of MVN distribution. - :param bool validate_args: (default=False) If True, validate `mean` anad `covariance_matrix` arguments. + :param bool validate_args: (default=False) If True, validate `mean` and `covariance_matrix` arguments. :param bool interleaved: (default=True) If True, covariance matrix is interpreted as block-diagonal w.r.t. inter-task covariances for each observation. If False, it is interpreted as block-diagonal w.r.t. inter-observation covariance for each task. @@ -275,3 +280,105 @@ def variance(self): new_shape = self._output_shape[:-2] + self._output_shape[:-3:-1] return var.view(new_shape).transpose(-1, -2).contiguous() return var.view(self._output_shape) + + def __getitem__(self, idx) -> MultivariateNormal: + r""" + Constructs a new MultivariateNormal that represents a random variable + modified by an indexing operation. + + The mean and covariance matrix arguments are indexed accordingly. + + :param idx: Index to apply to the mean. The covariance matrix is indexed accordingly. + :return: If indices specify a slice for samples and tasks, returns a + MultitaskMultivariateNormal, else returns a MultivariateNormal. + """ + + # Normalize index to a tuple + if not isinstance(idx, tuple): + idx = (idx,) + + if ... in idx: + # Replace ellipsis '...' with explicit indices + ellipsis_location = idx.index(...) + if ... in idx[ellipsis_location + 1 :]: + raise IndexError("Only one ellipsis '...' is supported!") + prefix = idx[:ellipsis_location] + suffix = idx[ellipsis_location + 1 :] + infix_length = self.mean.dim() - prefix - suffix + if infix_length < 0: + raise IndexError(f"Index {idx} has too many dimensions") + idx = prefix + (slice(None),) * infix_length + suffix + elif len(idx) == self.mean.dim() - 1: + # Normalize indices ignoring the task-index to include it + idx = idx + (slice(None),) + + new_mean = self.mean[idx] + + # We now create a covariance matrix appropriate for new_mean + if len(idx) <= self.mean.dim() - 2: + # We are only indexing the batch dimensions in this case + return MultitaskMultivariateNormal( + mean=new_mean, + covariance_matrix=self.lazy_covariance_matrix[idx], + interleaved=self._interleaved, + ) + elif len(idx) > self.mean.dim(): + raise IndexError(f"Index {idx} has too many dimensions") + else: + # We have an index that extends over all dimensions + batch_idx = idx[:-2] + if self._interleaved: + row_idx = idx[-2] + col_idx = idx[-1] + num_rows = self._output_shape[-2] + num_cols = self._output_shape[-1] + else: + row_idx = idx[-1] + col_idx = idx[-2] + num_rows = self._output_shape[-1] + num_cols = self._output_shape[-2] + + if isinstance(row_idx, int) and isinstance(col_idx, int): + # Single sample with single task + new_cov = DiagLinearOperator( + self.lazy_covariance_matrix.diagonal()[batch_idx + (row_idx * num_cols + col_idx,)] + ) + return MultivariateNormal(mean=new_mean, covariance_matrix=new_cov) + elif isinstance(row_idx, int) and isinstance(col_idx, slice): + # A block of the covariance matrix + new_slice = slice( + (col_idx.start if col_idx.start else 0) + row_idx * num_cols, + (col_idx.stop if col_idx.stop else num_cols) + row_idx * num_cols, + col_idx.step, + ) + new_cov = self.lazy_covariance_matrix[batch_idx + (new_slice, new_slice)] + return MultivariateNormal(mean=new_mean, covariance_matrix=new_cov) + elif isinstance(row_idx, slice) and isinstance(col_idx, int): + # A block of the reversely interleaved covariance matrix + new_slice = slice( + row_idx.start if row_idx.start else 0, + (row_idx.stop if row_idx.stop else num_rows) * num_cols, + (row_idx.step if row_idx.step else 1) * num_cols, + ) + new_cov = self.lazy_covariance_matrix[batch_idx + (new_slice, new_slice)] + return MultivariateNormal(mean=new_mean, covariance_matrix=new_cov) + elif isinstance(row_idx, slice) or isinstance(col_idx, slice): + # slice x slice or indices x slice or slice x indices + if isinstance(row_idx, slice): + row_idx = torch.arange(num_rows)[row_idx] + if isinstance(col_idx, slice): + col_idx = torch.arange(num_cols)[col_idx] + row_grid, col_grid = torch.meshgrid(row_idx, col_idx, indexing="ij") + indices = (row_grid * num_cols + col_grid).reshape(-1) + new_cov = self.lazy_covariance_matrix[batch_idx + (indices,)][..., indices] + return MultitaskMultivariateNormal( + mean=new_mean, covariance_matrix=new_cov, interleaved=self._interleaved, validate_args=False + ) + else: + # row_idx and col_idx have pairs of indices + indices = row_idx * num_cols + col_idx + new_cov = self.lazy_covariance_matrix[batch_idx + (indices,)][..., indices] + return MultivariateNormal( + mean=new_mean, + covariance_matrix=new_cov, + ) diff --git a/gpytorch/distributions/multivariate_normal.py b/gpytorch/distributions/multivariate_normal.py index d97110aba..1d7f75241 100644 --- a/gpytorch/distributions/multivariate_normal.py +++ b/gpytorch/distributions/multivariate_normal.py @@ -343,7 +343,7 @@ def __getitem__(self, idx) -> MultivariateNormal: The mean and covariance matrix arguments are indexed accordingly. - :param idx: Index to apply. + :param idx: Index to apply to the mean. The covariance matrix is indexed accordingly. """ if not isinstance(idx, tuple): @@ -353,6 +353,7 @@ def __getitem__(self, idx) -> MultivariateNormal: new_mean = self.mean[idx] if len(idx) <= self.mean.dim() - 1 and (Ellipsis not in rest_idx): + # We are only indexing the batch dimensions in this case new_cov = self.lazy_covariance_matrix[idx] elif len(idx) > self.mean.dim(): raise IndexError(f"Index {idx} has too many dimensions") diff --git a/gpytorch/mlls/__init__.py b/gpytorch/mlls/__init__.py index d5d358f9a..a5d4d5211 100644 --- a/gpytorch/mlls/__init__.py +++ b/gpytorch/mlls/__init__.py @@ -5,7 +5,7 @@ from .added_loss_term import AddedLossTerm from .deep_approximate_mll import DeepApproximateMLL from .deep_predictive_log_likelihood import DeepPredictiveLogLikelihood -from .exact_marginal_log_likelihood import ExactMarginalLogLikelihood +from .exact_marginal_log_likelihood import ExactMarginalLogLikelihood, ExactMarginalLogLikelihoodWithMissingObs from .gamma_robust_variational_elbo import GammaRobustVariationalELBO from .inducing_point_kernel_added_loss_term import InducingPointKernelAddedLossTerm from .kl_gaussian_added_loss_term import KLGaussianAddedLossTerm @@ -39,6 +39,7 @@ def __init__(self, *args, **kwargs): "DeepApproximateMLL", "DeepPredictiveLogLikelihood", "ExactMarginalLogLikelihood", + "ExactMarginalLogLikelihoodWithMissingObs", "InducingPointKernelAddedLossTerm", "LeaveOneOutPseudoLikelihood", "KLGaussianAddedLossTerm", diff --git a/gpytorch/mlls/exact_marginal_log_likelihood.py b/gpytorch/mlls/exact_marginal_log_likelihood.py index 9f8aa2ee5..76be3d968 100644 --- a/gpytorch/mlls/exact_marginal_log_likelihood.py +++ b/gpytorch/mlls/exact_marginal_log_likelihood.py @@ -1,4 +1,5 @@ #!/usr/bin/env python3 +import torch from ..distributions import MultivariateNormal from ..likelihoods import _GaussianLikelihoodBase @@ -67,3 +68,27 @@ def forward(self, function_dist, target, *params): # Scale by the amount of data we have num_data = function_dist.event_shape.numel() return res.div_(num_data) + + +class ExactMarginalLogLikelihoodWithMissingObs(ExactMarginalLogLikelihood): + """ + Like :obj:`~gpytorch.models.ExactGP` but with support for NaN values in the target. + These are just ignored for computation of the marginal. + """ + + def forward(self, function_dist, target, *params): + + if not isinstance(function_dist, MultivariateNormal): + raise RuntimeError("ExactMarginalLogLikelihood can only operate on Gaussian random variables") + + # Only operate on observed variables + observed = torch.nonzero(~torch.isnan(target), as_tuple=True) + + # Get the log prob of the marginal distribution + output = self.likelihood(function_dist, *params)[observed] + res = output.log_prob(target[observed]) + res = self._add_other_terms(res, params) + + # Scale by the amount of data we have + num_data = output.event_shape.numel() + return res.div_(num_data) diff --git a/gpytorch/models/exact_prediction_strategies.py b/gpytorch/models/exact_prediction_strategies.py index 596bdfc4a..668332687 100644 --- a/gpytorch/models/exact_prediction_strategies.py +++ b/gpytorch/models/exact_prediction_strategies.py @@ -237,14 +237,14 @@ def mean_cache(self): train_labels_offset = (self.train_labels - train_mean).unsqueeze(-1) - nan_labels = torch.isnan(self.train_labels) - if not torch.any(nan_labels): + not_observed = torch.isnan(self.train_labels) + if not torch.any(not_observed): mean_cache = train_train_covar.evaluate_kernel().solve(train_labels_offset).squeeze(-1) else: - non_nan_labels = torch.where(~nan_labels)[0] + observed = torch.where(~not_observed)[0] mean_cache = torch.full_like(self.train_labels, torch.nan) - non_nan_kernel = train_train_covar[..., non_nan_labels, :][..., :, non_nan_labels].evaluate_kernel() - mean_cache[~nan_labels] = non_nan_kernel.solve(train_labels_offset[non_nan_labels]).squeeze(-1) + non_nan_kernel = train_train_covar[..., observed, :][..., :, observed].evaluate_kernel() + mean_cache[~not_observed] = non_nan_kernel.solve(train_labels_offset[observed]).squeeze(-1) if settings.detach_test_caches.on(): mean_cache = mean_cache.detach() @@ -293,12 +293,12 @@ def exact_predictive_mean(self, test_mean: Tensor, test_train_covar: LinearOpera # NOTE TO FUTURE SELF: # You **cannot* use addmv here, because test_train_covar may not actually be a non lazy tensor even for an exact # GP, and using addmv requires you to to_dense test_train_covar, which is obviously a huge no-no! - nan_means = torch.isnan(self.mean_cache) - if not torch.any(nan_means): + not_observed = torch.isnan(self.mean_cache) + if not torch.any(not_observed): res = (test_train_covar @ self.mean_cache.unsqueeze(-1)).squeeze(-1) else: - non_nan_idx = torch.where(~nan_means)[0] - res = (test_train_covar[..., non_nan_idx] @ self.mean_cache[non_nan_idx].unsqueeze(-1)).squeeze(-1) + observed = torch.where(~not_observed)[0] + res = (test_train_covar[..., observed] @ self.mean_cache[observed].unsqueeze(-1)).squeeze(-1) res = res + test_mean return res From d0cf6515048eb564cc42c7471f56341706dfdeac Mon Sep 17 00:00:00 2001 From: Tilman Hoffbauer Date: Thu, 2 Mar 2023 15:36:41 +0100 Subject: [PATCH 03/23] Fix error in MultitaskMultivariateNormal indexing on '...' --- gpytorch/distributions/multitask_multivariate_normal.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gpytorch/distributions/multitask_multivariate_normal.py b/gpytorch/distributions/multitask_multivariate_normal.py index bf6d79191..5a161081f 100644 --- a/gpytorch/distributions/multitask_multivariate_normal.py +++ b/gpytorch/distributions/multitask_multivariate_normal.py @@ -304,7 +304,7 @@ def __getitem__(self, idx) -> MultivariateNormal: raise IndexError("Only one ellipsis '...' is supported!") prefix = idx[:ellipsis_location] suffix = idx[ellipsis_location + 1 :] - infix_length = self.mean.dim() - prefix - suffix + infix_length = self.mean.dim() - len(prefix) - len(suffix) if infix_length < 0: raise IndexError(f"Index {idx} has too many dimensions") idx = prefix + (slice(None),) * infix_length + suffix From 2c4ac2f5cab1f17ac9ab24c1f041aa73efb0c51e Mon Sep 17 00:00:00 2001 From: Tilman Hoffbauer Date: Thu, 2 Mar 2023 16:17:22 +0100 Subject: [PATCH 04/23] Fix indexing with negative values --- .../multitask_multivariate_normal.py | 40 +++++++++++++++---- 1 file changed, 33 insertions(+), 7 deletions(-) diff --git a/gpytorch/distributions/multitask_multivariate_normal.py b/gpytorch/distributions/multitask_multivariate_normal.py index 5a161081f..bce1abc86 100644 --- a/gpytorch/distributions/multitask_multivariate_normal.py +++ b/gpytorch/distributions/multitask_multivariate_normal.py @@ -340,26 +340,28 @@ def __getitem__(self, idx) -> MultivariateNormal: if isinstance(row_idx, int) and isinstance(col_idx, int): # Single sample with single task + row_idx = _normalize_index(row_idx, num_rows) + col_idx = _normalize_index(col_idx, num_cols) new_cov = DiagLinearOperator( self.lazy_covariance_matrix.diagonal()[batch_idx + (row_idx * num_cols + col_idx,)] ) return MultivariateNormal(mean=new_mean, covariance_matrix=new_cov) elif isinstance(row_idx, int) and isinstance(col_idx, slice): # A block of the covariance matrix + row_idx = _normalize_index(row_idx, num_rows) + col_idx = _normalize_slice(col_idx, num_cols) new_slice = slice( - (col_idx.start if col_idx.start else 0) + row_idx * num_cols, - (col_idx.stop if col_idx.stop else num_cols) + row_idx * num_cols, + col_idx.start + row_idx * num_cols, + col_idx.stop + row_idx * num_cols, col_idx.step, ) new_cov = self.lazy_covariance_matrix[batch_idx + (new_slice, new_slice)] return MultivariateNormal(mean=new_mean, covariance_matrix=new_cov) elif isinstance(row_idx, slice) and isinstance(col_idx, int): # A block of the reversely interleaved covariance matrix - new_slice = slice( - row_idx.start if row_idx.start else 0, - (row_idx.stop if row_idx.stop else num_rows) * num_cols, - (row_idx.step if row_idx.step else 1) * num_cols, - ) + row_idx = _normalize_slice(row_idx, num_rows) + col_idx = _normalize_index(col_idx, num_cols) + new_slice = slice(row_idx.start + col_idx, row_idx.stop * num_cols + col_idx, row_idx.step * num_cols) new_cov = self.lazy_covariance_matrix[batch_idx + (new_slice, new_slice)] return MultivariateNormal(mean=new_mean, covariance_matrix=new_cov) elif isinstance(row_idx, slice) or isinstance(col_idx, slice): @@ -382,3 +384,27 @@ def __getitem__(self, idx) -> MultivariateNormal: mean=new_mean, covariance_matrix=new_cov, ) + + +def _normalize_index(i: int, dim_size: int) -> int: + if i < 0: + return dim_size + i + else: + return i + + +def _normalize_slice(s: slice, dim_size: int) -> slice: + start = s.start + if start is None: + start = 0 + elif start < 0: + start = dim_size + start + stop = s.stop + if stop is None: + stop = dim_size + elif stop < 0: + stop = dim_size + stop + step = s.step + if step is None: + step = 1 + return slice(start, stop, step) From 03f2dd2224df787feb0711d0511d2f74193b8042 Mon Sep 17 00:00:00 2001 From: Tilman Hoffbauer Date: Fri, 3 Mar 2023 18:39:10 +0100 Subject: [PATCH 05/23] Add tests - Indexing MultitaskMultivariateNormal - Missing data in single-task and multitask models --- .../test_multitask_multivariate_normal.py | 150 ++++++++++++++++++ test/examples/test_missing_data.py | 95 +++++++++++ 2 files changed, 245 insertions(+) create mode 100644 test/examples/test_missing_data.py diff --git a/test/distributions/test_multitask_multivariate_normal.py b/test/distributions/test_multitask_multivariate_normal.py index 773a1b004..4df3a7be8 100644 --- a/test/distributions/test_multitask_multivariate_normal.py +++ b/test/distributions/test_multitask_multivariate_normal.py @@ -315,6 +315,156 @@ def test_multitask_multivariate_normal_broadcasting(self): covar = _covar @ _covar.transpose(-1, -2) MultitaskMultivariateNormal(mean, covar) + def test_getitem_interleaved(self): + mean_shape = (2, 4, 3, 2) + covar_shape = (2, 4, 6, 6) + mean = torch.randn(mean_shape) + _covar = torch.randn(covar_shape) + covar = _covar @ _covar.transpose(-1, -2) + distribution = MultitaskMultivariateNormal(mean, covar, validate_args=True) + + def flat(observation: int, task: int) -> int: + return observation * 2 + task + + part = distribution[1, -1] + self.assertIsInstance(part, MultitaskMultivariateNormal) + self.assertEqual(part.batch_shape, torch.Size(())) + self.assertEqual(part.event_shape, torch.Size((3, 2))) + self.assertAllClose(part.mean, mean[1, -1]) + self.assertAllClose(part.covariance_matrix, covar[1, -1]) + + part = distribution[..., 2, 1] + self.assertFalse(isinstance(part, MultitaskMultivariateNormal)) + self.assertIsInstance(part, MultivariateNormal) + self.assertEqual(part.batch_shape, torch.Size((2,))) + self.assertEqual(part.event_shape, (4,)) + self.assertAllClose(part.mean, mean[..., 2, 1]) + self.assertAllClose(part.covariance_matrix, torch.diag_embed(covar[:, :, flat(2, 1), flat(2, 1)])) + + part = distribution[1, ..., -2] + self.assertFalse(isinstance(part, MultitaskMultivariateNormal)) + self.assertIsInstance(part, MultivariateNormal) + self.assertEqual(part.batch_shape, torch.Size((4,))) + self.assertEqual(part.event_shape, torch.Size((3,))) + self.assertAllClose(part.mean, mean[1, :, :, 0]) + self.assertAllClose(part.covariance_matrix, covar[1, :, ::2, ::2]) + + part = distribution[..., 2, :] + self.assertFalse(isinstance(part, MultitaskMultivariateNormal)) + self.assertIsInstance(part, MultivariateNormal) + self.assertEqual(part.batch_shape, torch.Size((2, 4))) + self.assertEqual(part.event_shape, torch.Size((2,))) + self.assertAllClose(part.mean, mean[:, :, 2, :]) + self.assertAllClose(part.covariance_matrix, covar[:, :, 2 * 2 : 3 * 2, 2 * 2 : 3 * 2]) + + part = distribution[0, :, :, torch.tensor([1, 0])] + self.assertIsInstance(part, MultitaskMultivariateNormal) + self.assertEqual(part.batch_shape, torch.Size((4,))) + self.assertEqual(part.event_shape, torch.Size((3, 2))) + self.assertAllClose(part.mean, mean[0, ..., torch.tensor([1, 0])]) + indices = torch.tensor([1, 0, 3, 2, 5, 4]) + self.assertAllClose(part.covariance_matrix, covar[0, :, indices][..., indices]) + + part = distribution[:, 1, torch.tensor([2, 0])] + self.assertIsInstance(part, MultitaskMultivariateNormal) + self.assertEqual(part.batch_shape, torch.Size((2,))) + self.assertEqual(part.event_shape, torch.Size((2, 2))) + self.assertAllClose(part.mean, mean[:, 1, torch.tensor([2, 0])]) + indices = torch.tensor([4, 5, 0, 1]) + self.assertAllClose(part.covariance_matrix, covar[:, 1, indices][..., indices]) + + part = distribution[..., 1:, :-1] + self.assertIsInstance(part, MultitaskMultivariateNormal) + self.assertEqual(part.batch_shape, torch.Size((2, 4))) + self.assertEqual(part.event_shape, torch.Size((2, 1))) + self.assertAllClose(part.mean, mean[..., 1:, :-1]) + indices = torch.tensor([flat(1, 0), flat(2, 0)]) + self.assertAllClose(part.covariance_matrix, covar[..., indices, :][..., indices]) + + part = distribution[..., torch.tensor([2, 0, 2]), torch.tensor([1, 0, 0])] + self.assertFalse(isinstance(part, MultitaskMultivariateNormal)) + self.assertIsInstance(part, MultivariateNormal) + self.assertEqual(part.batch_shape, torch.Size((2, 4))) + self.assertEqual(part.event_shape, torch.Size((3,))) + self.assertAllClose(part.mean, mean[..., torch.tensor([2, 0, 2]), torch.tensor([1, 0, 0])]) + indices = torch.tensor([flat(2, 1), flat(0, 0), flat(2, 0)]) + self.assertAllClose(part.covariance_matrix, covar[..., indices, :][..., indices]) + + def test_getitem_non_interleaved(self): + mean_shape = (2, 4, 3, 2) + covar_shape = (2, 4, 6, 6) + mean = torch.randn(mean_shape) + _covar = torch.randn(covar_shape) + covar = _covar @ _covar.transpose(-1, -2) + distribution = MultitaskMultivariateNormal(mean, covar, validate_args=True, interleaved=False) + + def flat(observation: int, task: int) -> int: + return task * 3 + observation + + part = distribution[1, -1] + self.assertIsInstance(part, MultitaskMultivariateNormal) + self.assertEqual(part.batch_shape, torch.Size(())) + self.assertEqual(part.event_shape, torch.Size((3, 2))) + self.assertAllClose(part.mean, mean[1, -1]) + self.assertAllClose(part.covariance_matrix, covar[1, -1]) + + part = distribution[..., 2, 1] + self.assertFalse(isinstance(part, MultitaskMultivariateNormal)) + self.assertIsInstance(part, MultivariateNormal) + self.assertEqual(part.batch_shape, torch.Size((2,))) + self.assertEqual(part.event_shape, (4,)) + self.assertAllClose(part.mean, mean[..., 2, 1]) + self.assertAllClose(part.covariance_matrix, torch.diag_embed(covar[:, :, flat(2, 1), flat(2, 1)])) + + part = distribution[1, ..., -2] + self.assertFalse(isinstance(part, MultitaskMultivariateNormal)) + self.assertIsInstance(part, MultivariateNormal) + self.assertEqual(part.batch_shape, torch.Size((4,))) + self.assertEqual(part.event_shape, torch.Size((3,))) + self.assertAllClose(part.mean, mean[1, :, :, 0]) + self.assertAllClose(part.covariance_matrix, covar[1, :, :3, :3]) + + part = distribution[..., 2, :] + self.assertFalse(isinstance(part, MultitaskMultivariateNormal)) + self.assertIsInstance(part, MultivariateNormal) + self.assertEqual(part.batch_shape, torch.Size((2, 4))) + self.assertEqual(part.event_shape, torch.Size((2,))) + self.assertAllClose(part.mean, mean[:, :, 2, :]) + self.assertAllClose(part.covariance_matrix, covar[:, :, 2::3, 2::3]) + + part = distribution[0, :, :, torch.tensor([1, 0])] + self.assertIsInstance(part, MultitaskMultivariateNormal) + self.assertEqual(part.batch_shape, torch.Size((4,))) + self.assertEqual(part.event_shape, torch.Size((3, 2))) + self.assertAllClose(part.mean, mean[0, ..., torch.tensor([1, 0])]) + indices = torch.tensor([3, 4, 5, 0, 1, 2]) + self.assertAllClose(part.covariance_matrix, covar[0, :, indices][..., indices]) + + part = distribution[:, 1, torch.tensor([2, 0])] + self.assertIsInstance(part, MultitaskMultivariateNormal) + self.assertEqual(part.batch_shape, torch.Size((2,))) + self.assertEqual(part.event_shape, torch.Size((2, 2))) + self.assertAllClose(part.mean, mean[:, 1, torch.tensor([2, 0])]) + indices = torch.tensor([2, 0, 5, 3]) + self.assertAllClose(part.covariance_matrix, covar[:, 1, indices][..., indices]) + + part = distribution[..., 1:, :-1] + self.assertIsInstance(part, MultitaskMultivariateNormal) + self.assertEqual(part.batch_shape, torch.Size((2, 4))) + self.assertEqual(part.event_shape, torch.Size((2, 1))) + self.assertAllClose(part.mean, mean[..., 1:, :-1]) + indices = torch.tensor([flat(1, 0), flat(2, 0)]) + self.assertAllClose(part.covariance_matrix, covar[..., indices, :][..., indices]) + + part = distribution[..., torch.tensor([2, 0, 2]), torch.tensor([1, 0, 0])] + self.assertFalse(isinstance(part, MultitaskMultivariateNormal)) + self.assertIsInstance(part, MultivariateNormal) + self.assertEqual(part.batch_shape, torch.Size((2, 4))) + self.assertEqual(part.event_shape, torch.Size((3,))) + self.assertAllClose(part.mean, mean[..., torch.tensor([2, 0, 2]), torch.tensor([1, 0, 0])]) + indices = torch.tensor([flat(2, 1), flat(0, 0), flat(2, 0)]) + self.assertAllClose(part.covariance_matrix, covar[..., indices, :][..., indices]) + if __name__ == "__main__": unittest.main() diff --git a/test/examples/test_missing_data.py b/test/examples/test_missing_data.py new file mode 100644 index 000000000..e970c4f60 --- /dev/null +++ b/test/examples/test_missing_data.py @@ -0,0 +1,95 @@ +import unittest + +import torch + +from gpytorch.distributions import MultivariateNormal, MultitaskMultivariateNormal +from gpytorch.kernels import ScaleKernel, RBFKernel, MultitaskKernel +from gpytorch.likelihoods import GaussianLikelihood, Likelihood, MultitaskGaussianLikelihood +from gpytorch.means import ConstantMean, MultitaskMean +from gpytorch.mlls import ExactMarginalLogLikelihoodWithMissingObs +from gpytorch.models import ExactGP +from gpytorch.test.base_test_case import BaseTestCase + + +class SingleGPModel(ExactGP): + def __init__(self, train_inputs, train_targets, likelihood): + super(SingleGPModel, self).__init__(train_inputs, train_targets, likelihood) + self.mean_module = ConstantMean() + self.covar_module = ScaleKernel(RBFKernel()) + + def forward(self, x): + mean_x = self.mean_module(x) + covar_x = self.covar_module(x) + return MultivariateNormal(mean_x, covar_x) + + +class MultitaskGPModel(ExactGP): + def __init__(self, train_inputs, train_targets, likelihood, num_tasks): + super(MultitaskGPModel, self).__init__(train_inputs, train_targets, likelihood) + self.mean_module = MultitaskMean(ConstantMean(), num_tasks=num_tasks) + self.covar_module = MultitaskKernel(ScaleKernel(RBFKernel()), num_tasks=num_tasks) + + def forward(self, x): + mean_x = self.mean_module(x) + covar_x = self.covar_module(x) + return MultitaskMultivariateNormal(mean_x, covar_x) + + +class TestMissingData(BaseTestCase, unittest.TestCase): + seed = 1 + + def _train(self, model: ExactGP, likelihood: Likelihood): + model.train() + likelihood.train() + + mll = ExactMarginalLogLikelihoodWithMissingObs(likelihood, model) + optimizer = torch.optim.Adam(model.parameters(), lr=0.15) + + for _ in range(20): + optimizer.zero_grad() + output = model(*model.train_inputs) + loss = mll(output, model.train_targets) + self.assertFalse(torch.any(torch.isnan(output.mean)).item()) + self.assertFalse(torch.any(torch.isnan(output.covariance_matrix)).item()) + self.assertFalse(torch.isnan(loss).item()) + loss.backward() + optimizer.step() + + model.eval() + likelihood.eval() + + def test_single(self): + train_x = torch.linspace(0, 1, 21) + test_x = torch.linspace(0, 1, 51) + train_y = torch.sin(2 * torch.pi * train_x) + train_y += torch.normal(0, 0.01, train_y.shape) + train_y[::2] = torch.nan + + likelihood = GaussianLikelihood() + model = SingleGPModel(train_x, train_y, likelihood) + self._train(model, likelihood) + + with torch.no_grad(): + prediction = model(test_x) + + self.assertFalse(torch.any(torch.isnan(prediction.mean)).item()) + self.assertFalse(torch.any(torch.isnan(prediction.covariance_matrix)).item()) + + def test_multitask(self): + num_tasks = 10 + train_x = torch.linspace(0, 1, 21) + test_x = torch.linspace(0, 1, 51) + train_y = torch.sin(2 * torch.pi * train_x)[:, None] * torch.rand(1, num_tasks) + train_y += torch.normal(0, 0.01, train_y.shape) + train_y[::3, : num_tasks // 2] = torch.nan + train_y[::4, num_tasks // 2 :] = torch.nan + + likelihood = MultitaskGaussianLikelihood(num_tasks) + model = MultitaskGPModel(train_x, train_y, likelihood, num_tasks) + self._train(model, likelihood) + + with torch.no_grad(): + prediction = model(test_x) + + self.assertFalse(torch.any(torch.isnan(prediction.mean)).item()) + self.assertFalse(torch.any(torch.isnan(prediction.covariance_matrix)).item()) From 270216b86fddece4a9ab3ccf69ca65b379c93c32 Mon Sep 17 00:00:00 2001 From: Tilman Hoffbauer Date: Sun, 5 Mar 2023 10:20:35 +0100 Subject: [PATCH 06/23] Render docs for MultitaskMultivariateNormal indexing and missing observations --- docs/source/distributions.rst | 1 + docs/source/marginal_log_likelihoods.rst | 6 ++++++ 2 files changed, 7 insertions(+) diff --git a/docs/source/distributions.rst b/docs/source/distributions.rst index e7f3788ca..a98bd862b 100644 --- a/docs/source/distributions.rst +++ b/docs/source/distributions.rst @@ -44,6 +44,7 @@ MultitaskMultivariateNormal .. autoclass:: MultitaskMultivariateNormal :members: + :special-members: __getitem__ Delta diff --git a/docs/source/marginal_log_likelihoods.rst b/docs/source/marginal_log_likelihoods.rst index 20bb840bc..62ebe6ac3 100644 --- a/docs/source/marginal_log_likelihoods.rst +++ b/docs/source/marginal_log_likelihoods.rst @@ -37,6 +37,12 @@ These are MLLs for use with :obj:`~gpytorch.models.ExactGP` modules. They comput .. autoclass:: ExactMarginalLogLikelihood :members: +:hidden:`ExactMarginalLogLikelihoodWithMissingObs` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: ExactMarginalLogLikelihoodWithMissingObs + :members: + :hidden:`LeaveOneOutPseudoLikelihood` ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ From 954bed9aa6f4725542cea52b2fea6bfd286f6c37 Mon Sep 17 00:00:00 2001 From: Tilman Hoffbauer Date: Sun, 5 Mar 2023 10:29:01 +0100 Subject: [PATCH 07/23] Fix docs warning --- docs/source/marginal_log_likelihoods.rst | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/source/marginal_log_likelihoods.rst b/docs/source/marginal_log_likelihoods.rst index 62ebe6ac3..b39cae2d2 100644 --- a/docs/source/marginal_log_likelihoods.rst +++ b/docs/source/marginal_log_likelihoods.rst @@ -32,19 +32,19 @@ Exact GP Inference These are MLLs for use with :obj:`~gpytorch.models.ExactGP` modules. They compute the MLL exactly. :hidden:`ExactMarginalLogLikelihood` -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. autoclass:: ExactMarginalLogLikelihood :members: :hidden:`ExactMarginalLogLikelihoodWithMissingObs` -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. autoclass:: ExactMarginalLogLikelihoodWithMissingObs :members: :hidden:`LeaveOneOutPseudoLikelihood` -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. autoclass:: LeaveOneOutPseudoLikelihood :members: From 5fee76f33c822eb218fc78a096f1e0091f34d38e Mon Sep 17 00:00:00 2001 From: Tilman Hoffbauer Date: Mon, 6 Mar 2023 09:41:37 +0100 Subject: [PATCH 08/23] Fix docstring --- gpytorch/distributions/multitask_multivariate_normal.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/gpytorch/distributions/multitask_multivariate_normal.py b/gpytorch/distributions/multitask_multivariate_normal.py index bce1abc86..a8661e5ba 100644 --- a/gpytorch/distributions/multitask_multivariate_normal.py +++ b/gpytorch/distributions/multitask_multivariate_normal.py @@ -288,8 +288,9 @@ def __getitem__(self, idx) -> MultivariateNormal: The mean and covariance matrix arguments are indexed accordingly. - :param idx: Index to apply to the mean. The covariance matrix is indexed accordingly. - :return: If indices specify a slice for samples and tasks, returns a + :param Any idx: Index to apply to the mean. The covariance matrix is indexed accordingly. + + :returns: If indices specify a slice for samples and tasks, returns a MultitaskMultivariateNormal, else returns a MultivariateNormal. """ From cfa5435a169540b50fa0bb626bca5e9f9441e2ae Mon Sep 17 00:00:00 2001 From: Tilman Hoffbauer Date: Wed, 8 Mar 2023 14:25:09 +0100 Subject: [PATCH 09/23] Finally fix docstring --- gpytorch/distributions/multitask_multivariate_normal.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/gpytorch/distributions/multitask_multivariate_normal.py b/gpytorch/distributions/multitask_multivariate_normal.py index a8661e5ba..20ad77ad4 100644 --- a/gpytorch/distributions/multitask_multivariate_normal.py +++ b/gpytorch/distributions/multitask_multivariate_normal.py @@ -282,16 +282,15 @@ def variance(self): return var.view(self._output_shape) def __getitem__(self, idx) -> MultivariateNormal: - r""" + """ Constructs a new MultivariateNormal that represents a random variable modified by an indexing operation. The mean and covariance matrix arguments are indexed accordingly. :param Any idx: Index to apply to the mean. The covariance matrix is indexed accordingly. - :returns: If indices specify a slice for samples and tasks, returns a - MultitaskMultivariateNormal, else returns a MultivariateNormal. + MultitaskMultivariateNormal, else returns a MultivariateNormal. """ # Normalize index to a tuple From 25315880c392898a176976abee4d520032988cae Mon Sep 17 00:00:00 2001 From: Tilman Hoffbauer Date: Wed, 8 Mar 2023 15:14:36 +0100 Subject: [PATCH 10/23] Change missing data handling to option flag --- docs/source/marginal_log_likelihoods.rst | 6 ---- gpytorch/mlls/__init__.py | 3 +- .../mlls/exact_marginal_log_likelihood.py | 35 ++++++------------- test/examples/test_missing_data.py | 4 +-- 4 files changed, 14 insertions(+), 34 deletions(-) diff --git a/docs/source/marginal_log_likelihoods.rst b/docs/source/marginal_log_likelihoods.rst index b39cae2d2..4aab913cb 100644 --- a/docs/source/marginal_log_likelihoods.rst +++ b/docs/source/marginal_log_likelihoods.rst @@ -37,12 +37,6 @@ These are MLLs for use with :obj:`~gpytorch.models.ExactGP` modules. They comput .. autoclass:: ExactMarginalLogLikelihood :members: -:hidden:`ExactMarginalLogLikelihoodWithMissingObs` -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -.. autoclass:: ExactMarginalLogLikelihoodWithMissingObs - :members: - :hidden:`LeaveOneOutPseudoLikelihood` ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/gpytorch/mlls/__init__.py b/gpytorch/mlls/__init__.py index a5d4d5211..d5d358f9a 100644 --- a/gpytorch/mlls/__init__.py +++ b/gpytorch/mlls/__init__.py @@ -5,7 +5,7 @@ from .added_loss_term import AddedLossTerm from .deep_approximate_mll import DeepApproximateMLL from .deep_predictive_log_likelihood import DeepPredictiveLogLikelihood -from .exact_marginal_log_likelihood import ExactMarginalLogLikelihood, ExactMarginalLogLikelihoodWithMissingObs +from .exact_marginal_log_likelihood import ExactMarginalLogLikelihood from .gamma_robust_variational_elbo import GammaRobustVariationalELBO from .inducing_point_kernel_added_loss_term import InducingPointKernelAddedLossTerm from .kl_gaussian_added_loss_term import KLGaussianAddedLossTerm @@ -39,7 +39,6 @@ def __init__(self, *args, **kwargs): "DeepApproximateMLL", "DeepPredictiveLogLikelihood", "ExactMarginalLogLikelihood", - "ExactMarginalLogLikelihoodWithMissingObs", "InducingPointKernelAddedLossTerm", "LeaveOneOutPseudoLikelihood", "KLGaussianAddedLossTerm", diff --git a/gpytorch/mlls/exact_marginal_log_likelihood.py b/gpytorch/mlls/exact_marginal_log_likelihood.py index 76be3d968..32e880709 100644 --- a/gpytorch/mlls/exact_marginal_log_likelihood.py +++ b/gpytorch/mlls/exact_marginal_log_likelihood.py @@ -18,6 +18,8 @@ class ExactMarginalLogLikelihood(MarginalLogLikelihood): :param ~gpytorch.likelihoods.GaussianLikelihood likelihood: The Gaussian likelihood for the model :param ~gpytorch.models.ExactGP model: The exact GP model + :param bool nan_means_missing_data: If set to True, this module checks for NaN values in the output + and ignores them for calculations. Example: >>> # model is a gpytorch.models.ExactGP @@ -29,10 +31,11 @@ class ExactMarginalLogLikelihood(MarginalLogLikelihood): >>> loss.backward() """ - def __init__(self, likelihood, model): + def __init__(self, likelihood, model, nan_means_missing_data=False): if not isinstance(likelihood, _GaussianLikelihoodBase): raise RuntimeError("Likelihood must be Gaussian for exact inference") super(ExactMarginalLogLikelihood, self).__init__(likelihood, model) + self.nan_means_missing_data = nan_means_missing_data def _add_other_terms(self, res, params): # Add additional terms (SGPR / learned inducing points, heteroskedastic likelihood models) @@ -60,33 +63,17 @@ def forward(self, function_dist, target, *params): if not isinstance(function_dist, MultivariateNormal): raise RuntimeError("ExactMarginalLogLikelihood can only operate on Gaussian random variables") - # Get the log prob of the marginal distribution + # Determine output likelihood output = self.likelihood(function_dist, *params) - res = output.log_prob(target) - res = self._add_other_terms(res, params) - - # Scale by the amount of data we have - num_data = function_dist.event_shape.numel() - return res.div_(num_data) - - -class ExactMarginalLogLikelihoodWithMissingObs(ExactMarginalLogLikelihood): - """ - Like :obj:`~gpytorch.models.ExactGP` but with support for NaN values in the target. - These are just ignored for computation of the marginal. - """ - - def forward(self, function_dist, target, *params): - if not isinstance(function_dist, MultivariateNormal): - raise RuntimeError("ExactMarginalLogLikelihood can only operate on Gaussian random variables") - - # Only operate on observed variables - observed = torch.nonzero(~torch.isnan(target), as_tuple=True) + # Remove NaN values if enabled + if self.nan_means_missing_data: + observed = torch.nonzero(~torch.isnan(target), as_tuple=True) + output = output[observed] + target = target[observed] # Get the log prob of the marginal distribution - output = self.likelihood(function_dist, *params)[observed] - res = output.log_prob(target[observed]) + res = output.log_prob(target) res = self._add_other_terms(res, params) # Scale by the amount of data we have diff --git a/test/examples/test_missing_data.py b/test/examples/test_missing_data.py index e970c4f60..4564aec0a 100644 --- a/test/examples/test_missing_data.py +++ b/test/examples/test_missing_data.py @@ -2,11 +2,11 @@ import torch +from gpytorch import ExactMarginalLogLikelihood from gpytorch.distributions import MultivariateNormal, MultitaskMultivariateNormal from gpytorch.kernels import ScaleKernel, RBFKernel, MultitaskKernel from gpytorch.likelihoods import GaussianLikelihood, Likelihood, MultitaskGaussianLikelihood from gpytorch.means import ConstantMean, MultitaskMean -from gpytorch.mlls import ExactMarginalLogLikelihoodWithMissingObs from gpytorch.models import ExactGP from gpytorch.test.base_test_case import BaseTestCase @@ -42,7 +42,7 @@ def _train(self, model: ExactGP, likelihood: Likelihood): model.train() likelihood.train() - mll = ExactMarginalLogLikelihoodWithMissingObs(likelihood, model) + mll = ExactMarginalLogLikelihood(likelihood, model, nan_means_missing_data=True) optimizer = torch.optim.Adam(model.parameters(), lr=0.15) for _ in range(20): From e7fca203dff8e332b76e496671436fbef5686a08 Mon Sep 17 00:00:00 2001 From: Tilman Hoffbauer Date: Fri, 10 Mar 2023 17:22:29 +0100 Subject: [PATCH 11/23] Revamp missing value implementation - Enable via gpytorch.settings - Two modes: 'mask' and 'fill' - Makes GaussianLikelihoodWithMissingObs obsolete - Supports approximate GPs --- docs/source/likelihoods.rst | 6 - docs/source/marginal_log_likelihoods.rst | 4 +- gpytorch/distributions/multivariate_normal.py | 5 + gpytorch/likelihoods/__init__.py | 2 - gpytorch/likelihoods/gaussian_likelihood.py | 88 ++++---- .../mlls/exact_marginal_log_likelihood.py | 19 +- .../models/exact_prediction_strategies.py | 52 +++-- gpytorch/settings.py | 51 ++++- .../distributions/test_multivariate_normal.py | 4 + test/examples/test_missing_data.py | 194 +++++++++++++++--- test/likelihoods/test_gaussian_likelihood.py | 68 +++--- 11 files changed, 350 insertions(+), 143 deletions(-) diff --git a/docs/source/likelihoods.rst b/docs/source/likelihoods.rst index 2607d5ff2..0c10a12b4 100644 --- a/docs/source/likelihoods.rst +++ b/docs/source/likelihoods.rst @@ -32,12 +32,6 @@ reduce the variance when computing approximate GP objective functions. .. autoclass:: GaussianLikelihood :members: -:hidden:`GaussianLikelihoodWithMissingObs` -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -.. autoclass:: GaussianLikelihoodWithMissingObs - :members: - :hidden:`FixedNoiseGaussianLikelihood` ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/docs/source/marginal_log_likelihoods.rst b/docs/source/marginal_log_likelihoods.rst index 4aab913cb..20bb840bc 100644 --- a/docs/source/marginal_log_likelihoods.rst +++ b/docs/source/marginal_log_likelihoods.rst @@ -32,13 +32,13 @@ Exact GP Inference These are MLLs for use with :obj:`~gpytorch.models.ExactGP` modules. They compute the MLL exactly. :hidden:`ExactMarginalLogLikelihood` -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. autoclass:: ExactMarginalLogLikelihood :members: :hidden:`LeaveOneOutPseudoLikelihood` -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. autoclass:: LeaveOneOutPseudoLikelihood :members: diff --git a/gpytorch/distributions/multivariate_normal.py b/gpytorch/distributions/multivariate_normal.py index 1d7f75241..ce9f5f550 100644 --- a/gpytorch/distributions/multivariate_normal.py +++ b/gpytorch/distributions/multivariate_normal.py @@ -348,6 +348,11 @@ def __getitem__(self, idx) -> MultivariateNormal: if not isinstance(idx, tuple): idx = (idx,) + if len(idx) > self.mean.dim() and Ellipsis in idx: + idx = tuple(i for i in idx if i != Ellipsis) + if len(idx) < self.mean.dim(): + raise IndexError("Multiple ambiguous ellipsis in index!") + rest_idx = idx[:-1] last_idx = idx[-1] new_mean = self.mean[idx] diff --git a/gpytorch/likelihoods/__init__.py b/gpytorch/likelihoods/__init__.py index 31a370079..45fb205f4 100644 --- a/gpytorch/likelihoods/__init__.py +++ b/gpytorch/likelihoods/__init__.py @@ -7,7 +7,6 @@ DirichletClassificationLikelihood, FixedNoiseGaussianLikelihood, GaussianLikelihood, - GaussianLikelihoodWithMissingObs, ) from .laplace_likelihood import LaplaceLikelihood from .likelihood import _OneDimensionalLikelihood, Likelihood @@ -26,7 +25,6 @@ "DirichletClassificationLikelihood", "FixedNoiseGaussianLikelihood", "GaussianLikelihood", - "GaussianLikelihoodWithMissingObs", "HeteroskedasticNoise", "LaplaceLikelihood", "Likelihood", diff --git a/gpytorch/likelihoods/gaussian_likelihood.py b/gpytorch/likelihoods/gaussian_likelihood.py index 3ce286946..69bc6868c 100644 --- a/gpytorch/likelihoods/gaussian_likelihood.py +++ b/gpytorch/likelihoods/gaussian_likelihood.py @@ -1,5 +1,4 @@ #!/usr/bin/env python3 - import math import warnings from copy import deepcopy @@ -9,6 +8,7 @@ from linear_operator.operators import ZeroLinearOperator from torch import Tensor +from .. import settings from ..distributions import base_distributions, MultivariateNormal from ..utils.warnings import GPInputWarning from .likelihood import Likelihood @@ -35,17 +35,34 @@ def _shaped_noise_covar(self, base_shape: torch.Size, *params: Any, **kwargs: An return self.noise_covar(*params, shape=base_shape, **kwargs) def expected_log_prob(self, target: Tensor, input: MultivariateNormal, *params: Any, **kwargs: Any) -> Tensor: - mean, variance = input.mean, input.variance - num_event_dim = len(input.event_shape) - noise = self._shaped_noise_covar(mean.shape, *params, **kwargs).diagonal(dim1=-1, dim2=-2) + noise = self._shaped_noise_covar(input.mean.shape, *params, **kwargs).diagonal(dim1=-1, dim2=-2) # Potentially reshape the noise to deal with the multitask case noise = noise.view(*noise.shape[:-1], *input.event_shape) + # Handle NaN values if enabled + nan_policy = settings.observation_nan_policy.value() + if nan_policy == "mask": + observed = settings.observation_nan_policy._get_observed(target, input.event_shape) + input = input[(...,) + observed] + noise = noise[(...,) + observed] + target = target[(...,) + observed] + elif nan_policy == "fill": + missing = torch.isnan(target) + target = settings.observation_nan_policy._fill_tensor(target) + + mean, variance = input.mean, input.variance res = ((target - mean) ** 2 + variance) / noise + noise.log() + math.log(2 * math.pi) res = res.mul(-0.5) - if num_event_dim > 1: # Do appropriate summation for multitask Gaussian likelihoods + + if nan_policy == "fill": + res = res * ~missing + + # Do appropriate summation for multitask Gaussian likelihoods + num_event_dim = len(input.event_shape) + if num_event_dim > 1: res = res.sum(list(range(-1, -num_event_dim, -1))) + return res def forward(self, function_samples: Tensor, *params: Any, **kwargs: Any) -> base_distributions.Normal: @@ -56,12 +73,26 @@ def log_marginal( self, observations: Tensor, function_dist: MultivariateNormal, *params: Any, **kwargs: Any ) -> Tensor: marginal = self.marginal(function_dist, *params, **kwargs) + + # Handle NaN values if enabled + nan_policy = settings.observation_nan_policy.value() + if nan_policy == "mask": + observed = settings.observation_nan_policy._get_observed(observations, marginal.event_shape) + marginal = marginal[(...,) + observed] + observations = observations[(...,) + observed] + elif nan_policy == "fill": + missing = torch.isnan(observations) + observations = settings.observation_nan_policy._fill_tensor(observations) + # We're making everything conditionally independent indep_dist = base_distributions.Normal(marginal.mean, marginal.variance.clamp_min(1e-8).sqrt()) res = indep_dist.log_prob(observations) + if nan_policy == "fill": + res = res * ~missing + # Do appropriate summation for multitask Gaussian likelihoods - num_event_dim = len(function_dist.event_shape) + num_event_dim = len(marginal.event_shape) if num_event_dim > 1: res = res.sum(list(range(-1, -num_event_dim, -1))) return res @@ -119,51 +150,6 @@ def raw_noise(self, value: Tensor) -> None: self.noise_covar.initialize(raw_noise=value) -class GaussianLikelihoodWithMissingObs(GaussianLikelihood): - r""" - The standard likelihood for regression with support for missing values. - Assumes a standard homoskedastic noise model: - - .. math:: - p(y \mid f) = f + \epsilon, \quad \epsilon \sim \mathcal N (0, \sigma^2) - - where :math:`\sigma^2` is a noise parameter. Values of y that are nan do - not impact the likelihood calculation. - - .. note:: - This likelihood can be used for exact or approximate inference. - - :param noise_prior: Prior for noise parameter :math:`\sigma^2`. - :type noise_prior: ~gpytorch.priors.Prior, optional - :param noise_constraint: Constraint for noise parameter :math:`\sigma^2`. - :type noise_constraint: ~gpytorch.constraints.Interval, optional - :param batch_shape: The batch shape of the learned noise parameter (default: []). - :type batch_shape: torch.Size, optional - - :var torch.Tensor noise: :math:`\sigma^2` parameter (noise) - """ - - MISSING_VALUE_FILL = -999.0 - - def __init__(self, **kwargs): - super().__init__(**kwargs) - - def _get_masked_obs(self, x): - missing_idx = x.isnan() - x_masked = x.masked_fill(missing_idx, self.MISSING_VALUE_FILL) - return missing_idx, x_masked - - def expected_log_prob(self, target, input, *params, **kwargs): - missing_idx, target = self._get_masked_obs(target) - res = super().expected_log_prob(target, input, *params, **kwargs) - return res * ~missing_idx - - def log_marginal(self, observations, function_dist, *params, **kwargs): - missing_idx, observations = self._get_masked_obs(observations) - res = super().log_marginal(observations, function_dist, *params, **kwargs) - return res * ~missing_idx - - class FixedNoiseGaussianLikelihood(_GaussianLikelihoodBase): r""" A Likelihood that assumes fixed heteroscedastic noise. This is useful when you have fixed, known observation diff --git a/gpytorch/mlls/exact_marginal_log_likelihood.py b/gpytorch/mlls/exact_marginal_log_likelihood.py index 32e880709..30604a3fc 100644 --- a/gpytorch/mlls/exact_marginal_log_likelihood.py +++ b/gpytorch/mlls/exact_marginal_log_likelihood.py @@ -1,6 +1,6 @@ #!/usr/bin/env python3 -import torch +from .. import settings from ..distributions import MultivariateNormal from ..likelihoods import _GaussianLikelihoodBase from .marginal_log_likelihood import MarginalLogLikelihood @@ -18,8 +18,6 @@ class ExactMarginalLogLikelihood(MarginalLogLikelihood): :param ~gpytorch.likelihoods.GaussianLikelihood likelihood: The Gaussian likelihood for the model :param ~gpytorch.models.ExactGP model: The exact GP model - :param bool nan_means_missing_data: If set to True, this module checks for NaN values in the output - and ignores them for calculations. Example: >>> # model is a gpytorch.models.ExactGP @@ -31,11 +29,10 @@ class ExactMarginalLogLikelihood(MarginalLogLikelihood): >>> loss.backward() """ - def __init__(self, likelihood, model, nan_means_missing_data=False): + def __init__(self, likelihood, model): if not isinstance(likelihood, _GaussianLikelihoodBase): raise RuntimeError("Likelihood must be Gaussian for exact inference") super(ExactMarginalLogLikelihood, self).__init__(likelihood, model) - self.nan_means_missing_data = nan_means_missing_data def _add_other_terms(self, res, params): # Add additional terms (SGPR / learned inducing points, heteroskedastic likelihood models) @@ -67,15 +64,17 @@ def forward(self, function_dist, target, *params): output = self.likelihood(function_dist, *params) # Remove NaN values if enabled - if self.nan_means_missing_data: - observed = torch.nonzero(~torch.isnan(target), as_tuple=True) - output = output[observed] - target = target[observed] + if settings.observation_nan_policy.value() == "mask": + observed = settings.observation_nan_policy._get_observed(target, output.event_shape) + output = output[(...,) + observed] + target = target[(...,) + observed] + elif settings.observation_nan_policy.value() == "fill": + raise ValueError("NaN observation policy 'fill' is not supported by ExactMarginalLogLikelihood!") # Get the log prob of the marginal distribution res = output.log_prob(target) res = self._add_other_terms(res, params) # Scale by the amount of data we have - num_data = output.event_shape.numel() + num_data = function_dist.event_shape.numel() return res.div_(num_data) diff --git a/gpytorch/models/exact_prediction_strategies.py b/gpytorch/models/exact_prediction_strategies.py index 668332687..d2404762e 100644 --- a/gpytorch/models/exact_prediction_strategies.py +++ b/gpytorch/models/exact_prediction_strategies.py @@ -230,22 +230,39 @@ def covar_cache(self): return self._exact_predictive_covar_inv_quad_form_cache(train_train_covar_inv_root, self._last_test_train_covar) @property - @cached(name="mean_cache") def mean_cache(self): + return self._mean_cache(settings.observation_nan_policy.value()) + + @cached(name="mean_cache") + def _mean_cache(self, nan_policy: str) -> Tensor: mvn = self.likelihood(self.train_prior_dist, self.train_inputs) train_mean, train_train_covar = mvn.loc, mvn.lazy_covariance_matrix train_labels_offset = (self.train_labels - train_mean).unsqueeze(-1) - not_observed = torch.isnan(self.train_labels) - if not torch.any(not_observed): + if nan_policy == "ignore": mean_cache = train_train_covar.evaluate_kernel().solve(train_labels_offset).squeeze(-1) - else: - observed = torch.where(~not_observed)[0] + elif nan_policy == "mask": + # Restrict solving to the outputs observed in every batch element. + observed = settings.observation_nan_policy._get_observed( + self.train_labels, torch.Size((self.train_labels.shape[-1],)) + )[0] mean_cache = torch.full_like(self.train_labels, torch.nan) - non_nan_kernel = train_train_covar[..., observed, :][..., :, observed].evaluate_kernel() - mean_cache[~not_observed] = non_nan_kernel.solve(train_labels_offset[observed]).squeeze(-1) - + kernel = train_train_covar[..., observed, :][..., :, observed].evaluate_kernel() + mean_cache[..., observed] = kernel.solve(train_labels_offset[..., observed, :]).squeeze(-1) + else: # 'fill' + # Fill all rows and columns in the kernel matrix corresponding to the missing observations with 0. + # Don't touch the corresponding diagonal elements to ensure a unique solution. + # This ensures that missing data is ignored during solving. + kernel = train_train_covar.evaluate_kernel() + missing = torch.isnan(self.train_labels) + kernel_mask = (~missing).to(torch.float) + kernel_mask = kernel_mask[..., None] * kernel_mask[..., None, :] + torch.diagonal(kernel_mask, dim1=-2, dim2=-1)[...] = 1 + kernel = kernel * kernel_mask # Unfortunately, this makes the kernel dense at the moment. + train_labels_offset = settings.observation_nan_policy._fill_tensor(train_labels_offset) + mean_cache = kernel.solve(train_labels_offset).squeeze(-1) + mean_cache[missing] = torch.nan # Ensure that nobody expects these values to be valid. if settings.detach_test_caches.on(): mean_cache = mean_cache.detach() @@ -293,12 +310,21 @@ def exact_predictive_mean(self, test_mean: Tensor, test_train_covar: LinearOpera # NOTE TO FUTURE SELF: # You **cannot* use addmv here, because test_train_covar may not actually be a non lazy tensor even for an exact # GP, and using addmv requires you to to_dense test_train_covar, which is obviously a huge no-no! - not_observed = torch.isnan(self.mean_cache) - if not torch.any(not_observed): + nan_policy = settings.observation_nan_policy.value() + if nan_policy == "ignore": res = (test_train_covar @ self.mean_cache.unsqueeze(-1)).squeeze(-1) - else: - observed = torch.where(~not_observed)[0] - res = (test_train_covar[..., observed] @ self.mean_cache[observed].unsqueeze(-1)).squeeze(-1) + elif nan_policy == "mask": + # Restrict train dimension to observed values + observed = settings.observation_nan_policy._get_observed( + self.mean_cache, torch.Size((self.mean_cache.shape[-1],)) + )[0] + res = (test_train_covar[..., observed] @ self.mean_cache[..., observed].unsqueeze(-1)).squeeze(-1) + else: # 'fill' + # Set the columns corresponding to missing observations to 0 to ignore them during matmul. + mask = (~torch.isnan(self.mean_cache)).to(torch.float)[..., None, :] + test_train_covar = test_train_covar * mask + mean = settings.observation_nan_policy._fill_tensor(self.mean_cache) + res = (test_train_covar @ mean.unsqueeze(-1)).squeeze(-1) res = res + test_mean return res diff --git a/gpytorch/settings.py b/gpytorch/settings.py index d7207375a..b862a7aa5 100644 --- a/gpytorch/settings.py +++ b/gpytorch/settings.py @@ -1,5 +1,4 @@ #!/usr/bin/env python3 - import torch from linear_operator.settings import ( _linalg_dtype_cholesky, @@ -27,6 +26,7 @@ use_toeplitz, verbose_linalg, ) +from torch import Tensor class _dtype_value_context: @@ -401,6 +401,54 @@ def value(cls, dtype=None): return super().value(dtype=dtype) +class observation_nan_policy(_value_context): + """ + NaN handling policy for observations. + + * ``ignore``: Do not check for NaN values. + * ``mask``: Mask out NaN values during calculation. If an output is NaN in a single batch element, this output is + masked for the complete batch. + * ``fill``: Fill in NaN values with a dummy value, perform computations and filter them later. + Not supported by :class:`gpytorch.mlls.ExactMarginalLogLikelihood`. + Does not support lazy covariance matrices during prediction. + """ + + _fill_value = -999.0 + _global_value = "ignore" + + def __init__(self, value): + if value not in {"ignore", "mask", "fill"}: + raise ValueError(f"NaN handling policy {value} not supported!") + super().__init__(value) + + @staticmethod + def _get_observed(observations, event_shape) -> tuple[Tensor, ...]: + """ + Constructs an index that masks out all elements in the event shape of the tensor which contain a NaN value in + any batch element. Applying this index flattens the event_shape, as the task structure cannot be retained. + This function is used if observation_nan_policy is set to 'mask'. + + :param Tensor observations: The observations to search for NaN values in. + :param torch.Size event_shape: The shape of a single event, i.e. the shape of observations without batch + dimensions. + :return: The index to the event dimensions of the observations. + """ + missing = torch.any(torch.isnan(observations.reshape(-1, *event_shape)), dim=0) + index = torch.nonzero(~missing, as_tuple=True) + return index + + @classmethod + def _fill_tensor(cls, observations) -> Tensor: + """ + Fills a tensor's NaN values with a filling value. + This function is used if observation_nan_policy is set to 'fill'. + + :param Tensor observations: The tensor to fill with values. + :return: The filled in observations. + """ + return torch.nan_to_num(observations, nan=cls._fill_value) + + __all__ = [ "_linalg_dtype_symeig", "_linalg_dtype_cholesky", @@ -431,6 +479,7 @@ def value(cls, dtype=None): "num_gauss_hermite_locs", "num_likelihood_samples", "num_trace_samples", + "observation_nan_policy", "preconditioner_tolerance", "prior_mode", "sgpr_diagonal_correction", diff --git a/test/distributions/test_multivariate_normal.py b/test/distributions/test_multivariate_normal.py index b9cc2568a..e287884fb 100644 --- a/test/distributions/test_multivariate_normal.py +++ b/test/distributions/test_multivariate_normal.py @@ -299,6 +299,10 @@ def test_getitem(self): assert torch.equal(d.mean, dist.mean[1, 2, 2, :]) self.assertAllClose(d.covariance_matrix, dist_cov[1, 2, 2, :, :]) + d = dist[0, 1, ..., 2, 1] + assert torch.equal(d.mean, dist.mean[0, 1, 2, 1]) + self.assertAllClose(d.covariance_matrix, dist_cov[0, 1, 2, 1, 1]) + def test_base_sample_shape(self): a = torch.randn(5, 10) lazy_square_a = RootLinearOperator(to_linear_operator(a)) diff --git a/test/examples/test_missing_data.py b/test/examples/test_missing_data.py index 4564aec0a..846920e9d 100644 --- a/test/examples/test_missing_data.py +++ b/test/examples/test_missing_data.py @@ -2,20 +2,24 @@ import torch +from gpytorch import settings from gpytorch import ExactMarginalLogLikelihood from gpytorch.distributions import MultivariateNormal, MultitaskMultivariateNormal from gpytorch.kernels import ScaleKernel, RBFKernel, MultitaskKernel from gpytorch.likelihoods import GaussianLikelihood, Likelihood, MultitaskGaussianLikelihood from gpytorch.means import ConstantMean, MultitaskMean -from gpytorch.models import ExactGP +from gpytorch.mlls import PredictiveLogLikelihood, MarginalLogLikelihood, VariationalELBO +from gpytorch.models import ExactGP, VariationalGP, GP from gpytorch.test.base_test_case import BaseTestCase +from gpytorch.utils.memoize import clear_cache_hook +from gpytorch.variational import CholeskyVariationalDistribution, LMCVariationalStrategy, VariationalStrategy class SingleGPModel(ExactGP): - def __init__(self, train_inputs, train_targets, likelihood): + def __init__(self, train_inputs, train_targets, likelihood, batch_shape): super(SingleGPModel, self).__init__(train_inputs, train_targets, likelihood) - self.mean_module = ConstantMean() - self.covar_module = ScaleKernel(RBFKernel()) + self.mean_module = ConstantMean(batch_shape=batch_shape) + self.covar_module = ScaleKernel(RBFKernel(batch_shape=batch_shape)) def forward(self, x): mean_x = self.mean_module(x) @@ -35,20 +39,56 @@ def forward(self, x): return MultitaskMultivariateNormal(mean_x, covar_x) +class MultitaskVariationalGPModel(VariationalGP): + def __init__(self, num_latents, num_tasks): + inducing_points = torch.rand(num_latents, 21, 1) + variational_distribution = CholeskyVariationalDistribution( + inducing_points.size(-2), batch_shape=torch.Size([num_latents]) + ) + variational_strategy = LMCVariationalStrategy( + VariationalStrategy( + self, inducing_points, variational_distribution, learn_inducing_locations=True + ), + num_tasks=num_tasks, + num_latents=num_latents, + latent_dim=-1 + ) + super().__init__(variational_strategy) + self.mean_module = ConstantMean(batch_shape=torch.Size([num_latents])) + self.covar_module = ScaleKernel( + RBFKernel(batch_shape=torch.Size([num_latents])), + batch_shape=torch.Size([num_latents]) + ) + + def forward(self, x): + mean_x = self.mean_module(x) + covar_x = self.covar_module(x) + return MultivariateNormal(mean_x, covar_x) + + class TestMissingData(BaseTestCase, unittest.TestCase): seed = 1 - def _train(self, model: ExactGP, likelihood: Likelihood): + def _check( + self, + model: GP, + likelihood: Likelihood, + train_x: torch.Tensor, + train_y: torch.Tensor, + test_x: torch.Tensor, + test_y: torch.Tensor, + optimizer: torch.optim.Optimizer, + mll: MarginalLogLikelihood, + epochs: int = 30, + atol: float = 0.2 + ) -> None: model.train() likelihood.train() - mll = ExactMarginalLogLikelihood(likelihood, model, nan_means_missing_data=True) - optimizer = torch.optim.Adam(model.parameters(), lr=0.15) - - for _ in range(20): + for _ in range(epochs): optimizer.zero_grad() - output = model(*model.train_inputs) - loss = mll(output, model.train_targets) + output = model(train_x) + loss = -mll(output, train_y).sum() self.assertFalse(torch.any(torch.isnan(output.mean)).item()) self.assertFalse(torch.any(torch.isnan(output.covariance_matrix)).item()) self.assertFalse(torch.isnan(loss).item()) @@ -58,38 +98,136 @@ def _train(self, model: ExactGP, likelihood: Likelihood): model.eval() likelihood.eval() - def test_single(self): - train_x = torch.linspace(0, 1, 21) - test_x = torch.linspace(0, 1, 51) - train_y = torch.sin(2 * torch.pi * train_x) - train_y += torch.normal(0, 0.01, train_y.shape) - train_y[::2] = torch.nan + with torch.no_grad(): + if isinstance(model, ExactGP): + self._check_predictions_exact_gp(model, test_x, test_y, atol) + else: + prediction = model(test_x) + self._check_prediction(prediction, test_y, atol) + + def _check_predictions_exact_gp(self, model: ExactGP, test_x: torch.Tensor, test_y: torch.Tensor, atol: float): + with settings.observation_nan_policy("mask"): + prediction = model(test_x) + self._check_prediction(prediction, test_y, atol) - likelihood = GaussianLikelihood() - model = SingleGPModel(train_x, train_y, likelihood) - self._train(model, likelihood) + clear_cache_hook(model.prediction_strategy) - with torch.no_grad(): + with settings.observation_nan_policy("fill"): + prediction = model(test_x) + self._check_prediction(prediction, test_y, atol) + + clear_cache_hook(model.prediction_strategy) + + with settings.observation_nan_policy("mask"): + model(test_x) + with settings.observation_nan_policy("fill"): prediction = model(test_x) + self._check_prediction(prediction, test_y, atol) + clear_cache_hook(model.prediction_strategy) + + with settings.observation_nan_policy("fill"): + model(test_x) + with settings.observation_nan_policy("mask"): + prediction = model(test_x) + self._check_prediction(prediction, test_y, atol) + + def _check_prediction(self, prediction: MultivariateNormal, test_y: torch.Tensor, atol: float): self.assertFalse(torch.any(torch.isnan(prediction.mean)).item()) self.assertFalse(torch.any(torch.isnan(prediction.covariance_matrix)).item()) + self.assertAllClose(prediction.mean, test_y, atol=atol) + + def test_single(self): + train_x = torch.linspace(0, 1, 41) + test_x = torch.linspace(0, 1, 51) + train_y = torch.sin(2 * torch.pi * train_x).squeeze() + train_y += torch.normal(0, 0.01, train_y.shape) + test_y = torch.sin(2 * torch.pi * test_x).squeeze() + train_y[::4] = torch.nan + + batch_shape = torch.Size(()) + likelihood = GaussianLikelihood(batch_shape=batch_shape) + model = SingleGPModel(train_x, train_y, likelihood, batch_shape=batch_shape) + + mll = ExactMarginalLogLikelihood(likelihood, model) + optimizer = torch.optim.Adam(model.parameters(), lr=0.15) + + with settings.observation_nan_policy("mask"): + self._check(model, likelihood, train_x, train_y, test_x, test_y, optimizer, mll) + + def test_single_batch(self): + train_x = torch.stack([torch.linspace(0, 1, 41), torch.linspace(1, 2, 41)]).reshape(2, 41, 1) + test_x = torch.stack([torch.linspace(0, 1, 51), torch.linspace(1, 2, 51)]).reshape(2, 51, 1) + train_y = torch.sin(2 * torch.pi * train_x).squeeze() + train_y += torch.normal(0, 0.01, train_y.shape) + test_y = torch.sin(2 * torch.pi * test_x).squeeze() + train_y[0, ::4] = torch.nan + + batch_shape = torch.Size((2,)) + likelihood = GaussianLikelihood(batch_shape=batch_shape) + model = SingleGPModel(train_x, train_y, likelihood, batch_shape=batch_shape) + + mll = ExactMarginalLogLikelihood(likelihood, model) + optimizer = torch.optim.Adam(model.parameters(), lr=0.15) + + with settings.observation_nan_policy("mask"): + self._check(model, likelihood, train_x, train_y, test_x, test_y, optimizer, mll) def test_multitask(self): num_tasks = 10 - train_x = torch.linspace(0, 1, 21) + train_x = torch.linspace(0, 1, 41) test_x = torch.linspace(0, 1, 51) - train_y = torch.sin(2 * torch.pi * train_x)[:, None] * torch.rand(1, num_tasks) + coefficients = torch.rand(1, num_tasks) + train_y = torch.sin(2 * torch.pi * train_x)[:, None] * coefficients train_y += torch.normal(0, 0.01, train_y.shape) + test_y = torch.sin(2 * torch.pi * test_x)[:, None] * coefficients train_y[::3, : num_tasks // 2] = torch.nan train_y[::4, num_tasks // 2 :] = torch.nan likelihood = MultitaskGaussianLikelihood(num_tasks) model = MultitaskGPModel(train_x, train_y, likelihood, num_tasks) - self._train(model, likelihood) - with torch.no_grad(): - prediction = model(test_x) + mll = ExactMarginalLogLikelihood(likelihood, model) + optimizer = torch.optim.Adam(model.parameters(), lr=0.15) - self.assertFalse(torch.any(torch.isnan(prediction.mean)).item()) - self.assertFalse(torch.any(torch.isnan(prediction.covariance_matrix)).item()) + with settings.observation_nan_policy("mask"): + self._check(model, likelihood, train_x, train_y, test_x, test_y, optimizer, mll) + + def test_variational_multitask(self): + num_latents = 1 + train_x = torch.linspace(0, 1, 41) + test_x = torch.linspace(0, 1, 51) + train_y = torch.stack([ + torch.sin(train_x * (2 * torch.pi)) + torch.randn(train_x.size()) * 0.2, + -torch.sin(train_x * (2 * torch.pi)) + torch.randn(train_x.size()) * 0.2, + ], -1) + test_y = torch.stack([ + torch.sin(test_x * (2 * torch.pi)), + -torch.sin(test_x * (2 * torch.pi)), + ], -1) + num_tasks = train_y.shape[-1] + + # nan out a few train_y + train_y[-3:, 1] = torch.nan + + likelihood = MultitaskGaussianLikelihood(num_tasks=num_tasks) + model = MultitaskVariationalGPModel(num_latents, num_tasks) + model.train() + likelihood.train() + + optimizer = torch.optim.Adam([ + {'params': model.parameters()}, + {'params': likelihood.parameters()}, + ], lr=0.05) + + mll = VariationalELBO(likelihood, model, num_data=train_y.size(0)) + with settings.observation_nan_policy("mask"): + self._check(model, likelihood, train_x, train_y, test_x, test_y, optimizer, mll, epochs=50, atol=0.3) + with settings.observation_nan_policy("fill"): + self._check(model, likelihood, train_x, train_y, test_x, test_y, optimizer, mll, epochs=50, atol=0.2) + + mll = PredictiveLogLikelihood(likelihood, model, num_data=train_y.size(0)) + with settings.observation_nan_policy("mask"): + self._check(model, likelihood, train_x, train_y, test_x, test_y, optimizer, mll, epochs=50, atol=0.3) + with settings.observation_nan_policy("fill"): + self._check(model, likelihood, train_x, train_y, test_x, test_y, optimizer, mll, epochs=50, atol=0.2) diff --git a/test/likelihoods/test_gaussian_likelihood.py b/test/likelihoods/test_gaussian_likelihood.py index 9874395a9..2c11a2a5f 100644 --- a/test/likelihoods/test_gaussian_likelihood.py +++ b/test/likelihoods/test_gaussian_likelihood.py @@ -8,12 +8,7 @@ from gpytorch import settings from gpytorch.distributions import MultivariateNormal -from gpytorch.likelihoods import ( - DirichletClassificationLikelihood, - FixedNoiseGaussianLikelihood, - GaussianLikelihood, - GaussianLikelihoodWithMissingObs, -) +from gpytorch.likelihoods import DirichletClassificationLikelihood, FixedNoiseGaussianLikelihood, GaussianLikelihood from gpytorch.likelihoods.noise_models import FixedGaussianNoise from gpytorch.priors import GammaPrior from gpytorch.test.base_likelihood_test_case import BaseLikelihoodTestCase @@ -164,13 +159,13 @@ def test_dirichlet_classification_likelihood(self, cuda=False): self.assertTrue(torch.allclose(out.variance, 1.0 + obs_targets)) -class TestGaussianLikelihoodwithMissingObs(BaseLikelihoodTestCase, unittest.TestCase): +class TestGaussianLikelihoodWithMissingObs(BaseLikelihoodTestCase, unittest.TestCase): seed = 42 def create_likelihood(self): - return GaussianLikelihoodWithMissingObs() + return GaussianLikelihood() - def test_missing_value_inference(self): + def test_missing_value_inference_fill(self): """ samples = mvn samples + noise samples In this test, we try to recover noise parameters when some elements in @@ -179,43 +174,56 @@ def test_missing_value_inference(self): torch.manual_seed(self.seed) - mu = torch.zeros(2, 3) - sigma = torch.tensor([[[1, 0.999, -0.999], [0.999, 1, -0.999], [-0.999, -0.999, 1]]] * 2).float() - mvn = MultivariateNormal(mu, sigma) - samples = mvn.sample(torch.Size([10000])) # mvn samples - - noise_sd = 0.5 - noise_dist = torch.distributions.Normal(0, noise_sd) - samples += noise_dist.sample(samples.shape) # noise + mvn, samples = self._make_data() - missing_prop = 0.33 - missing_idx = torch.distributions.Binomial(1, missing_prop).sample(samples.shape).bool() + missing_probability = 0.33 + missing_idx = torch.distributions.Binomial(1, missing_probability).sample(samples.shape).bool() samples[missing_idx] = float("nan") - likelihood = GaussianLikelihoodWithMissingObs() + # check that the correct noise sd is recovered - # check that the missing value fill doesn't impact the likelihood + with settings.observation_nan_policy("fill"): + self._check_recovery(mvn, samples) - likelihood.MISSING_VALUE_FILL = 999.0 - like_init_plus = likelihood.log_marginal(samples, mvn).sum().data + def test_missing_value_inference_mask(self): + """ + samples = mvn samples + noise samples + In this test, we try to recover noise parameters when some elements in + 'samples' are missing at random. + """ - likelihood.MISSING_VALUE_FILL = -999.0 - like_init_minus = likelihood.log_marginal(samples, mvn).sum().data + torch.manual_seed(self.seed) + + mvn, samples = self._make_data() - torch.testing.assert_close(like_init_plus, like_init_minus) + missing_prop = 0.33 + missing_idx = torch.distributions.Binomial(1, missing_prop).sample(samples.shape[1:]).bool() + samples[1, missing_idx] = float("nan") # check that the correct noise sd is recovered - opt = torch.optim.Adam(likelihood.parameters(), lr=0.05) + with settings.observation_nan_policy("fill"): + self._check_recovery(mvn, samples) + def _make_data(self): + mu = torch.zeros(2, 3) + sigma = torch.tensor([[[1, 0.999, -0.999], [0.999, 1, -0.999], [-0.999, -0.999, 1]]] * 2).float() + mvn = MultivariateNormal(mu, sigma) + samples = mvn.sample(torch.Size([10000])) # mvn samples + noise_sd = 0.5 + noise_dist = torch.distributions.Normal(0, noise_sd) + samples += noise_dist.sample(samples.shape) # noise + return mvn, samples + + def _check_recovery(self, mvn, samples): + likelihood = GaussianLikelihood() + opt = torch.optim.Adam(likelihood.parameters(), lr=0.05) for _ in range(100): opt.zero_grad() loss = -likelihood.log_marginal(samples, mvn).sum() loss.backward() opt.step() - - assert abs(float(likelihood.noise.sqrt()) - 0.5) < 0.02 - + self.assertTrue(abs(float(likelihood.noise.sqrt()) - 0.5) < 0.02) # Check log marginal works likelihood.log_marginal(samples[0], mvn) From 31ed8dde22b5833620669d4e5b8780a95f4e0363 Mon Sep 17 00:00:00 2001 From: Tilman Hoffbauer Date: Fri, 10 Mar 2023 17:38:32 +0100 Subject: [PATCH 12/23] Fix Python version incompatibility --- gpytorch/settings.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/gpytorch/settings.py b/gpytorch/settings.py index b862a7aa5..2edd4c68f 100644 --- a/gpytorch/settings.py +++ b/gpytorch/settings.py @@ -1,4 +1,6 @@ #!/usr/bin/env python3 +from typing import Tuple + import torch from linear_operator.settings import ( _linalg_dtype_cholesky, @@ -422,7 +424,7 @@ def __init__(self, value): super().__init__(value) @staticmethod - def _get_observed(observations, event_shape) -> tuple[Tensor, ...]: + def _get_observed(observations, event_shape) -> Tuple[Tensor, ...]: """ Constructs an index that masks out all elements in the event shape of the tensor which contain a NaN value in any batch element. Applying this index flattens the event_shape, as the task structure cannot be retained. From 8546fd3c6e70e50eb62740ab8887c8e2773988e7 Mon Sep 17 00:00:00 2001 From: Tilman Hoffbauer Date: Fri, 10 Mar 2023 17:49:49 +0100 Subject: [PATCH 13/23] Increase atol on variational tests --- test/examples/test_missing_data.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/test/examples/test_missing_data.py b/test/examples/test_missing_data.py index 846920e9d..1f7044a9c 100644 --- a/test/examples/test_missing_data.py +++ b/test/examples/test_missing_data.py @@ -67,7 +67,7 @@ def forward(self, x): class TestMissingData(BaseTestCase, unittest.TestCase): - seed = 1 + seed = 20 def _check( self, @@ -222,12 +222,12 @@ def test_variational_multitask(self): mll = VariationalELBO(likelihood, model, num_data=train_y.size(0)) with settings.observation_nan_policy("mask"): - self._check(model, likelihood, train_x, train_y, test_x, test_y, optimizer, mll, epochs=50, atol=0.3) + self._check(model, likelihood, train_x, train_y, test_x, test_y, optimizer, mll, epochs=50, atol=0.7) with settings.observation_nan_policy("fill"): - self._check(model, likelihood, train_x, train_y, test_x, test_y, optimizer, mll, epochs=50, atol=0.2) + self._check(model, likelihood, train_x, train_y, test_x, test_y, optimizer, mll, epochs=50, atol=0.3) mll = PredictiveLogLikelihood(likelihood, model, num_data=train_y.size(0)) with settings.observation_nan_policy("mask"): - self._check(model, likelihood, train_x, train_y, test_x, test_y, optimizer, mll, epochs=50, atol=0.3) + self._check(model, likelihood, train_x, train_y, test_x, test_y, optimizer, mll, epochs=50, atol=0.7) with settings.observation_nan_policy("fill"): - self._check(model, likelihood, train_x, train_y, test_x, test_y, optimizer, mll, epochs=50, atol=0.2) + self._check(model, likelihood, train_x, train_y, test_x, test_y, optimizer, mll, epochs=50, atol=0.3) From 6ab4e5514a5199094981986e172a8de9fb75ce74 Mon Sep 17 00:00:00 2001 From: Tilman Hoffbauer Date: Thu, 16 Mar 2023 16:59:48 +0100 Subject: [PATCH 14/23] Add ExactMarginalLogLikelihoodWithMissingObs back with deprecation warning --- docs/source/marginal_log_likelihoods.rst | 6 +++ gpytorch/likelihoods/__init__.py | 2 + gpytorch/likelihoods/gaussian_likelihood.py | 54 +++++++++++++++++++++ 3 files changed, 62 insertions(+) diff --git a/docs/source/marginal_log_likelihoods.rst b/docs/source/marginal_log_likelihoods.rst index 20bb840bc..3fb9e2f14 100644 --- a/docs/source/marginal_log_likelihoods.rst +++ b/docs/source/marginal_log_likelihoods.rst @@ -37,6 +37,12 @@ These are MLLs for use with :obj:`~gpytorch.models.ExactGP` modules. They comput .. autoclass:: ExactMarginalLogLikelihood :members: +:hidden:`GaussianLikelihoodWithMissingObs` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: GaussianLikelihoodWithMissingObs + :members: + :hidden:`LeaveOneOutPseudoLikelihood` ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/gpytorch/likelihoods/__init__.py b/gpytorch/likelihoods/__init__.py index 45fb205f4..31a370079 100644 --- a/gpytorch/likelihoods/__init__.py +++ b/gpytorch/likelihoods/__init__.py @@ -7,6 +7,7 @@ DirichletClassificationLikelihood, FixedNoiseGaussianLikelihood, GaussianLikelihood, + GaussianLikelihoodWithMissingObs, ) from .laplace_likelihood import LaplaceLikelihood from .likelihood import _OneDimensionalLikelihood, Likelihood @@ -25,6 +26,7 @@ "DirichletClassificationLikelihood", "FixedNoiseGaussianLikelihood", "GaussianLikelihood", + "GaussianLikelihoodWithMissingObs", "HeteroskedasticNoise", "LaplaceLikelihood", "Likelihood", diff --git a/gpytorch/likelihoods/gaussian_likelihood.py b/gpytorch/likelihoods/gaussian_likelihood.py index 8201b9d39..2a681f317 100644 --- a/gpytorch/likelihoods/gaussian_likelihood.py +++ b/gpytorch/likelihoods/gaussian_likelihood.py @@ -167,6 +167,60 @@ def marginal(self, function_dist: MultivariateNormal, *args: Any, **kwargs: Any) return super().marginal(function_dist, *args, **kwargs) +class GaussianLikelihoodWithMissingObs(GaussianLikelihood): + r""" + The standard likelihood for regression with support for missing values. + Assumes a standard homoskedastic noise model: + .. math:: + p(y \mid f) = f + \epsilon, \quad \epsilon \sim \mathcal N (0, \sigma^2) + where :math:`\sigma^2` is a noise parameter. Values of y that are nan do + not impact the likelihood calculation. + .. note:: + This likelihood can be used for exact or approximate inference. + :param noise_prior: Prior for noise parameter :math:`\sigma^2`. + :type noise_prior: ~gpytorch.priors.Prior, optional + :param noise_constraint: Constraint for noise parameter :math:`\sigma^2`. + :type noise_constraint: ~gpytorch.constraints.Interval, optional + :param batch_shape: The batch shape of the learned noise parameter (default: []). + :type batch_shape: torch.Size, optional + :var torch.Tensor noise: :math:`\sigma^2` parameter (noise) + .. note:: + GaussianLikelihoodWithMissingObs has an analytic marginal distribution. + """ + + MISSING_VALUE_FILL: float = -999.0 + + def __init__(self, **kwargs: Any) -> None: + warnings.warn( + "ExactMarginalLogLikelihood is replaced by gpytorch.settings.observation_nan_policy('fill').", + DeprecationWarning, + ) + super().__init__(**kwargs) + + def _get_masked_obs(self, x: Tensor) -> Tuple[Tensor, Tensor]: + missing_idx = x.isnan() + x_masked = x.masked_fill(missing_idx, self.MISSING_VALUE_FILL) + return missing_idx, x_masked + + def expected_log_prob(self, target: Tensor, input: MultivariateNormal, *params: Any, **kwargs: Any) -> Tensor: + missing_idx, target = self._get_masked_obs(target) + res = super().expected_log_prob(target, input, *params, **kwargs) + return res * ~missing_idx + + def log_marginal( + self, observations: Tensor, function_dist: MultivariateNormal, *params: Any, **kwargs: Any + ) -> Tensor: + missing_idx, observations = self._get_masked_obs(observations) + res = super().log_marginal(observations, function_dist, *params, **kwargs) + return res * ~missing_idx + + def marginal(self, function_dist: MultivariateNormal, *args: Any, **kwargs: Any) -> MultivariateNormal: + """ + :return: Analytic marginal :math:`p(\mathbf y)`. + """ + return super().marginal(function_dist, *args, **kwargs) + + class FixedNoiseGaussianLikelihood(_GaussianLikelihoodBase): r""" A Likelihood that assumes fixed heteroscedastic noise. This is useful when you have fixed, known observation From e2713ac0a8a490cc25e0dd5432c3622419016ff0 Mon Sep 17 00:00:00 2001 From: Tilman Hoffbauer Date: Fri, 21 Apr 2023 11:20:21 +0200 Subject: [PATCH 15/23] Add warning if kernel matrix is made dense --- gpytorch/models/exact_prediction_strategies.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/gpytorch/models/exact_prediction_strategies.py b/gpytorch/models/exact_prediction_strategies.py index d2404762e..971739afe 100644 --- a/gpytorch/models/exact_prediction_strategies.py +++ b/gpytorch/models/exact_prediction_strategies.py @@ -2,6 +2,7 @@ import functools import string +import warnings import torch from linear_operator import to_dense, to_linear_operator @@ -254,6 +255,10 @@ def _mean_cache(self, nan_policy: str) -> Tensor: # Fill all rows and columns in the kernel matrix corresponding to the missing observations with 0. # Don't touch the corresponding diagonal elements to ensure a unique solution. # This ensures that missing data is ignored during solving. + warnings.warn( + "Observation NaN policy 'fill' makes the kernel matrix dense during exact prediction.", + RuntimeWarning, + ) kernel = train_train_covar.evaluate_kernel() missing = torch.isnan(self.train_labels) kernel_mask = (~missing).to(torch.float) From a67b40bee4165bdebb6c235e2a8023ed6ad498b3 Mon Sep 17 00:00:00 2001 From: Tilman Hoffbauer Date: Fri, 21 Apr 2023 11:53:43 +0200 Subject: [PATCH 16/23] Fix docs --- docs/source/likelihoods.rst | 6 ++++++ docs/source/marginal_log_likelihoods.rst | 6 ------ gpytorch/likelihoods/gaussian_likelihood.py | 2 +- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/docs/source/likelihoods.rst b/docs/source/likelihoods.rst index c69feec3f..3b71b42ae 100644 --- a/docs/source/likelihoods.rst +++ b/docs/source/likelihoods.rst @@ -33,6 +33,12 @@ reduce the variance when computing approximate GP objective functions. .. autoclass:: GaussianLikelihood :members: +:hidden:`GaussianLikelihoodWithMissingObs` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: GaussianLikelihoodWithMissingObs + :members: + :hidden:`FixedNoiseGaussianLikelihood` ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/docs/source/marginal_log_likelihoods.rst b/docs/source/marginal_log_likelihoods.rst index 3fb9e2f14..20bb840bc 100644 --- a/docs/source/marginal_log_likelihoods.rst +++ b/docs/source/marginal_log_likelihoods.rst @@ -37,12 +37,6 @@ These are MLLs for use with :obj:`~gpytorch.models.ExactGP` modules. They comput .. autoclass:: ExactMarginalLogLikelihood :members: -:hidden:`GaussianLikelihoodWithMissingObs` -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -.. autoclass:: GaussianLikelihoodWithMissingObs - :members: - :hidden:`LeaveOneOutPseudoLikelihood` ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/gpytorch/likelihoods/gaussian_likelihood.py b/gpytorch/likelihoods/gaussian_likelihood.py index 2a681f317..645e7aecf 100644 --- a/gpytorch/likelihoods/gaussian_likelihood.py +++ b/gpytorch/likelihoods/gaussian_likelihood.py @@ -192,7 +192,7 @@ class GaussianLikelihoodWithMissingObs(GaussianLikelihood): def __init__(self, **kwargs: Any) -> None: warnings.warn( - "ExactMarginalLogLikelihood is replaced by gpytorch.settings.observation_nan_policy('fill').", + "GaussianLikelihoodWithMissingObs is replaced by gpytorch.settings.observation_nan_policy('fill').", DeprecationWarning, ) super().__init__(**kwargs) From 49fc2f4c1895c18940b01841b8d118b99640b272 Mon Sep 17 00:00:00 2001 From: Tilman Hoffbauer Date: Fri, 21 Apr 2023 12:05:32 +0200 Subject: [PATCH 17/23] Add quick path for noop slice indices --- .../distributions/multitask_multivariate_normal.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/gpytorch/distributions/multitask_multivariate_normal.py b/gpytorch/distributions/multitask_multivariate_normal.py index 20ad77ad4..6461500af 100644 --- a/gpytorch/distributions/multitask_multivariate_normal.py +++ b/gpytorch/distributions/multitask_multivariate_normal.py @@ -364,6 +364,18 @@ def __getitem__(self, idx) -> MultivariateNormal: new_slice = slice(row_idx.start + col_idx, row_idx.stop * num_cols + col_idx, row_idx.step * num_cols) new_cov = self.lazy_covariance_matrix[batch_idx + (new_slice, new_slice)] return MultivariateNormal(mean=new_mean, covariance_matrix=new_cov) + elif ( + isinstance(row_idx, slice) + and isinstance(col_idx, slice) + and row_idx == col_idx == slice(None, None, None) + ): + new_cov = self.lazy_covariance_matrix[batch_idx] + return MultitaskMultivariateNormal( + mean=new_mean, + covariance_matrix=new_cov, + interleaved=self._interleaved, + validate_args=False, + ) elif isinstance(row_idx, slice) or isinstance(col_idx, slice): # slice x slice or indices x slice or slice x indices if isinstance(row_idx, slice): From e8ecbef28cf8284324ca4418299b14e5f9f68a9d Mon Sep 17 00:00:00 2001 From: Tilman Hoffbauer Date: Fri, 21 Apr 2023 12:13:45 +0200 Subject: [PATCH 18/23] Add test for noop slice indexing --- test/distributions/test_multitask_multivariate_normal.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/test/distributions/test_multitask_multivariate_normal.py b/test/distributions/test_multitask_multivariate_normal.py index 4df3a7be8..245cef3e3 100644 --- a/test/distributions/test_multitask_multivariate_normal.py +++ b/test/distributions/test_multitask_multivariate_normal.py @@ -333,6 +333,13 @@ def flat(observation: int, task: int) -> int: self.assertAllClose(part.mean, mean[1, -1]) self.assertAllClose(part.covariance_matrix, covar[1, -1]) + part = distribution[1, 0, ...] + self.assertIsInstance(part, MultitaskMultivariateNormal) + self.assertEqual(part.batch_shape, torch.Size(())) + self.assertEqual(part.event_shape, torch.Size((3, 2))) + self.assertAllClose(part.mean, mean[1, 0]) + self.assertAllClose(part.covariance_matrix, covar[1, 0]) + part = distribution[..., 2, 1] self.assertFalse(isinstance(part, MultitaskMultivariateNormal)) self.assertIsInstance(part, MultivariateNormal) From 470647788d0403c4665c24c3ab970e7dc19628ad Mon Sep 17 00:00:00 2001 From: Tilman Hoffbauer Date: Fri, 21 Apr 2023 12:14:16 +0200 Subject: [PATCH 19/23] Fix docs --- gpytorch/likelihoods/gaussian_likelihood.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/gpytorch/likelihoods/gaussian_likelihood.py b/gpytorch/likelihoods/gaussian_likelihood.py index 645e7aecf..ceea71c6e 100644 --- a/gpytorch/likelihoods/gaussian_likelihood.py +++ b/gpytorch/likelihoods/gaussian_likelihood.py @@ -171,12 +171,19 @@ class GaussianLikelihoodWithMissingObs(GaussianLikelihood): r""" The standard likelihood for regression with support for missing values. Assumes a standard homoskedastic noise model: + .. math:: p(y \mid f) = f + \epsilon, \quad \epsilon \sim \mathcal N (0, \sigma^2) + where :math:`\sigma^2` is a noise parameter. Values of y that are nan do not impact the likelihood calculation. + .. note:: This likelihood can be used for exact or approximate inference. + + .. warning:: + This likelihood is deprecated in favor of :class:`gpytorch.settings.observation_nan_policy`. + :param noise_prior: Prior for noise parameter :math:`\sigma^2`. :type noise_prior: ~gpytorch.priors.Prior, optional :param noise_constraint: Constraint for noise parameter :math:`\sigma^2`. @@ -184,6 +191,7 @@ class GaussianLikelihoodWithMissingObs(GaussianLikelihood): :param batch_shape: The batch shape of the learned noise parameter (default: []). :type batch_shape: torch.Size, optional :var torch.Tensor noise: :math:`\sigma^2` parameter (noise) + .. note:: GaussianLikelihoodWithMissingObs has an analytic marginal distribution. """ From e8788c8b7817d158c3759289b7d2e92e98d6d80b Mon Sep 17 00:00:00 2001 From: Tilman Hoffbauer Date: Wed, 7 Jun 2023 18:33:05 +0200 Subject: [PATCH 20/23] Switch to MaskedLinearOperator --- gpytorch/likelihoods/gaussian_likelihood.py | 21 +++- .../mlls/exact_marginal_log_likelihood.py | 10 +- .../models/exact_prediction_strategies.py | 15 ++- gpytorch/settings.py | 17 ++-- gpytorch/utils/masked_linear_operator.py | 96 +++++++++++++++++++ test/examples/test_missing_data.py | 7 +- 6 files changed, 141 insertions(+), 25 deletions(-) create mode 100644 gpytorch/utils/masked_linear_operator.py diff --git a/gpytorch/likelihoods/gaussian_likelihood.py b/gpytorch/likelihoods/gaussian_likelihood.py index ceea71c6e..040ab17da 100644 --- a/gpytorch/likelihoods/gaussian_likelihood.py +++ b/gpytorch/likelihoods/gaussian_likelihood.py @@ -13,6 +13,7 @@ from ..constraints import Interval from ..distributions import base_distributions, MultivariateNormal from ..priors import Prior +from ..utils.masked_linear_operator import MaskedLinearOperator from ..utils.warnings import GPInputWarning from .likelihood import Likelihood from .noise_models import FixedGaussianNoise, HomoskedasticNoise, Noise @@ -48,9 +49,14 @@ def expected_log_prob(self, target: Tensor, input: MultivariateNormal, *params: nan_policy = settings.observation_nan_policy.value() if nan_policy == "mask": observed = settings.observation_nan_policy._get_observed(target, input.event_shape) - input = input[(...,) + observed] - noise = noise[(...,) + observed] - target = target[(...,) + observed] + input = MultivariateNormal( + mean=input.mean[..., observed], + covariance_matrix=MaskedLinearOperator( + input.lazy_covariance_matrix, observed.reshape(-1), observed.reshape(-1) + ), + ) + noise = noise[..., observed] + target = target[..., observed] elif nan_policy == "fill": missing = torch.isnan(target) target = settings.observation_nan_policy._fill_tensor(target) @@ -82,8 +88,13 @@ def log_marginal( nan_policy = settings.observation_nan_policy.value() if nan_policy == "mask": observed = settings.observation_nan_policy._get_observed(observations, marginal.event_shape) - marginal = marginal[(...,) + observed] - observations = observations[(...,) + observed] + marginal = MultivariateNormal( + mean=marginal.mean[..., observed], + covariance_matrix=MaskedLinearOperator( + marginal.lazy_covariance_matrix, observed.reshape(-1), observed.reshape(-1) + ), + ) + observations = observations[..., observed] elif nan_policy == "fill": missing = torch.isnan(observations) observations = settings.observation_nan_policy._fill_tensor(observations) diff --git a/gpytorch/mlls/exact_marginal_log_likelihood.py b/gpytorch/mlls/exact_marginal_log_likelihood.py index 30604a3fc..9c944ae42 100644 --- a/gpytorch/mlls/exact_marginal_log_likelihood.py +++ b/gpytorch/mlls/exact_marginal_log_likelihood.py @@ -3,6 +3,7 @@ from .. import settings from ..distributions import MultivariateNormal from ..likelihoods import _GaussianLikelihoodBase +from ..utils.masked_linear_operator import MaskedLinearOperator from .marginal_log_likelihood import MarginalLogLikelihood @@ -66,8 +67,13 @@ def forward(self, function_dist, target, *params): # Remove NaN values if enabled if settings.observation_nan_policy.value() == "mask": observed = settings.observation_nan_policy._get_observed(target, output.event_shape) - output = output[(...,) + observed] - target = target[(...,) + observed] + output = MultivariateNormal( + mean=output.mean[..., observed], + covariance_matrix=MaskedLinearOperator( + output.lazy_covariance_matrix, observed.reshape(-1), observed.reshape(-1) + ), + ) + target = target[..., observed] elif settings.observation_nan_policy.value() == "fill": raise ValueError("NaN observation policy 'fill' is not supported by ExactMarginalLogLikelihood!") diff --git a/gpytorch/models/exact_prediction_strategies.py b/gpytorch/models/exact_prediction_strategies.py index 971739afe..4b3df050f 100644 --- a/gpytorch/models/exact_prediction_strategies.py +++ b/gpytorch/models/exact_prediction_strategies.py @@ -24,6 +24,7 @@ from .. import settings from ..lazy import LazyEvaluatedKernelTensor +from ..utils.masked_linear_operator import MaskedLinearOperator from ..utils.memoize import add_to_cache, cached, clear_cache_hook, pop_from_cache @@ -244,12 +245,14 @@ def _mean_cache(self, nan_policy: str) -> Tensor: if nan_policy == "ignore": mean_cache = train_train_covar.evaluate_kernel().solve(train_labels_offset).squeeze(-1) elif nan_policy == "mask": - # Restrict solving to the outputs observed in every batch element. + # Mask all rows and columns in the kernel matrix corresponding to the missing observations. observed = settings.observation_nan_policy._get_observed( self.train_labels, torch.Size((self.train_labels.shape[-1],)) - )[0] + ) mean_cache = torch.full_like(self.train_labels, torch.nan) - kernel = train_train_covar[..., observed, :][..., :, observed].evaluate_kernel() + kernel = MaskedLinearOperator( + train_train_covar.evaluate_kernel(), observed.reshape(-1), observed.reshape(-1) + ) mean_cache[..., observed] = kernel.solve(train_labels_offset[..., observed, :]).squeeze(-1) else: # 'fill' # Fill all rows and columns in the kernel matrix corresponding to the missing observations with 0. @@ -322,8 +325,10 @@ def exact_predictive_mean(self, test_mean: Tensor, test_train_covar: LinearOpera # Restrict train dimension to observed values observed = settings.observation_nan_policy._get_observed( self.mean_cache, torch.Size((self.mean_cache.shape[-1],)) - )[0] - res = (test_train_covar[..., observed] @ self.mean_cache[..., observed].unsqueeze(-1)).squeeze(-1) + ) + full_mask = torch.ones(test_mean.shape[-1], dtype=torch.bool, device=test_mean.device) + test_train_covar = MaskedLinearOperator(test_train_covar, full_mask, observed.reshape(-1)) + res = (test_train_covar @ self.mean_cache[..., observed].unsqueeze(-1)).squeeze(-1) else: # 'fill' # Set the columns corresponding to missing observations to 0 to ignore them during matmul. mask = (~torch.isnan(self.mean_cache)).to(torch.float)[..., None, :] diff --git a/gpytorch/settings.py b/gpytorch/settings.py index 2edd4c68f..595ffe10d 100644 --- a/gpytorch/settings.py +++ b/gpytorch/settings.py @@ -1,5 +1,4 @@ #!/usr/bin/env python3 -from typing import Tuple import torch from linear_operator.settings import ( @@ -407,9 +406,9 @@ class observation_nan_policy(_value_context): """ NaN handling policy for observations. - * ``ignore``: Do not check for NaN values. - * ``mask``: Mask out NaN values during calculation. If an output is NaN in a single batch element, this output is - masked for the complete batch. + * ``ignore``: Do not check for NaN values (the default). + * ``mask``: Mask out NaN values during calculation. If an output is NaN in a single batch element, this output + is masked for the complete batch. This strategy likely is a good choice if you have NaN values. * ``fill``: Fill in NaN values with a dummy value, perform computations and filter them later. Not supported by :class:`gpytorch.mlls.ExactMarginalLogLikelihood`. Does not support lazy covariance matrices during prediction. @@ -424,20 +423,18 @@ def __init__(self, value): super().__init__(value) @staticmethod - def _get_observed(observations, event_shape) -> Tuple[Tensor, ...]: + def _get_observed(observations, event_shape) -> Tensor: """ - Constructs an index that masks out all elements in the event shape of the tensor which contain a NaN value in + Constructs a tensor that masks out all elements in the event shape of the tensor which contain a NaN value in any batch element. Applying this index flattens the event_shape, as the task structure cannot be retained. This function is used if observation_nan_policy is set to 'mask'. :param Tensor observations: The observations to search for NaN values in. :param torch.Size event_shape: The shape of a single event, i.e. the shape of observations without batch dimensions. - :return: The index to the event dimensions of the observations. + :return: The mask to the event dimensions of the observations. """ - missing = torch.any(torch.isnan(observations.reshape(-1, *event_shape)), dim=0) - index = torch.nonzero(~missing, as_tuple=True) - return index + return ~torch.any(torch.isnan(observations.reshape(-1, *event_shape)), dim=0) @classmethod def _fill_tensor(cls, observations) -> Tensor: diff --git a/gpytorch/utils/masked_linear_operator.py b/gpytorch/utils/masked_linear_operator.py new file mode 100644 index 000000000..e5c703e80 --- /dev/null +++ b/gpytorch/utils/masked_linear_operator.py @@ -0,0 +1,96 @@ +from typing import Optional, Union + +import torch +from linear_operator import LinearOperator +from torch import Tensor + + +class MaskedLinearOperator(LinearOperator): + def __init__(self, base: LinearOperator, row_mask: Tensor, col_mask: Tensor): + super().__init__(base, row_mask, col_mask) + self.base = base + self.row_mask = row_mask + self.col_mask = col_mask + self.row_eq_col_mask = row_mask is not None and col_mask is not None and torch.equal(row_mask, col_mask) + + def _matmul(self, rhs: Tensor) -> Tensor: + if self.col_mask is not None: + rhs_expanded = torch.zeros( + *rhs.shape[:-2], + self.base.size(-1), + rhs.shape[-1], + device=rhs.device, + dtype=rhs.dtype, + ) + rhs_expanded[..., self.col_mask, :] = rhs + rhs = rhs_expanded + + res = self.base.matmul(rhs) + + if self.row_mask is not None: + res = res[..., self.row_mask, :] + + return res + + def _size(self) -> torch.Size: + base_size = list(self.base.size()) + if self.row_mask is not None: + base_size[-2] = torch.count_nonzero(self.row_mask) + if self.col_mask is not None: + base_size[-1] = torch.count_nonzero(self.col_mask) + return torch.Size(tuple(base_size)) + + def _transpose_nonbatch(self) -> LinearOperator: + return MaskedLinearOperator(self.base.mT, self.col_mask, self.row_mask) + + def _getitem( + self, + row_index: Union[slice, torch.LongTensor], + col_index: Union[slice, torch.LongTensor], + *batch_indices: tuple[Union[int, slice, torch.LongTensor], ...], + ) -> LinearOperator: + raise NotImplementedError("Indexing with %r, %r, %r not supported." % (batch_indices, row_index, col_index)) + + def _get_indices( + self, + row_index: torch.LongTensor, + col_index: torch.LongTensor, + *batch_indices: tuple[torch.LongTensor, ...], + ) -> torch.Tensor: + def map_indices(index: torch.LongTensor, mask: Optional[Tensor], base_size: int) -> torch.LongTensor: + if mask is None: + return index + map = torch.arange(base_size, device=self.base.device)[mask] + return map[index] + + if len(batch_indices) == 0: + row_index = map_indices(row_index, self.row_mask, self.base.size(-2)) + col_index = map_indices(col_index, self.col_mask, self.base.size(-1)) + return self.base._get_indices(row_index, col_index) + + raise NotImplementedError("Indexing with %r, %r, %r not supported." % (batch_indices, row_index, col_index)) + + def _diagonal(self) -> Tensor: + if not self.row_eq_col_mask: + raise NotImplementedError() + diag = self.base.diagonal() + return diag[self.row_mask] + + def to_dense(self) -> torch.Tensor: + full_dense = self.base.to_dense() + return full_dense[..., self.row_mask, :][..., :, self.col_mask] + + def _cholesky_solve(self, rhs, upper: bool = False) -> LinearOperator: + raise NotImplementedError() + + def _expand_batch(self, batch_shape: torch.Size) -> LinearOperator: + raise NotImplementedError() + + def _isclose(self, other, rtol: float = 1e-05, atol: float = 1e-08, equal_nan: bool = False) -> Tensor: + raise NotImplementedError() + + def _prod_batch(self, dim: int) -> LinearOperator: + raise NotImplementedError() + + def _sum_batch(self, dim: int) -> LinearOperator: + raise NotImplementedError() diff --git a/test/examples/test_missing_data.py b/test/examples/test_missing_data.py index 1f7044a9c..4a3b651cb 100644 --- a/test/examples/test_missing_data.py +++ b/test/examples/test_missing_data.py @@ -68,6 +68,7 @@ def forward(self, x): class TestMissingData(BaseTestCase, unittest.TestCase): seed = 20 + warning = "Observation NaN policy 'fill' makes the kernel matrix dense during exact prediction." def _check( self, @@ -112,7 +113,7 @@ def _check_predictions_exact_gp(self, model: ExactGP, test_x: torch.Tensor, test clear_cache_hook(model.prediction_strategy) - with settings.observation_nan_policy("fill"): + with settings.observation_nan_policy("fill"), self.assertWarns(RuntimeWarning, msg=self.warning): prediction = model(test_x) self._check_prediction(prediction, test_y, atol) @@ -120,13 +121,13 @@ def _check_predictions_exact_gp(self, model: ExactGP, test_x: torch.Tensor, test with settings.observation_nan_policy("mask"): model(test_x) - with settings.observation_nan_policy("fill"): + with settings.observation_nan_policy("fill"), self.assertWarns(RuntimeWarning, msg=self.warning): prediction = model(test_x) self._check_prediction(prediction, test_y, atol) clear_cache_hook(model.prediction_strategy) - with settings.observation_nan_policy("fill"): + with settings.observation_nan_policy("fill"), self.assertWarns(RuntimeWarning, msg=self.warning): model(test_x) with settings.observation_nan_policy("mask"): prediction = model(test_x) From 522cbcf1f682904dae0d086887ceb9aca9307f5b Mon Sep 17 00:00:00 2001 From: Tilman Hoffbauer Date: Tue, 8 Aug 2023 14:17:20 +0200 Subject: [PATCH 21/23] Switch to MaskedLinearOperator from linear-operator 0.5.1 --- .conda/meta.yaml | 2 +- docs/requirements.txt | 2 +- gpytorch/likelihoods/gaussian_likelihood.py | 3 +- .../mlls/exact_marginal_log_likelihood.py | 3 +- .../models/exact_prediction_strategies.py | 6 +- gpytorch/utils/masked_linear_operator.py | 96 ------------------- setup.py | 2 +- 7 files changed, 10 insertions(+), 104 deletions(-) delete mode 100644 gpytorch/utils/masked_linear_operator.py diff --git a/.conda/meta.yaml b/.conda/meta.yaml index f15f1bc9f..77abb83e9 100644 --- a/.conda/meta.yaml +++ b/.conda/meta.yaml @@ -18,7 +18,7 @@ requirements: run: - pytorch>=1.11 - scikit-learn - - linear_operator>=0.5.0 + - linear_operator>=0.5.1 test: imports: diff --git a/docs/requirements.txt b/docs/requirements.txt index 2a300985f..258800a64 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -1,7 +1,7 @@ setuptools_scm<=7.1.0 ipython<=8.6.0 ipykernel<=6.17.1 -linear_operator>=0.5.0 +linear_operator>=0.5.1 m2r2<=0.3.3.post2 nbclient<=0.7.3 nbformat<=5.8.0 diff --git a/gpytorch/likelihoods/gaussian_likelihood.py b/gpytorch/likelihoods/gaussian_likelihood.py index 3ecea027e..e753f92c3 100644 --- a/gpytorch/likelihoods/gaussian_likelihood.py +++ b/gpytorch/likelihoods/gaussian_likelihood.py @@ -5,7 +5,7 @@ from typing import Any, Optional, Tuple, Union import torch -from linear_operator.operators import LinearOperator, ZeroLinearOperator +from linear_operator.operators import LinearOperator, MaskedLinearOperator, ZeroLinearOperator from torch import Tensor from torch.distributions import Distribution, Normal @@ -13,7 +13,6 @@ from ..constraints import Interval from ..distributions import base_distributions, MultivariateNormal from ..priors import Prior -from ..utils.masked_linear_operator import MaskedLinearOperator from ..utils.warnings import GPInputWarning from .likelihood import Likelihood from .noise_models import FixedGaussianNoise, HomoskedasticNoise, Noise diff --git a/gpytorch/mlls/exact_marginal_log_likelihood.py b/gpytorch/mlls/exact_marginal_log_likelihood.py index 399b000ca..7b2987f50 100644 --- a/gpytorch/mlls/exact_marginal_log_likelihood.py +++ b/gpytorch/mlls/exact_marginal_log_likelihood.py @@ -1,9 +1,10 @@ #!/usr/bin/env python3 +from linear_operator.operators import MaskedLinearOperator + from .. import settings from ..distributions import MultivariateNormal from ..likelihoods import _GaussianLikelihoodBase -from ..utils.masked_linear_operator import MaskedLinearOperator from .marginal_log_likelihood import MarginalLogLikelihood diff --git a/gpytorch/models/exact_prediction_strategies.py b/gpytorch/models/exact_prediction_strategies.py index 4a27f671d..2b716d73f 100644 --- a/gpytorch/models/exact_prediction_strategies.py +++ b/gpytorch/models/exact_prediction_strategies.py @@ -14,6 +14,7 @@ InterpolatedLinearOperator, LinearOperator, LowRankRootAddedDiagLinearOperator, + MaskedLinearOperator, MatmulLinearOperator, RootLinearOperator, ZeroLinearOperator, @@ -26,7 +27,6 @@ from ..distributions import MultitaskMultivariateNormal from ..lazy import LazyEvaluatedKernelTensor -from ..utils.masked_linear_operator import MaskedLinearOperator from ..utils.memoize import add_to_cache, cached, clear_cache_hook, pop_from_cache @@ -349,7 +349,9 @@ def exact_predictive_mean(self, test_mean: Tensor, test_train_covar: LinearOpera # Restrict train dimension to observed values observed = settings.observation_nan_policy._get_observed(mean_cache, torch.Size((mean_cache.shape[-1],))) full_mask = torch.ones(test_mean.shape[-1], dtype=torch.bool, device=test_mean.device) - test_train_covar = MaskedLinearOperator(test_train_covar, full_mask, observed.reshape(-1)) + test_train_covar = MaskedLinearOperator( + to_linear_operator(test_train_covar), full_mask, observed.reshape(-1) + ) res = (test_train_covar @ mean_cache[..., observed].unsqueeze(-1)).squeeze(-1) else: # 'fill' # Set the columns corresponding to missing observations to 0 to ignore them during matmul. diff --git a/gpytorch/utils/masked_linear_operator.py b/gpytorch/utils/masked_linear_operator.py deleted file mode 100644 index e5c703e80..000000000 --- a/gpytorch/utils/masked_linear_operator.py +++ /dev/null @@ -1,96 +0,0 @@ -from typing import Optional, Union - -import torch -from linear_operator import LinearOperator -from torch import Tensor - - -class MaskedLinearOperator(LinearOperator): - def __init__(self, base: LinearOperator, row_mask: Tensor, col_mask: Tensor): - super().__init__(base, row_mask, col_mask) - self.base = base - self.row_mask = row_mask - self.col_mask = col_mask - self.row_eq_col_mask = row_mask is not None and col_mask is not None and torch.equal(row_mask, col_mask) - - def _matmul(self, rhs: Tensor) -> Tensor: - if self.col_mask is not None: - rhs_expanded = torch.zeros( - *rhs.shape[:-2], - self.base.size(-1), - rhs.shape[-1], - device=rhs.device, - dtype=rhs.dtype, - ) - rhs_expanded[..., self.col_mask, :] = rhs - rhs = rhs_expanded - - res = self.base.matmul(rhs) - - if self.row_mask is not None: - res = res[..., self.row_mask, :] - - return res - - def _size(self) -> torch.Size: - base_size = list(self.base.size()) - if self.row_mask is not None: - base_size[-2] = torch.count_nonzero(self.row_mask) - if self.col_mask is not None: - base_size[-1] = torch.count_nonzero(self.col_mask) - return torch.Size(tuple(base_size)) - - def _transpose_nonbatch(self) -> LinearOperator: - return MaskedLinearOperator(self.base.mT, self.col_mask, self.row_mask) - - def _getitem( - self, - row_index: Union[slice, torch.LongTensor], - col_index: Union[slice, torch.LongTensor], - *batch_indices: tuple[Union[int, slice, torch.LongTensor], ...], - ) -> LinearOperator: - raise NotImplementedError("Indexing with %r, %r, %r not supported." % (batch_indices, row_index, col_index)) - - def _get_indices( - self, - row_index: torch.LongTensor, - col_index: torch.LongTensor, - *batch_indices: tuple[torch.LongTensor, ...], - ) -> torch.Tensor: - def map_indices(index: torch.LongTensor, mask: Optional[Tensor], base_size: int) -> torch.LongTensor: - if mask is None: - return index - map = torch.arange(base_size, device=self.base.device)[mask] - return map[index] - - if len(batch_indices) == 0: - row_index = map_indices(row_index, self.row_mask, self.base.size(-2)) - col_index = map_indices(col_index, self.col_mask, self.base.size(-1)) - return self.base._get_indices(row_index, col_index) - - raise NotImplementedError("Indexing with %r, %r, %r not supported." % (batch_indices, row_index, col_index)) - - def _diagonal(self) -> Tensor: - if not self.row_eq_col_mask: - raise NotImplementedError() - diag = self.base.diagonal() - return diag[self.row_mask] - - def to_dense(self) -> torch.Tensor: - full_dense = self.base.to_dense() - return full_dense[..., self.row_mask, :][..., :, self.col_mask] - - def _cholesky_solve(self, rhs, upper: bool = False) -> LinearOperator: - raise NotImplementedError() - - def _expand_batch(self, batch_shape: torch.Size) -> LinearOperator: - raise NotImplementedError() - - def _isclose(self, other, rtol: float = 1e-05, atol: float = 1e-08, equal_nan: bool = False) -> Tensor: - raise NotImplementedError() - - def _prod_batch(self, dim: int) -> LinearOperator: - raise NotImplementedError() - - def _sum_batch(self, dim: int) -> LinearOperator: - raise NotImplementedError() diff --git a/setup.py b/setup.py index 649e2fead..e0a6b8a43 100644 --- a/setup.py +++ b/setup.py @@ -40,7 +40,7 @@ def find_version(*file_paths): torch_min = "1.11" install_requires = [ "scikit-learn", - "linear_operator>=0.5.0", + "linear_operator>=0.5.1", ] # if recent dev version of PyTorch is installed, no need to install stable try: From 15ad5422fea5f17a47fff0cd57c83422aa04be0d Mon Sep 17 00:00:00 2001 From: Tilman Hoffbauer Date: Tue, 8 Aug 2023 14:43:35 +0200 Subject: [PATCH 22/23] Disable test_t_matmul_matrix() for LazyEvaluatedKernelTensor The test fails because LazyEvaluatedKernelTensor only supports _matmul() with checkpointing, but checkpointing is deprecated. --- test/lazy/test_lazy_evaluated_kernel_tensor.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/test/lazy/test_lazy_evaluated_kernel_tensor.py b/test/lazy/test_lazy_evaluated_kernel_tensor.py index 7bb9e00af..e14433c0b 100644 --- a/test/lazy/test_lazy_evaluated_kernel_tensor.py +++ b/test/lazy/test_lazy_evaluated_kernel_tensor.py @@ -158,6 +158,10 @@ def test_grad_state(self): lazy_tensor = k(X) self.assertFalse(lazy_tensor.to_dense().requires_grad) + def test_t_matmul_matrix(self): + # Not supported without checkpointing, and checkpointing is deprecated. (#2361) + pass + class TestLazyEvaluatedKernelTensorMultitaskBatch(TestLazyEvaluatedKernelTensorBatch): seed = 0 From 10349558acf0820985ed40946f82d2facd8cba8d Mon Sep 17 00:00:00 2001 From: Tilman Hoffbauer Date: Wed, 16 Aug 2023 09:35:50 +0200 Subject: [PATCH 23/23] Fix merge conflict --- test/lazy/test_lazy_evaluated_kernel_tensor.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/test/lazy/test_lazy_evaluated_kernel_tensor.py b/test/lazy/test_lazy_evaluated_kernel_tensor.py index 74170a1ea..5a3528704 100644 --- a/test/lazy/test_lazy_evaluated_kernel_tensor.py +++ b/test/lazy/test_lazy_evaluated_kernel_tensor.py @@ -161,10 +161,6 @@ def test_grad_state(self): lazy_tensor = k(X) self.assertFalse(lazy_tensor.to_dense().requires_grad) - def test_t_matmul_matrix(self): - # Not supported without checkpointing, and checkpointing is deprecated. (#2361) - pass - class TestLazyEvaluatedKernelTensorMultitaskBatch(TestLazyEvaluatedKernelTensorBatch): seed = 0