Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Better support for missing labels #2288

Merged
merged 31 commits into from
Sep 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
6bf4ba1
Fix prediction with NaN values in training labels
Turakar Mar 1, 2023
1c4e19c
Missing observation support for multitask and allow MultivariateMulti…
Turakar Mar 2, 2023
d0cf651
Fix error in MultitaskMultivariateNormal indexing on '...'
Turakar Mar 2, 2023
2c4ac2f
Fix indexing with negative values
Turakar Mar 2, 2023
03f2dd2
Add tests
Turakar Mar 3, 2023
270216b
Render docs for MultitaskMultivariateNormal indexing and missing obse…
Turakar Mar 5, 2023
954bed9
Fix docs warning
Turakar Mar 5, 2023
5fee76f
Fix docstring
Turakar Mar 6, 2023
cfa5435
Finally fix docstring
Turakar Mar 8, 2023
2531588
Change missing data handling to option flag
Turakar Mar 8, 2023
e7fca20
Revamp missing value implementation
Turakar Mar 10, 2023
31ed8dd
Fix Python version incompatibility
Turakar Mar 10, 2023
8546fd3
Increase atol on variational tests
Turakar Mar 10, 2023
1ac3b8e
Merge branch 'master' into missing-data
Turakar Mar 10, 2023
605db40
Merge branch 'master' into missing-data
Turakar Mar 15, 2023
e7b10cd
Merge branch 'master' into missing-data
Turakar Mar 16, 2023
6ab4e55
Add ExactMarginalLogLikelihoodWithMissingObs back with deprecation wa…
Turakar Mar 16, 2023
a55470a
Merge branch 'master' into missing-data
Turakar Apr 21, 2023
e2713ac
Add warning if kernel matrix is made dense
Turakar Apr 21, 2023
a67b40b
Fix docs
Turakar Apr 21, 2023
49fc2f4
Add quick path for noop slice indices
Turakar Apr 21, 2023
e8ecbef
Add test for noop slice indexing
Turakar Apr 21, 2023
4706477
Fix docs
Turakar Apr 21, 2023
e8788c8
Switch to MaskedLinearOperator
Turakar Jun 7, 2023
bd75162
Merge branch 'master' into missing-data
Turakar Jun 8, 2023
866bb6b
Merge branch 'master' into missing-data
Turakar Aug 8, 2023
522cbcf
Switch to MaskedLinearOperator from linear-operator 0.5.1
Turakar Aug 8, 2023
15ad542
Disable test_t_matmul_matrix() for LazyEvaluatedKernelTensor
Turakar Aug 8, 2023
c7677a1
Merge branch 'master' into missing-data
Turakar Aug 16, 2023
1034955
Fix merge conflict
Turakar Aug 16, 2023
6310aff
Merge branch 'master' into missing-data
Turakar Sep 5, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/distributions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ MultitaskMultivariateNormal

.. autoclass:: MultitaskMultivariateNormal
:members:
:special-members: __getitem__


Delta
Expand Down
2 changes: 1 addition & 1 deletion docs/source/likelihoods.rst
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ reduce the variance when computing approximate GP objective functions.
:members:

:hidden:`GaussianLikelihoodWithMissingObs`
Turakar marked this conversation as resolved.
Show resolved Hide resolved
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: GaussianLikelihoodWithMissingObs
:members:
Expand Down
149 changes: 147 additions & 2 deletions gpytorch/distributions/multitask_multivariate_normal.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,12 @@

import torch
from linear_operator import LinearOperator, to_linear_operator
from linear_operator.operators import BlockDiagLinearOperator, BlockInterleavedLinearOperator, CatLinearOperator
from linear_operator.operators import (
BlockDiagLinearOperator,
BlockInterleavedLinearOperator,
CatLinearOperator,
DiagLinearOperator,
)

from .multivariate_normal import MultivariateNormal

Expand All @@ -18,7 +23,7 @@ class MultitaskMultivariateNormal(MultivariateNormal):
:param torch.Tensor mean: An `n x t` or batch `b x n x t` matrix of means for the MVN distribution.
:param ~linear_operator.operators.LinearOperator covar: An `... x NT x NT` (batch) matrix.
covariance matrix of MVN distribution.
:param bool validate_args: (default=False) If True, validate `mean` anad `covariance_matrix` arguments.
:param bool validate_args: (default=False) If True, validate `mean` and `covariance_matrix` arguments.
:param bool interleaved: (default=True) If True, covariance matrix is interpreted as block-diagonal w.r.t.
inter-task covariances for each observation. If False, it is interpreted as block-diagonal
w.r.t. inter-observation covariance for each task.
Expand Down Expand Up @@ -276,5 +281,145 @@ def variance(self):
return var.view(new_shape).transpose(-1, -2).contiguous()
return var.view(self._output_shape)

def __getitem__(self, idx) -> MultivariateNormal:
"""
Constructs a new MultivariateNormal that represents a random variable
modified by an indexing operation.

The mean and covariance matrix arguments are indexed accordingly.

:param Any idx: Index to apply to the mean. The covariance matrix is indexed accordingly.
:returns: If indices specify a slice for samples and tasks, returns a
MultitaskMultivariateNormal, else returns a MultivariateNormal.
"""

# Normalize index to a tuple
if not isinstance(idx, tuple):
idx = (idx,)

if ... in idx:
gpleiss marked this conversation as resolved.
Show resolved Hide resolved
Turakar marked this conversation as resolved.
Show resolved Hide resolved
# Replace ellipsis '...' with explicit indices
ellipsis_location = idx.index(...)
if ... in idx[ellipsis_location + 1 :]:
raise IndexError("Only one ellipsis '...' is supported!")
prefix = idx[:ellipsis_location]
suffix = idx[ellipsis_location + 1 :]
infix_length = self.mean.dim() - len(prefix) - len(suffix)
if infix_length < 0:
raise IndexError(f"Index {idx} has too many dimensions")
idx = prefix + (slice(None),) * infix_length + suffix
elif len(idx) == self.mean.dim() - 1:
# Normalize indices ignoring the task-index to include it
idx = idx + (slice(None),)

new_mean = self.mean[idx]

# We now create a covariance matrix appropriate for new_mean
if len(idx) <= self.mean.dim() - 2:
# We are only indexing the batch dimensions in this case
return MultitaskMultivariateNormal(
mean=new_mean,
covariance_matrix=self.lazy_covariance_matrix[idx],
interleaved=self._interleaved,
)
elif len(idx) > self.mean.dim():
raise IndexError(f"Index {idx} has too many dimensions")
else:
# We have an index that extends over all dimensions
batch_idx = idx[:-2]
if self._interleaved:
row_idx = idx[-2]
col_idx = idx[-1]
num_rows = self._output_shape[-2]
num_cols = self._output_shape[-1]
else:
row_idx = idx[-1]
col_idx = idx[-2]
num_rows = self._output_shape[-1]
num_cols = self._output_shape[-2]

if isinstance(row_idx, int) and isinstance(col_idx, int):
# Single sample with single task
row_idx = _normalize_index(row_idx, num_rows)
col_idx = _normalize_index(col_idx, num_cols)
new_cov = DiagLinearOperator(
self.lazy_covariance_matrix.diagonal()[batch_idx + (row_idx * num_cols + col_idx,)]
)
return MultivariateNormal(mean=new_mean, covariance_matrix=new_cov)
elif isinstance(row_idx, int) and isinstance(col_idx, slice):
# A block of the covariance matrix
row_idx = _normalize_index(row_idx, num_rows)
col_idx = _normalize_slice(col_idx, num_cols)
new_slice = slice(
col_idx.start + row_idx * num_cols,
col_idx.stop + row_idx * num_cols,
col_idx.step,
)
new_cov = self.lazy_covariance_matrix[batch_idx + (new_slice, new_slice)]
return MultivariateNormal(mean=new_mean, covariance_matrix=new_cov)
elif isinstance(row_idx, slice) and isinstance(col_idx, int):
# A block of the reversely interleaved covariance matrix
row_idx = _normalize_slice(row_idx, num_rows)
col_idx = _normalize_index(col_idx, num_cols)
new_slice = slice(row_idx.start + col_idx, row_idx.stop * num_cols + col_idx, row_idx.step * num_cols)
new_cov = self.lazy_covariance_matrix[batch_idx + (new_slice, new_slice)]
return MultivariateNormal(mean=new_mean, covariance_matrix=new_cov)
elif (
isinstance(row_idx, slice)
and isinstance(col_idx, slice)
and row_idx == col_idx == slice(None, None, None)
):
new_cov = self.lazy_covariance_matrix[batch_idx]
return MultitaskMultivariateNormal(
mean=new_mean,
covariance_matrix=new_cov,
interleaved=self._interleaved,
validate_args=False,
)
elif isinstance(row_idx, slice) or isinstance(col_idx, slice):
# slice x slice or indices x slice or slice x indices
if isinstance(row_idx, slice):
row_idx = torch.arange(num_rows)[row_idx]
if isinstance(col_idx, slice):
col_idx = torch.arange(num_cols)[col_idx]
row_grid, col_grid = torch.meshgrid(row_idx, col_idx, indexing="ij")
gpleiss marked this conversation as resolved.
Show resolved Hide resolved
indices = (row_grid * num_cols + col_grid).reshape(-1)
new_cov = self.lazy_covariance_matrix[batch_idx + (indices,)][..., indices]
return MultitaskMultivariateNormal(
mean=new_mean, covariance_matrix=new_cov, interleaved=self._interleaved, validate_args=False
)
else:
# row_idx and col_idx have pairs of indices
indices = row_idx * num_cols + col_idx
new_cov = self.lazy_covariance_matrix[batch_idx + (indices,)][..., indices]
return MultivariateNormal(
mean=new_mean,
covariance_matrix=new_cov,
)

def __repr__(self) -> str:
return f"MultitaskMultivariateNormal(mean shape: {self._output_shape})"


def _normalize_index(i: int, dim_size: int) -> int:
if i < 0:
return dim_size + i
else:
return i


def _normalize_slice(s: slice, dim_size: int) -> slice:
start = s.start
if start is None:
start = 0
elif start < 0:
start = dim_size + start
stop = s.stop
if stop is None:
stop = dim_size
elif stop < 0:
stop = dim_size + stop
step = s.step
if step is None:
step = 1
return slice(start, stop, step)
8 changes: 7 additions & 1 deletion gpytorch/distributions/multivariate_normal.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,16 +343,22 @@ def __getitem__(self, idx) -> MultivariateNormal:

The mean and covariance matrix arguments are indexed accordingly.

:param idx: Index to apply.
:param idx: Index to apply to the mean. The covariance matrix is indexed accordingly.
"""

if not isinstance(idx, tuple):
idx = (idx,)
if len(idx) > self.mean.dim() and Ellipsis in idx:
idx = tuple(i for i in idx if i != Ellipsis)
if len(idx) < self.mean.dim():
raise IndexError("Multiple ambiguous ellipsis in index!")

rest_idx = idx[:-1]
last_idx = idx[-1]
new_mean = self.mean[idx]

if len(idx) <= self.mean.dim() - 1 and (Ellipsis not in rest_idx):
# We are only indexing the batch dimensions in this case
new_cov = self.lazy_covariance_matrix[idx]
elif len(idx) > self.mean.dim():
raise IndexError(f"Index {idx} has too many dimensions")
Expand Down
63 changes: 55 additions & 8 deletions gpytorch/likelihoods/gaussian_likelihood.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
#!/usr/bin/env python3

import math
import warnings
from copy import deepcopy
from typing import Any, Optional, Tuple, Union

import torch
from linear_operator.operators import LinearOperator, ZeroLinearOperator
from linear_operator.operators import LinearOperator, MaskedLinearOperator, ZeroLinearOperator
from torch import Tensor
from torch.distributions import Distribution, Normal

from .. import settings
from ..constraints import Interval
from ..distributions import base_distributions, MultivariateNormal
from ..priors import Prior
Expand Down Expand Up @@ -39,17 +39,39 @@ def _shaped_noise_covar(self, base_shape: torch.Size, *params: Any, **kwargs: An
return self.noise_covar(*params, shape=base_shape, **kwargs)

def expected_log_prob(self, target: Tensor, input: MultivariateNormal, *params: Any, **kwargs: Any) -> Tensor:
mean, variance = input.mean, input.variance
num_event_dim = len(input.event_shape)

noise = self._shaped_noise_covar(mean.shape, *params, **kwargs).diagonal(dim1=-1, dim2=-2)
noise = self._shaped_noise_covar(input.mean.shape, *params, **kwargs).diagonal(dim1=-1, dim2=-2)
# Potentially reshape the noise to deal with the multitask case
noise = noise.view(*noise.shape[:-1], *input.event_shape)

# Handle NaN values if enabled
nan_policy = settings.observation_nan_policy.value()
if nan_policy == "mask":
observed = settings.observation_nan_policy._get_observed(target, input.event_shape)
input = MultivariateNormal(
mean=input.mean[..., observed],
covariance_matrix=MaskedLinearOperator(
input.lazy_covariance_matrix, observed.reshape(-1), observed.reshape(-1)
),
)
noise = noise[..., observed]
target = target[..., observed]
elif nan_policy == "fill":
missing = torch.isnan(target)
target = settings.observation_nan_policy._fill_tensor(target)

mean, variance = input.mean, input.variance
res = ((target - mean).square() + variance) / noise + noise.log() + math.log(2 * math.pi)
res = res.mul(-0.5)
if num_event_dim > 1: # Do appropriate summation for multitask Gaussian likelihoods

if nan_policy == "fill":
res = res * ~missing

# Do appropriate summation for multitask Gaussian likelihoods
num_event_dim = len(input.event_shape)
if num_event_dim > 1:
res = res.sum(list(range(-1, -num_event_dim, -1)))

return res

def forward(self, function_samples: Tensor, *params: Any, **kwargs: Any) -> Normal:
Expand All @@ -60,12 +82,31 @@ def log_marginal(
self, observations: Tensor, function_dist: MultivariateNormal, *params: Any, **kwargs: Any
) -> Tensor:
marginal = self.marginal(function_dist, *params, **kwargs)

# Handle NaN values if enabled
nan_policy = settings.observation_nan_policy.value()
if nan_policy == "mask":
observed = settings.observation_nan_policy._get_observed(observations, marginal.event_shape)
marginal = MultivariateNormal(
mean=marginal.mean[..., observed],
covariance_matrix=MaskedLinearOperator(
marginal.lazy_covariance_matrix, observed.reshape(-1), observed.reshape(-1)
),
)
observations = observations[..., observed]
elif nan_policy == "fill":
missing = torch.isnan(observations)
observations = settings.observation_nan_policy._fill_tensor(observations)

# We're making everything conditionally independent
indep_dist = base_distributions.Normal(marginal.mean, marginal.variance.clamp_min(1e-8).sqrt())
res = indep_dist.log_prob(observations)

if nan_policy == "fill":
res = res * ~missing

# Do appropriate summation for multitask Gaussian likelihoods
num_event_dim = len(function_dist.event_shape)
num_event_dim = len(marginal.event_shape)
if num_event_dim > 1:
res = res.sum(list(range(-1, -num_event_dim, -1)))
return res
Expand Down Expand Up @@ -150,13 +191,15 @@ class GaussianLikelihoodWithMissingObs(GaussianLikelihood):
.. note::
This likelihood can be used for exact or approximate inference.

.. warning::
This likelihood is deprecated in favor of :class:`gpytorch.settings.observation_nan_policy`.

:param noise_prior: Prior for noise parameter :math:`\sigma^2`.
:type noise_prior: ~gpytorch.priors.Prior, optional
:param noise_constraint: Constraint for noise parameter :math:`\sigma^2`.
:type noise_constraint: ~gpytorch.constraints.Interval, optional
:param batch_shape: The batch shape of the learned noise parameter (default: []).
:type batch_shape: torch.Size, optional

:var torch.Tensor noise: :math:`\sigma^2` parameter (noise)

.. note::
Expand All @@ -166,6 +209,10 @@ class GaussianLikelihoodWithMissingObs(GaussianLikelihood):
MISSING_VALUE_FILL: float = -999.0

def __init__(self, **kwargs: Any) -> None:
warnings.warn(
"GaussianLikelihoodWithMissingObs is replaced by gpytorch.settings.observation_nan_policy('fill').",
DeprecationWarning,
)
super().__init__(**kwargs)

def _get_masked_obs(self, x: Tensor) -> Tuple[Tensor, Tensor]:
Expand Down
20 changes: 19 additions & 1 deletion gpytorch/mlls/exact_marginal_log_likelihood.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
#!/usr/bin/env python3

from linear_operator.operators import MaskedLinearOperator

from .. import settings
from ..distributions import MultivariateNormal
from ..likelihoods import _GaussianLikelihoodBase
from .marginal_log_likelihood import MarginalLogLikelihood
Expand Down Expand Up @@ -59,8 +62,23 @@ def forward(self, function_dist, target, *params):
if not isinstance(function_dist, MultivariateNormal):
raise RuntimeError("ExactMarginalLogLikelihood can only operate on Gaussian random variables")

# Get the log prob of the marginal distribution
# Determine output likelihood
output = self.likelihood(function_dist, *params)

gpleiss marked this conversation as resolved.
Show resolved Hide resolved
# Remove NaN values if enabled
if settings.observation_nan_policy.value() == "mask":
observed = settings.observation_nan_policy._get_observed(target, output.event_shape)
output = MultivariateNormal(
mean=output.mean[..., observed],
covariance_matrix=MaskedLinearOperator(
output.lazy_covariance_matrix, observed.reshape(-1), observed.reshape(-1)
),
)
target = target[..., observed]
elif settings.observation_nan_policy.value() == "fill":
raise ValueError("NaN observation policy 'fill' is not supported by ExactMarginalLogLikelihood!")

# Get the log prob of the marginal distribution
res = output.log_prob(target)
res = self._add_other_terms(res, params)

Expand Down
Loading