diff --git a/docs/source/distributions.rst b/docs/source/distributions.rst index e7f3788ca..a98bd862b 100644 --- a/docs/source/distributions.rst +++ b/docs/source/distributions.rst @@ -44,6 +44,7 @@ MultitaskMultivariateNormal .. autoclass:: MultitaskMultivariateNormal :members: + :special-members: __getitem__ Delta diff --git a/docs/source/likelihoods.rst b/docs/source/likelihoods.rst index a4c215a6e..3b71b42ae 100644 --- a/docs/source/likelihoods.rst +++ b/docs/source/likelihoods.rst @@ -34,7 +34,7 @@ reduce the variance when computing approximate GP objective functions. :members: :hidden:`GaussianLikelihoodWithMissingObs` -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. autoclass:: GaussianLikelihoodWithMissingObs :members: diff --git a/gpytorch/distributions/multitask_multivariate_normal.py b/gpytorch/distributions/multitask_multivariate_normal.py index 987826540..b27ffdc1a 100644 --- a/gpytorch/distributions/multitask_multivariate_normal.py +++ b/gpytorch/distributions/multitask_multivariate_normal.py @@ -2,7 +2,12 @@ import torch from linear_operator import LinearOperator, to_linear_operator -from linear_operator.operators import BlockDiagLinearOperator, BlockInterleavedLinearOperator, CatLinearOperator +from linear_operator.operators import ( + BlockDiagLinearOperator, + BlockInterleavedLinearOperator, + CatLinearOperator, + DiagLinearOperator, +) from .multivariate_normal import MultivariateNormal @@ -18,7 +23,7 @@ class MultitaskMultivariateNormal(MultivariateNormal): :param torch.Tensor mean: An `n x t` or batch `b x n x t` matrix of means for the MVN distribution. :param ~linear_operator.operators.LinearOperator covar: An `... x NT x NT` (batch) matrix. covariance matrix of MVN distribution. - :param bool validate_args: (default=False) If True, validate `mean` anad `covariance_matrix` arguments. + :param bool validate_args: (default=False) If True, validate `mean` and `covariance_matrix` arguments. :param bool interleaved: (default=True) If True, covariance matrix is interpreted as block-diagonal w.r.t. inter-task covariances for each observation. If False, it is interpreted as block-diagonal w.r.t. inter-observation covariance for each task. @@ -276,5 +281,145 @@ def variance(self): return var.view(new_shape).transpose(-1, -2).contiguous() return var.view(self._output_shape) + def __getitem__(self, idx) -> MultivariateNormal: + """ + Constructs a new MultivariateNormal that represents a random variable + modified by an indexing operation. + + The mean and covariance matrix arguments are indexed accordingly. + + :param Any idx: Index to apply to the mean. The covariance matrix is indexed accordingly. + :returns: If indices specify a slice for samples and tasks, returns a + MultitaskMultivariateNormal, else returns a MultivariateNormal. + """ + + # Normalize index to a tuple + if not isinstance(idx, tuple): + idx = (idx,) + + if ... in idx: + # Replace ellipsis '...' with explicit indices + ellipsis_location = idx.index(...) + if ... in idx[ellipsis_location + 1 :]: + raise IndexError("Only one ellipsis '...' 
is supported!") + prefix = idx[:ellipsis_location] + suffix = idx[ellipsis_location + 1 :] + infix_length = self.mean.dim() - len(prefix) - len(suffix) + if infix_length < 0: + raise IndexError(f"Index {idx} has too many dimensions") + idx = prefix + (slice(None),) * infix_length + suffix + elif len(idx) == self.mean.dim() - 1: + # Normalize indices ignoring the task-index to include it + idx = idx + (slice(None),) + + new_mean = self.mean[idx] + + # We now create a covariance matrix appropriate for new_mean + if len(idx) <= self.mean.dim() - 2: + # We are only indexing the batch dimensions in this case + return MultitaskMultivariateNormal( + mean=new_mean, + covariance_matrix=self.lazy_covariance_matrix[idx], + interleaved=self._interleaved, + ) + elif len(idx) > self.mean.dim(): + raise IndexError(f"Index {idx} has too many dimensions") + else: + # We have an index that extends over all dimensions + batch_idx = idx[:-2] + if self._interleaved: + row_idx = idx[-2] + col_idx = idx[-1] + num_rows = self._output_shape[-2] + num_cols = self._output_shape[-1] + else: + row_idx = idx[-1] + col_idx = idx[-2] + num_rows = self._output_shape[-1] + num_cols = self._output_shape[-2] + + if isinstance(row_idx, int) and isinstance(col_idx, int): + # Single sample with single task + row_idx = _normalize_index(row_idx, num_rows) + col_idx = _normalize_index(col_idx, num_cols) + new_cov = DiagLinearOperator( + self.lazy_covariance_matrix.diagonal()[batch_idx + (row_idx * num_cols + col_idx,)] + ) + return MultivariateNormal(mean=new_mean, covariance_matrix=new_cov) + elif isinstance(row_idx, int) and isinstance(col_idx, slice): + # A block of the covariance matrix + row_idx = _normalize_index(row_idx, num_rows) + col_idx = _normalize_slice(col_idx, num_cols) + new_slice = slice( + col_idx.start + row_idx * num_cols, + col_idx.stop + row_idx * num_cols, + col_idx.step, + ) + new_cov = self.lazy_covariance_matrix[batch_idx + (new_slice, new_slice)] + return MultivariateNormal(mean=new_mean, covariance_matrix=new_cov) + elif isinstance(row_idx, slice) and isinstance(col_idx, int): + # A block of the reversely interleaved covariance matrix + row_idx = _normalize_slice(row_idx, num_rows) + col_idx = _normalize_index(col_idx, num_cols) + new_slice = slice(row_idx.start + col_idx, row_idx.stop * num_cols + col_idx, row_idx.step * num_cols) + new_cov = self.lazy_covariance_matrix[batch_idx + (new_slice, new_slice)] + return MultivariateNormal(mean=new_mean, covariance_matrix=new_cov) + elif ( + isinstance(row_idx, slice) + and isinstance(col_idx, slice) + and row_idx == col_idx == slice(None, None, None) + ): + new_cov = self.lazy_covariance_matrix[batch_idx] + return MultitaskMultivariateNormal( + mean=new_mean, + covariance_matrix=new_cov, + interleaved=self._interleaved, + validate_args=False, + ) + elif isinstance(row_idx, slice) or isinstance(col_idx, slice): + # slice x slice or indices x slice or slice x indices + if isinstance(row_idx, slice): + row_idx = torch.arange(num_rows)[row_idx] + if isinstance(col_idx, slice): + col_idx = torch.arange(num_cols)[col_idx] + row_grid, col_grid = torch.meshgrid(row_idx, col_idx, indexing="ij") + indices = (row_grid * num_cols + col_grid).reshape(-1) + new_cov = self.lazy_covariance_matrix[batch_idx + (indices,)][..., indices] + return MultitaskMultivariateNormal( + mean=new_mean, covariance_matrix=new_cov, interleaved=self._interleaved, validate_args=False + ) + else: + # row_idx and col_idx have pairs of indices + indices = row_idx * num_cols + col_idx + new_cov = 
self.lazy_covariance_matrix[batch_idx + (indices,)][..., indices] + return MultivariateNormal( + mean=new_mean, + covariance_matrix=new_cov, + ) + def __repr__(self) -> str: return f"MultitaskMultivariateNormal(mean shape: {self._output_shape})" + + +def _normalize_index(i: int, dim_size: int) -> int: + if i < 0: + return dim_size + i + else: + return i + + +def _normalize_slice(s: slice, dim_size: int) -> slice: + start = s.start + if start is None: + start = 0 + elif start < 0: + start = dim_size + start + stop = s.stop + if stop is None: + stop = dim_size + elif stop < 0: + stop = dim_size + stop + step = s.step + if step is None: + step = 1 + return slice(start, stop, step) diff --git a/gpytorch/distributions/multivariate_normal.py b/gpytorch/distributions/multivariate_normal.py index 8f7911aec..c5f29307b 100644 --- a/gpytorch/distributions/multivariate_normal.py +++ b/gpytorch/distributions/multivariate_normal.py @@ -343,16 +343,22 @@ def __getitem__(self, idx) -> MultivariateNormal: The mean and covariance matrix arguments are indexed accordingly. - :param idx: Index to apply. + :param idx: Index to apply to the mean. The covariance matrix is indexed accordingly. """ if not isinstance(idx, tuple): idx = (idx,) + if len(idx) > self.mean.dim() and Ellipsis in idx: + idx = tuple(i for i in idx if i != Ellipsis) + if len(idx) < self.mean.dim(): + raise IndexError("Multiple ambiguous ellipsis in index!") + rest_idx = idx[:-1] last_idx = idx[-1] new_mean = self.mean[idx] if len(idx) <= self.mean.dim() - 1 and (Ellipsis not in rest_idx): + # We are only indexing the batch dimensions in this case new_cov = self.lazy_covariance_matrix[idx] elif len(idx) > self.mean.dim(): raise IndexError(f"Index {idx} has too many dimensions") diff --git a/gpytorch/likelihoods/gaussian_likelihood.py b/gpytorch/likelihoods/gaussian_likelihood.py index 42cd715e1..e753f92c3 100644 --- a/gpytorch/likelihoods/gaussian_likelihood.py +++ b/gpytorch/likelihoods/gaussian_likelihood.py @@ -1,15 +1,15 @@ #!/usr/bin/env python3 - import math import warnings from copy import deepcopy from typing import Any, Optional, Tuple, Union import torch -from linear_operator.operators import LinearOperator, ZeroLinearOperator +from linear_operator.operators import LinearOperator, MaskedLinearOperator, ZeroLinearOperator from torch import Tensor from torch.distributions import Distribution, Normal +from .. 
import settings from ..constraints import Interval from ..distributions import base_distributions, MultivariateNormal from ..priors import Prior @@ -39,17 +39,39 @@ def _shaped_noise_covar(self, base_shape: torch.Size, *params: Any, **kwargs: An return self.noise_covar(*params, shape=base_shape, **kwargs) def expected_log_prob(self, target: Tensor, input: MultivariateNormal, *params: Any, **kwargs: Any) -> Tensor: - mean, variance = input.mean, input.variance - num_event_dim = len(input.event_shape) - noise = self._shaped_noise_covar(mean.shape, *params, **kwargs).diagonal(dim1=-1, dim2=-2) + noise = self._shaped_noise_covar(input.mean.shape, *params, **kwargs).diagonal(dim1=-1, dim2=-2) # Potentially reshape the noise to deal with the multitask case noise = noise.view(*noise.shape[:-1], *input.event_shape) + # Handle NaN values if enabled + nan_policy = settings.observation_nan_policy.value() + if nan_policy == "mask": + observed = settings.observation_nan_policy._get_observed(target, input.event_shape) + input = MultivariateNormal( + mean=input.mean[..., observed], + covariance_matrix=MaskedLinearOperator( + input.lazy_covariance_matrix, observed.reshape(-1), observed.reshape(-1) + ), + ) + noise = noise[..., observed] + target = target[..., observed] + elif nan_policy == "fill": + missing = torch.isnan(target) + target = settings.observation_nan_policy._fill_tensor(target) + + mean, variance = input.mean, input.variance res = ((target - mean).square() + variance) / noise + noise.log() + math.log(2 * math.pi) res = res.mul(-0.5) - if num_event_dim > 1: # Do appropriate summation for multitask Gaussian likelihoods + + if nan_policy == "fill": + res = res * ~missing + + # Do appropriate summation for multitask Gaussian likelihoods + num_event_dim = len(input.event_shape) + if num_event_dim > 1: res = res.sum(list(range(-1, -num_event_dim, -1))) + return res def forward(self, function_samples: Tensor, *params: Any, **kwargs: Any) -> Normal: @@ -60,12 +82,31 @@ def log_marginal( self, observations: Tensor, function_dist: MultivariateNormal, *params: Any, **kwargs: Any ) -> Tensor: marginal = self.marginal(function_dist, *params, **kwargs) + + # Handle NaN values if enabled + nan_policy = settings.observation_nan_policy.value() + if nan_policy == "mask": + observed = settings.observation_nan_policy._get_observed(observations, marginal.event_shape) + marginal = MultivariateNormal( + mean=marginal.mean[..., observed], + covariance_matrix=MaskedLinearOperator( + marginal.lazy_covariance_matrix, observed.reshape(-1), observed.reshape(-1) + ), + ) + observations = observations[..., observed] + elif nan_policy == "fill": + missing = torch.isnan(observations) + observations = settings.observation_nan_policy._fill_tensor(observations) + # We're making everything conditionally independent indep_dist = base_distributions.Normal(marginal.mean, marginal.variance.clamp_min(1e-8).sqrt()) res = indep_dist.log_prob(observations) + if nan_policy == "fill": + res = res * ~missing + # Do appropriate summation for multitask Gaussian likelihoods - num_event_dim = len(function_dist.event_shape) + num_event_dim = len(marginal.event_shape) if num_event_dim > 1: res = res.sum(list(range(-1, -num_event_dim, -1))) return res @@ -150,13 +191,15 @@ class GaussianLikelihoodWithMissingObs(GaussianLikelihood): .. note:: This likelihood can be used for exact or approximate inference. + .. warning:: + This likelihood is deprecated in favor of :class:`gpytorch.settings.observation_nan_policy`. 
+ :param noise_prior: Prior for noise parameter :math:`\sigma^2`. :type noise_prior: ~gpytorch.priors.Prior, optional :param noise_constraint: Constraint for noise parameter :math:`\sigma^2`. :type noise_constraint: ~gpytorch.constraints.Interval, optional :param batch_shape: The batch shape of the learned noise parameter (default: []). :type batch_shape: torch.Size, optional - :var torch.Tensor noise: :math:`\sigma^2` parameter (noise) .. note:: @@ -166,6 +209,10 @@ class GaussianLikelihoodWithMissingObs(GaussianLikelihood): MISSING_VALUE_FILL: float = -999.0 def __init__(self, **kwargs: Any) -> None: + warnings.warn( + "GaussianLikelihoodWithMissingObs is replaced by gpytorch.settings.observation_nan_policy('fill').", + DeprecationWarning, + ) super().__init__(**kwargs) def _get_masked_obs(self, x: Tensor) -> Tuple[Tensor, Tensor]: diff --git a/gpytorch/mlls/exact_marginal_log_likelihood.py b/gpytorch/mlls/exact_marginal_log_likelihood.py index 5338c9bb3..7b2987f50 100644 --- a/gpytorch/mlls/exact_marginal_log_likelihood.py +++ b/gpytorch/mlls/exact_marginal_log_likelihood.py @@ -1,5 +1,8 @@ #!/usr/bin/env python3 +from linear_operator.operators import MaskedLinearOperator + +from .. import settings from ..distributions import MultivariateNormal from ..likelihoods import _GaussianLikelihoodBase from .marginal_log_likelihood import MarginalLogLikelihood @@ -59,8 +62,23 @@ def forward(self, function_dist, target, *params): if not isinstance(function_dist, MultivariateNormal): raise RuntimeError("ExactMarginalLogLikelihood can only operate on Gaussian random variables") - # Get the log prob of the marginal distribution + # Determine output likelihood output = self.likelihood(function_dist, *params) + + # Remove NaN values if enabled + if settings.observation_nan_policy.value() == "mask": + observed = settings.observation_nan_policy._get_observed(target, output.event_shape) + output = MultivariateNormal( + mean=output.mean[..., observed], + covariance_matrix=MaskedLinearOperator( + output.lazy_covariance_matrix, observed.reshape(-1), observed.reshape(-1) + ), + ) + target = target[..., observed] + elif settings.observation_nan_policy.value() == "fill": + raise ValueError("NaN observation policy 'fill' is not supported by ExactMarginalLogLikelihood!") + + # Get the log prob of the marginal distribution res = output.log_prob(target) res = self._add_other_terms(res, params) diff --git a/gpytorch/models/exact_prediction_strategies.py b/gpytorch/models/exact_prediction_strategies.py index cd7047412..2b716d73f 100644 --- a/gpytorch/models/exact_prediction_strategies.py +++ b/gpytorch/models/exact_prediction_strategies.py @@ -2,6 +2,7 @@ import functools import string +import warnings import torch from linear_operator import to_dense, to_linear_operator @@ -13,6 +14,7 @@ InterpolatedLinearOperator, LinearOperator, LowRankRootAddedDiagLinearOperator, + MaskedLinearOperator, MatmulLinearOperator, RootLinearOperator, ZeroLinearOperator, @@ -247,14 +249,45 @@ def covar_cache(self): return self._exact_predictive_covar_inv_quad_form_cache(train_train_covar_inv_root, self._last_test_train_covar) @property - @cached(name="mean_cache") def mean_cache(self): + return self._mean_cache(settings.observation_nan_policy.value()) + + @cached(name="mean_cache") + def _mean_cache(self, nan_policy: str) -> Tensor: mvn = self.likelihood(self.train_prior_dist, self.train_inputs) train_mean, train_train_covar = mvn.loc, mvn.lazy_covariance_matrix train_labels_offset = (self.train_labels - train_mean).unsqueeze(-1) - 
mean_cache = train_train_covar.evaluate_kernel().solve(train_labels_offset).squeeze(-1) + if nan_policy == "ignore": + mean_cache = train_train_covar.evaluate_kernel().solve(train_labels_offset).squeeze(-1) + elif nan_policy == "mask": + # Mask all rows and columns in the kernel matrix corresponding to the missing observations. + observed = settings.observation_nan_policy._get_observed( + self.train_labels, torch.Size((self.train_labels.shape[-1],)) + ) + mean_cache = torch.full_like(self.train_labels, torch.nan) + kernel = MaskedLinearOperator( + train_train_covar.evaluate_kernel(), observed.reshape(-1), observed.reshape(-1) + ) + mean_cache[..., observed] = kernel.solve(train_labels_offset[..., observed, :]).squeeze(-1) + else: # 'fill' + # Fill all rows and columns in the kernel matrix corresponding to the missing observations with 0. + # Don't touch the corresponding diagonal elements to ensure a unique solution. + # This ensures that missing data is ignored during solving. + warnings.warn( + "Observation NaN policy 'fill' makes the kernel matrix dense during exact prediction.", + RuntimeWarning, + ) + kernel = train_train_covar.evaluate_kernel() + missing = torch.isnan(self.train_labels) + kernel_mask = (~missing).to(torch.float) + kernel_mask = kernel_mask[..., None] * kernel_mask[..., None, :] + torch.diagonal(kernel_mask, dim1=-2, dim2=-1)[...] = 1 + kernel = kernel * kernel_mask # Unfortunately, this makes the kernel dense at the moment. + train_labels_offset = settings.observation_nan_policy._fill_tensor(train_labels_offset) + mean_cache = kernel.solve(train_labels_offset).squeeze(-1) + mean_cache[missing] = torch.nan # Ensure that nobody expects these values to be valid. if settings.detach_test_caches.on(): mean_cache = mean_cache.detach() @@ -303,10 +336,29 @@ def exact_predictive_mean(self, test_mean: Tensor, test_train_covar: LinearOpera # You **cannot* use addmv here, because test_train_covar may not actually be a non lazy tensor even for an exact # GP, and using addmv requires you to to_dense test_train_covar, which is obviously a huge no-no! - if len(self.mean_cache.shape) == 4: - res = (test_train_covar @ self.mean_cache.squeeze(1).unsqueeze(-1)).squeeze(-1) - else: - res = (test_train_covar @ self.mean_cache.unsqueeze(-1)).squeeze(-1) + # see https://github.com/cornellius-gp/gpytorch/pull/2317#discussion_r1157994719 + mean_cache = self.mean_cache + if len(mean_cache.shape) == 4: + mean_cache = mean_cache.squeeze(1) + + # Handle NaNs + nan_policy = settings.observation_nan_policy.value() + if nan_policy == "ignore": + res = (test_train_covar @ mean_cache.unsqueeze(-1)).squeeze(-1) + elif nan_policy == "mask": + # Restrict train dimension to observed values + observed = settings.observation_nan_policy._get_observed(mean_cache, torch.Size((mean_cache.shape[-1],))) + full_mask = torch.ones(test_mean.shape[-1], dtype=torch.bool, device=test_mean.device) + test_train_covar = MaskedLinearOperator( + to_linear_operator(test_train_covar), full_mask, observed.reshape(-1) + ) + res = (test_train_covar @ mean_cache[..., observed].unsqueeze(-1)).squeeze(-1) + else: # 'fill' + # Set the columns corresponding to missing observations to 0 to ignore them during matmul. 
+ mask = (~torch.isnan(mean_cache)).to(torch.float)[..., None, :] + test_train_covar = test_train_covar * mask + mean = settings.observation_nan_policy._fill_tensor(mean_cache) + res = (test_train_covar @ mean.unsqueeze(-1)).squeeze(-1) res = res + test_mean return res diff --git a/gpytorch/settings.py b/gpytorch/settings.py index d7207375a..595ffe10d 100644 --- a/gpytorch/settings.py +++ b/gpytorch/settings.py @@ -27,6 +27,7 @@ use_toeplitz, verbose_linalg, ) +from torch import Tensor class _dtype_value_context: @@ -401,6 +402,52 @@ def value(cls, dtype=None): return super().value(dtype=dtype) +class observation_nan_policy(_value_context): + """ + NaN handling policy for observations. + + * ``ignore``: Do not check for NaN values (the default). + * ``mask``: Mask out NaN values during calculation. If an output is NaN in a single batch element, this output + is masked for the complete batch. This strategy likely is a good choice if you have NaN values. + * ``fill``: Fill in NaN values with a dummy value, perform computations and filter them later. + Not supported by :class:`gpytorch.mlls.ExactMarginalLogLikelihood`. + Does not support lazy covariance matrices during prediction. + """ + + _fill_value = -999.0 + _global_value = "ignore" + + def __init__(self, value): + if value not in {"ignore", "mask", "fill"}: + raise ValueError(f"NaN handling policy {value} not supported!") + super().__init__(value) + + @staticmethod + def _get_observed(observations, event_shape) -> Tensor: + """ + Constructs a tensor that masks out all elements in the event shape of the tensor which contain a NaN value in + any batch element. Applying this index flattens the event_shape, as the task structure cannot be retained. + This function is used if observation_nan_policy is set to 'mask'. + + :param Tensor observations: The observations to search for NaN values in. + :param torch.Size event_shape: The shape of a single event, i.e. the shape of observations without batch + dimensions. + :return: The mask to the event dimensions of the observations. + """ + return ~torch.any(torch.isnan(observations.reshape(-1, *event_shape)), dim=0) + + @classmethod + def _fill_tensor(cls, observations) -> Tensor: + """ + Fills a tensor's NaN values with a filling value. + This function is used if observation_nan_policy is set to 'fill'. + + :param Tensor observations: The tensor to fill with values. + :return: The filled in observations. 
+ """ + return torch.nan_to_num(observations, nan=cls._fill_value) + + __all__ = [ "_linalg_dtype_symeig", "_linalg_dtype_cholesky", @@ -431,6 +478,7 @@ def value(cls, dtype=None): "num_gauss_hermite_locs", "num_likelihood_samples", "num_trace_samples", + "observation_nan_policy", "preconditioner_tolerance", "prior_mode", "sgpr_diagonal_correction", diff --git a/test/distributions/test_multitask_multivariate_normal.py b/test/distributions/test_multitask_multivariate_normal.py index 9e5db7b8c..dc8f59bf2 100644 --- a/test/distributions/test_multitask_multivariate_normal.py +++ b/test/distributions/test_multitask_multivariate_normal.py @@ -315,6 +315,163 @@ def test_multitask_multivariate_normal_broadcasting(self): covar = _covar @ _covar.transpose(-1, -2) MultitaskMultivariateNormal(mean, covar) + def test_getitem_interleaved(self): + mean_shape = (2, 4, 3, 2) + covar_shape = (2, 4, 6, 6) + mean = torch.randn(mean_shape) + _covar = torch.randn(covar_shape) + covar = _covar @ _covar.transpose(-1, -2) + distribution = MultitaskMultivariateNormal(mean, covar, validate_args=True) + + def flat(observation: int, task: int) -> int: + return observation * 2 + task + + part = distribution[1, -1] + self.assertIsInstance(part, MultitaskMultivariateNormal) + self.assertEqual(part.batch_shape, torch.Size(())) + self.assertEqual(part.event_shape, torch.Size((3, 2))) + self.assertAllClose(part.mean, mean[1, -1]) + self.assertAllClose(part.covariance_matrix, covar[1, -1]) + + part = distribution[1, 0, ...] + self.assertIsInstance(part, MultitaskMultivariateNormal) + self.assertEqual(part.batch_shape, torch.Size(())) + self.assertEqual(part.event_shape, torch.Size((3, 2))) + self.assertAllClose(part.mean, mean[1, 0]) + self.assertAllClose(part.covariance_matrix, covar[1, 0]) + + part = distribution[..., 2, 1] + self.assertFalse(isinstance(part, MultitaskMultivariateNormal)) + self.assertIsInstance(part, MultivariateNormal) + self.assertEqual(part.batch_shape, torch.Size((2,))) + self.assertEqual(part.event_shape, (4,)) + self.assertAllClose(part.mean, mean[..., 2, 1]) + self.assertAllClose(part.covariance_matrix, torch.diag_embed(covar[:, :, flat(2, 1), flat(2, 1)])) + + part = distribution[1, ..., -2] + self.assertFalse(isinstance(part, MultitaskMultivariateNormal)) + self.assertIsInstance(part, MultivariateNormal) + self.assertEqual(part.batch_shape, torch.Size((4,))) + self.assertEqual(part.event_shape, torch.Size((3,))) + self.assertAllClose(part.mean, mean[1, :, :, 0]) + self.assertAllClose(part.covariance_matrix, covar[1, :, ::2, ::2]) + + part = distribution[..., 2, :] + self.assertFalse(isinstance(part, MultitaskMultivariateNormal)) + self.assertIsInstance(part, MultivariateNormal) + self.assertEqual(part.batch_shape, torch.Size((2, 4))) + self.assertEqual(part.event_shape, torch.Size((2,))) + self.assertAllClose(part.mean, mean[:, :, 2, :]) + self.assertAllClose(part.covariance_matrix, covar[:, :, 2 * 2 : 3 * 2, 2 * 2 : 3 * 2]) + + part = distribution[0, :, :, torch.tensor([1, 0])] + self.assertIsInstance(part, MultitaskMultivariateNormal) + self.assertEqual(part.batch_shape, torch.Size((4,))) + self.assertEqual(part.event_shape, torch.Size((3, 2))) + self.assertAllClose(part.mean, mean[0, ..., torch.tensor([1, 0])]) + indices = torch.tensor([1, 0, 3, 2, 5, 4]) + self.assertAllClose(part.covariance_matrix, covar[0, :, indices][..., indices]) + + part = distribution[:, 1, torch.tensor([2, 0])] + self.assertIsInstance(part, MultitaskMultivariateNormal) + self.assertEqual(part.batch_shape, 
torch.Size((2,))) + self.assertEqual(part.event_shape, torch.Size((2, 2))) + self.assertAllClose(part.mean, mean[:, 1, torch.tensor([2, 0])]) + indices = torch.tensor([4, 5, 0, 1]) + self.assertAllClose(part.covariance_matrix, covar[:, 1, indices][..., indices]) + + part = distribution[..., 1:, :-1] + self.assertIsInstance(part, MultitaskMultivariateNormal) + self.assertEqual(part.batch_shape, torch.Size((2, 4))) + self.assertEqual(part.event_shape, torch.Size((2, 1))) + self.assertAllClose(part.mean, mean[..., 1:, :-1]) + indices = torch.tensor([flat(1, 0), flat(2, 0)]) + self.assertAllClose(part.covariance_matrix, covar[..., indices, :][..., indices]) + + part = distribution[..., torch.tensor([2, 0, 2]), torch.tensor([1, 0, 0])] + self.assertFalse(isinstance(part, MultitaskMultivariateNormal)) + self.assertIsInstance(part, MultivariateNormal) + self.assertEqual(part.batch_shape, torch.Size((2, 4))) + self.assertEqual(part.event_shape, torch.Size((3,))) + self.assertAllClose(part.mean, mean[..., torch.tensor([2, 0, 2]), torch.tensor([1, 0, 0])]) + indices = torch.tensor([flat(2, 1), flat(0, 0), flat(2, 0)]) + self.assertAllClose(part.covariance_matrix, covar[..., indices, :][..., indices]) + + def test_getitem_non_interleaved(self): + mean_shape = (2, 4, 3, 2) + covar_shape = (2, 4, 6, 6) + mean = torch.randn(mean_shape) + _covar = torch.randn(covar_shape) + covar = _covar @ _covar.transpose(-1, -2) + distribution = MultitaskMultivariateNormal(mean, covar, validate_args=True, interleaved=False) + + def flat(observation: int, task: int) -> int: + return task * 3 + observation + + part = distribution[1, -1] + self.assertIsInstance(part, MultitaskMultivariateNormal) + self.assertEqual(part.batch_shape, torch.Size(())) + self.assertEqual(part.event_shape, torch.Size((3, 2))) + self.assertAllClose(part.mean, mean[1, -1]) + self.assertAllClose(part.covariance_matrix, covar[1, -1]) + + part = distribution[..., 2, 1] + self.assertFalse(isinstance(part, MultitaskMultivariateNormal)) + self.assertIsInstance(part, MultivariateNormal) + self.assertEqual(part.batch_shape, torch.Size((2,))) + self.assertEqual(part.event_shape, (4,)) + self.assertAllClose(part.mean, mean[..., 2, 1]) + self.assertAllClose(part.covariance_matrix, torch.diag_embed(covar[:, :, flat(2, 1), flat(2, 1)])) + + part = distribution[1, ..., -2] + self.assertFalse(isinstance(part, MultitaskMultivariateNormal)) + self.assertIsInstance(part, MultivariateNormal) + self.assertEqual(part.batch_shape, torch.Size((4,))) + self.assertEqual(part.event_shape, torch.Size((3,))) + self.assertAllClose(part.mean, mean[1, :, :, 0]) + self.assertAllClose(part.covariance_matrix, covar[1, :, :3, :3]) + + part = distribution[..., 2, :] + self.assertFalse(isinstance(part, MultitaskMultivariateNormal)) + self.assertIsInstance(part, MultivariateNormal) + self.assertEqual(part.batch_shape, torch.Size((2, 4))) + self.assertEqual(part.event_shape, torch.Size((2,))) + self.assertAllClose(part.mean, mean[:, :, 2, :]) + self.assertAllClose(part.covariance_matrix, covar[:, :, 2::3, 2::3]) + + part = distribution[0, :, :, torch.tensor([1, 0])] + self.assertIsInstance(part, MultitaskMultivariateNormal) + self.assertEqual(part.batch_shape, torch.Size((4,))) + self.assertEqual(part.event_shape, torch.Size((3, 2))) + self.assertAllClose(part.mean, mean[0, ..., torch.tensor([1, 0])]) + indices = torch.tensor([3, 4, 5, 0, 1, 2]) + self.assertAllClose(part.covariance_matrix, covar[0, :, indices][..., indices]) + + part = distribution[:, 1, torch.tensor([2, 0])] + 
self.assertIsInstance(part, MultitaskMultivariateNormal) + self.assertEqual(part.batch_shape, torch.Size((2,))) + self.assertEqual(part.event_shape, torch.Size((2, 2))) + self.assertAllClose(part.mean, mean[:, 1, torch.tensor([2, 0])]) + indices = torch.tensor([2, 0, 5, 3]) + self.assertAllClose(part.covariance_matrix, covar[:, 1, indices][..., indices]) + + part = distribution[..., 1:, :-1] + self.assertIsInstance(part, MultitaskMultivariateNormal) + self.assertEqual(part.batch_shape, torch.Size((2, 4))) + self.assertEqual(part.event_shape, torch.Size((2, 1))) + self.assertAllClose(part.mean, mean[..., 1:, :-1]) + indices = torch.tensor([flat(1, 0), flat(2, 0)]) + self.assertAllClose(part.covariance_matrix, covar[..., indices, :][..., indices]) + + part = distribution[..., torch.tensor([2, 0, 2]), torch.tensor([1, 0, 0])] + self.assertFalse(isinstance(part, MultitaskMultivariateNormal)) + self.assertIsInstance(part, MultivariateNormal) + self.assertEqual(part.batch_shape, torch.Size((2, 4))) + self.assertEqual(part.event_shape, torch.Size((3,))) + self.assertAllClose(part.mean, mean[..., torch.tensor([2, 0, 2]), torch.tensor([1, 0, 0])]) + indices = torch.tensor([flat(2, 1), flat(0, 0), flat(2, 0)]) + self.assertAllClose(part.covariance_matrix, covar[..., indices, :][..., indices]) + def test_repr(self): mean = torch.randn(5, 1, 3) covar = torch.eye(6) diff --git a/test/distributions/test_multivariate_normal.py b/test/distributions/test_multivariate_normal.py index b9cc2568a..e287884fb 100644 --- a/test/distributions/test_multivariate_normal.py +++ b/test/distributions/test_multivariate_normal.py @@ -299,6 +299,10 @@ def test_getitem(self): assert torch.equal(d.mean, dist.mean[1, 2, 2, :]) self.assertAllClose(d.covariance_matrix, dist_cov[1, 2, 2, :, :]) + d = dist[0, 1, ..., 2, 1] + assert torch.equal(d.mean, dist.mean[0, 1, 2, 1]) + self.assertAllClose(d.covariance_matrix, dist_cov[0, 1, 2, 1, 1]) + def test_base_sample_shape(self): a = torch.randn(5, 10) lazy_square_a = RootLinearOperator(to_linear_operator(a)) diff --git a/test/examples/test_missing_data.py b/test/examples/test_missing_data.py new file mode 100644 index 000000000..4a3b651cb --- /dev/null +++ b/test/examples/test_missing_data.py @@ -0,0 +1,234 @@ +import unittest + +import torch + +from gpytorch import settings +from gpytorch import ExactMarginalLogLikelihood +from gpytorch.distributions import MultivariateNormal, MultitaskMultivariateNormal +from gpytorch.kernels import ScaleKernel, RBFKernel, MultitaskKernel +from gpytorch.likelihoods import GaussianLikelihood, Likelihood, MultitaskGaussianLikelihood +from gpytorch.means import ConstantMean, MultitaskMean +from gpytorch.mlls import PredictiveLogLikelihood, MarginalLogLikelihood, VariationalELBO +from gpytorch.models import ExactGP, VariationalGP, GP +from gpytorch.test.base_test_case import BaseTestCase +from gpytorch.utils.memoize import clear_cache_hook +from gpytorch.variational import CholeskyVariationalDistribution, LMCVariationalStrategy, VariationalStrategy + + +class SingleGPModel(ExactGP): + def __init__(self, train_inputs, train_targets, likelihood, batch_shape): + super(SingleGPModel, self).__init__(train_inputs, train_targets, likelihood) + self.mean_module = ConstantMean(batch_shape=batch_shape) + self.covar_module = ScaleKernel(RBFKernel(batch_shape=batch_shape)) + + def forward(self, x): + mean_x = self.mean_module(x) + covar_x = self.covar_module(x) + return MultivariateNormal(mean_x, covar_x) + + +class MultitaskGPModel(ExactGP): + def __init__(self, 
train_inputs, train_targets, likelihood, num_tasks): + super(MultitaskGPModel, self).__init__(train_inputs, train_targets, likelihood) + self.mean_module = MultitaskMean(ConstantMean(), num_tasks=num_tasks) + self.covar_module = MultitaskKernel(ScaleKernel(RBFKernel()), num_tasks=num_tasks) + + def forward(self, x): + mean_x = self.mean_module(x) + covar_x = self.covar_module(x) + return MultitaskMultivariateNormal(mean_x, covar_x) + + +class MultitaskVariationalGPModel(VariationalGP): + def __init__(self, num_latents, num_tasks): + inducing_points = torch.rand(num_latents, 21, 1) + variational_distribution = CholeskyVariationalDistribution( + inducing_points.size(-2), batch_shape=torch.Size([num_latents]) + ) + variational_strategy = LMCVariationalStrategy( + VariationalStrategy( + self, inducing_points, variational_distribution, learn_inducing_locations=True + ), + num_tasks=num_tasks, + num_latents=num_latents, + latent_dim=-1 + ) + super().__init__(variational_strategy) + self.mean_module = ConstantMean(batch_shape=torch.Size([num_latents])) + self.covar_module = ScaleKernel( + RBFKernel(batch_shape=torch.Size([num_latents])), + batch_shape=torch.Size([num_latents]) + ) + + def forward(self, x): + mean_x = self.mean_module(x) + covar_x = self.covar_module(x) + return MultivariateNormal(mean_x, covar_x) + + +class TestMissingData(BaseTestCase, unittest.TestCase): + seed = 20 + warning = "Observation NaN policy 'fill' makes the kernel matrix dense during exact prediction." + + def _check( + self, + model: GP, + likelihood: Likelihood, + train_x: torch.Tensor, + train_y: torch.Tensor, + test_x: torch.Tensor, + test_y: torch.Tensor, + optimizer: torch.optim.Optimizer, + mll: MarginalLogLikelihood, + epochs: int = 30, + atol: float = 0.2 + ) -> None: + model.train() + likelihood.train() + + for _ in range(epochs): + optimizer.zero_grad() + output = model(train_x) + loss = -mll(output, train_y).sum() + self.assertFalse(torch.any(torch.isnan(output.mean)).item()) + self.assertFalse(torch.any(torch.isnan(output.covariance_matrix)).item()) + self.assertFalse(torch.isnan(loss).item()) + loss.backward() + optimizer.step() + + model.eval() + likelihood.eval() + + with torch.no_grad(): + if isinstance(model, ExactGP): + self._check_predictions_exact_gp(model, test_x, test_y, atol) + else: + prediction = model(test_x) + self._check_prediction(prediction, test_y, atol) + + def _check_predictions_exact_gp(self, model: ExactGP, test_x: torch.Tensor, test_y: torch.Tensor, atol: float): + with settings.observation_nan_policy("mask"): + prediction = model(test_x) + self._check_prediction(prediction, test_y, atol) + + clear_cache_hook(model.prediction_strategy) + + with settings.observation_nan_policy("fill"), self.assertWarns(RuntimeWarning, msg=self.warning): + prediction = model(test_x) + self._check_prediction(prediction, test_y, atol) + + clear_cache_hook(model.prediction_strategy) + + with settings.observation_nan_policy("mask"): + model(test_x) + with settings.observation_nan_policy("fill"), self.assertWarns(RuntimeWarning, msg=self.warning): + prediction = model(test_x) + self._check_prediction(prediction, test_y, atol) + + clear_cache_hook(model.prediction_strategy) + + with settings.observation_nan_policy("fill"), self.assertWarns(RuntimeWarning, msg=self.warning): + model(test_x) + with settings.observation_nan_policy("mask"): + prediction = model(test_x) + self._check_prediction(prediction, test_y, atol) + + def _check_prediction(self, prediction: MultivariateNormal, test_y: torch.Tensor, atol: 
float): + self.assertFalse(torch.any(torch.isnan(prediction.mean)).item()) + self.assertFalse(torch.any(torch.isnan(prediction.covariance_matrix)).item()) + self.assertAllClose(prediction.mean, test_y, atol=atol) + + def test_single(self): + train_x = torch.linspace(0, 1, 41) + test_x = torch.linspace(0, 1, 51) + train_y = torch.sin(2 * torch.pi * train_x).squeeze() + train_y += torch.normal(0, 0.01, train_y.shape) + test_y = torch.sin(2 * torch.pi * test_x).squeeze() + train_y[::4] = torch.nan + + batch_shape = torch.Size(()) + likelihood = GaussianLikelihood(batch_shape=batch_shape) + model = SingleGPModel(train_x, train_y, likelihood, batch_shape=batch_shape) + + mll = ExactMarginalLogLikelihood(likelihood, model) + optimizer = torch.optim.Adam(model.parameters(), lr=0.15) + + with settings.observation_nan_policy("mask"): + self._check(model, likelihood, train_x, train_y, test_x, test_y, optimizer, mll) + + def test_single_batch(self): + train_x = torch.stack([torch.linspace(0, 1, 41), torch.linspace(1, 2, 41)]).reshape(2, 41, 1) + test_x = torch.stack([torch.linspace(0, 1, 51), torch.linspace(1, 2, 51)]).reshape(2, 51, 1) + train_y = torch.sin(2 * torch.pi * train_x).squeeze() + train_y += torch.normal(0, 0.01, train_y.shape) + test_y = torch.sin(2 * torch.pi * test_x).squeeze() + train_y[0, ::4] = torch.nan + + batch_shape = torch.Size((2,)) + likelihood = GaussianLikelihood(batch_shape=batch_shape) + model = SingleGPModel(train_x, train_y, likelihood, batch_shape=batch_shape) + + mll = ExactMarginalLogLikelihood(likelihood, model) + optimizer = torch.optim.Adam(model.parameters(), lr=0.15) + + with settings.observation_nan_policy("mask"): + self._check(model, likelihood, train_x, train_y, test_x, test_y, optimizer, mll) + + def test_multitask(self): + num_tasks = 10 + train_x = torch.linspace(0, 1, 41) + test_x = torch.linspace(0, 1, 51) + coefficients = torch.rand(1, num_tasks) + train_y = torch.sin(2 * torch.pi * train_x)[:, None] * coefficients + train_y += torch.normal(0, 0.01, train_y.shape) + test_y = torch.sin(2 * torch.pi * test_x)[:, None] * coefficients + train_y[::3, : num_tasks // 2] = torch.nan + train_y[::4, num_tasks // 2 :] = torch.nan + + likelihood = MultitaskGaussianLikelihood(num_tasks) + model = MultitaskGPModel(train_x, train_y, likelihood, num_tasks) + + mll = ExactMarginalLogLikelihood(likelihood, model) + optimizer = torch.optim.Adam(model.parameters(), lr=0.15) + + with settings.observation_nan_policy("mask"): + self._check(model, likelihood, train_x, train_y, test_x, test_y, optimizer, mll) + + def test_variational_multitask(self): + num_latents = 1 + train_x = torch.linspace(0, 1, 41) + test_x = torch.linspace(0, 1, 51) + train_y = torch.stack([ + torch.sin(train_x * (2 * torch.pi)) + torch.randn(train_x.size()) * 0.2, + -torch.sin(train_x * (2 * torch.pi)) + torch.randn(train_x.size()) * 0.2, + ], -1) + test_y = torch.stack([ + torch.sin(test_x * (2 * torch.pi)), + -torch.sin(test_x * (2 * torch.pi)), + ], -1) + num_tasks = train_y.shape[-1] + + # nan out a few train_y + train_y[-3:, 1] = torch.nan + + likelihood = MultitaskGaussianLikelihood(num_tasks=num_tasks) + model = MultitaskVariationalGPModel(num_latents, num_tasks) + model.train() + likelihood.train() + + optimizer = torch.optim.Adam([ + {'params': model.parameters()}, + {'params': likelihood.parameters()}, + ], lr=0.05) + + mll = VariationalELBO(likelihood, model, num_data=train_y.size(0)) + with settings.observation_nan_policy("mask"): + self._check(model, likelihood, train_x, train_y, test_x, 
test_y, optimizer, mll, epochs=50, atol=0.7) + with settings.observation_nan_policy("fill"): + self._check(model, likelihood, train_x, train_y, test_x, test_y, optimizer, mll, epochs=50, atol=0.3) + + mll = PredictiveLogLikelihood(likelihood, model, num_data=train_y.size(0)) + with settings.observation_nan_policy("mask"): + self._check(model, likelihood, train_x, train_y, test_x, test_y, optimizer, mll, epochs=50, atol=0.7) + with settings.observation_nan_policy("fill"): + self._check(model, likelihood, train_x, train_y, test_x, test_y, optimizer, mll, epochs=50, atol=0.3) diff --git a/test/likelihoods/test_gaussian_likelihood.py b/test/likelihoods/test_gaussian_likelihood.py index 9874395a9..2c11a2a5f 100644 --- a/test/likelihoods/test_gaussian_likelihood.py +++ b/test/likelihoods/test_gaussian_likelihood.py @@ -8,12 +8,7 @@ from gpytorch import settings from gpytorch.distributions import MultivariateNormal -from gpytorch.likelihoods import ( - DirichletClassificationLikelihood, - FixedNoiseGaussianLikelihood, - GaussianLikelihood, - GaussianLikelihoodWithMissingObs, -) +from gpytorch.likelihoods import DirichletClassificationLikelihood, FixedNoiseGaussianLikelihood, GaussianLikelihood from gpytorch.likelihoods.noise_models import FixedGaussianNoise from gpytorch.priors import GammaPrior from gpytorch.test.base_likelihood_test_case import BaseLikelihoodTestCase @@ -164,13 +159,13 @@ def test_dirichlet_classification_likelihood(self, cuda=False): self.assertTrue(torch.allclose(out.variance, 1.0 + obs_targets)) -class TestGaussianLikelihoodwithMissingObs(BaseLikelihoodTestCase, unittest.TestCase): +class TestGaussianLikelihoodWithMissingObs(BaseLikelihoodTestCase, unittest.TestCase): seed = 42 def create_likelihood(self): - return GaussianLikelihoodWithMissingObs() + return GaussianLikelihood() - def test_missing_value_inference(self): + def test_missing_value_inference_fill(self): """ samples = mvn samples + noise samples In this test, we try to recover noise parameters when some elements in @@ -179,43 +174,56 @@ def test_missing_value_inference(self): torch.manual_seed(self.seed) - mu = torch.zeros(2, 3) - sigma = torch.tensor([[[1, 0.999, -0.999], [0.999, 1, -0.999], [-0.999, -0.999, 1]]] * 2).float() - mvn = MultivariateNormal(mu, sigma) - samples = mvn.sample(torch.Size([10000])) # mvn samples - - noise_sd = 0.5 - noise_dist = torch.distributions.Normal(0, noise_sd) - samples += noise_dist.sample(samples.shape) # noise + mvn, samples = self._make_data() - missing_prop = 0.33 - missing_idx = torch.distributions.Binomial(1, missing_prop).sample(samples.shape).bool() + missing_probability = 0.33 + missing_idx = torch.distributions.Binomial(1, missing_probability).sample(samples.shape).bool() samples[missing_idx] = float("nan") - likelihood = GaussianLikelihoodWithMissingObs() + # check that the correct noise sd is recovered - # check that the missing value fill doesn't impact the likelihood + with settings.observation_nan_policy("fill"): + self._check_recovery(mvn, samples) - likelihood.MISSING_VALUE_FILL = 999.0 - like_init_plus = likelihood.log_marginal(samples, mvn).sum().data + def test_missing_value_inference_mask(self): + """ + samples = mvn samples + noise samples + In this test, we try to recover noise parameters when some elements in + 'samples' are missing at random. 
+ """ - likelihood.MISSING_VALUE_FILL = -999.0 - like_init_minus = likelihood.log_marginal(samples, mvn).sum().data + torch.manual_seed(self.seed) + + mvn, samples = self._make_data() - torch.testing.assert_close(like_init_plus, like_init_minus) + missing_prop = 0.33 + missing_idx = torch.distributions.Binomial(1, missing_prop).sample(samples.shape[1:]).bool() + samples[1, missing_idx] = float("nan") # check that the correct noise sd is recovered - opt = torch.optim.Adam(likelihood.parameters(), lr=0.05) + with settings.observation_nan_policy("fill"): + self._check_recovery(mvn, samples) + def _make_data(self): + mu = torch.zeros(2, 3) + sigma = torch.tensor([[[1, 0.999, -0.999], [0.999, 1, -0.999], [-0.999, -0.999, 1]]] * 2).float() + mvn = MultivariateNormal(mu, sigma) + samples = mvn.sample(torch.Size([10000])) # mvn samples + noise_sd = 0.5 + noise_dist = torch.distributions.Normal(0, noise_sd) + samples += noise_dist.sample(samples.shape) # noise + return mvn, samples + + def _check_recovery(self, mvn, samples): + likelihood = GaussianLikelihood() + opt = torch.optim.Adam(likelihood.parameters(), lr=0.05) for _ in range(100): opt.zero_grad() loss = -likelihood.log_marginal(samples, mvn).sum() loss.backward() opt.step() - - assert abs(float(likelihood.noise.sqrt()) - 0.5) < 0.02 - + self.assertTrue(abs(float(likelihood.noise.sqrt()) - 0.5) < 0.02) # Check log marginal works likelihood.log_marginal(samples[0], mvn)
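
Usage sketch for the new `observation_nan_policy` setting, distilled from `test/examples/test_missing_data.py` above. The `ToyGPModel` name and hyperparameters are hypothetical placeholders; the behavior shown (NaN targets masked out of the exact marginal log likelihood, 'fill' rejected by `ExactMarginalLogLikelihood`) follows the diff.

    import torch
    from gpytorch import ExactMarginalLogLikelihood, settings
    from gpytorch.distributions import MultivariateNormal
    from gpytorch.kernels import RBFKernel, ScaleKernel
    from gpytorch.likelihoods import GaussianLikelihood
    from gpytorch.means import ConstantMean
    from gpytorch.models import ExactGP


    class ToyGPModel(ExactGP):
        # Hypothetical minimal exact GP, mirroring SingleGPModel from the new test file.
        def __init__(self, train_x, train_y, likelihood):
            super().__init__(train_x, train_y, likelihood)
            self.mean_module = ConstantMean()
            self.covar_module = ScaleKernel(RBFKernel())

        def forward(self, x):
            return MultivariateNormal(self.mean_module(x), self.covar_module(x))


    train_x = torch.linspace(0, 1, 41)
    train_y = torch.sin(2 * torch.pi * train_x)
    train_y[::4] = torch.nan  # some observations are missing

    likelihood = GaussianLikelihood()
    model = ToyGPModel(train_x, train_y, likelihood)
    mll = ExactMarginalLogLikelihood(likelihood, model)

    # 'mask' drops the NaN targets from the objective; 'ignore' (the default) would
    # propagate NaNs, and 'fill' raises a ValueError in ExactMarginalLogLikelihood.
    with settings.observation_nan_policy("mask"):
        loss = -mll(model(train_x), train_y)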
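The new `MultitaskMultivariateNormal.__getitem__` can be exercised as in the sketch below, which paraphrases `test_getitem_interleaved`; shapes follow the `b x n x t` convention from the docstring (here a 2 x 4 batch of 3 observations and 2 tasks), and the return types match the assertions in the new tests.

    import torch
    from gpytorch.distributions import MultitaskMultivariateNormal

    mean = torch.randn(2, 4, 3, 2)           # batch 2 x 4, n=3 observations, t=2 tasks
    base = torch.randn(2, 4, 6, 6)
    covar = base @ base.transpose(-1, -2)    # interleaved 6 x 6 (= n*t) covariance blocks
    mtmvn = MultitaskMultivariateNormal(mean, covar)

    mtmvn[1, -1]         # index only batch dims -> MultitaskMultivariateNormal, event shape 3 x 2
    mtmvn[..., 2, 1]     # one observation, one task -> plain MultivariateNormal over the batch
    mtmvn[..., 1:, :-1]  # slice observations and tasks -> MultitaskMultivariateNormal, event 2 x 1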