
Commit 1c743fa: Likelihood bugfix (#2395)
* Typehints for approximate gp

* Likelihood passes in args/kwargs to expected_log_prob

* Fix CI errors
gpleiss authored Aug 11, 2023
1 parent 8979210 commit 1c743fa
Showing 3 changed files with 29 additions and 18 deletions.
gpytorch/likelihoods/likelihood.py (6 changes: 4 additions & 2 deletions)
@@ -45,7 +45,7 @@ def expected_log_prob(
         self, observations: Tensor, function_dist: MultivariateNormal, *args: Any, **kwargs: Any
     ) -> Tensor:
         likelihood_samples = self._draw_likelihood_samples(function_dist, *args, **kwargs)
-        res = likelihood_samples.log_prob(observations).mean(dim=0)
+        res = likelihood_samples.log_prob(observations, *args, **kwargs).mean(dim=0)
         return res
 
     @abstractmethod
@@ -410,7 +410,9 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
     def expected_log_prob(
         self, observations: Tensor, function_dist: MultivariateNormal, *args: Any, **kwargs: Any
     ) -> Tensor:
-        log_prob_lambda = lambda function_samples: self.forward(function_samples).log_prob(observations)
+        log_prob_lambda = lambda function_samples: self.forward(function_samples, *args, **kwargs).log_prob(
+            observations
+        )
         log_prob = self.quadrature(log_prob_lambda, function_dist)
         return log_prob

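What the fix enables, in a minimal sketch: extra arguments passed to expected_log_prob now reach the likelihood's forward(), so per-call quantities can parameterize the conditional p(y | f). The KnownNoiseLikelihood class and its `noise` kwarg below are hypothetical, written only to illustrate the change; they are not part of this commit.

import torch
from gpytorch.distributions import MultivariateNormal
from gpytorch.likelihoods.likelihood import _OneDimensionalLikelihood


class KnownNoiseLikelihood(_OneDimensionalLikelihood):
    """p(y | f) = N(y; f, noise), with the noise variance supplied at call time."""

    def forward(self, function_samples, noise=None, **kwargs):
        # `noise` arrives here only because expected_log_prob now forwards
        # *args/**kwargs into forward(); previously it never reached this method.
        return torch.distributions.Normal(function_samples, noise.sqrt())


likelihood = KnownNoiseLikelihood()
f_dist = MultivariateNormal(torch.zeros(5), torch.eye(5))
y = torch.randn(5)
noise = torch.full((5,), 0.25)

# The quadrature-based expected_log_prob evaluates forward() at the
# Gauss-Hermite locations of f_dist, now with noise=noise passed through.
ell = likelihood.expected_log_prob(y, f_dist, noise=noise)  # shape: (5,)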
gpytorch/models/approximate_gp.py (38 changes: 22 additions & 16 deletions)
@@ -1,5 +1,12 @@
 #!/usr/bin/env python3
 
+from typing import Any, Optional
+
+from torch import Tensor
+
+from ..distributions import MultivariateNormal
+from .exact_gp import ExactGP
+
 from .gp import GP
 from .pyro import _PyroMixin  # This will only contain functions if Pyro is installed
@@ -44,38 +51,38 @@ class ApproximateGP(GP, _PyroMixin):
 
     def __init__(self, variational_strategy):
         super().__init__()
         self.variational_strategy = variational_strategy
 
-    def forward(self, x):
+    def forward(self, x: Tensor):
         raise NotImplementedError
 
-    def pyro_guide(self, input, beta=1.0, name_prefix=""):
+    def pyro_guide(self, input: Tensor, beta: float = 1.0, name_prefix: str = ""):
         r"""
         (For Pyro integration only). The component of a `pyro.guide` that
         corresponds to drawing samples from the latent GP function.
 
-        :param torch.Tensor input: The inputs :math:`\mathbf X`.
-        :param float beta: (default=1.) How much to scale the :math:`\text{KL} [ q(\mathbf f) \Vert p(\mathbf f) ]`
+        :param input: The inputs :math:`\mathbf X`.
+        :param beta: (default=1.) How much to scale the :math:`\text{KL} [ q(\mathbf f) \Vert p(\mathbf f) ]`
             term by.
-        :param str name_prefix: (default="") A name prefix to prepend to pyro sample sites.
+        :param name_prefix: (default="") A name prefix to prepend to pyro sample sites.
         """
         return super().pyro_guide(input, beta=beta, name_prefix=name_prefix)
 
-    def pyro_model(self, input, beta=1.0, name_prefix=""):
+    def pyro_model(self, input: Tensor, beta: float = 1.0, name_prefix: str = "") -> Tensor:
         r"""
         (For Pyro integration only). The component of a `pyro.model` that
         corresponds to drawing samples from the latent GP function.
 
-        :param torch.Tensor input: The inputs :math:`\mathbf X`.
-        :param float beta: (default=1.) How much to scale the :math:`\text{KL} [ q(\mathbf f) \Vert p(\mathbf f) ]`
+        :param input: The inputs :math:`\mathbf X`.
+        :param beta: (default=1.) How much to scale the :math:`\text{KL} [ q(\mathbf f) \Vert p(\mathbf f) ]`
             term by.
-        :param str name_prefix: (default="") A name prefix to prepend to pyro sample sites.
+        :param name_prefix: (default="") A name prefix to prepend to pyro sample sites.
         :return: samples from :math:`q(\mathbf f)`
-        :rtype: torch.Tensor
         """
         return super().pyro_model(input, beta=beta, name_prefix=name_prefix)
 
-    def get_fantasy_model(self, inputs, targets, **kwargs):
+    def get_fantasy_model(self, inputs: Tensor, targets: Tensor, **kwargs: Any) -> ExactGP:
         r"""
         Returns a new GP model that incorporates the specified inputs and targets as new training data using
         online variational conditioning (OVC).
@@ -88,12 +95,11 @@ def get_fantasy_model(self, inputs, targets, **kwargs):
         If `inputs` is of the same (or lesser) dimension as `targets`, then it is assumed that the fantasy points
         are the same for each target batch.
 
-        :param torch.Tensor inputs: (`b1 x ... x bk x m x d` or `f x b1 x ... x bk x m x d`) Locations of fantasy
+        :param inputs: (`b1 x ... x bk x m x d` or `f x b1 x ... x bk x m x d`) Locations of fantasy
             observations.
-        :param torch.Tensor targets: (`b1 x ... x bk x m` or `f x b1 x ... x bk x m`) Labels of fantasy observations.
+        :param targets: (`b1 x ... x bk x m` or `f x b1 x ... x bk x m`) Labels of fantasy observations.
         :return: An `ExactGP` model with `n + m` training examples, where the `m` fantasy examples have been added
             and all test-time caches have been updated.
-        :rtype: ~gpytorch.models.ExactGP
 
         Reference: "Conditioning Sparse Variational Gaussian Processes for Online Decision-Making,"
             Maddox, Stanton, Wilson, NeurIPS, '21
@@ -102,7 +108,7 @@ def get_fantasy_model(self, inputs, targets, **kwargs):
         """
         return self.variational_strategy.get_fantasy_model(inputs=inputs, targets=targets, **kwargs)
 
-    def __call__(self, inputs, prior=False, **kwargs):
-        if inputs.dim() == 1:
+    def __call__(self, inputs: Optional[Tensor], prior: bool = False, **kwargs) -> MultivariateNormal:
+        if inputs is not None and inputs.dim() == 1:
             inputs = inputs.unsqueeze(-1)
         return self.variational_strategy(inputs, prior=prior, **kwargs)
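For context, a minimal sketch of an ApproximateGP subclass written against the newly annotated interface (forward takes a Tensor; __call__ returns a MultivariateNormal). The SVGP model below, including its mean, kernel, and variational choices, is illustrative and not part of this commit.

import torch
from torch import Tensor

import gpytorch
from gpytorch.distributions import MultivariateNormal
from gpytorch.models import ApproximateGP
from gpytorch.variational import CholeskyVariationalDistribution, VariationalStrategy


class SVGPModel(ApproximateGP):
    def __init__(self, inducing_points: Tensor):
        variational_distribution = CholeskyVariationalDistribution(inducing_points.size(-2))
        variational_strategy = VariationalStrategy(
            self, inducing_points, variational_distribution, learn_inducing_locations=True
        )
        super().__init__(variational_strategy)
        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())

    def forward(self, x: Tensor) -> MultivariateNormal:
        # Matches the annotated signature forward(self, x: Tensor)
        return MultivariateNormal(self.mean_module(x), self.covar_module(x))


model = SVGPModel(inducing_points=torch.randn(16, 1))
output = model(torch.randn(10, 1))  # __call__ is now annotated Optional[Tensor] -> MultivariateNormal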
test/lazy/test_lazy_evaluated_kernel_tensor.py (3 changes: 3 additions & 0 deletions)
@@ -141,6 +141,9 @@ def test_getitem_tensor_index(self):
     def test_bilinear_derivative(self):
         pass
 
+    def test_t_matmul_matrix(self):
+        pass
+
     def test_half(self):
         # many transform operations aren't supported in half so we overwrite
         # this test
