Skip to content

Commit

Permalink
Better support for missing labels (#2288)
Browse files Browse the repository at this point in the history
* Fix prediction with NaN values in training labels

* Missing observation support for multitask and allow MultivariateMultitaskNormal indexing

* Fix error in MultitaskMultivariateNormal indexing on '...'

* Fix indexing with negative values

* Add tests

- Indexing MultitaskMultivariateNormal
- Missing data in single-task and multitask models

* Render docs for MultitaskMultivariateNormal indexing and missing observations

* Fix docs warning

* Fix docstring

* Finally fix docstring

* Change missing data handling to option flag

* Revamp missing value implementation

- Enable via gpytorch.settings
- Two modes: 'mask' and 'fill'
- Makes GaussianLikelihoodWithMissingObs obsolete
- Supports approximate GPs

* Fix Python version incompatibility

* Increase atol on variational tests

* Add ExactMarginalLogLikelihoodWithMissingObs back with deprecation warning

* Add warning if kernel matrix is made dense

* Fix docs

* Add quick path for noop slice indices

* Add test for noop slice indexing

* Fix docs

* Switch to MaskedLinearOperator

* Switch to MaskedLinearOperator from linear-operator 0.5.1

* Disable test_t_matmul_matrix() for LazyEvaluatedKernelTensor

The test fails because LazyEvaluatedKernelTensor only supports _matmul() with checkpointing, but checkpointing is deprecated.

* Fix merge conflict
  • Loading branch information
Turakar authored Sep 7, 2023
1 parent 5e93d2c commit 981edd8
Show file tree
Hide file tree
Showing 12 changed files with 769 additions and 49 deletions.
1 change: 1 addition & 0 deletions docs/source/distributions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ MultitaskMultivariateNormal

.. autoclass:: MultitaskMultivariateNormal
:members:
:special-members: __getitem__


Delta
Expand Down
2 changes: 1 addition & 1 deletion docs/source/likelihoods.rst
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ reduce the variance when computing approximate GP objective functions.
:members:

:hidden:`GaussianLikelihoodWithMissingObs`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: GaussianLikelihoodWithMissingObs
:members:
Expand Down
149 changes: 147 additions & 2 deletions gpytorch/distributions/multitask_multivariate_normal.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,12 @@

import torch
from linear_operator import LinearOperator, to_linear_operator
from linear_operator.operators import BlockDiagLinearOperator, BlockInterleavedLinearOperator, CatLinearOperator
from linear_operator.operators import (
BlockDiagLinearOperator,
BlockInterleavedLinearOperator,
CatLinearOperator,
DiagLinearOperator,
)

from .multivariate_normal import MultivariateNormal

Expand All @@ -18,7 +23,7 @@ class MultitaskMultivariateNormal(MultivariateNormal):
:param torch.Tensor mean: An `n x t` or batch `b x n x t` matrix of means for the MVN distribution.
:param ~linear_operator.operators.LinearOperator covar: An `... x NT x NT` (batch) matrix.
covariance matrix of MVN distribution.
:param bool validate_args: (default=False) If True, validate `mean` anad `covariance_matrix` arguments.
:param bool validate_args: (default=False) If True, validate `mean` and `covariance_matrix` arguments.
:param bool interleaved: (default=True) If True, covariance matrix is interpreted as block-diagonal w.r.t.
inter-task covariances for each observation. If False, it is interpreted as block-diagonal
w.r.t. inter-observation covariance for each task.
Expand Down Expand Up @@ -276,5 +281,145 @@ def variance(self):
return var.view(new_shape).transpose(-1, -2).contiguous()
return var.view(self._output_shape)

def __getitem__(self, idx) -> MultivariateNormal:
"""
Constructs a new MultivariateNormal that represents a random variable
modified by an indexing operation.
The mean and covariance matrix arguments are indexed accordingly.
:param Any idx: Index to apply to the mean. The covariance matrix is indexed accordingly.
:returns: If indices specify a slice for samples and tasks, returns a
MultitaskMultivariateNormal, else returns a MultivariateNormal.
"""

# Normalize index to a tuple
if not isinstance(idx, tuple):
idx = (idx,)

if ... in idx:
# Replace ellipsis '...' with explicit indices
ellipsis_location = idx.index(...)
if ... in idx[ellipsis_location + 1 :]:
raise IndexError("Only one ellipsis '...' is supported!")
prefix = idx[:ellipsis_location]
suffix = idx[ellipsis_location + 1 :]
infix_length = self.mean.dim() - len(prefix) - len(suffix)
if infix_length < 0:
raise IndexError(f"Index {idx} has too many dimensions")
idx = prefix + (slice(None),) * infix_length + suffix
elif len(idx) == self.mean.dim() - 1:
# Normalize indices ignoring the task-index to include it
idx = idx + (slice(None),)

new_mean = self.mean[idx]

# We now create a covariance matrix appropriate for new_mean
if len(idx) <= self.mean.dim() - 2:
# We are only indexing the batch dimensions in this case
return MultitaskMultivariateNormal(
mean=new_mean,
covariance_matrix=self.lazy_covariance_matrix[idx],
interleaved=self._interleaved,
)
elif len(idx) > self.mean.dim():
raise IndexError(f"Index {idx} has too many dimensions")
else:
# We have an index that extends over all dimensions
batch_idx = idx[:-2]
if self._interleaved:
row_idx = idx[-2]
col_idx = idx[-1]
num_rows = self._output_shape[-2]
num_cols = self._output_shape[-1]
else:
row_idx = idx[-1]
col_idx = idx[-2]
num_rows = self._output_shape[-1]
num_cols = self._output_shape[-2]

if isinstance(row_idx, int) and isinstance(col_idx, int):
# Single sample with single task
row_idx = _normalize_index(row_idx, num_rows)
col_idx = _normalize_index(col_idx, num_cols)
new_cov = DiagLinearOperator(
self.lazy_covariance_matrix.diagonal()[batch_idx + (row_idx * num_cols + col_idx,)]
)
return MultivariateNormal(mean=new_mean, covariance_matrix=new_cov)
elif isinstance(row_idx, int) and isinstance(col_idx, slice):
# A block of the covariance matrix
row_idx = _normalize_index(row_idx, num_rows)
col_idx = _normalize_slice(col_idx, num_cols)
new_slice = slice(
col_idx.start + row_idx * num_cols,
col_idx.stop + row_idx * num_cols,
col_idx.step,
)
new_cov = self.lazy_covariance_matrix[batch_idx + (new_slice, new_slice)]
return MultivariateNormal(mean=new_mean, covariance_matrix=new_cov)
elif isinstance(row_idx, slice) and isinstance(col_idx, int):
# A block of the reversely interleaved covariance matrix
row_idx = _normalize_slice(row_idx, num_rows)
col_idx = _normalize_index(col_idx, num_cols)
new_slice = slice(row_idx.start + col_idx, row_idx.stop * num_cols + col_idx, row_idx.step * num_cols)
new_cov = self.lazy_covariance_matrix[batch_idx + (new_slice, new_slice)]
return MultivariateNormal(mean=new_mean, covariance_matrix=new_cov)
elif (
isinstance(row_idx, slice)
and isinstance(col_idx, slice)
and row_idx == col_idx == slice(None, None, None)
):
new_cov = self.lazy_covariance_matrix[batch_idx]
return MultitaskMultivariateNormal(
mean=new_mean,
covariance_matrix=new_cov,
interleaved=self._interleaved,
validate_args=False,
)
elif isinstance(row_idx, slice) or isinstance(col_idx, slice):
# slice x slice or indices x slice or slice x indices
if isinstance(row_idx, slice):
row_idx = torch.arange(num_rows)[row_idx]
if isinstance(col_idx, slice):
col_idx = torch.arange(num_cols)[col_idx]
row_grid, col_grid = torch.meshgrid(row_idx, col_idx, indexing="ij")
indices = (row_grid * num_cols + col_grid).reshape(-1)
new_cov = self.lazy_covariance_matrix[batch_idx + (indices,)][..., indices]
return MultitaskMultivariateNormal(
mean=new_mean, covariance_matrix=new_cov, interleaved=self._interleaved, validate_args=False
)
else:
# row_idx and col_idx have pairs of indices
indices = row_idx * num_cols + col_idx
new_cov = self.lazy_covariance_matrix[batch_idx + (indices,)][..., indices]
return MultivariateNormal(
mean=new_mean,
covariance_matrix=new_cov,
)

def __repr__(self) -> str:
return f"MultitaskMultivariateNormal(mean shape: {self._output_shape})"


def _normalize_index(i: int, dim_size: int) -> int:
if i < 0:
return dim_size + i
else:
return i


def _normalize_slice(s: slice, dim_size: int) -> slice:
start = s.start
if start is None:
start = 0
elif start < 0:
start = dim_size + start
stop = s.stop
if stop is None:
stop = dim_size
elif stop < 0:
stop = dim_size + stop
step = s.step
if step is None:
step = 1
return slice(start, stop, step)
8 changes: 7 additions & 1 deletion gpytorch/distributions/multivariate_normal.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,16 +343,22 @@ def __getitem__(self, idx) -> MultivariateNormal:
The mean and covariance matrix arguments are indexed accordingly.
:param idx: Index to apply.
:param idx: Index to apply to the mean. The covariance matrix is indexed accordingly.
"""

if not isinstance(idx, tuple):
idx = (idx,)
if len(idx) > self.mean.dim() and Ellipsis in idx:
idx = tuple(i for i in idx if i != Ellipsis)
if len(idx) < self.mean.dim():
raise IndexError("Multiple ambiguous ellipsis in index!")

rest_idx = idx[:-1]
last_idx = idx[-1]
new_mean = self.mean[idx]

if len(idx) <= self.mean.dim() - 1 and (Ellipsis not in rest_idx):
# We are only indexing the batch dimensions in this case
new_cov = self.lazy_covariance_matrix[idx]
elif len(idx) > self.mean.dim():
raise IndexError(f"Index {idx} has too many dimensions")
Expand Down
63 changes: 55 additions & 8 deletions gpytorch/likelihoods/gaussian_likelihood.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
#!/usr/bin/env python3

import math
import warnings
from copy import deepcopy
from typing import Any, Optional, Tuple, Union

import torch
from linear_operator.operators import LinearOperator, ZeroLinearOperator
from linear_operator.operators import LinearOperator, MaskedLinearOperator, ZeroLinearOperator
from torch import Tensor
from torch.distributions import Distribution, Normal

from .. import settings
from ..constraints import Interval
from ..distributions import base_distributions, MultivariateNormal
from ..priors import Prior
Expand Down Expand Up @@ -39,17 +39,39 @@ def _shaped_noise_covar(self, base_shape: torch.Size, *params: Any, **kwargs: An
return self.noise_covar(*params, shape=base_shape, **kwargs)

def expected_log_prob(self, target: Tensor, input: MultivariateNormal, *params: Any, **kwargs: Any) -> Tensor:
mean, variance = input.mean, input.variance
num_event_dim = len(input.event_shape)

noise = self._shaped_noise_covar(mean.shape, *params, **kwargs).diagonal(dim1=-1, dim2=-2)
noise = self._shaped_noise_covar(input.mean.shape, *params, **kwargs).diagonal(dim1=-1, dim2=-2)
# Potentially reshape the noise to deal with the multitask case
noise = noise.view(*noise.shape[:-1], *input.event_shape)

# Handle NaN values if enabled
nan_policy = settings.observation_nan_policy.value()
if nan_policy == "mask":
observed = settings.observation_nan_policy._get_observed(target, input.event_shape)
input = MultivariateNormal(
mean=input.mean[..., observed],
covariance_matrix=MaskedLinearOperator(
input.lazy_covariance_matrix, observed.reshape(-1), observed.reshape(-1)
),
)
noise = noise[..., observed]
target = target[..., observed]
elif nan_policy == "fill":
missing = torch.isnan(target)
target = settings.observation_nan_policy._fill_tensor(target)

mean, variance = input.mean, input.variance
res = ((target - mean).square() + variance) / noise + noise.log() + math.log(2 * math.pi)
res = res.mul(-0.5)
if num_event_dim > 1: # Do appropriate summation for multitask Gaussian likelihoods

if nan_policy == "fill":
res = res * ~missing

# Do appropriate summation for multitask Gaussian likelihoods
num_event_dim = len(input.event_shape)
if num_event_dim > 1:
res = res.sum(list(range(-1, -num_event_dim, -1)))

return res

def forward(self, function_samples: Tensor, *params: Any, **kwargs: Any) -> Normal:
Expand All @@ -60,12 +82,31 @@ def log_marginal(
self, observations: Tensor, function_dist: MultivariateNormal, *params: Any, **kwargs: Any
) -> Tensor:
marginal = self.marginal(function_dist, *params, **kwargs)

# Handle NaN values if enabled
nan_policy = settings.observation_nan_policy.value()
if nan_policy == "mask":
observed = settings.observation_nan_policy._get_observed(observations, marginal.event_shape)
marginal = MultivariateNormal(
mean=marginal.mean[..., observed],
covariance_matrix=MaskedLinearOperator(
marginal.lazy_covariance_matrix, observed.reshape(-1), observed.reshape(-1)
),
)
observations = observations[..., observed]
elif nan_policy == "fill":
missing = torch.isnan(observations)
observations = settings.observation_nan_policy._fill_tensor(observations)

# We're making everything conditionally independent
indep_dist = base_distributions.Normal(marginal.mean, marginal.variance.clamp_min(1e-8).sqrt())
res = indep_dist.log_prob(observations)

if nan_policy == "fill":
res = res * ~missing

# Do appropriate summation for multitask Gaussian likelihoods
num_event_dim = len(function_dist.event_shape)
num_event_dim = len(marginal.event_shape)
if num_event_dim > 1:
res = res.sum(list(range(-1, -num_event_dim, -1)))
return res
Expand Down Expand Up @@ -150,13 +191,15 @@ class GaussianLikelihoodWithMissingObs(GaussianLikelihood):
.. note::
This likelihood can be used for exact or approximate inference.
.. warning::
This likelihood is deprecated in favor of :class:`gpytorch.settings.observation_nan_policy`.
:param noise_prior: Prior for noise parameter :math:`\sigma^2`.
:type noise_prior: ~gpytorch.priors.Prior, optional
:param noise_constraint: Constraint for noise parameter :math:`\sigma^2`.
:type noise_constraint: ~gpytorch.constraints.Interval, optional
:param batch_shape: The batch shape of the learned noise parameter (default: []).
:type batch_shape: torch.Size, optional
:var torch.Tensor noise: :math:`\sigma^2` parameter (noise)
.. note::
Expand All @@ -166,6 +209,10 @@ class GaussianLikelihoodWithMissingObs(GaussianLikelihood):
MISSING_VALUE_FILL: float = -999.0

def __init__(self, **kwargs: Any) -> None:
warnings.warn(
"GaussianLikelihoodWithMissingObs is replaced by gpytorch.settings.observation_nan_policy('fill').",
DeprecationWarning,
)
super().__init__(**kwargs)

def _get_masked_obs(self, x: Tensor) -> Tuple[Tensor, Tensor]:
Expand Down
20 changes: 19 additions & 1 deletion gpytorch/mlls/exact_marginal_log_likelihood.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
#!/usr/bin/env python3

from linear_operator.operators import MaskedLinearOperator

from .. import settings
from ..distributions import MultivariateNormal
from ..likelihoods import _GaussianLikelihoodBase
from .marginal_log_likelihood import MarginalLogLikelihood
Expand Down Expand Up @@ -59,8 +62,23 @@ def forward(self, function_dist, target, *params):
if not isinstance(function_dist, MultivariateNormal):
raise RuntimeError("ExactMarginalLogLikelihood can only operate on Gaussian random variables")

# Get the log prob of the marginal distribution
# Determine output likelihood
output = self.likelihood(function_dist, *params)

# Remove NaN values if enabled
if settings.observation_nan_policy.value() == "mask":
observed = settings.observation_nan_policy._get_observed(target, output.event_shape)
output = MultivariateNormal(
mean=output.mean[..., observed],
covariance_matrix=MaskedLinearOperator(
output.lazy_covariance_matrix, observed.reshape(-1), observed.reshape(-1)
),
)
target = target[..., observed]
elif settings.observation_nan_policy.value() == "fill":
raise ValueError("NaN observation policy 'fill' is not supported by ExactMarginalLogLikelihood!")

# Get the log prob of the marginal distribution
res = output.log_prob(target)
res = self._add_other_terms(res, params)

Expand Down
Loading

0 comments on commit 981edd8

Please sign in to comment.