Better support for missing labels #2288
Conversation
If we wanted, we could probably incorporate the missing obs into the existing ExactMarginalLogLikelihood via a …

Probably yes, it should already be saving you. All we'd need to check is that the output of the likelihood call has a

```python
output_dist = model(train_x)
mll = exact_mll(output_dist, train_y)
```

And you'd expect output_dist to have a …
Does one of you have an idea what the problem with the docstring is?
Ok, I like the idea of adding it as an option flag.
- Enable via gpytorch.settings
- Two modes: 'mask' and 'fill'
- Makes GaussianLikelihoodWithMissingObs obsolete
- Supports approximate GPs
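For context, the new flag is scoped through a context manager like other `gpytorch.settings` options (the benchmark scripts below use it the same way). A minimal usage sketch:

```python
import gpytorch

# Scope the NaN handling policy to a block of training / prediction code.
# Values per this PR: "ignore", "mask", "fill".
with gpytorch.settings.observation_nan_policy("mask"):
    pass  # calls that may see NaN labels go here
```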
I reworked large parts of this PR. It should be ready for review now.
Thanks @Turakar for the awesome (and very thorough) PR! I'm excited to get this merged in.
See below for some comments. Mostly nit-picking, but one or two questions about performance.
I created some initial benchmarks. This is the benchmark code:

Benchmark code:

```python
import copy
import gc
import math
import time

import gpytorch.settings
import torch
from gpytorch import ExactMarginalLogLikelihood
from gpytorch.distributions import MultivariateNormal
from gpytorch.kernels import ScaleKernel, RBFKernel
from gpytorch.likelihoods import GaussianLikelihood
from gpytorch.means import ConstantMean
from gpytorch.models import ExactGP
from plotly.subplots import make_subplots
import plotly.graph_objs as go
from torch import Tensor
from tqdm import tqdm

device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")


class Model(ExactGP):
    def __init__(self, train_x: Tensor, train_y: Tensor):
        super().__init__(train_x, train_y, GaussianLikelihood())
        self.mean_module = ConstantMean()
        self.covar_module = ScaleKernel(RBFKernel())

    def forward(self, x: Tensor) -> MultivariateNormal:
        return MultivariateNormal(
            self.mean_module(x),
            self.covar_module(x),
        )


def make_dataset(train_num: int, train_missing: float, val_num: int, val_missing: float) -> tuple[Tensor, Tensor, Tensor, Tensor]:
    train_x = torch.linspace(0, 1, train_num, device=device)
    train_y = torch.sin(2 * torch.pi * train_x)
    val_x = torch.linspace(0, 1, val_num, device=device)
    val_y = torch.sin(2 * torch.pi * val_x)

    # Randomly mask out some data
    if train_missing > 0:
        train_mask = torch.bernoulli(torch.full_like(train_y, train_missing)).to(torch.bool)
        train_y[train_mask] = torch.nan
    if val_missing > 0:
        val_mask = torch.bernoulli(torch.full_like(val_y, val_missing)).to(torch.bool)
        val_y[val_mask] = torch.nan

    train_x = train_x.unsqueeze(-1)
    val_x = val_x.unsqueeze(-1)
    return train_x, train_y, val_x, val_y


def prepare_model(train_x: Tensor, train_y: Tensor, steps: int) -> Model:
    model = Model(train_x, train_y).to(device)
    model.train()

    # Pre-train model s.t. we have realistic values and convergence times
    mll = ExactMarginalLogLikelihood(model.likelihood, model).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    for _ in range(steps):
        optimizer.zero_grad(set_to_none=True)
        loss = -mll(model(*model.train_inputs), model.train_targets)
        loss.backward()
        optimizer.step()

    model.cpu()
    return model


def measure(model: Model, train_x: Tensor, train_y: Tensor, val_x: Tensor) -> tuple[float, float]:
    # Create a copy of the model with the new training data
    model = copy.deepcopy(model)
    model.to(device)
    model.set_train_data(train_x, train_y, strict=False)

    # Simulate training step
    mll = ExactMarginalLogLikelihood(model.likelihood, model).to(device)
    prior_time = time.time()
    loss = mll(model(*model.train_inputs), model.train_targets)
    loss.backward()
    train_time = time.time() - prior_time

    # Simulate prediction
    with torch.no_grad():
        prior_time = time.time()
        model.eval()
        prediction: MultivariateNormal = model.likelihood(model(val_x))
        mean = prediction.mean
        covar = prediction.covariance_matrix
        val_time = time.time() - prior_time

    return train_time, val_time


def measure_multiple(models: list[Model], train_x: Tensor, train_y: Tensor, val_x: Tensor) -> tuple[float, float, float, float]:
    train_times = []
    val_times = []
    for model in tqdm(models, desc="Collecting", leave=False):
        train_time, val_time = measure(model, train_x, train_y, val_x)
        train_times.append(train_time)
        val_times.append(val_time)
        gc.collect()

    train_times = torch.tensor(train_times)
    train_mean = torch.mean(train_times).item()
    train_sem = torch.std(train_times).item() / math.sqrt(len(models))
    val_times = torch.tensor(val_times)
    val_mean = torch.mean(val_times).item()
    val_sem = torch.std(val_times).item() / math.sqrt(len(models))
    return train_mean, train_sem, val_mean, val_sem


def main():
    with gpytorch.settings.max_cholesky_size(0):
        n = 8000
        iterations = 50
        train_steps = 50
        nan_fractions = [x.item() for x in torch.linspace(0, 0.5, 12)]

        # Prepare some models
        # We will use the same models for each NaN fraction and just change their training datasets.
        sample_x, sample_y, _, _ = make_dataset(n, 0, n // 10, 0)
        models = []
        for _ in tqdm(range(iterations), desc="Preparing models"):
            models.append(prepare_model(sample_x, sample_y, train_steps))

        # Collect measurements
        measurements = []
        for nan_fraction in tqdm(nan_fractions, desc="NaN fractions"):
            with gpytorch.settings.observation_nan_policy("mask" if nan_fraction > 0 else "ignore"):
                nan_n = int(n / (1 - nan_fraction))  # Scale n s.t. we have an equal amount of observed data
                train_x, train_y, val_x, val_y = make_dataset(nan_n, nan_fraction, nan_n // 10, nan_fraction)
                measurements.append(list(measure_multiple(models, train_x, train_y, val_x)))
        measurements = torch.tensor(measurements)

        # Create a plot showing the mean and std. error of the mean for both training and prediction
        fig = make_subplots(rows=1, cols=2, column_titles=["Training step", "Prediction step"])
        fig.add_trace(go.Scatter(
            x=nan_fractions,
            y=measurements[:, 0],
            error_y=dict(
                type="data",
                array=measurements[:, 1],
            ),
            mode="lines",
        ), row=1, col=1)
        fig.add_trace(go.Scatter(
            x=nan_fractions,
            y=measurements[:, 2],
            error_y=dict(
                type="data",
                array=measurements[:, 3],
            ),
            mode="lines",
        ), row=1, col=2)
        fig.update_layout(
            title="NaN masking performance for simple RBF model",
            showlegend=False,
        )
        fig.update_xaxes(title="NaN fraction")
        fig.update_yaxes(title="Time per step (s)")
        fig.show(renderer="browser")
        fig.write_html("missing_data_performance.html")
        fig.write_image("missing_data_performance.svg")


if __name__ == '__main__':
    main()
```

And this is the result:

(result plot omitted)

It would also be interesting to do this for a multitask model which uses Kronecker structure, as this is the more likely use case. But preliminarily, I think it is safe to say that masking creates a certain overhead during training that is independent of the masked fraction, while prediction time increases steadily with the number of masked points.
And here are the benchmarks for the Kronecker case. It seems like this operator cannot make use of indexing to accelerate training or inference.

Benchmark code:

```python
import copy
import gc
import math
import time

import gpytorch.settings
import torch
from gpytorch import ExactMarginalLogLikelihood
from gpytorch.distributions import MultivariateNormal, MultitaskMultivariateNormal
from gpytorch.kernels import ScaleKernel, RBFKernel, LCMKernel
from gpytorch.likelihoods import MultitaskGaussianLikelihood
from gpytorch.means import ConstantMean, MultitaskMean
from gpytorch.models import ExactGP
from plotly.subplots import make_subplots
import plotly.graph_objs as go
from torch import Tensor
from tqdm import tqdm

device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")


class Model(ExactGP):
    def __init__(self, train_x: Tensor, train_y: Tensor, num_latents: int):
        num_tasks = train_y.shape[-1]
        super().__init__(train_x, train_y, MultitaskGaussianLikelihood(num_tasks))
        self.mean_module = MultitaskMean(ConstantMean(), num_tasks)
        self.covar_module = LCMKernel(
            [ScaleKernel(RBFKernel()) for _ in range(num_latents)],
            num_tasks,
        )

    def forward(self, x: Tensor) -> MultitaskMultivariateNormal:
        return MultitaskMultivariateNormal(
            self.mean_module(x),
            self.covar_module(x),
        )


def make_dataset(train_num: int, train_missing: float, val_num: int, val_missing: float) -> tuple[Tensor, Tensor, Tensor, Tensor]:
    def target_function(x: Tensor) -> Tensor:
        return torch.stack([
            torch.sin(2 * torch.pi * x),
            torch.sin(2 * torch.pi * x) * 0.25,
            torch.sin(3 * torch.pi * x) + torch.sin(2 * torch.pi * x),
        ], dim=1)

    train_x = torch.linspace(0, 1, train_num, device=device)
    train_y = target_function(train_x)
    val_x = torch.linspace(0, 1, val_num, device=device)
    val_y = target_function(val_x)

    # Randomly mask out some data
    if train_missing > 0:
        train_mask = torch.bernoulli(torch.full_like(train_y, train_missing)).to(torch.bool)
        train_y[train_mask] = torch.nan
    if val_missing > 0:
        val_mask = torch.bernoulli(torch.full_like(val_y, val_missing)).to(torch.bool)
        val_y[val_mask] = torch.nan

    train_x = train_x.unsqueeze(-1)
    val_x = val_x.unsqueeze(-1)
    return train_x, train_y, val_x, val_y


def prepare_model(train_x: Tensor, train_y: Tensor, num_latents: int, steps: int) -> Model:
    model = Model(train_x, train_y, num_latents).to(device)
    model.train()

    # Pre-train model s.t. we have realistic values and convergence times
    mll = ExactMarginalLogLikelihood(model.likelihood, model).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    for _ in range(steps):
        optimizer.zero_grad(set_to_none=True)
        loss = -mll(model(*model.train_inputs), model.train_targets)
        loss.backward()
        optimizer.step()

    model.cpu()
    return model


def measure(model: Model, train_x: Tensor, train_y: Tensor, val_x: Tensor) -> tuple[float, float]:
    # Create a copy of the model with the new training data
    model = copy.deepcopy(model)
    model.to(device)
    model.set_train_data(train_x, train_y, strict=False)

    # Simulate training step
    mll = ExactMarginalLogLikelihood(model.likelihood, model).to(device)
    prior_time = time.time()
    loss = mll(model(*model.train_inputs), model.train_targets)
    loss.backward()
    train_time = time.time() - prior_time

    # Simulate prediction
    with torch.no_grad():
        prior_time = time.time()
        model.eval()
        prediction: MultivariateNormal = model.likelihood(model(val_x))
        mean = prediction.mean
        covar = prediction.covariance_matrix
        val_time = time.time() - prior_time

    return train_time, val_time


def measure_multiple(models: list[Model], train_x: Tensor, train_y: Tensor, val_x: Tensor) -> tuple[float, float, float, float]:
    train_times = []
    val_times = []
    for model in tqdm(models, desc="Collecting", leave=False):
        train_time, val_time = measure(model, train_x, train_y, val_x)
        train_times.append(train_time)
        val_times.append(val_time)
        gc.collect()

    train_times = torch.tensor(train_times)
    train_mean = torch.mean(train_times).item()
    train_sem = torch.std(train_times).item() / math.sqrt(len(models))
    val_times = torch.tensor(val_times)
    val_mean = torch.mean(val_times).item()
    val_sem = torch.std(val_times).item() / math.sqrt(len(models))
    return train_mean, train_sem, val_mean, val_sem


def main():
    with gpytorch.settings.max_cholesky_size(0):
        n = 3000
        iterations = 50
        train_steps = 50
        num_latents = 2
        nan_fractions = [x.item() for x in torch.linspace(0, 0.5, 12)]

        # Prepare some models
        # We will use the same models for each NaN fraction and just change their training datasets.
        sample_x, sample_y, _, _ = make_dataset(n, 0, n // 10, 0)
        models = []
        for _ in tqdm(range(iterations), desc="Preparing models"):
            models.append(prepare_model(sample_x, sample_y, num_latents, train_steps))

        # Collect measurements
        measurements = []
        for nan_fraction in tqdm(nan_fractions, desc="NaN fractions"):
            with gpytorch.settings.observation_nan_policy("mask" if nan_fraction > 0 else "ignore"):
                nan_n = int(n / (1 - nan_fraction))  # Scale n s.t. we have an equal amount of observed data
                train_x, train_y, val_x, val_y = make_dataset(nan_n, nan_fraction, nan_n // 10, nan_fraction)
                measurements.append(list(measure_multiple(models, train_x, train_y, val_x)))
        measurements = torch.tensor(measurements)

        # Create a plot showing the mean and std. error of the mean for both training and prediction
        fig = make_subplots(rows=1, cols=2, column_titles=["Training step", "Prediction step"])
        fig.add_trace(go.Scatter(
            x=nan_fractions,
            y=measurements[:, 0],
            error_y=dict(
                type="data",
                array=measurements[:, 1],
            ),
            mode="lines",
        ), row=1, col=1)
        fig.add_trace(go.Scatter(
            x=nan_fractions,
            y=measurements[:, 2],
            error_y=dict(
                type="data",
                array=measurements[:, 3],
            ),
            mode="lines",
        ), row=1, col=2)
        fig.update_layout(
            title="NaN masking performance for multitask LCM model",
            showlegend=False,
        )
        fig.update_xaxes(title="NaN fraction")
        fig.update_yaxes(title="Time per step (s)")
        fig.show(renderer="browser")
        fig.write_html("missing_data_performance.html")
        fig.write_image("missing_data_performance.svg")


if __name__ == '__main__':
    main()
```

And the result:

(result plot omitted)
I did another test: instead of indexing the linear operator in the Kronecker case, which seems to bring no improvement, I instead attached a new MaskedLinearOperator:

```python
from typing import Optional, Union

import torch
from linear_operator import LinearOperator
from torch import Tensor


class MaskedLinearOperator(LinearOperator):
    def __init__(self, base: LinearOperator, row_mask: Tensor, col_mask: Tensor):
        super().__init__(base, row_mask, col_mask)
        self.base = base
        self.row_mask = row_mask
        self.col_mask = col_mask
        self.row_eq_col_mask = (
            row_mask is not None and col_mask is not None and torch.equal(row_mask, col_mask)
        )

    def _matmul(self, rhs: Tensor) -> Tensor:
        if self.col_mask is not None:
            rhs_expanded = torch.zeros(
                *rhs.shape[:-2],
                self.base.size(-1),
                rhs.shape[-1],
                device=rhs.device,
                dtype=rhs.dtype,
            )
            rhs_expanded[..., self.col_mask, :] = rhs
            rhs = rhs_expanded
        res = self.base.matmul(rhs)
        if self.row_mask is not None:
            res = res[..., self.row_mask, :]
        return res

    def _size(self) -> torch.Size:
        base_size = list(self.base.size())
        if self.row_mask is not None:
            base_size[-2] = torch.count_nonzero(self.row_mask)
        if self.col_mask is not None:
            base_size[-1] = torch.count_nonzero(self.col_mask)
        return torch.Size(tuple(base_size))

    def _transpose_nonbatch(self) -> LinearOperator:
        return MaskedLinearOperator(self.base.mT, self.col_mask, self.row_mask)

    def _getitem(
        self,
        row_index: Union[slice, torch.LongTensor],
        col_index: Union[slice, torch.LongTensor],
        *batch_indices: tuple[Union[int, slice, torch.LongTensor], ...],
    ) -> LinearOperator:
        raise NotImplementedError(
            "Indexing with %r, %r, %r not supported." % (batch_indices, row_index, col_index)
        )

    def _get_indices(
        self,
        row_index: torch.LongTensor,
        col_index: torch.LongTensor,
        *batch_indices: tuple[torch.LongTensor, ...],
    ) -> torch.Tensor:
        def map_indices(index: torch.LongTensor, mask: Optional[Tensor], base_size: int) -> torch.LongTensor:
            if mask is None:
                return index
            map = torch.arange(base_size, device=self.base.device)[mask]
            return map[index]

        if len(batch_indices) == 0:
            row_index = map_indices(row_index, self.row_mask, self.base.size(-2))
            col_index = map_indices(col_index, self.col_mask, self.base.size(-1))
            return self.base._get_indices(row_index, col_index)
        raise NotImplementedError(
            "Indexing with %r, %r, %r not supported." % (batch_indices, row_index, col_index)
        )

    def _diagonal(self) -> Tensor:
        if not self.row_eq_col_mask:
            raise NotImplementedError()
        diag = self.base.diagonal()
        return diag[self.row_mask]

    def to_dense(self) -> torch.Tensor:
        full_dense = self.base.to_dense()
        return full_dense[..., self.row_mask, :][..., :, self.col_mask]

    def _cholesky_solve(self, rhs, upper: bool = False) -> LinearOperator:
        raise NotImplementedError()

    def _expand_batch(self, batch_shape: torch.Size) -> LinearOperator:
        raise NotImplementedError()

    def _isclose(self, other, rtol: float = 1e-05, atol: float = 1e-08, equal_nan: bool = False) -> Tensor:
        raise NotImplementedError()

    def _prod_batch(self, dim: int) -> LinearOperator:
        raise NotImplementedError()

    def _sum_batch(self, dim: int) -> LinearOperator:
        raise NotImplementedError()
```
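As an illustrative check (my addition, not part of the PR): wrapping a small dense operator and comparing against plain boolean indexing, assuming the `MaskedLinearOperator` class above is in scope.

```python
import torch
from linear_operator.operators import DenseLinearOperator

# A 5x5 dense base matrix wrapped as a linear operator,
# with a boolean mask selecting 3 observed rows/columns.
base = DenseLinearOperator(torch.randn(5, 5))
mask = torch.tensor([True, False, True, True, False])
masked = MaskedLinearOperator(base, mask, mask)

# to_dense() should match boolean indexing on the dense base matrix.
dense_ref = base.to_dense()[mask][:, mask]
assert torch.allclose(masked.to_dense(), dense_ref)

# _matmul() expands the rhs with zeros at masked positions,
# multiplies with the base, and re-masks the result rows.
rhs = torch.randn(3, 2)
assert torch.allclose(masked._matmul(rhs), dense_ref @ rhs)
```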
@Turakar would you be able to add a PR for MaskedLinearOperator to the linear_operator repo, and then we can merge this PR?
Sorry for the delay @Turakar - I'm back online now!
Just one small change: moving MaskedLinearOperator and adding a short unit test for it. Then I'll merge!
Hm, I think the failing unit tests might be caused by an incompatibility with linear-operator 0.5.1.
The test fails because LazyEvaluatedKernelTensor only supports _matmul() with checkpointing, but checkpointing is deprecated.
Considering #2342, I decided to just disable the failing test. It is specific to the missing support of …
RTD is removing the "use system packages" feature on 29 Aug 2023. This PR ensures that our docs will still build. Moreover, the linear_operator requirement needs to be updated for #2288.
I fixed the merge conflicts.

Finally merged! Thanks for the patience @Turakar!

I am happy it's merged 🙂
Summary
Fix #1790. Fix #1881.
While GPyTorch does have limited support for missing labels, e.g. in `GaussianLikelihoodWithMissingObs`, it does not have general support for this. For example, neither single-task nor multitask training / prediction allows for NaN target values. NaN values are especially useful for training multitask models with partially missing observations (cf. #1790). To fix this, this PR adds the following new functionality:
- A new setting `observation_nan_policy` with values `ignore`, `mask` and `fill`. See documentation and implementation for details.
- The setting is respected in `ExactMarginalLogLikelihood` for exact training, in `DefaultPredictionStrategy` for exact prediction, and in `_GaussianLikelihoodBase` for variational training.
- `fill` is used during exact prediction, because I need to zero-out some of the elements in the kernel matrices.
- `MultitaskMultivariateNormal` can now be indexed (required for indexing the observed values).
- `MultivariateNormal` indexing can deal with a superfluous `...` now.
- `GaussianLikelihoodWithMissingObs` is now API-equivalent to `GaussianLikelihood` with `observation_nan_policy('fill')`.

Alternative to missing data support
Alternatively, one may pass the task index as an additional input to the model (sketched below). However, depending on the choice of the kernel matrix, it may become complicated to construct. It is conceptually simpler to construct the kernel matrix for all tasks and samples at once (e.g. this allows for `BlockDiagOperator` and the like) and then use the NaN values later for filtering before calculating the marginal log likelihood.
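For reference, that alternative would look roughly like gpytorch's Hadamard multitask pattern. The snippet below is an illustrative sketch (the class name, shapes, and toy data are mine, not from this PR):

```python
import torch
import gpytorch

# Observations for each task, with explicit task indices instead of NaN padding
train_x1, train_x2 = torch.rand(30, 1), torch.rand(50, 1)
full_train_x = torch.cat([train_x1, train_x2])
full_train_i = torch.cat([
    torch.zeros(train_x1.shape[0], 1, dtype=torch.long),
    torch.ones(train_x2.shape[0], 1, dtype=torch.long),
])
full_train_y = torch.cat([
    torch.sin(2 * torch.pi * train_x1[:, 0]),
    torch.cos(2 * torch.pi * train_x2[:, 0]),
])


class HadamardMultitaskGP(gpytorch.models.ExactGP):
    def __init__(self, train_inputs, train_y, likelihood):
        super().__init__(train_inputs, train_y, likelihood)
        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = gpytorch.kernels.RBFKernel()
        self.task_covar_module = gpytorch.kernels.IndexKernel(num_tasks=2, rank=1)

    def forward(self, x, i):
        # Hadamard product of the data kernel and the task kernel
        return gpytorch.distributions.MultivariateNormal(
            self.mean_module(x),
            self.covar_module(x).mul(self.task_covar_module(i)),
        )


likelihood = gpytorch.likelihoods.GaussianLikelihood()
model = HadamardMultitaskGP((full_train_x, full_train_i), full_train_y, likelihood)
```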
Alternative to using a setting for this

Either subclassing, like `GaussianLikelihoodWithMissingObs` already does, or passing a keyword argument everywhere. In my opinion, a setting is way more useful, especially considering that `DefaultPredictionStrategy` is deeply nested and hard to reach otherwise.

Open points
- What to do with `GaussianLikelihoodWithMissingObs()`? At the moment, I just removed it, but we probably want some sort of deprecation?

Examples
The following snippets demonstrate the capabilities of the proposed changes:
Single Task
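The original snippet is collapsed in this view. As a stand-in, here is a minimal sketch of single-task training with NaN labels, using only the API surface described in this PR (data and hyperparameters are illustrative):

```python
import math
import torch
import gpytorch

train_x = torch.linspace(0, 1, 100)
train_y = torch.sin(2 * math.pi * train_x)
train_y[::4] = torch.nan  # mark some labels as missing


class GPModel(gpytorch.models.ExactGP):
    def __init__(self, train_x, train_y, likelihood):
        super().__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())

    def forward(self, x):
        return gpytorch.distributions.MultivariateNormal(self.mean_module(x), self.covar_module(x))


likelihood = gpytorch.likelihoods.GaussianLikelihood()
model = GPModel(train_x, train_y, likelihood)
mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)

model.train()
with gpytorch.settings.observation_nan_policy("mask"):
    for _ in range(50):
        optimizer.zero_grad()
        loss = -mll(model(train_x), train_y)
        loss.backward()
        optimizer.step()
```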
Multitask
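Again a sketch rather than the collapsed original: a Kronecker-structured multitask model trained with partially missing task observations (shapes and hyperparameters are illustrative):

```python
import math
import torch
import gpytorch

train_x = torch.linspace(0, 1, 100)
train_y = torch.stack([
    torch.sin(2 * math.pi * train_x),
    torch.cos(2 * math.pi * train_x),
], dim=-1)
train_y[::3, 0] = torch.nan  # task 0 unobserved at some inputs
train_y[::5, 1] = torch.nan  # task 1 unobserved at other inputs


class MultitaskGPModel(gpytorch.models.ExactGP):
    def __init__(self, train_x, train_y, likelihood):
        super().__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.MultitaskMean(gpytorch.means.ConstantMean(), num_tasks=2)
        self.covar_module = gpytorch.kernels.MultitaskKernel(gpytorch.kernels.RBFKernel(), num_tasks=2)

    def forward(self, x):
        return gpytorch.distributions.MultitaskMultivariateNormal(
            self.mean_module(x), self.covar_module(x)
        )


likelihood = gpytorch.likelihoods.MultitaskGaussianLikelihood(num_tasks=2)
model = MultitaskGPModel(train_x, train_y, likelihood)
mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)

model.train()
with gpytorch.settings.observation_nan_policy("mask"):
    for _ in range(50):
        optimizer.zero_grad()
        loss = -mll(model(train_x), train_y)
        loss.backward()
        optimizer.step()
```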
Variational Multitask
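And a variational counterpart, sketched after gpytorch's LMC variational multitask pattern (inducing-point count, number of latents, and toy data are illustrative, not from the PR):

```python
import math
import torch
import gpytorch

train_x = torch.linspace(0, 1, 100)
train_y = torch.stack([
    torch.sin(2 * math.pi * train_x),
    torch.cos(2 * math.pi * train_x),
], dim=-1)
train_y[::3, 0] = torch.nan  # some labels of task 0 are missing


class VariationalMultitaskModel(gpytorch.models.ApproximateGP):
    def __init__(self, num_latents: int = 2, num_tasks: int = 2):
        inducing_points = torch.rand(num_latents, 16, 1)
        variational_distribution = gpytorch.variational.CholeskyVariationalDistribution(
            inducing_points.size(-2), batch_shape=torch.Size([num_latents])
        )
        variational_strategy = gpytorch.variational.LMCVariationalStrategy(
            gpytorch.variational.VariationalStrategy(
                self, inducing_points, variational_distribution, learn_inducing_locations=True
            ),
            num_tasks=num_tasks,
            num_latents=num_latents,
        )
        super().__init__(variational_strategy)
        self.mean_module = gpytorch.means.ConstantMean(batch_shape=torch.Size([num_latents]))
        self.covar_module = gpytorch.kernels.ScaleKernel(
            gpytorch.kernels.RBFKernel(batch_shape=torch.Size([num_latents])),
            batch_shape=torch.Size([num_latents]),
        )

    def forward(self, x):
        return gpytorch.distributions.MultivariateNormal(self.mean_module(x), self.covar_module(x))


model = VariationalMultitaskModel()
likelihood = gpytorch.likelihoods.MultitaskGaussianLikelihood(num_tasks=2)
mll = gpytorch.mlls.VariationalELBO(likelihood, model, num_data=train_y.size(0))
optimizer = torch.optim.Adam([*model.parameters(), *likelihood.parameters()], lr=0.1)

model.train()
likelihood.train()
with gpytorch.settings.observation_nan_policy("mask"):
    for _ in range(50):
        optimizer.zero_grad()
        loss = -mll(model(train_x), train_y)
        loss.backward()
        optimizer.step()
```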