diff --git a/docs/source/kernels.rst b/docs/source/kernels.rst index 5fa89b916..0964085be 100644 --- a/docs/source/kernels.rst +++ b/docs/source/kernels.rst @@ -119,24 +119,12 @@ Composition/Decoration Kernels .. autoclass:: MultiDeviceKernel :members: -:hidden:`AdditiveStructureKernel` -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -.. autoclass:: AdditiveStructureKernel - :members: - :hidden:`ProductKernel` ~~~~~~~~~~~~~~~~~~~~~~~~~ .. autoclass:: ProductKernel :members: -:hidden:`ProductStructureKernel` -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -.. autoclass:: ProductStructureKernel - :members: - :hidden:`ScaleKernel` ~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/gpytorch/kernels/__init__.py b/gpytorch/kernels/__init__.py index 55119b784..e3470fb82 100644 --- a/gpytorch/kernels/__init__.py +++ b/gpytorch/kernels/__init__.py @@ -1,6 +1,5 @@ #!/usr/bin/env python3 from . import keops -from .additive_structure_kernel import AdditiveStructureKernel from .arc_kernel import ArcKernel from .constant_kernel import ConstantKernel from .cosine_kernel import CosineKernel @@ -19,12 +18,10 @@ from .matern_kernel import MaternKernel from .multi_device_kernel import MultiDeviceKernel from .multitask_kernel import MultitaskKernel -from .newton_girard_additive_kernel import NewtonGirardAdditiveKernel from .periodic_kernel import PeriodicKernel from .piecewise_polynomial_kernel import PiecewisePolynomialKernel from .polynomial_kernel import PolynomialKernel from .polynomial_kernel_grad import PolynomialKernelGrad -from .product_structure_kernel import ProductStructureKernel from .rbf_kernel import RBFKernel from .rbf_kernel_grad import RBFKernelGrad from .rbf_kernel_gradgrad import RBFKernelGradGrad @@ -39,7 +36,6 @@ "Kernel", "ArcKernel", "AdditiveKernel", - "AdditiveStructureKernel", "ConstantKernel", "CylindricalKernel", "MultiDeviceKernel", @@ -55,13 +51,11 @@ "LinearKernel", "MaternKernel", "MultitaskKernel", - "NewtonGirardAdditiveKernel", "PeriodicKernel", "PiecewisePolynomialKernel", "PolynomialKernel", "PolynomialKernelGrad", "ProductKernel", - "ProductStructureKernel", "RBFKernel", "RFFKernel", "RBFKernelGrad", diff --git a/gpytorch/kernels/additive_structure_kernel.py b/gpytorch/kernels/additive_structure_kernel.py deleted file mode 100644 index 35dfea259..000000000 --- a/gpytorch/kernels/additive_structure_kernel.py +++ /dev/null @@ -1,73 +0,0 @@ -#!/usr/bin/env python3 - -import warnings -from typing import Optional, Tuple - -from .kernel import Kernel - - -class AdditiveStructureKernel(Kernel): - r""" - A Kernel decorator for kernels with additive structure. If a kernel decomposes - additively, then this module will be much more computationally efficient. - - A kernel function `k` decomposes additively if it can be written as - - .. math:: - - \begin{equation*} - k(\mathbf{x_1}, \mathbf{x_2}) = k'(x_1^{(1)}, x_2^{(1)}) + \ldots + k'(x_1^{(d)}, x_2^{(d)}) - \end{equation*} - - for some kernel :math:`k'` that operates on a subset of dimensions. - - Given a `b x n x d` input, `AdditiveStructureKernel` computes `d` one-dimensional kernels - (using the supplied base_kernel), and then adds the component kernels together. - Unlike :class:`~gpytorch.kernels.AdditiveKernel`, `AdditiveStructureKernel` computes each - of the additive terms in batch, making it very fast. - - Args: - base_kernel (Kernel): - The kernel to approximate with KISS-GP - num_dims (int): - The dimension of the input data. - active_dims (tuple of ints, optional): - Passed down to the `base_kernel`. 
- """ - - @property - def is_stationary(self) -> bool: - """ - Kernel is stationary if the base kernel is stationary. - """ - return self.base_kernel.is_stationary - - def __init__( - self, - base_kernel: Kernel, - num_dims: int, - active_dims: Optional[Tuple[int, ...]] = None, - ): - warnings.warn( - "AdditiveStructureKernel is deprecated, and will be removed in GPyTorch 2.0. " - 'Please refer to the "Kernels with Additive or Product Structure" tutorial ' - "in the GPyTorch docs for how to implement GPs with additive structure.", - DeprecationWarning, - ) - super(AdditiveStructureKernel, self).__init__(active_dims=active_dims) - self.base_kernel = base_kernel - self.num_dims = num_dims - - def forward(self, x1, x2, diag=False, last_dim_is_batch=False, **params): - if last_dim_is_batch: - raise RuntimeError("AdditiveStructureKernel does not accept the last_dim_is_batch argument.") - - res = self.base_kernel(x1, x2, diag=diag, last_dim_is_batch=True, **params) - res = res.sum(-2 if diag else -3) - return res - - def prediction_strategy(self, train_inputs, train_prior_dist, train_labels, likelihood): - return self.base_kernel.prediction_strategy(train_inputs, train_prior_dist, train_labels, likelihood) - - def num_outputs_per_input(self, x1, x2): - return self.base_kernel.num_outputs_per_input(x1, x2) diff --git a/gpytorch/kernels/constant_kernel.py b/gpytorch/kernels/constant_kernel.py index 98a3560e2..ab177519c 100644 --- a/gpytorch/kernels/constant_kernel.py +++ b/gpytorch/kernels/constant_kernel.py @@ -90,7 +90,6 @@ def forward( x1: Tensor, x2: Tensor, diag: Optional[bool] = False, - last_dim_is_batch: Optional[bool] = False, ) -> Tensor: """Evaluates the constant kernel. @@ -98,17 +97,11 @@ def forward( x1: First input tensor of shape (batch_shape x n1 x d). x2: Second input tensor of shape (batch_shape x n2 x d). diag: If True, returns the diagonal of the covariance matrix. - last_dim_is_batch: If True, the last dimension of size `d` of the input - tensors are treated as a batch dimension. Returns: A (batch_shape x n1 x n2)-dim, resp. (batch_shape x n1)-dim, tensor of constant covariance values if diag is False, resp. True. 
""" - if last_dim_is_batch: - x1 = x1.transpose(-1, -2).unsqueeze(-1) - x2 = x2.transpose(-1, -2).unsqueeze(-1) - dtype = torch.promote_types(x1.dtype, x2.dtype) batch_shape = torch.broadcast_shapes(x1.shape[:-2], x2.shape[:-2]) shape = batch_shape + (x1.shape[-2],) + (() if diag else (x2.shape[-2],)) @@ -117,7 +110,4 @@ def forward( if not diag: constant = constant.unsqueeze(-1) - if last_dim_is_batch: - constant = constant.unsqueeze(-1) - return constant.expand(shape) diff --git a/gpytorch/kernels/cosine_kernel.py b/gpytorch/kernels/cosine_kernel.py index 11add6f2f..49d6a67e6 100644 --- a/gpytorch/kernels/cosine_kernel.py +++ b/gpytorch/kernels/cosine_kernel.py @@ -56,8 +56,6 @@ class CosineKernel(Kernel): >>> covar = covar_module(x) # Output: LazyVariable of size (2 x 10 x 10) """ - is_stationary = True - def __init__( self, period_length_prior: Optional[Prior] = None, @@ -85,6 +83,10 @@ def __init__( self.register_constraint("raw_period_length", period_length_constraint) + @property + def is_stationary(self): + return True + @property def period_length(self): return self.raw_period_length_constraint.transform(self.raw_period_length) diff --git a/gpytorch/kernels/cylindrical_kernel.py b/gpytorch/kernels/cylindrical_kernel.py index 48f24958c..2ad270bd7 100644 --- a/gpytorch/kernels/cylindrical_kernel.py +++ b/gpytorch/kernels/cylindrical_kernel.py @@ -4,7 +4,6 @@ import torch -from .. import settings from ..constraints import Interval, Positive from ..priors import Prior from .kernel import Kernel @@ -152,8 +151,7 @@ def forward(self, x1: torch.Tensor, x2: torch.Tensor, diag: Optional[bool] = Fal else: angular_kernel = angular_kernel + self.angular_weights[..., p, None].mul(gram_mat.pow(p)) - with settings.lazily_evaluate_kernels(False): - radial_kernel = self.radial_base_kernel(self.kuma(r1), self.kuma(r2), diag=diag, **params) + radial_kernel = self.radial_base_kernel.forward(self.kuma(r1), self.kuma(r2), diag=diag, **params) return radial_kernel.mul(angular_kernel) def kuma(self, x: torch.Tensor) -> torch.Tensor: diff --git a/gpytorch/kernels/grid_interpolation_kernel.py b/gpytorch/kernels/grid_interpolation_kernel.py index bcdc48ed1..4abdef4d1 100644 --- a/gpytorch/kernels/grid_interpolation_kernel.py +++ b/gpytorch/kernels/grid_interpolation_kernel.py @@ -1,10 +1,12 @@ #!/usr/bin/env python3 -from typing import List, Optional, Tuple, Union +from typing import Iterable, Optional, Tuple, Union import torch +from jaxtyping import Float from linear_operator import to_linear_operator -from linear_operator.operators import InterpolatedLinearOperator +from linear_operator.operators import InterpolatedLinearOperator, LinearOperator +from torch import Tensor from ..models.exact_prediction_strategies import InterpolatedPredictionStrategy from ..utils.grid import create_grid @@ -25,14 +27,14 @@ class GridInterpolationKernel(GridKernel): .. 
math:: \begin{equation*} - k(\mathbf{x_1}, \mathbf{x_2}) = \mathbf{w_{x_1}}^\top K_{U,U} \mathbf{w_{x_2}} + k(\mathbf{x_1}, \mathbf{x_2}) = \mathbf{w_{x_1}}^\top K_{\boldsymbol Z, \boldsymbol Z} \mathbf{w_{x_2}} \end{equation*} where - * :math:`U` is the set of gridded inducing points + * :math:`\boldsymbol Z` is the set of gridded inducing points - * :math:`K_{U,U}` is the kernel matrix between the inducing points + * :math:`K_{\boldsymbol Z, \boldsymbol Z}` is the kernel matrix between the inducing points * :math:`\mathbf{w_{x_1}}` and :math:`\mathbf{w_{x_2}}` are sparse vectors based on :math:`\mathbf{x_1}` and :math:`\mathbf{x_2}` that apply cubic interpolation. @@ -50,20 +52,13 @@ class GridInterpolationKernel(GridKernel): `GridInterpolationKernel` can only wrap **stationary kernels** (such as RBF, Matern, Periodic, Spectral Mixture, etc.) - Args: - base_kernel (Kernel): - The kernel to approximate with KISS-GP - grid_size (Union[int, List[int]]): - The size of the grid in each dimension. - If a single int is provided, then every dimension will have the same grid size. - num_dims (int): - The dimension of the input data. Required if `grid_bounds=None` - grid_bounds (tuple(float, float), optional): - The bounds of the grid, if known (high performance mode). - The length of the tuple must match the number of dimensions. - The entries represent the min/max values for each dimension. - active_dims (tuple of ints, optional): - Passed down to the `base_kernel`. + :param base_kernel: The kernel to approximate with KISS-GP. + :param grid_size: The size of the grid in each dimension. + If a single int is provided, then every dimension will have the same grid size. + :param num_dims: The dimension of the input data. Required if `grid_bounds=None` + :param grid_bounds: The bounds of the grid, if known (high performance mode). + The length of the tuple must match the number of dimensions. + The entries represent the min/max values for each dimension. .. _Kernel Interpolation for Scalable Structured Gaussian Processes: http://proceedings.mlr.press/v37/wilson15.pdf @@ -72,10 +67,10 @@ class GridInterpolationKernel(GridKernel): def __init__( self, base_kernel: Kernel, - grid_size: Union[int, List[int]], + grid_size: Union[int, Iterable[int]], num_dims: Optional[int] = None, grid_bounds: Optional[Tuple[float, float]] = None, - active_dims: Optional[Tuple[int, ...]] = None, + **kwargs, ): has_initialized_grid = 0 grid_is_dynamic = True @@ -116,11 +111,17 @@ def __init__( super(GridInterpolationKernel, self).__init__( base_kernel=base_kernel, grid=grid, - interpolation_mode=True, - active_dims=active_dims, + **kwargs, ) self.register_buffer("has_initialized_grid", torch.tensor(has_initialized_grid, dtype=torch.bool)) + @property + def _lazily_evaluate(self) -> bool: + # GridInterpolationKernels should not lazily evaluate; there are few gains (the inducing point kernel + # matrix always needs to be evaluated; regardless of the size of x1 and x2), and the + # InterpolatedLinearOperator structure is needed for fast predictions. 
+ return False + @property def _tight_grid_bounds(self): grid_spacings = tuple((bound[1] - bound[0]) / self.grid_sizes[i] for i, bound in enumerate(self.grid_bounds)) @@ -129,23 +130,26 @@ def _tight_grid_bounds(self): for bound, spacing in zip(self.grid_bounds, grid_spacings) ) - def _compute_grid(self, inputs, last_dim_is_batch=False): - n_data, n_dimensions = inputs.size(-2), inputs.size(-1) - if last_dim_is_batch: - inputs = inputs.transpose(-1, -2).unsqueeze(-1) - n_dimensions = 1 - batch_shape = inputs.shape[:-2] - + def _compute_grid(self, inputs): + *batch_shape, n_data, n_dimensions = inputs.shape inputs = inputs.reshape(-1, n_dimensions) interp_indices, interp_values = Interpolation().interpolate(self.grid, inputs) interp_indices = interp_indices.view(*batch_shape, n_data, -1) interp_values = interp_values.view(*batch_shape, n_data, -1) return interp_indices, interp_values - def _inducing_forward(self, last_dim_is_batch, **params): - return super().forward(self.grid, self.grid, last_dim_is_batch=last_dim_is_batch, **params) + def _create_or_update_full_grid(self, grid: Iterable[Tensor]): + pass - def forward(self, x1, x2, diag=False, last_dim_is_batch=False, **params): + def _validate_inputs(self, x: Float[Tensor, "... N D"]) -> bool: + return True + + def _inducing_forward(self, **params): + return super().forward(None, None, **params) + + def forward( + self, x1: Float[Tensor, "... N_1 D"], x2: Float[Tensor, "... N_2 D"], diag: bool = False, **params + ) -> Float[Union[Tensor, LinearOperator], "... N_1 N_2"]: # See if we need to update the grid or not if self.grid_is_dynamic: # This is true if a grid_bounds wasn't passed in if torch.equal(x1, x2): @@ -180,16 +184,13 @@ def forward(self, x1, x2, diag=False, last_dim_is_batch=False, **params): ) self.update_grid(grid) - base_lazy_tsr = to_linear_operator(self._inducing_forward(last_dim_is_batch=last_dim_is_batch, **params)) - if last_dim_is_batch and base_lazy_tsr.size(-3) == 1: - base_lazy_tsr = base_lazy_tsr.repeat(*x1.shape[:-2], x1.size(-1), 1, 1) - - left_interp_indices, left_interp_values = self._compute_grid(x1, last_dim_is_batch) + base_lazy_tsr = to_linear_operator(self._inducing_forward(**params)) + left_interp_indices, left_interp_values = self._compute_grid(x1) if torch.equal(x1, x2): right_interp_indices = left_interp_indices right_interp_values = left_interp_values else: - right_interp_indices, right_interp_values = self._compute_grid(x2, last_dim_is_batch) + right_interp_indices, right_interp_values = self._compute_grid(x2) batch_shape = torch.broadcast_shapes( base_lazy_tsr.batch_shape, diff --git a/gpytorch/kernels/grid_kernel.py b/gpytorch/kernels/grid_kernel.py index 8a3503943..334b28d6f 100644 --- a/gpytorch/kernels/grid_kernel.py +++ b/gpytorch/kernels/grid_kernel.py @@ -1,24 +1,25 @@ #!/usr/bin/env python3 import warnings -from typing import Optional +from typing import Iterable, Union import torch +from jaxtyping import Float from linear_operator import to_dense -from linear_operator.operators import KroneckerProductLinearOperator, ToeplitzLinearOperator +from linear_operator.operators import KroneckerProductLinearOperator, LinearOperator, ToeplitzLinearOperator from torch import Tensor from .. 
import settings -from ..utils.grid import convert_legacy_grid, create_data_from_grid +from ..utils.grid import create_data_from_grid from .kernel import Kernel class GridKernel(Kernel): r""" - If the input data :math:`X` are regularly spaced on a grid, then - `GridKernel` can dramatically speed up computatations for stationary kernel. - - GridKernel exploits Toeplitz and Kronecker structure within the covariance matrix. + `GridKernel` wraps a stationary kernel that is computed on a (multidimensional) + grid that is regularly spaced along each dimension. + It exploits Toeplitz and Kronecker structure within the covariance matrix + for massive computational speedups. See `Fast kernel learning for multidimensional pattern extrapolation`_ for more info. .. note:: `GridKernel` can only wrap **stationary kernels** (such as RBF, Matern, Periodic, Spectral Mixture, etc.) - Args: - base_kernel (Kernel): - The kernel to speed up with grid methods. - grid (Tensor): - A g x d tensor where column i consists of the projections of the - grid in dimension i. - active_dims (tuple of ints, optional): - Passed down to the `base_kernel`. - interpolation_mode (bool): - Used for GridInterpolationKernel where we want the covariance - between points in the projections of the grid of each dimension. - We do this by treating `grid` as d batches of g x 1 tensors by - calling base_kernel(grid, grid) with last_dim_is_batch to get a d x g x g Tensor - which we Kronecker product to get a g x g KroneckerProductLinearOperator. + :param base_kernel: The stationary kernel to speed up with grid methods. + :param grid: A list of tensors where tensor `i` consists of the projections + of the grid in dimension i. + :param active_dims: Passed down to the `base_kernel`. + + :ivar ragged_grid: A concatenation of all grid projections + :type ragged_grid: Tensor (max(M_i) x D) + :ivar full_grid: A full representation of the grid + :type full_grid: Tensor (N x D) ..
_Fast kernel learning for multidimensional pattern extrapolation: http://www.cs.cmu.edu/~andrewgw/manet.pdf """ - is_stationary = True - def __init__( self, base_kernel: Kernel, - grid: Tensor, - interpolation_mode: Optional[bool] = False, - active_dims: Optional[bool] = None, + grid: Iterable[Float[Tensor, "M_i"]], # noqa F821 + **kwargs, ): if not base_kernel.is_stationary: raise RuntimeError("The base_kernel for GridKernel must be stationary.") + batch_shapes, num_grid_points = zip(*[(sub_grid.shape[:-1], sub_grid.shape[-1]) for sub_grid in grid]) - super().__init__(active_dims=active_dims) - if torch.is_tensor(grid): - grid = convert_legacy_grid(grid) - self.interpolation_mode = interpolation_mode + super().__init__(**kwargs) self.base_kernel = base_kernel self.num_dims = len(grid) - self.register_buffer_list("grid", grid) - if not self.interpolation_mode: - self.register_buffer("full_grid", create_data_from_grid(grid)) + self.num_grid_points = num_grid_points + + # Store each grid in a buffer + for i, sub_grid in enumerate(grid): + assert sub_grid.dim() == 1 + self.register_buffer(f"grid_{i}", sub_grid) + + # Create a buffer to store a concatenation of all grids + num_grid_points = [sub_grid.size(-1) for sub_grid in grid] + ragged_grid: Float[Tensor, "M D"] = torch.zeros( + max(self.num_grid_points), self.num_dims, dtype=grid[0].dtype, device=grid[0].device + ) + self.register_buffer("ragged_grid", ragged_grid) + + # Update the ragged_grid buffer + # Also create the full_grid buffer + self.update_grid(grid) + + @property + def _lazily_evaluate(self) -> bool: + # Toeplitz structure is very efficient; no need to lazily evaluate + return False + + @property + def is_stationary(self) -> bool: + return True def _clear_cache(self): if hasattr(self, "_cached_kernel_mat"): del self._cached_kernel_mat - def register_buffer_list(self, base_name, tensors): - """Helper to register several buffers at once under a single base name""" - for i, tensor in enumerate(tensors): - self.register_buffer(base_name + "_" + str(i), tensor) + def _create_or_update_full_grid(self, grid: Iterable[Float[Tensor, "M_i"]]): # noqa F821 + full_grid = create_data_from_grid(self.grid) + if hasattr(self, "full_grid"): + self.full_grid.reshape(full_grid.shape) + self.full_grid.copy_(full_grid.type_as(self.full_grid)) + else: + self.register_buffer("full_grid", full_grid) + + def _validate_inputs(self, x: Float[Tensor, "... N D"]) -> bool: + return torch.equal(self.full_grid.expand(*x.shape[:-2], *self.full_grid.shape[-2:]), x) @property - def grid(self): + def grid(self) -> Float[Tensor, "N D"]: return [getattr(self, f"grid_{i}") for i in range(self.num_dims)] - def update_grid(self, grid): + def update_grid(self, grid: Iterable[Float[Tensor, "M_i"]]): # noqa F821 """ Supply a new `grid` if it ever changes. 
""" - if torch.is_tensor(grid): - grid = convert_legacy_grid(grid) - if len(grid) != self.num_dims: raise RuntimeError("New grid should have the same number of dimensions as before.") + num_grid_points = [sub_grid.size(-1) for sub_grid in grid] + + # Update the size of the ragged_grid buffer + self.ragged_grid.reshape(max(self.num_grid_points), self.num_dims) + self.num_grid_points = num_grid_points - for i in range(self.num_dims): - setattr(self, f"grid_{i}", grid[i]) + # Update the grid and ragged_grid buffers + for i, (num_grid_point, sub_grid) in enumerate(zip(num_grid_points, grid)): + assert sub_grid.dim() == 1 + getattr(self, f"grid_{i}").reshape(sub_grid.shape) + getattr(self, f"grid_{i}").copy_(sub_grid.type_as(self.ragged_grid)) + # Grids aren't necessarily the same size across each dimension + # Some grids will be padded by zeros, which will be removed after computing kernel rows + self.ragged_grid[..., :num_grid_point, i] = sub_grid.type_as(self.ragged_grid) - if not self.interpolation_mode: - self.full_grid = create_data_from_grid(self.grid) + # Update the full_grid buffer + self._create_or_update_full_grid(grid) + # Clear cache self._clear_cache() return self - @property - def is_ragged(self): - return not all(self.grid[0].size() == proj.size() for proj in self.grid) - - def forward(self, x1, x2, diag=False, last_dim_is_batch=False, **params): - if last_dim_is_batch and not self.interpolation_mode: - raise ValueError("last_dim_is_batch is only valid with interpolation model") - - grid = self.grid - if self.is_ragged: - # Pad the grid - so that grid is the same size for each dimension - max_grid_size = max(proj.size(-1) for proj in grid) - padded_grid = [] - for proj in grid: - padding_size = max_grid_size - proj.size(-1) - if padding_size > 0: - dtype = proj.dtype - device = proj.device - padded_grid.append( - torch.cat([proj, torch.zeros(*proj.shape[:-1], padding_size, dtype=dtype, device=device)]) - ) - else: - padded_grid.append(proj) - else: - padded_grid = grid - - if not self.interpolation_mode: - if len(x1.shape[:-2]): - full_grid = self.full_grid.expand(*x1.shape[:-2], *self.full_grid.shape[-2:]) - else: - full_grid = self.full_grid - - if self.interpolation_mode or (torch.equal(x1, full_grid) and torch.equal(x2, full_grid)): - if not self.training and hasattr(self, "_cached_kernel_mat"): - return self._cached_kernel_mat - # Can exploit Toeplitz structure if grid points in each dimension are equally - # spaced and using a translation-invariant kernel - if settings.use_toeplitz.on(): - # Use padded grid for batch mode - first_grid_point = torch.stack([proj[0].unsqueeze(0) for proj in grid], dim=-1) - full_grid = torch.stack(padded_grid, dim=-1) - with warnings.catch_warnings(): # Hide the GPyTorch 2.0 deprecation warning - warnings.simplefilter("ignore", DeprecationWarning) - covars = to_dense(self.base_kernel(first_grid_point, full_grid, last_dim_is_batch=True, **params)) - - if last_dim_is_batch: - # Toeplitz expects batches of columns so we concatenate the - # 1 x grid_size[i] tensors together - # Note that this requires all the dimensions to have the same number of grid points - covar = ToeplitzLinearOperator(covars.squeeze(-2)) - else: - # Non-batched ToeplitzLinearOperator expects a 1D tensor, so we squeeze out the row dimension - covars = covars.squeeze(-2) # Get rid of the dimension corresponding to the first point - # Un-pad the grid - covars = [ToeplitzLinearOperator(covars[..., i, : proj.size(-1)]) for i, proj in enumerate(grid)] - # Due to legacy reasons, 
KroneckerProductLinearOperator(A, B, C) is actually (C Kron B Kron A) - covar = KroneckerProductLinearOperator(*covars[::-1]) - else: - full_grid = torch.stack(padded_grid, dim=-1) - with warnings.catch_warnings(): # Hide the GPyTorch 2.0 deprecation warning - warnings.simplefilter("ignore", DeprecationWarning) - covars = to_dense(self.base_kernel(full_grid, full_grid, last_dim_is_batch=True, **params)) - if last_dim_is_batch: - # Note that this requires all the dimensions to have the same number of grid points - covar = covars - else: - covars = [covars[..., i, : proj.size(-1), : proj.size(-1)] for i, proj in enumerate(self.grid)] - covar = KroneckerProductLinearOperator(*covars[::-1]) - - if not self.training: - self._cached_kernel_mat = covar - - return covar - else: - return self.base_kernel.forward(x1, x2, diag=diag, last_dim_is_batch=last_dim_is_batch, **params) - - def num_outputs_per_input(self, x1, x2): + def forward( + self, x1: Float[Tensor, "... N_1 D"], x2: Float[Tensor, "... N_2 D"], diag: bool = False, **params + ) -> Union[Float[LinearOperator, "... N_1 N_2"], Float[Tensor, "... N_1"]]: + if diag: + return self.base_kernel(x1, x2, diag=True, **params) + + # If this kernel is not called with the grid data, directly call base_kernel + if not (self._validate_inputs(x1) and self._validate_inputs(x2)): + warnings.warn("GridKernel was called with non-grid data.", RuntimeWarning) + return self.base_kernel(x1, x2, diag=False, **params) + + # Default case + if not self.training and hasattr(self, "_cached_kernel_mat"): + return self._cached_kernel_mat + + first_grid_points = self.ragged_grid[..., :1, :] + + # Compute the first rows of each univariate kernel on each of the D-dimensions + # The result will be batched and stored in a D x ... x M matrix + # Hack: + # Base kernel expects a d-dimensional input. To compute the kernel on + # the grid projected ondo dim i, we zero the data in all other dimensions. + # Since the kernel is stationary, the other dimensions won't contribute to the covariance. + batch_shape = torch.broadcast_shapes(self.ragged_grid.shape[:-2], self.base_kernel.batch_shape) + masks = torch.eye(self.num_dims, dtype=self.ragged_grid.dtype, device=self.ragged_grid.device).view( + self.num_dims, *[1 for _ in batch_shape], 1, self.num_dims + ) # D x ... x 1 x D + # This mask will zero out all but the i^th dimension for the i^th batch member + with settings.lazily_evaluate_kernels(False): + unidimensional_kernel_first_rows = to_dense( + self.base_kernel(first_grid_points * masks, self.ragged_grid * masks, **params) + ) # D x ... x M + + # Convert the first rows of the unidimensional kernels into ToeplitzLinearOperators + # (Un-pad the kernel first row as necessary) + unidimensional_kernels = [ + ToeplitzLinearOperator(unidimensional_kernel_first_rows[i, ..., 0, :num_grid_point]) + for i, num_grid_point in enumerate(self.num_grid_points) + ] # D x ... x M_i x M_i + # Due to legacy reasons, KroneckerProductLinearOperator(A, B, C) is actually (C Kron B Kron A) + covar = KroneckerProductLinearOperator(*unidimensional_kernels[::-1]) # ... x N x N + + if not self.training: + self._cached_kernel_mat = covar + + return covar + + def num_outputs_per_input(self, x1: Float[Tensor, "... N_1 D"], x2: Float[Tensor, "... 
N_2 D"]) -> int: return self.base_kernel.num_outputs_per_input(x1, x2) diff --git a/gpytorch/kernels/index_kernel.py b/gpytorch/kernels/index_kernel.py index 7fa5e01f3..f7399b652 100644 --- a/gpytorch/kernels/index_kernel.py +++ b/gpytorch/kernels/index_kernel.py @@ -76,6 +76,12 @@ def __init__( self.register_constraint("raw_var", var_constraint) + @property + def _lazily_evaluate(self) -> bool: + # IndexKernel does not need lazy evaluation, since the complete BB^T + D_v` is always + # computed regardless of x1 and x2 + return False + @property def var(self): return self.raw_var_constraint.transform(self.raw_var) diff --git a/gpytorch/kernels/inducing_point_kernel.py b/gpytorch/kernels/inducing_point_kernel.py index 7ea19f283..11d2e4473 100644 --- a/gpytorch/kernels/inducing_point_kernel.py +++ b/gpytorch/kernels/inducing_point_kernel.py @@ -47,6 +47,12 @@ def _clear_cache(self): if hasattr(self, "_cached_kernel_inv_root"): del self._cached_kernel_inv_root + @property + def _lazily_evaluate(self) -> bool: + # InducingPointKernels kernels should not lazily evaluate; to use the Woodbury formula, + # we want the Kernel to return a LowRankLinearOperator, not a KernelLinaerOperator. + return False + @property def _inducing_mat(self): if not self.training and hasattr(self, "_cached_kernel_mat"): diff --git a/gpytorch/kernels/keops/rbf_kernel.py b/gpytorch/kernels/keops/rbf_kernel.py index 5497f0f47..918ec523d 100644 --- a/gpytorch/kernels/keops/rbf_kernel.py +++ b/gpytorch/kernels/keops/rbf_kernel.py @@ -1,6 +1,5 @@ #!/usr/bin/env python3 -# from linear_operator.operators import KeOpsLinearOperator from linear_operator.operators import KernelLinearOperator from .keops_kernel import _lazify_and_expand_inputs, KeOpsKernel diff --git a/gpytorch/kernels/kernel.py b/gpytorch/kernels/kernel.py index 67e576db3..fc0c3e68d 100644 --- a/gpytorch/kernels/kernel.py +++ b/gpytorch/kernels/kernel.py @@ -4,12 +4,13 @@ import warnings from abc import abstractmethod +from collections import defaultdict, OrderedDict from copy import deepcopy -from typing import Callable, Dict, Iterable, Optional, Tuple, Union +from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Union import torch from linear_operator import to_dense, to_linear_operator -from linear_operator.operators import LinearOperator, ZeroLinearOperator +from linear_operator.operators import KernelLinearOperator, LinearOperator, ZeroLinearOperator from torch import Tensor from torch.nn import ModuleList @@ -81,6 +82,44 @@ def _dist(self, x1, x2, x1_eq_x2=False, postprocess=False): return self._postprocess(res) if postprocess else res +class _autograd_kernel_hack: + """ + Helper class. + + When using KernelLinearOperator, the `covar_func` cannot close over any Tensors that require gradients. + (Any Tensor that `covar_func` closes over will not backpropagate gradients.) + Unfortunately, for most kernels, `covar_func=self.forward`, which closes over all of the kernel's parameters. + + This context manager temporarily replaces a kernel (and its submodules') parameter assignments with an + external set of references to these parameters. + The external set of references will be passed in by KernelLinearOperator. + + This way, when calling self.forward, no parameter references are closed over, and so all parameters + will receive the appropriate gradients. 
+ """ + + def __init__( + self, + kernel: Kernel, + params: Dict[str, torch.nn.Parameters], + module_params: Dict[torch.nn.Module, Iterable[str]], + ): + self.temp_module_param_dicts = defaultdict(OrderedDict) + for module, param_names in module_params.items(): + self.temp_module_param_dicts[module] = OrderedDict( + (param_name.rsplit(".", 1)[-1], params[param_name]) for param_name in param_names + ) + self.orig_model_param_dicts = dict((module, module._parameters) for module in self.temp_module_param_dicts) + + def __enter__(self): + for module, temp_param_dict in self.temp_module_param_dicts.items(): + object.__setattr__(module, "_parameters", temp_param_dict) + + def __exit__(self, type, value, traceback): + for module, orig_param_dict in self.orig_model_param_dicts.items(): + object.__setattr__(module, "_parameters", orig_param_dict) + + class Kernel(Module): r""" Kernels in GPyTorch are implemented as a :class:`gpytorch.Module` that, when called on two :class:`torch.Tensor` @@ -212,6 +251,45 @@ def __init__( # TODO: Remove this on next official PyTorch release. self.__pdist_supports_batch = True + @property + def _lazily_evaluate(self) -> bool: + r""" + Determines whether or not the kernel is lazily evaluated. + + If False, kernel(x1, x2) produces a Tensor/LinearOperator where the covariance function has been evaluated + over x1 and x2. + + If True, kernel(x1, x2) produces a KernelLinearOperator that delays evaluation of the kernel function. + The kernel function will only be evaluated when either + - An mathematical operation is performed on the kernel matrix (e.g. solves, logdets, etc.), or + - An indexing operation is performed on the kernel matrix to select specific covariance entries. + + In general, _lazily_evaluate should return True (this option is more efficient), unless lazy evaluation + offers no gains and there is specific structure that will be lost with lazy evaluation + (e.g. low-rank/Nystrom approximations). + """ + return True + + def _kernel_linear_operator_covar_func( + self, + x1: Tensor, + x2: Tensor, + non_param_kwargs: Dict[str, Any], + module_params: Dict[torch.nn.Module, Iterable[str]], + **params: torch.nn.Parameter, + ) -> Union[Tensor, LinearOperator]: + # This is the `covar_function` that is passed into KernelLinearOperator + # This function calls self.forward, but does so in a way so that no parameters are closed over + # (by using the _autograd_kernel_hack context manager) + try: + if any(param.requires_grad for param in params.values()): + with _autograd_kernel_hack(self, params, module_params): + return self.forward(x1, x2, **non_param_kwargs) + else: + return self.forward(x1, x2, **non_param_kwargs) + except Exception as e: + raise e + def _lengthscale_param(self, m: Kernel) -> Tensor: # Used by the lengthscale_prior return m.lengthscale @@ -231,9 +309,7 @@ def _set_lengthscale(self, value: Tensor): self.initialize(raw_lengthscale=self.raw_lengthscale_constraint.inverse_transform(value)) @abstractmethod - def forward( - self, x1: Tensor, x2: Tensor, diag: bool = False, last_dim_is_batch: bool = False, **params - ) -> Union[Tensor, LinearOperator]: + def forward(self, x1: Tensor, x2: Tensor, diag: bool = False, **params) -> Union[Tensor, LinearOperator]: r""" Computes the covariance between :math:`\mathbf x_1` and :math:`\mathbf x_2`. This method should be implemented by all Kernel subclasses. @@ -242,16 +318,11 @@ def forward( :param x2: Second set of data (... x M x D). :param diag: Should the Kernel compute the whole kernel, or just the diag? 
If True, it must be the case that `x1 == x2`. (Default: False.) - :param last_dim_is_batch: If True, treat the last dimension - of `x1` and `x2` as another batch dimension. - (Useful for additive structure over the dimensions). (Default: False.) :return: The kernel matrix or vector. The shape depends on the kernel's evaluation mode: * `full_covar`: `... x N x M` - * `full_covar` with `last_dim_is_batch=True`: `... x K x N x M` * `diag`: `... x N` - * `diag` with `last_dim_is_batch=True`: `... x K x N` """ raise NotImplementedError() @@ -314,7 +385,6 @@ def covar_dist( x1: Tensor, x2: Tensor, diag: bool = False, - last_dim_is_batch: bool = False, square_dist: bool = False, **params, ) -> Tensor: @@ -326,22 +396,13 @@ def covar_dist( :param x2: Second set of data (... x M x D). :param diag: Should the Kernel compute the whole kernel, or just the diag? If True, it must be the case that `x1 == x2`. (Default: False.) - :param last_dim_is_batch: If True, treat the last dimension - of `x1` and `x2` as another batch dimension. - (Useful for additive structure over the dimensions). (Default: False.) :param square_dist: If True, returns the squared distance rather than the standard distance. (Default: False.) :return: The kernel matrix or vector. The shape depends on the kernel's evaluation mode: * `full_covar`: `... x N x M` - * `full_covar` with `last_dim_is_batch=True`: `... x K x N x M` * `diag`: `... x N` - * `diag` with `last_dim_is_batch=True`: `... x K x N` """ - if last_dim_is_batch: - x1 = x1.transpose(-1, -2).unsqueeze(-1) - x2 = x2.transpose(-1, -2).unsqueeze(-1) - x1_eq_x2 = torch.equal(x1, x2) res = None @@ -457,7 +518,7 @@ def sub_kernels(self) -> Iterable[Kernel]: yield kernel def __call__( - self, x1: Tensor, x2: Optional[Tensor] = None, diag: bool = False, last_dim_is_batch: bool = False, **params + self, x1: Tensor, x2: Optional[Tensor] = None, diag: bool = False, **params ) -> Union[LazyEvaluatedKernelTensor, LinearOperator, Tensor]: r""" Computes the covariance between :math:`\mathbf x_1` and :math:`\mathbf x_2`. @@ -473,27 +534,13 @@ def __call__( (If `None`, then `x2` is set to `x1`.) :param diag: Should the Kernel compute the whole kernel, or just the diag? If True, it must be the case that `x1 == x2`. (Default: False.) - :param last_dim_is_batch: If True, treat the last dimension - of `x1` and `x2` as another batch dimension. - (Useful for additive structure over the dimensions). (Default: False.) :return: An object that will lazily evaluate to the kernel matrix or vector. The shape depends on the kernel's evaluation mode: * `full_covar`: `... x N x M` - * `full_covar` with `last_dim_is_batch=True`: `... x K x N x M` * `diag`: `... x N` - * `diag` with `last_dim_is_batch=True`: `... x K x N` """ - if last_dim_is_batch: - warnings.warn( - "The last_dim_is_batch argument is deprecated, and will be removed in GPyTorch 2.0. " - "If you are using it as part of AdditiveStructureKernel or ProductStructureKernel, " - 'please update your code according to the "Kernels with Additive or Product Structure" ' - "tutorial in the GPyTorch docs.", - DeprecationWarning, - ) - x1_, x2_ = x1, x2 # Select the active dimensions @@ -523,7 +570,7 @@ def __call__( ) if diag: - res = super(Kernel, self).__call__(x1_, x2_, diag=True, last_dim_is_batch=last_dim_is_batch, **params) + res = super(Kernel, self).__call__(x1_, x2_, diag=True, **params) # Did this Kernel eat the diag option? 
# If it does not return a LazyEvaluatedKernelTensor, we can call diag on the output if not isinstance(res, LazyEvaluatedKernelTensor): @@ -532,12 +579,65 @@ def __call__( return res else: - if settings.lazily_evaluate_kernels.on(): - res = LazyEvaluatedKernelTensor(x1_, x2_, kernel=self, last_dim_is_batch=last_dim_is_batch, **params) - else: - res = to_linear_operator( - super(Kernel, self).__call__(x1_, x2_, last_dim_is_batch=last_dim_is_batch, **params) + if settings.lazily_evaluate_kernels.on() and self._lazily_evaluate: + num_outputs_per_input = self.num_outputs_per_input(x1_, x2_) + if isinstance(num_outputs_per_input, int): + num_outputs_per_input = (num_outputs_per_input, num_outputs_per_input) + + def _get_parameter_parent_module_and_batch_shape(module): + num_module_batch_dimension = len(module.batch_shape) if isinstance(module, Kernel) else 0 + for name, param in module._parameters.items(): + yield name, (param, module, param.dim() - num_module_batch_dimension) + + # The following returns a list of tuples for each parameter + parameters of sub-modules: + # (param_name, (param_val, param_parent_module, param_batch_shape)) + named_parameters_parent_modules_and_batch_dimensions = tuple( + self._named_members( + _get_parameter_parent_module_and_batch_shape, + prefix="", + recurse=True, + ) ) + + if len(named_parameters_parent_modules_and_batch_dimensions): + # Information we need for the KernelLinearOperator, as well as the autograd hack: + # - the names/values of all parameters + # - the parent module associated with each parameter + # - the number of non-batch dimensions associated with each parameter + # WE get this information from the list constructed in the previous step + params = dict() + module_params = defaultdict(list) + num_nonbatch_dimensions = dict() + for name, ( + param, + parent_module, + num_nonbatch_dimension, + ) in named_parameters_parent_modules_and_batch_dimensions: + params[name] = param + module_params[parent_module].append(name) + num_nonbatch_dimensions[name] = num_nonbatch_dimension + + # Construct the KernelLinearOperator + res = KernelLinearOperator( + x1_, + x2_, + covar_func=self._kernel_linear_operator_covar_func, + num_outputs_per_input=num_outputs_per_input, + num_nonbatch_dimensions=num_nonbatch_dimensions, + module_params=module_params, # params for _kernel_linear_operator_covar_func + non_param_kwargs=dict(**params), # params for forward + **params, + ) + else: + res = KernelLinearOperator( + x1_, + x2_, + covar_func=self.forward, + num_outputs_per_input=num_outputs_per_input, + non_param_kwargs=dict(**params), # params for forward + ) + else: + res = to_linear_operator(super(Kernel, self).__call__(x1_, x2_, **params)) return res def __getstate__(self): @@ -608,13 +708,17 @@ class AdditiveKernel(Kernel): :param kernels: Kernels to add together. """ + def __init__(self, *kernels: Iterable[Kernel]): + super(AdditiveKernel, self).__init__() + self.kernels = ModuleList(kernels) + @property def is_stationary(self) -> bool: return all(k.is_stationary for k in self.kernels) - def __init__(self, *kernels: Iterable[Kernel]): - super(AdditiveKernel, self).__init__() - self.kernels = ModuleList(kernels) + @property + def _lazily_evaluate(self) -> bool: + return all(k._lazily_evaluate for k in self.kernels) def forward(self, x1: Tensor, x2: Tensor, diag: bool = False, **params) -> Union[Tensor, LinearOperator]: res = ZeroLinearOperator() if not diag else 0 @@ -650,13 +754,17 @@ class ProductKernel(Kernel): :param kernels: Kernels to multiply together. 
""" + def __init__(self, *kernels: Iterable[Kernel]): + super(ProductKernel, self).__init__() + self.kernels = ModuleList(kernels) + @property def is_stationary(self) -> bool: return all(k.is_stationary for k in self.kernels) - def __init__(self, *kernels: Iterable[Kernel]): - super(ProductKernel, self).__init__() - self.kernels = ModuleList(kernels) + @property + def _lazily_evaluate(self) -> bool: + return False def forward(self, x1: Tensor, x2: Tensor, diag: bool = False, **params) -> Union[Tensor, LinearOperator]: x1_eq_x2 = torch.equal(x1, x2) diff --git a/gpytorch/kernels/linear_kernel.py b/gpytorch/kernels/linear_kernel.py index d7ecd1014..c2df55e5f 100644 --- a/gpytorch/kernels/linear_kernel.py +++ b/gpytorch/kernels/linear_kernel.py @@ -72,6 +72,12 @@ def __init__( self.register_constraint("raw_variance", variance_constraint) + @property + def _lazily_evaluate(self) -> bool: + # LinearKernel should not lazily evaluate; to use the Woodbury formula, + # we want the Kernel to return a LowRankLinearOperator, not a KernelLinaerOperator. + return False + @property def variance(self) -> Tensor: return self.raw_variance_constraint.transform(self.raw_variance) @@ -85,12 +91,8 @@ def _set_variance(self, value: Union[float, Tensor]): value = torch.as_tensor(value).to(self.raw_variance) self.initialize(raw_variance=self.raw_variance_constraint.inverse_transform(value)) - def forward( - self, x1: Tensor, x2: Tensor, diag: Optional[bool] = False, last_dim_is_batch: Optional[bool] = False, **params - ) -> LinearOperator: + def forward(self, x1: Tensor, x2: Tensor, diag: Optional[bool] = False, **params) -> LinearOperator: x1_ = x1 * self.variance.sqrt() - if last_dim_is_batch: - x1_ = x1_.transpose(-1, -2).unsqueeze(-1) if x1.size() == x2.size() and torch.equal(x1, x2): # Use RootLinearOperator when x1 == x2 for efficiency when composing @@ -99,9 +101,6 @@ def forward( else: x2_ = x2 * self.variance.sqrt() - if last_dim_is_batch: - x2_ = x2_.transpose(-1, -2).unsqueeze(-1) - prod = MatmulLinearOperator(x1_, x2_.transpose(-2, -1)) if diag: diff --git a/gpytorch/kernels/matern_kernel.py b/gpytorch/kernels/matern_kernel.py index baf145e36..3824ef49b 100644 --- a/gpytorch/kernels/matern_kernel.py +++ b/gpytorch/kernels/matern_kernel.py @@ -89,7 +89,6 @@ def forward(self, x1, x2, diag=False, **params): or x2.requires_grad or (self.ard_num_dims is not None and self.ard_num_dims > 1) or diag - or params.get("last_dim_is_batch", False) or trace_mode.on() ): mean = x1.mean(dim=-2, keepdim=True) diff --git a/gpytorch/kernels/multi_device_kernel.py b/gpytorch/kernels/multi_device_kernel.py index 3d416a1c9..43bab2132 100644 --- a/gpytorch/kernels/multi_device_kernel.py +++ b/gpytorch/kernels/multi_device_kernel.py @@ -42,10 +42,18 @@ def __init__( self.__cached_x1 = torch.empty(1) self.__cached_x2 = torch.empty(1) + @property + def _lazily_evaluate(self) -> bool: + return self.base_kernel._lazily_evaluate + @property def base_kernel(self): return self.module + @property + def is_stationary(self): + return self.base_kernel.is_stationary + def forward(self, x1, x2, diag=False, **kwargs): if diag: return self.module.forward(x1, x2, diag=True, **kwargs).to(self.output_device) diff --git a/gpytorch/kernels/multitask_kernel.py b/gpytorch/kernels/multitask_kernel.py index 79a4f1388..f2e0b16d7 100644 --- a/gpytorch/kernels/multitask_kernel.py +++ b/gpytorch/kernels/multitask_kernel.py @@ -43,9 +43,7 @@ def __init__( self.data_covar_module = data_covar_module self.num_tasks = num_tasks - def forward(self, x1, x2, 
diag=False, last_dim_is_batch=False, **params): - if last_dim_is_batch: - raise RuntimeError("MultitaskKernel does not accept the last_dim_is_batch argument.") + def forward(self, x1, x2, diag=False, **params): covar_i = self.task_covar_module.covar_matrix if len(x1.shape[:-2]): covar_i = covar_i.repeat(*x1.shape[:-2], 1, 1) diff --git a/gpytorch/kernels/newton_girard_additive_kernel.py b/gpytorch/kernels/newton_girard_additive_kernel.py deleted file mode 100644 index 89be591cf..000000000 --- a/gpytorch/kernels/newton_girard_additive_kernel.py +++ /dev/null @@ -1,127 +0,0 @@ -#!/usr/bin/env python3 - -import warnings -from typing import Optional, Tuple - -import torch -from linear_operator import to_dense - -from ..constraints import Positive -from .kernel import Kernel - - -class NewtonGirardAdditiveKernel(Kernel): - def __init__( - self, - base_kernel: Kernel, - num_dims: int, - max_degree: Optional[int] = None, - active_dims: Optional[Tuple[int, ...]] = None, - **kwargs, - ): - """Create an Additive Kernel a la https://arxiv.org/abs/1112.4394 using Newton-Girard Formulae - - :param base_kernel: a base 1-dimensional kernel. NOTE: put ard_num_dims=d in the base kernel... - :param max_degree: the maximum numbers of kernel degrees to compute - :param active_dims: - :param kwargs: - """ - - warnings.warn( - "NewtonGirardAdditiveKernel is deprecated, and will be removed in GPyTorch 2.0. " - 'Please refer to the "Kernels with Additive or Product Structure" tutorial ' - "in the GPyTorch docs for how to implement GPs with additive structure.", - DeprecationWarning, - ) - super(NewtonGirardAdditiveKernel, self).__init__(active_dims=active_dims, **kwargs) - - self.base_kernel = base_kernel - self.num_dims = num_dims - if max_degree is None: - self.max_degree = self.num_dims - elif max_degree > self.num_dims: # force cap on max_degree (silently) - self.max_degree = self.num_dims - else: - self.max_degree = max_degree - - self.register_parameter( - name="raw_outputscale", parameter=torch.nn.Parameter(torch.zeros(*self.batch_shape, self.max_degree)) - ) - outputscale_constraint = Positive() - self.register_constraint("raw_outputscale", outputscale_constraint) - self.outputscale_constraint = outputscale_constraint - self.outputscale = [1 / self.max_degree for _ in range(self.max_degree)] - - @property - def outputscale(self): - return self.raw_outputscale_constraint.transform(self.raw_outputscale) - - @outputscale.setter - def outputscale(self, value): - self._set_outputscale(value) - - def _set_outputscale(self, value): - if not torch.is_tensor(value): - value = torch.as_tensor(value).to(self.raw_outputscale) - - self.initialize(raw_outputscale=self.outputscale_constraint.inverse_transform(value)) - - def forward(self, x1, x2, diag=False, last_dim_is_batch=False, **params): - """Forward proceeds by Newton-Girard formulae""" - if last_dim_is_batch: - raise RuntimeError("NewtonGirardAdditiveKernel does not accept the last_dim_is_batch argument.") - - # NOTE: comments about shape are only correct for the single-batch cases. - # kern_values is just the order-1 terms - # kern_values = D x n x n unless diag=True - kern_values = to_dense(self.base_kernel(x1, x2, diag=diag, last_dim_is_batch=True, **params)) - # last dim is batch, which gets moved up to pos. 
1 - - kernel_dim = -3 if not diag else -2 - - shape = [1 for _ in range(len(kern_values.shape) + 1)] - shape[kernel_dim - 1] = -1 - kvals = torch.arange(1, self.max_degree + 1, device=kern_values.device).reshape(*shape) - # kvals = R x 1 x 1 x 1 (these are indexes only) - - # e_n = torch.ones(self.max_degree+1, *kern_values.shape[1:], device=kern_values.device) # includes 0 - # e_n: elementary symmetric polynomial of degree n (e.g. z1 z2 + z1 z3 + z2 z3) - # e_n is R x n x n, and the array is properly 0 indexed. - shape = [d_ for d_ in kern_values.shape] - shape[kernel_dim] = self.max_degree + 1 - e_n = torch.empty(*shape, device=kern_values.device) - if kernel_dim == -3: - e_n[..., 0, :, :] = 1.0 - else: - e_n[..., 0, :] = 1.0 - - # power sums s_k (e.g. sum_i^num_dims z_i^k - # s_k is R x n x n - s_k = kern_values.unsqueeze(kernel_dim - 1).pow(kvals).sum(dim=kernel_dim) - - # just the constant -1 - m1 = torch.tensor([-1], dtype=torch.float, device=kern_values.device) - - shape = [1 for _ in range(len(kern_values.shape))] - shape[kernel_dim] = -1 - for deg in range(1, self.max_degree + 1): # deg goes from 1 to R (it's 1-indexed!) - # we avg over k [1, ..., deg] (-1)^(k-1)e_{deg-k} s_{k} - - ks = torch.arange(1, deg + 1, device=kern_values.device, dtype=torch.float).reshape(*shape) # use for pow - kslong = torch.arange(1, deg + 1, device=kern_values.device, dtype=torch.long) # use for indexing - - # note that s_k is 0-indexed, so we must subtract 1 from kslong - sum_ = ( - m1.pow(ks - 1) * e_n.index_select(kernel_dim, deg - kslong) * s_k.index_select(kernel_dim, kslong - 1) - ).sum(dim=kernel_dim) / deg - if kernel_dim == -3: - e_n[..., deg, :, :] = sum_ - else: - e_n[..., deg, :] = sum_ - - if kernel_dim == -3: - return (self.outputscale.unsqueeze(-1).unsqueeze(-1) * e_n.narrow(kernel_dim, 1, self.max_degree)).sum( - dim=kernel_dim - ) - else: - return (self.outputscale.unsqueeze(-1) * e_n.narrow(kernel_dim, 1, self.max_degree)).sum(dim=kernel_dim) diff --git a/gpytorch/kernels/periodic_kernel.py b/gpytorch/kernels/periodic_kernel.py index 2972b523a..f8c543921 100644 --- a/gpytorch/kernels/periodic_kernel.py +++ b/gpytorch/kernels/periodic_kernel.py @@ -124,23 +124,20 @@ def _set_period_length(self, value): self.initialize(raw_period_length=self.raw_period_length_constraint.inverse_transform(value)) def forward(self, x1, x2, diag=False, **params): - # Pop this argument so that we can manually sum over dimensions - last_dim_is_batch = params.pop("last_dim_is_batch", False) # Get lengthscale lengthscale = self.lengthscale x1_ = x1.div(self.period_length / math.pi) x2_ = x2.div(self.period_length / math.pi) - # We are automatically overriding last_dim_is_batch here so that we can manually sum over dimensions. - diff = self.covar_dist(x1_, x2_, diag=diag, last_dim_is_batch=True, **params) + diff = self.covar_dist( + x1_.transpose(-1, -2).unsqueeze(-1), x2_.transpose(-1, -2).unsqueeze(-1), diag=diag, **params + ) # A ... 
x D x N x N kernel if diag: lengthscale = lengthscale[..., 0, :, None] else: lengthscale = lengthscale[..., 0, :, None, None] exp_term = diff.sin().pow(2.0).div(lengthscale).mul(-2.0) - - if not last_dim_is_batch: - exp_term = exp_term.sum(dim=(-2 if diag else -3)) + exp_term = exp_term.sum(dim=(-2 if diag else -3)) return exp_term.exp() diff --git a/gpytorch/kernels/piecewise_polynomial_kernel.py b/gpytorch/kernels/piecewise_polynomial_kernel.py index 8135979f0..0bdb358c7 100644 --- a/gpytorch/kernels/piecewise_polynomial_kernel.py +++ b/gpytorch/kernels/piecewise_polynomial_kernel.py @@ -101,20 +101,13 @@ def __init__(self, q: Optional[int] = 2, **kwargs): raise ValueError("q expected to be 0, 1, 2 or 3") self.q = q - def forward(self, x1: Tensor, x2: Tensor, last_dim_is_batch: bool = False, diag: bool = False, **params) -> Tensor: + def forward(self, x1: Tensor, x2: Tensor, diag: bool = False, **params) -> Tensor: x1_ = x1.div(self.lengthscale) x2_ = x2.div(self.lengthscale) - if last_dim_is_batch is True: - D = x1.shape[1] - else: - D = x1.shape[-1] + D = x1.shape[-1] j = math.floor(D / 2.0) + self.q + 1 - if last_dim_is_batch and diag: - r = self.covar_dist(x1_, x2_, last_dim_is_batch=True, diag=True) - elif diag: + if diag: r = self.covar_dist(x1_, x2_, diag=True) - elif last_dim_is_batch: - r = self.covar_dist(x1_, x2_, last_dim_is_batch=True) else: r = self.covar_dist(x1_, x2_) cov_matrix = _fmax(r, j, self.q) * _get_cov(r, j, self.q) diff --git a/gpytorch/kernels/polynomial_kernel.py b/gpytorch/kernels/polynomial_kernel.py index 3a98e8d4e..1dd57be88 100644 --- a/gpytorch/kernels/polynomial_kernel.py +++ b/gpytorch/kernels/polynomial_kernel.py @@ -81,15 +81,10 @@ def forward( x1: torch.Tensor, x2: torch.Tensor, diag: Optional[bool] = False, - last_dim_is_batch: Optional[bool] = False, **params, ) -> torch.Tensor: offset = self.offset.view(*self.batch_shape, 1, 1) - if last_dim_is_batch: - x1 = x1.transpose(-1, -2).unsqueeze(-1) - x2 = x2.transpose(-1, -2).unsqueeze(-1) - if diag: return ((x1 * x2).sum(dim=-1) + self.offset).pow(self.power) diff --git a/gpytorch/kernels/polynomial_kernel_grad.py b/gpytorch/kernels/polynomial_kernel_grad.py index f499bc23a..a8a17313d 100644 --- a/gpytorch/kernels/polynomial_kernel_grad.py +++ b/gpytorch/kernels/polynomial_kernel_grad.py @@ -13,7 +13,6 @@ def forward( x1: torch.Tensor, x2: torch.Tensor, diag: Optional[bool] = False, - last_dim_is_batch: Optional[bool] = False, **params, ) -> torch.Tensor: offset = self.offset.view(*self.batch_shape, 1, 1) diff --git a/gpytorch/kernels/product_structure_kernel.py b/gpytorch/kernels/product_structure_kernel.py deleted file mode 100644 index 49f782876..000000000 --- a/gpytorch/kernels/product_structure_kernel.py +++ /dev/null @@ -1,94 +0,0 @@ -#!/usr/bin/env python3 - -import warnings -from typing import Optional, Tuple - -from linear_operator.operators import to_linear_operator - -from .kernel import Kernel - - -class ProductStructureKernel(Kernel): - r""" - A Kernel decorator for kernels with product structure. If a kernel decomposes - multiplicatively, then this module will be much more computationally efficient. - - A kernel function `k` has product structure if it can be written as - - .. math:: - - \begin{equation*} - k(\mathbf{x_1}, \mathbf{x_2}) = k'(x_1^{(1)}, x_2^{(1)}) * \ldots * k'(x_1^{(d)}, x_2^{(d)}) - \end{equation*} - - for some kernel :math:`k'` that operates on each dimension. 
- - Given a `b x n x d` input, `ProductStructureKernel` computes `d` one-dimensional kernels - (using the supplied base_kernel), and then multiplies the component kernels together. - Unlike :class:`~gpytorch.kernels.ProductKernel`, `ProductStructureKernel` computes each - of the product terms in batch, making it very fast. - - See `Product Kernel Interpolation for Scalable Gaussian Processes`_ for more detail. - - Args: - base_kernel (Kernel): - The kernel to approximate with KISS-GP - num_dims (int): - The dimension of the input data. - active_dims (tuple of ints, optional): - Passed down to the `base_kernel`. - - .. _Product Kernel Interpolation for Scalable Gaussian Processes: - https://arxiv.org/pdf/1802.08903 - """ - - @property - def is_stationary(self) -> bool: - """ - Kernel is stationary if the base kernel is stationary. - """ - return self.base_kernel.is_stationary - - def __init__( - self, - base_kernel: Kernel, - num_dims: int, - active_dims: Optional[Tuple[int, ...]] = None, - ): - warnings.warn( - "ProductStructureKernel is deprecated, and will be removed in GPyTorch 2.0. " - 'Please refer to the "Kernels with Additive or Product Structure" tutorial ' - "in the GPyTorch docs for how to implement GPs with product structure.", - DeprecationWarning, - ) - - super(ProductStructureKernel, self).__init__(active_dims=active_dims) - self.base_kernel = base_kernel - self.num_dims = num_dims - - def forward(self, x1, x2, diag=False, last_dim_is_batch=False, **params): - if last_dim_is_batch: - raise RuntimeError("ProductStructureKernel does not accept the last_dim_is_batch argument.") - - res = self.base_kernel(x1, x2, diag=diag, last_dim_is_batch=True, **params) - res = res.prod(-2 if diag else -3) - return res - - def num_outputs_per_input(self, x1, x2): - return self.base_kernel.num_outputs_per_input(x1, x2) - - def __call__(self, x1_, x2_=None, diag=False, last_dim_is_batch=False, **params): - """ - We cannot lazily evaluate actual kernel calls when using SKIP, because we - cannot root decompose rectangular matrices. - - Because we slice in to the kernel during prediction to get the test x train - covar before calling evaluate_kernel, the order of operations would mean we - would get a MulLinearOperator representing a rectangular matrix, which we - cannot matmul with because we cannot root decompose it. Thus, SKIP actually - *requires* that we work with the full (train + test) x (train + test) - kernel matrix. 
- """ - res = super().__call__(x1_, x2_, diag=diag, last_dim_is_batch=last_dim_is_batch, **params) - res = to_linear_operator(res).evaluate_kernel() - return res diff --git a/gpytorch/kernels/rbf_kernel.py b/gpytorch/kernels/rbf_kernel.py index 932e59724..073d30f3e 100644 --- a/gpytorch/kernels/rbf_kernel.py +++ b/gpytorch/kernels/rbf_kernel.py @@ -71,7 +71,6 @@ def forward(self, x1, x2, diag=False, **params): or x2.requires_grad or (self.ard_num_dims is not None and self.ard_num_dims > 1) or diag - or params.get("last_dim_is_batch", False) or trace_mode.on() ): x1_ = x1.div(self.lengthscale) diff --git a/gpytorch/kernels/rff_kernel.py b/gpytorch/kernels/rff_kernel.py index c6b5e4ccd..8f39be147 100644 --- a/gpytorch/kernels/rff_kernel.py +++ b/gpytorch/kernels/rff_kernel.py @@ -98,6 +98,12 @@ def __init__(self, num_samples: int, num_dims: Optional[int] = None, **kwargs): if num_dims is not None: self._init_weights(num_dims, num_samples) + @property + def _lazily_evaluate(self) -> bool: + # RFF kernels should not lazily evaluate; to use the Woodbury formula, + # we want the Kernel to return a LowRankLinearOperator, not a KernelLinaerOperator. + return False + def _init_weights( self, num_dims: Optional[int] = None, num_samples: Optional[int] = None, randn_weights: Optional[Tensor] = None ): @@ -111,10 +117,7 @@ def _init_weights( ) self.register_buffer("randn_weights", randn_weights) - def forward(self, x1: Tensor, x2: Tensor, diag: bool = False, last_dim_is_batch: bool = False, **kwargs) -> Tensor: - if last_dim_is_batch: - x1 = x1.transpose(-1, -2).unsqueeze(-1) - x2 = x2.transpose(-1, -2).unsqueeze(-1) + def forward(self, x1: Tensor, x2: Tensor, diag: bool = False, **kwargs) -> Tensor: num_dims = x1.size(-1) if not hasattr(self, "randn_weights"): self._init_weights(num_dims, self.num_samples) diff --git a/gpytorch/kernels/scale_kernel.py b/gpytorch/kernels/scale_kernel.py index 520913265..c8c02d30b 100644 --- a/gpytorch/kernels/scale_kernel.py +++ b/gpytorch/kernels/scale_kernel.py @@ -54,13 +54,6 @@ class ScaleKernel(Kernel): >>> covar = scaled_covar_module(x) # Output: LinearOperator of size (10 x 10) """ - @property - def is_stationary(self) -> bool: - """ - Kernel is stationary if base kernel is stationary. - """ - return self.base_kernel.is_stationary - def __init__( self, base_kernel: Kernel, @@ -86,6 +79,17 @@ def __init__( self.register_constraint("raw_outputscale", outputscale_constraint) + @property + def _lazily_evaluate(self) -> bool: + return self.base_kernel._lazily_evaluate + + @property + def is_stationary(self) -> bool: + """ + Kernel is stationary if base kernel is stationary. 
+ """ + return self.base_kernel.is_stationary + def _outputscale_param(self, m): return m.outputscale @@ -105,11 +109,9 @@ def _set_outputscale(self, value): value = torch.as_tensor(value).to(self.raw_outputscale) self.initialize(raw_outputscale=self.raw_outputscale_constraint.inverse_transform(value)) - def forward(self, x1, x2, last_dim_is_batch=False, diag=False, **params): - orig_output = self.base_kernel.forward(x1, x2, diag=diag, last_dim_is_batch=last_dim_is_batch, **params) + def forward(self, x1, x2, diag=False, **params): + orig_output = self.base_kernel.forward(x1, x2, diag=diag, **params) outputscales = self.outputscale - if last_dim_is_batch: - outputscales = outputscales.unsqueeze(-1) if diag: outputscales = outputscales.unsqueeze(-1) return to_dense(orig_output) * outputscales diff --git a/gpytorch/kernels/spectral_mixture_kernel.py b/gpytorch/kernels/spectral_mixture_kernel.py index c8de79010..037dd2ef8 100644 --- a/gpytorch/kernels/spectral_mixture_kernel.py +++ b/gpytorch/kernels/spectral_mixture_kernel.py @@ -72,8 +72,6 @@ class SpectralMixtureKernel(Kernel): https://arxiv.org/pdf/1302.4245.pdf """ - is_stationary = True # kernel is stationary even though it does not have a lengthscale - def __init__( self, num_mixtures: Optional[int] = None, @@ -116,6 +114,11 @@ def __init__( self.register_constraint("raw_mixture_means", mixture_means_constraint) self.register_constraint("raw_mixture_weights", mixture_weights_constraint) + @property + def is_stationary(self) -> bool: + # kernel is stationary even though it does not have a lengthscale + return True + @property def mixture_scales(self): return self.raw_mixture_scales_constraint.transform(self.raw_mixture_scales) @@ -268,7 +271,7 @@ def initialize_from_data(self, train_x: torch.Tensor, train_y: torch.Tensor, **k self.mixture_weights = train_y.std().div(self.num_mixtures) def _create_input_grid( - self, x1: torch.Tensor, x2: torch.Tensor, diag: bool = False, last_dim_is_batch: bool = False, **params + self, x1: torch.Tensor, x2: torch.Tensor, diag: bool = False, **params ) -> Tuple[torch.Tensor, torch.Tensor]: """ This is a helper method for creating a grid of the kernel's inputs. @@ -280,33 +283,20 @@ def _create_input_grid( :param torch.Tensor x2: ... x m x d (for diag mode, these must be the same inputs) :param diag: Should the Kernel compute the whole kernel, or just the diag? (Default: True.) :type diag: bool, optional - :param last_dim_is_batch: If this is true, it treats the last dimension - of the data as another batch dimension. (Useful for additive - structure over the dimensions). (Default: False.) - :type last_dim_is_batch: bool, optional :rtype: torch.Tensor, torch.Tensor :return: Grid corresponding to x1 and x2. The shape depends on the kernel's mode: * `full_covar`: (`... x n x 1 x d` and `... x 1 x m x d`) - * `full_covar` with `last_dim_is_batch=True`: (`... x k x n x 1 x 1` and `... x k x 1 x m x 1`) * `diag`: (`... x n x d` and `... x n x d`) - * `diag` with `last_dim_is_batch=True`: (`... x k x n x 1` and `... 
x k x n x 1`) """ x1_, x2_ = x1, x2 - if last_dim_is_batch: - x1_ = x1_.transpose(-1, -2).unsqueeze(-1) - if torch.equal(x1, x2): - x2_ = x1_ - else: - x2_ = x2_.transpose(-1, -2).unsqueeze(-1) - if diag: return x1_, x2_ else: return x1_.unsqueeze(-2), x2_.unsqueeze(-3) def forward( - self, x1: torch.Tensor, x2: torch.Tensor, diag: bool = False, last_dim_is_batch: bool = False, **params + self, x1: torch.Tensor, x2: torch.Tensor, diag: bool = False, **params ) -> Tuple[torch.Tensor, torch.Tensor]: n, num_dims = x1.shape[-2:] @@ -344,10 +334,5 @@ def forward( res = (res * mixture_weights).sum(-3 if diag else -4) # Product over dimensions - if last_dim_is_batch: - # Put feature-dimension in front of data1/data2 dimensions - res = res.permute(*list(range(0, res.dim() - 3)), -1, -3, -2) - else: - res = res.prod(-1) - + res = res.prod(-1) return res diff --git a/gpytorch/lazy/lazy_evaluated_kernel_tensor.py b/gpytorch/lazy/lazy_evaluated_kernel_tensor.py index 3efe398b4..d2c23e810 100644 --- a/gpytorch/lazy/lazy_evaluated_kernel_tensor.py +++ b/gpytorch/lazy/lazy_evaluated_kernel_tensor.py @@ -31,20 +31,17 @@ def wrapped(self, *args, **kwargs): class LazyEvaluatedKernelTensor(LinearOperator): _check_size = False - def _check_args(self, x1, x2, kernel, last_dim_is_batch=False, **params): + def _check_args(self, x1, x2, kernel, **params): if not torch.is_tensor(x1): return "x1 must be a tensor. Got {}".format(x1.__class__.__name__) if not torch.is_tensor(x2): return "x1 must be a tensor. Got {}".format(x1.__class__.__name__) - def __init__(self, x1, x2, kernel, last_dim_is_batch=False, **params): - super(LazyEvaluatedKernelTensor, self).__init__( - x1, x2, kernel=kernel, last_dim_is_batch=last_dim_is_batch, **params - ) + def __init__(self, x1, x2, kernel, **params): + super(LazyEvaluatedKernelTensor, self).__init__(x1, x2, kernel=kernel, **params) self.kernel = kernel self.x1 = x1 self.x2 = x2 - self.last_dim_is_batch = last_dim_is_batch self.params = params self._is_grad_enabled = torch.is_grad_enabled() # records grad state at instantiation @@ -92,7 +89,6 @@ def _bilinear_derivative(self, left_vecs, right_vecs): sub_x1, x2, diag=False, - last_dim_is_batch=self.last_dim_is_batch, **self.params, ) ) @@ -115,9 +111,7 @@ def _diagonal(self) -> torch.Tensor: x1 = self.x1 x2 = self.x2 - res = super(Kernel, self.kernel).__call__( - x1, x2, diag=True, last_dim_is_batch=self.last_dim_is_batch, **self.params - ) + res = super(Kernel, self.kernel).__call__(x1, x2, diag=True, **self.params) # Now we'll make sure that the shape we're getting from diag makes sense if settings.debug.on(): @@ -193,12 +187,7 @@ def _getitem(self, row_index, col_index, *batch_indices): col_index = slice(col_start // num_outs_per_in_cols, col_end // num_outs_per_in_cols, None) # Define the index we're using for the last index - # If the last index corresponds to a batch, then we'll use the appropriate batch_index - # Otherwise, we'll use the _noop_index - if self.last_dim_is_batch: - *batch_indices, dim_index = batch_indices - else: - dim_index = _noop_index + dim_index = _noop_index # Get the indices of x1 and x2 that matter for the kernel # Call x1[*batch_indices, row_index, :] @@ -238,7 +227,6 @@ def _getitem(self, row_index, col_index, *batch_indices): x1, x2, kernel=new_kernel, - last_dim_is_batch=self.last_dim_is_batch, **self.params, ) @@ -265,7 +253,6 @@ def _matmul(self, rhs): sub_x1, x2, diag=False, - last_dim_is_batch=self.last_dim_is_batch, **self.params, ) ) @@ -312,9 +299,6 @@ def _size(self): f"Got x1.shape = 
{x1.shape} and x2.shape = {x2.shape}" ) - # Handle when the last dim is batch - if self.last_dim_is_batch: - expected_size = expected_size[:-2] + x1.shape[-1:] + expected_size[-2:] return expected_size @recall_grad_state @@ -323,7 +307,6 @@ def _transpose_nonbatch(self): self.x2, self.x1, kernel=self.kernel, - last_dim_is_batch=self.last_dim_is_batch, **self.params, ) @@ -335,7 +318,6 @@ def _unsqueeze_batch(self, dim): x1, x2, kernel=self.kernel, - last_dim_is_batch=self.last_dim_is_batch, **self.params, ) @@ -356,7 +338,6 @@ def evaluate_kernel(self): x1, x2, diag=False, - last_dim_is_batch=self.last_dim_is_batch, **self.params, ) self.kernel.active_dims = temp_active_dims @@ -383,7 +364,6 @@ def repeat(self, *repeats): x1, x2, kernel=self.kernel, - last_dim_is_batch=self.last_dim_is_batch, **self.params, ) diff --git a/gpytorch/models/exact_prediction_strategies.py b/gpytorch/models/exact_prediction_strategies.py index 2e95e2162..ad4760e86 100644 --- a/gpytorch/models/exact_prediction_strategies.py +++ b/gpytorch/models/exact_prediction_strategies.py @@ -835,7 +835,6 @@ def exact_prediction(self, joint_mean, joint_covar): test_test_covar.x1, test_test_covar.x2, test_test_covar.kernel.base_kernel, - test_test_covar.last_dim_is_batch, **test_test_covar.params, ) diff --git a/gpytorch/utils/grid.py b/gpytorch/utils/grid.py index 7cd57d877..aa61b5d39 100644 --- a/gpytorch/utils/grid.py +++ b/gpytorch/utils/grid.py @@ -99,10 +99,6 @@ def choose_grid_size(train_inputs, ratio=1.0, kronecker_structure=True): return ratio * num_data -def convert_legacy_grid(grid: torch.Tensor) -> List[torch.Tensor]: - return [grid[:, i] for i in range(grid.size(-1))] - - def create_data_from_grid(grid: List[torch.Tensor]) -> torch.Tensor: """ :param grid: Each Tensor is a 1D set of increments for the grid in that dimension @@ -110,8 +106,6 @@ def create_data_from_grid(grid: List[torch.Tensor]) -> torch.Tensor: :return: The set of points on the grid going by column-major order :rtype: torch.Tensor """ - if torch.is_tensor(grid): - grid = convert_legacy_grid(grid) ndims = len(grid) assert all(axis.dim() == 1 for axis in grid) projections = torch.meshgrid(*grid, indexing="ij") diff --git a/gpytorch/utils/interpolation.py b/gpytorch/utils/interpolation.py index 0f7e1a596..2349d11ff 100644 --- a/gpytorch/utils/interpolation.py +++ b/gpytorch/utils/interpolation.py @@ -8,8 +8,6 @@ import torch from linear_operator.utils.interpolation import left_interp as _left_interp, left_t_interp as _left_t_interp -from .grid import convert_legacy_grid - class Interpolation(object): def _cubic_interpolation_kernel(self, scaled_grid_dist): @@ -41,8 +39,6 @@ def _cubic_interpolation_kernel(self, scaled_grid_dist): return res def interpolate(self, x_grid: List[torch.Tensor], x_target: torch.Tensor, interp_points=range(-2, 2), eps=1e-10): - if torch.is_tensor(x_grid): - x_grid = convert_legacy_grid(x_grid) num_target_points = x_target.size(0) num_dim = x_target.size(-1) assert num_dim == len(x_grid) diff --git a/gpytorch/variational/grid_interpolation_variational_strategy.py b/gpytorch/variational/grid_interpolation_variational_strategy.py index 15c934043..ad57f9421 100644 --- a/gpytorch/variational/grid_interpolation_variational_strategy.py +++ b/gpytorch/variational/grid_interpolation_variational_strategy.py @@ -30,6 +30,11 @@ class GridInterpolationVariationalStrategy(_VariationalStrategy): :param list grid_bounds: Bounds of each dimension of the grid (should be a list of (float, float) tuples) :param 
~gpytorch.variational.VariationalDistribution variational_distribution: A VariationalDistribution object that represents the form of the variational distribution :math:`q(\mathbf u)` + + :ivar grid: The grid of points that the inducing points are based on. + The grid is stored as a matrix, where each column corresponds to the + projection of the grid onto one dimension. + :type grid: torch.Tensor (M x D) """ def __init__(self, model, grid_size, grid_bounds, variational_distribution): @@ -51,15 +56,14 @@ def __init__(self, model, grid_size, grid_bounds, variational_distribution): model, inducing_points, variational_distribution, learn_inducing_locations=False ) object.__setattr__(self, "model", model) - self.register_buffer("grid", grid) def _compute_grid(self, inputs): - n_data, n_dimensions = inputs.size(-2), inputs.size(-1) - batch_shape = inputs.shape[:-2] + *batch_shape, n_data, n_dimensions = inputs.shape + grid = tuple(self.grid[..., i] for i in range(n_dimensions)) inputs = inputs.reshape(-1, n_dimensions) - interp_indices, interp_values = Interpolation().interpolate(self.grid, inputs) + interp_indices, interp_values = Interpolation().interpolate(grid, inputs) interp_indices = interp_indices.view(*batch_shape, n_data, -1) interp_values = interp_values.view(*batch_shape, n_data, -1) diff --git a/test/examples/test_grid_gp_regression.py b/test/examples/test_grid_gp_regression.py index 9d4453be6..18176ff51 100644 --- a/test/examples/test_grid_gp_regression.py +++ b/test/examples/test_grid_gp_regression.py @@ -61,12 +61,12 @@ def test_grid_gp_mean_abs_error(self, num_dim=1, cuda=False): device = torch.device("cuda") if cuda else torch.device("cpu") grid_bounds = [(0, 1)] if num_dim == 1 else [(0, 1), (0, 2)] grid_size = 25 - grid = torch.zeros(grid_size, len(grid_bounds), device=device) + grid = [] for i in range(len(grid_bounds)): grid_diff = float(grid_bounds[i][1] - grid_bounds[i][0]) / (grid_size - 2) - grid[:, i] = torch.linspace( + grid.append(torch.linspace( grid_bounds[i][0] - grid_diff, grid_bounds[i][1] + grid_diff, grid_size, device=device - ) + )) train_x, train_y, test_x, test_y = make_data(grid, cuda=cuda) likelihood = gpytorch.likelihoods.GaussianLikelihood() diff --git a/test/examples/test_kissgp_additive_regression.py b/test/examples/test_kissgp_additive_regression.py index 4018e11ad..087c77d16 100644 --- a/test/examples/test_kissgp_additive_regression.py +++ b/test/examples/test_kissgp_additive_regression.py @@ -9,7 +9,7 @@ import gpytorch from gpytorch.distributions import MultivariateNormal -from gpytorch.kernels import AdditiveStructureKernel, GridInterpolationKernel, RBFKernel, ScaleKernel +from gpytorch.kernels import GridInterpolationKernel, RBFKernel, ScaleKernel from gpytorch.likelihoods import GaussianLikelihood from gpytorch.means import ZeroMean @@ -36,14 +36,12 @@ class GPRegressionModel(gpytorch.models.ExactGP): def __init__(self, train_x, train_y, likelihood): super(GPRegressionModel, self).__init__(train_x, train_y, likelihood) self.mean_module = ZeroMean() - self.base_covar_module = ScaleKernel(RBFKernel(ard_num_dims=2)) - self.covar_module = AdditiveStructureKernel( - GridInterpolationKernel(self.base_covar_module, grid_size=100, num_dims=1), num_dims=2 - ) + self.base_covar_module = ScaleKernel(RBFKernel(batch_shape=torch.Size([2]))) + self.covar_module = GridInterpolationKernel(self.base_covar_module, grid_size=100, num_dims=1) def forward(self, x): mean_x = self.mean_module(x) - covar_x = self.covar_module(x) + covar_x = self.covar_module(x.mT[..., 
None]).sum(dim=-3) return MultivariateNormal(mean_x, covar_x) diff --git a/test/examples/test_kissgp_multiplicative_regression.py b/test/examples/test_kissgp_multiplicative_regression.py index ca8b40360..d16869f95 100644 --- a/test/examples/test_kissgp_multiplicative_regression.py +++ b/test/examples/test_kissgp_multiplicative_regression.py @@ -10,7 +10,7 @@ import gpytorch from gpytorch.distributions import MultivariateNormal -from gpytorch.kernels import GridInterpolationKernel, ProductStructureKernel, RBFKernel, ScaleKernel +from gpytorch.kernels import GridInterpolationKernel, RBFKernel, ScaleKernel from gpytorch.likelihoods import GaussianLikelihood from gpytorch.means import ConstantMean from gpytorch.priors import SmoothedBoxPrior @@ -42,14 +42,12 @@ class GPRegressionModel(gpytorch.models.ExactGP): def __init__(self, train_x, train_y, likelihood): super(GPRegressionModel, self).__init__(train_x, train_y, likelihood) self.mean_module = ConstantMean(constant_prior=SmoothedBoxPrior(-1, 1)) - self.base_covar_module = ScaleKernel(RBFKernel()) - self.covar_module = ProductStructureKernel( - GridInterpolationKernel(self.base_covar_module, grid_size=100, num_dims=1), num_dims=2 - ) + self.base_covar_module = ScaleKernel(RBFKernel(batch_shape=torch.Size([2]))) + self.covar_module = GridInterpolationKernel(self.base_covar_module, grid_size=100, num_dims=1) def forward(self, x): mean_x = self.mean_module(x) - covar_x = self.covar_module(x) + covar_x = self.covar_module(x.mT[..., None]).prod(dim=-3) return MultivariateNormal(mean_x, covar_x) diff --git a/test/examples/test_simple_gp_regression.py b/test/examples/test_simple_gp_regression.py index caae8c3f3..372ada516 100644 --- a/test/examples/test_simple_gp_regression.py +++ b/test/examples/test_simple_gp_regression.py @@ -441,7 +441,7 @@ def test_posterior_latent_gp_and_likelihood_fast_pred_var(self, cuda=False): self.assertLess(torch.max(var_diff / noise), 0.05) - def test_pyro_sampling(self): + def pending_test_pyro_sampling(self): try: import pyro # noqa from pyro.infer.mcmc import MCMC, NUTS diff --git a/test/kernels/test_constant_kernel.py b/test/kernels/test_constant_kernel.py index 849ec3996..af46029fe 100644 --- a/test/kernels/test_constant_kernel.py +++ b/test/kernels/test_constant_kernel.py @@ -46,17 +46,6 @@ def _test_constant_kernel(self, device: torch.device): # standard deviation is zero iff KM is constant self.assertAlmostEqual(KM.std().item(), 0, places=places) - # testing last_dim_is_batch - with self.subTest(last_dim_is_batch=True): - KD = constant_kernel(X, last_dim_is_batch=True).to(device=device) - self.assertIsInstance(KD, LazyEvaluatedKernelTensor) - KM = KD.to_dense() - self.assertIsInstance(KM, Tensor) - self.assertEqual(KM.shape, (*batch_shape, d, n, n)) - self.assertAlmostEqual(KM.std().item(), 0, places=places) - self.assertEqual(KM.dtype, dtype) - self.assertEqual(KM.device.type, device.type) - # testing diag with self.subTest(diag=True): KD = constant_kernel(X, diag=True) @@ -66,15 +55,6 @@ def _test_constant_kernel(self, device: torch.device): self.assertEqual(KD.dtype, dtype) self.assertEqual(KD.device.type, device.type) - # testing diag and last_dim_is_batch - with self.subTest(diag=True, last_dim_is_batch=True): - KD = constant_kernel(X, diag=True, last_dim_is_batch=True) - self.assertIsInstance(KD, Tensor) - self.assertEqual(KD.shape, (*batch_shape, d, n)) - self.assertAlmostEqual(KD.std().item(), 0, places=places) - self.assertEqual(KD.dtype, dtype) - self.assertEqual(KD.device.type, device.type) - # 
testing AD with self.subTest(requires_grad=True): X.requires_grad = True diff --git a/test/kernels/test_cosine_kernel.py b/test/kernels/test_cosine_kernel.py index e6d903bd5..9e2d8ac5c 100644 --- a/test/kernels/test_cosine_kernel.py +++ b/test/kernels/test_cosine_kernel.py @@ -30,25 +30,11 @@ def test_computes_periodic_function(self): actual = actual.diagonal(dim1=-1, dim2=-2) self.assertLess(torch.norm(res - actual), 1e-5) - # batch_dims - actual = torch.zeros(2, 3, 3) - for i in range(3): - for j in range(3): - for l in range(2): - actual[l, i, j] = torch.cos(math.pi * ((a[i, l] - b[j, l]) / period)) - res = kernel(a, b, last_dim_is_batch=True).to_dense() - self.assertLess(torch.norm(res - actual), 1e-5) - - # batch_dims + diag - res = kernel(a, b, last_dim_is_batch=True).diagonal(dim1=-1, dim2=-2) - actual = torch.cat([actual[i].diagonal(dim1=-1, dim2=-2).unsqueeze(0) for i in range(actual.size(0))]) - self.assertLess(torch.norm(res - actual), 1e-5) - def test_batch(self): a = torch.tensor([[4, 2, 8], [1, 2, 3]], dtype=torch.float).view(2, 3, 1) b = torch.tensor([[0, 2, 1], [-1, 2, 0]], dtype=torch.float).view(2, 3, 1) period = torch.tensor(1, dtype=torch.float).view(1, 1, 1) - kernel = CosineKernel().initialize(period_length=period) + kernel = CosineKernel(batch_shape=torch.Size([1])).initialize(period_length=period) kernel.eval() actual = torch.zeros(2, 3, 3) @@ -81,21 +67,6 @@ def test_batch_separate(self): actual = torch.cat([actual[i].diagonal(dim1=-1, dim2=-2).unsqueeze(0) for i in range(actual.size(0))]) self.assertLess(torch.norm(res - actual), 1e-5) - # batch_dims - actual = torch.zeros(2, 2, 3, 3) - for k in range(2): - for i in range(3): - for j in range(3): - for l in range(2): - actual[k, l, i, j] = torch.cos(math.pi * ((a[k, i, l] - b[k, j, l]) / period[k])) - res = kernel(a, b, last_dim_is_batch=True).to_dense() - self.assertLess(torch.norm(res - actual), 1e-5) - - # batch_dims + diag - res = kernel(a, b, last_dim_is_batch=True).diagonal(dim1=-1, dim2=-2) - actual = actual.diagonal(dim1=-2, dim2=-1) - self.assertLess(torch.norm(res - actual), 1e-5) - def create_kernel_with_prior(self, period_length_prior): return CosineKernel(period_length_prior=period_length_prior) diff --git a/test/kernels/test_grid_interpolation_kernel.py b/test/kernels/test_grid_interpolation_kernel.py index 3725f1579..787a5b088 100644 --- a/test/kernels/test_grid_interpolation_kernel.py +++ b/test/kernels/test_grid_interpolation_kernel.py @@ -5,6 +5,7 @@ import torch from linear_operator.operators import InterpolatedLinearOperator +import gpytorch from gpytorch.kernels import GridInterpolationKernel, RBFKernel @@ -14,7 +15,8 @@ def test_standard(self): kernel = GridInterpolationKernel(base_kernel, num_dims=2, grid_size=128, grid_bounds=[(-1.2, 1.2)] * 2) xs = torch.randn(5, 2).clamp(-1, 1) - interp_covar = kernel(xs, xs).evaluate_kernel() + with gpytorch.settings.lazily_evaluate_kernels(False): + interp_covar = kernel(xs, xs) self.assertIsInstance(interp_covar, InterpolatedLinearOperator) xs = torch.randn(5, 2).clamp(-1, 1) diff --git a/test/kernels/test_grid_kernel.py b/test/kernels/test_grid_kernel.py index 8cd682afd..971449504 100644 --- a/test/kernels/test_grid_kernel.py +++ b/test/kernels/test_grid_kernel.py @@ -9,48 +9,34 @@ from gpytorch.kernels import GridKernel, LinearKernel, RBFKernel from gpytorch.utils.grid import create_data_from_grid -grid = [torch.linspace(0, 1, 5), torch.linspace(0, 2, 3)] +grid = [torch.linspace(0, 1, 5), torch.linspace(0, 2, 3), torch.linspace(0, 2, 4)] d = 
len(grid) grid_data = create_data_from_grid(grid) class TestGridKernel(unittest.TestCase): - def test_grid_grid(self, toeplitz=True): - with gpytorch.settings.use_toeplitz(toeplitz): - base_kernel = RBFKernel(ard_num_dims=2) - kernel = GridKernel(base_kernel, grid) - grid_covar = kernel(grid_data, grid_data).evaluate_kernel() - self.assertIsInstance(grid_covar, KroneckerProductLinearOperator) - grid_eval = kernel(grid_data, grid_data).to_dense() - actual_eval = base_kernel(grid_data, grid_data).to_dense() - self.assertLess(torch.norm(grid_eval - actual_eval), 2e-5) - - def test_grid_grid_nontoeplitz(self): - return self.test_grid_grid(toeplitz=False) - - def test_nongrid_grid(self, toeplitz=True): - with gpytorch.settings.use_toeplitz(toeplitz): - base_kernel = RBFKernel(ard_num_dims=2) - data = torch.randn(5, d) - kernel = GridKernel(base_kernel, grid) - grid_eval = kernel(grid_data, data).to_dense() - actual_eval = base_kernel(grid_data, data).to_dense() - self.assertLess(torch.norm(grid_eval - actual_eval), 1e-5) - - def test_nongrid_grid_nontoeplitz(self): - return self.test_nongrid_grid(toeplitz=False) - - def test_nongrid_nongrid(self, toeplitz=True): - with gpytorch.settings.use_toeplitz(toeplitz): - base_kernel = RBFKernel(ard_num_dims=2) - data = torch.randn(5, d) - kernel = GridKernel(base_kernel, grid) - grid_eval = kernel(data, data).to_dense() - actual_eval = base_kernel(data, data).to_dense() - self.assertLess(torch.norm(grid_eval - actual_eval), 1e-5) - - def test_nongrid_nongrid_nontoeplitz(self): - return self.test_nongrid_nongrid(toeplitz=False) + def test_grid(self): + base_kernel = RBFKernel(ard_num_dims=d) + kernel = GridKernel(base_kernel, grid) + with gpytorch.settings.lazily_evaluate_kernels(False): + grid_covar = kernel(grid_data, grid_data) + self.assertIsInstance(grid_covar, KroneckerProductLinearOperator) + grid_eval = grid_covar.to_dense() + actual_eval = base_kernel(grid_data, grid_data).to_dense() + self.assertLess(torch.norm(grid_eval - actual_eval), 2e-5) + + grid_covar_diag = kernel(grid_data, diag=True) + actual_diag = base_kernel(grid_data, grid_data, diag=True) + self.assertLess(torch.norm(grid_covar_diag - actual_diag), 2e-5) + + def test_nongrid(self): + base_kernel = RBFKernel(ard_num_dims=d) + data = torch.randn(5, d) + kernel = GridKernel(base_kernel, grid) + with gpytorch.settings.lazily_evaluate_kernels(False), self.assertWarnsRegex(RuntimeWarning, "non-grid"): + grid_eval = kernel(data, grid_data).to_dense() + actual_eval = base_kernel(data, grid_data).to_dense() + self.assertLess(torch.norm(grid_eval - actual_eval), 1e-5) def test_non_stationary_base(self): base_kernel = LinearKernel() diff --git a/test/kernels/test_linear_kernel.py b/test/kernels/test_linear_kernel.py index b520842fd..708fc1253 100644 --- a/test/kernels/test_linear_kernel.py +++ b/test/kernels/test_linear_kernel.py @@ -42,18 +42,6 @@ def test_computes_linear_function_square(self): actual = actual.diagonal(dim1=-1, dim2=-2) self.assertLess(torch.norm(res - actual), 1e-4) - # batch_dims - dim_group_a = a - dim_group_a = dim_group_a.permute(1, 0).reshape(-1, 3) - actual = 3.14 * torch.mul(dim_group_a.unsqueeze(-1), dim_group_a.unsqueeze(-2)) - res = kernel(a, a, last_dim_is_batch=True).to_dense() - self.assertLess(torch.norm(res - actual), 1e-4) - - # batch_dims + diag - res = kernel(a, a, last_dim_is_batch=True).diagonal(dim1=-1, dim2=-2) - actual = torch.cat([actual[i].diagonal(dim1=-1, dim2=-2).unsqueeze(0) for i in range(actual.size(0))]) - self.assertLess(torch.norm(res - 
actual), 1e-4) - def test_computes_linear_function_square_batch(self): a = torch.tensor([[[4, 1], [2, 0], [8, 3]], [[1, 1], [2, 1], [1, 3]]], dtype=torch.float) @@ -68,18 +56,6 @@ def test_computes_linear_function_square_batch(self): actual = torch.cat([actual[i].diagonal(dim1=-1, dim2=-2).unsqueeze(0) for i in range(actual.size(0))]) self.assertLess(torch.norm(res - actual), 1e-4) - # batch_dims - dim_group_a = a - dim_group_a = dim_group_a.transpose(-1, -2).unsqueeze(-1) - actual = dim_group_a.matmul(dim_group_a.transpose(-2, -1)) - res = kernel(a, a, last_dim_is_batch=True).to_dense() - self.assertLess(torch.norm(res - actual), 1e-4) - - # batch_dims + diag - res = kernel(a, a, last_dim_is_batch=True).diagonal(dim1=-1, dim2=-2) - actual = actual.diagonal(dim1=-2, dim2=-1) - self.assertLess(torch.norm(res - actual), 1e-4) - def create_kernel_with_prior(self, variance_prior): return self.create_kernel_no_ard(variance_prior=variance_prior) diff --git a/test/kernels/test_matern_kernel.py b/test/kernels/test_matern_kernel.py index a544947e8..20b3a18f9 100644 --- a/test/kernels/test_matern_kernel.py +++ b/test/kernels/test_matern_kernel.py @@ -96,18 +96,6 @@ def test_ard(self): actual = actual.diagonal(dim1=-1, dim2=-2) self.assertLess(torch.norm(res - actual), 1e-5) - # batch_dims - dist = torch.tensor([[[0, 0], [2, 2]], [[1, 1], [0, 0]]], dtype=torch.float) - dist.mul_(math.sqrt(5)) - actual = (dist**2 / 3 + dist + 1).mul(torch.exp(-dist)) - res = kernel(a, b, last_dim_is_batch=True).to_dense() - self.assertLess(torch.norm(res - actual), 1e-5) - - # batch_dims + diag - res = kernel(a, b, last_dim_is_batch=True).diagonal(dim1=-1, dim2=-2) - actual = torch.cat([actual[i].diagonal(dim1=-1, dim2=-2).unsqueeze(0) for i in range(actual.size(0))]) - self.assertLess(torch.norm(res - actual), 1e-5) - def test_ard_batch(self): a = torch.tensor([[[1, 2, 3], [2, 4, 3]], [[2, -1, 2], [2, -1, 0]]], dtype=torch.float) b = torch.tensor([[[1, 4, 3]], [[2, -1, 0]]], dtype=torch.float) @@ -141,26 +129,6 @@ def test_ard_separate_batch(self): actual = torch.cat([actual[i].diagonal(dim1=-1, dim2=-2).unsqueeze(0) for i in range(actual.size(0))]) self.assertLess(torch.norm(res - actual), 1e-5) - # batch_dims - dist = torch.tensor( - [ - [[[0.0, 0.0], [1.0, 1.0]], [[0.0, 0.0], [0.0, 0.0]]], - [[[1.0, 1.0], [0.0, 0.0]], [[0.0, 0.0], [0.0, 0.0]]], - [[[0.0, 0.0], [0.0, 0.0]], [[4.0, 4.0], [0.0, 0.0]]], - ] - ) - - dist.mul_(math.sqrt(5)) - dist = dist.view(3, 2, 2, 2).transpose(0, 1) - actual = (dist**2 / 3 + dist + 1).mul(torch.exp(-dist)) - res = kernel(a, b, last_dim_is_batch=True).to_dense() - self.assertLess(torch.norm(res - actual), 1e-5) - - # batch_dims + diag - res = kernel(a, b, last_dim_is_batch=True).diagonal(dim1=-1, dim2=-2) - actual = actual.diagonal(dim1=-2, dim2=-1) - self.assertLess(torch.norm(res - actual), 1e-5) - def create_kernel_with_prior(self, lengthscale_prior): return self.create_kernel_no_ard(lengthscale_prior=lengthscale_prior) diff --git a/test/kernels/test_newton_girard_additive_kernel.py b/test/kernels/test_newton_girard_additive_kernel.py deleted file mode 100644 index 698c7fa56..000000000 --- a/test/kernels/test_newton_girard_additive_kernel.py +++ /dev/null @@ -1,158 +0,0 @@ -#!/usr/bin/env python3 - -from unittest import TestCase - -import torch - -from gpytorch.distributions import MultivariateNormal -from gpytorch.kernels import AdditiveKernel, NewtonGirardAdditiveKernel, RBFKernel, ScaleKernel -from gpytorch.likelihoods import GaussianLikelihood -from gpytorch.means import 
ConstantMean -from gpytorch.mlls import ExactMarginalLogLikelihood -from gpytorch.models import ExactGP -from gpytorch.test.base_kernel_test_case import BaseKernelTestCase - - -class TestNewtonGirardAdditiveKernel(TestCase, BaseKernelTestCase): - def create_kernel_no_ard(self, **kwargs): - return NewtonGirardAdditiveKernel(RBFKernel(), 4, 2, **kwargs) - - def create_kernel_ard(self, num_dims, **kwargs): - return NewtonGirardAdditiveKernel(RBFKernel(ard_num_dims=num_dims), num_dims, 2, **kwargs) - - def test_degree1(self): - AddK = NewtonGirardAdditiveKernel(RBFKernel(ard_num_dims=3), 3, 1) - self.assertEqual(AddK.base_kernel.lengthscale.numel(), 3) - self.assertEqual(AddK.outputscale.numel(), 1) - - testvals = torch.tensor([[1, 2, 3], [7, 5, 2]], dtype=torch.float) - add_k_val = AddK(testvals, testvals).to_dense() - - manual_k = ScaleKernel( - AdditiveKernel(RBFKernel(active_dims=0), RBFKernel(active_dims=1), RBFKernel(active_dims=2)) - ) - manual_k.initialize(outputscale=1.0) - manual_add_k_val = manual_k(testvals, testvals).to_dense() - - # np.testing.assert_allclose(add_k_val.detach().numpy(), manual_add_k_val.detach().numpy(), atol=1e-5) - self.assertTrue(torch.allclose(add_k_val, manual_add_k_val, atol=1e-5)) - - def test_degree2(self): - AddK = NewtonGirardAdditiveKernel(RBFKernel(ard_num_dims=3), 3, 2) - self.assertEqual(AddK.base_kernel.lengthscale.numel(), 3) - self.assertEqual(AddK.outputscale.numel(), 2) - - testvals = torch.tensor([[1, 2, 3], [7, 5, 2]], dtype=torch.float) - add_k_val = AddK(testvals, testvals).to_dense() - - manual_k1 = ScaleKernel( - AdditiveKernel(RBFKernel(active_dims=0), RBFKernel(active_dims=1), RBFKernel(active_dims=2)) - ) - manual_k1.initialize(outputscale=1 / 2) - manual_k2 = ScaleKernel( - AdditiveKernel(RBFKernel(active_dims=[0, 1]), RBFKernel(active_dims=[1, 2]), RBFKernel(active_dims=[0, 2])) - ) - manual_k2.initialize(outputscale=1 / 2) - manual_k = AdditiveKernel(manual_k1, manual_k2) - manual_add_k_val = manual_k(testvals, testvals).to_dense() - - # np.testing.assert_allclose(add_k_val.detach().numpy(), manual_add_k_val.detach().numpy(), atol=1e-5) - self.assertTrue(torch.allclose(add_k_val, manual_add_k_val, atol=1e-5)) - - def test_degree3(self): - # just make sure it doesn't break here. - AddK = NewtonGirardAdditiveKernel(RBFKernel(ard_num_dims=3), 3, 3) - self.assertEqual(AddK.base_kernel.lengthscale.numel(), 3) - self.assertEqual(AddK.outputscale.numel(), 3) - - testvals = torch.tensor([[1, 2, 3], [7, 5, 2]], dtype=torch.float) - add_k_val = AddK(testvals, testvals).to_dense() - - manual_k1 = ScaleKernel( - AdditiveKernel(RBFKernel(active_dims=0), RBFKernel(active_dims=1), RBFKernel(active_dims=2)) - ) - manual_k1.initialize(outputscale=1 / 3) - manual_k2 = ScaleKernel( - AdditiveKernel(RBFKernel(active_dims=[0, 1]), RBFKernel(active_dims=[1, 2]), RBFKernel(active_dims=[0, 2])) - ) - manual_k2.initialize(outputscale=1 / 3) - - manual_k3 = ScaleKernel(AdditiveKernel(RBFKernel())) - manual_k3.initialize(outputscale=1 / 3) - manual_k = AdditiveKernel(manual_k1, manual_k2, manual_k3) - manual_add_k_val = manual_k(testvals, testvals).to_dense() - # np.testing.assert_allclose(add_k_val.detach().numpy(), manual_add_k_val.detach().numpy(), atol=1e-5) - self.assertTrue(torch.allclose(add_k_val, manual_add_k_val, atol=1e-5)) - - def test_optimizing(self): - # This tests should pass so long as nothing breaks. 
- torch.random.manual_seed(1) - data = torch.randn(40, 4) - target = torch.sin(data).sum(dim=-1) - d = 4 - - AddK = NewtonGirardAdditiveKernel(RBFKernel(ard_num_dims=d), d, max_degree=3) - - class TestGPModel(ExactGP): - def __init__(self, train_x, train_y, likelihood, kernel): - super().__init__(train_x, train_y, likelihood) - self.mean_module = ConstantMean() - self.covar_module = kernel - - def forward(self, x): - mean_x = self.mean_module(x) - covar_x = self.covar_module(x) - return MultivariateNormal(mean_x, covar_x) - - model = TestGPModel(data, target, GaussianLikelihood(), ScaleKernel(AddK)) - optim = torch.optim.Adam(model.parameters(), lr=0.1) - mll = ExactMarginalLogLikelihood(model.likelihood, model) - model.train() - for i in range(2): - optim.zero_grad() - out = model(data) - loss = -mll(out, target) - loss.backward() - optim.step() - - def test_ard(self): - base_k = RBFKernel(ard_num_dims=3) - base_k.initialize(lengthscale=[1.0, 2.0, 3.0]) - AddK = NewtonGirardAdditiveKernel(base_k, 3, max_degree=1) - - testvals = torch.tensor([[1, 2, 3], [7, 5, 2]], dtype=torch.float) - add_k_val = AddK(testvals, testvals).to_dense() - - ks = [] - for i in range(3): - k = RBFKernel(active_dims=i) - k.initialize(lengthscale=i + 1) - ks.append(k) - manual_k = ScaleKernel(AdditiveKernel(*ks)) - manual_k.initialize(outputscale=1.0) - manual_add_k_val = manual_k(testvals, testvals).to_dense() - - # np.testing.assert_allclose(add_k_val.detach().numpy(), manual_add_k_val.detach().numpy(), atol=1e-5) - self.assertTrue(torch.allclose(add_k_val, manual_add_k_val, atol=1e-5)) - - def test_diag(self): - AddK = NewtonGirardAdditiveKernel(RBFKernel(ard_num_dims=3), 3, 2) - self.assertEqual(AddK.base_kernel.lengthscale.numel(), 3) - self.assertEqual(AddK.outputscale.numel(), 2) - - testvals = torch.tensor([[1, 2, 3], [7, 5, 2]], dtype=torch.float) - add_k_val = AddK(testvals, testvals).diagonal(dim1=-1, dim2=-2) - - manual_k1 = ScaleKernel( - AdditiveKernel(RBFKernel(active_dims=0), RBFKernel(active_dims=1), RBFKernel(active_dims=2)) - ) - manual_k1.initialize(outputscale=1 / 2) - manual_k2 = ScaleKernel( - AdditiveKernel(RBFKernel(active_dims=[0, 1]), RBFKernel(active_dims=[1, 2]), RBFKernel(active_dims=[0, 2])) - ) - manual_k2.initialize(outputscale=1 / 2) - manual_k = AdditiveKernel(manual_k1, manual_k2) - manual_add_k_val = manual_k(testvals, testvals).diagonal(dim1=-1, dim2=-2) - - # np.testing.assert_allclose(add_k_val.detach().numpy(), manual_add_k_val.detach().numpy(), atol=1e-5) - self.assertTrue(torch.allclose(add_k_val, manual_add_k_val, atol=1e-5)) diff --git a/test/kernels/test_piecewise_polynomial_kernel.py b/test/kernels/test_piecewise_polynomial_kernel.py index 3b5f7e766..f09181070 100644 --- a/test/kernels/test_piecewise_polynomial_kernel.py +++ b/test/kernels/test_piecewise_polynomial_kernel.py @@ -51,19 +51,6 @@ def test_fmax(r, j, q): res = kernel(a, b).diagonal(dim1=-1, dim2=-2) self.assertLess(torch.norm(res - actual), 1e-5) - # batch_dims - actual = torch.zeros(2, 3, 3) - for i in range(2): - actual[i] = kernel(a[:, i].unsqueeze(-1), b[:, i].unsqueeze(-1)).to_dense() - - res = kernel(a, b, last_dim_is_batch=True).to_dense() - self.assertLess(torch.norm(res - actual), 1e-5) - - # batch_dims + diag - res = kernel(a, b, last_dim_is_batch=True).diagonal(dim1=-1, dim2=-2) - actual = torch.cat([actual[i].diagonal(dim1=-1, dim2=-2).unsqueeze(0) for i in range(actual.size(0))]) - self.assertLess(torch.norm(res - actual), 1e-5) - def test_piecewise_polynomial_kernel_batch(self): a = 
torch.tensor([[4, 2, 8], [1, 2, 3]], dtype=torch.float).view(2, 3, 1) b = torch.tensor([[0, 2, 1], [-1, 2, 0]], dtype=torch.float).view(2, 3, 1) diff --git a/test/kernels/test_polynomial_kernel.py b/test/kernels/test_polynomial_kernel.py index ad57536a0..a82bfc12a 100644 --- a/test/kernels/test_polynomial_kernel.py +++ b/test/kernels/test_polynomial_kernel.py @@ -31,19 +31,6 @@ def test_computes_quadratic_kernel(self): actual = actual.diagonal(dim1=-1, dim2=-2) self.assertLess(torch.norm(res - actual), 1e-5) - # batch_dims - actual = torch.zeros(2, 3, 3) - for i in range(2): - actual[i] = kernel(a[:, i].unsqueeze(-1), b[:, i].unsqueeze(-1)).to_dense() - - res = kernel(a, b, last_dim_is_batch=True).to_dense() - self.assertLess(torch.norm(res - actual), 1e-5) - - # batch_dims + diag - res = kernel(a, b, last_dim_is_batch=True).diagonal(dim1=-1, dim2=-2) - actual = torch.cat([actual[i].diagonal(dim1=-1, dim2=-2).unsqueeze(0) for i in range(actual.size(0))]) - self.assertLess(torch.norm(res - actual), 1e-5) - def test_computes_cubic_kernel(self): a = torch.tensor([[4, 1], [2, 2], [8, 0]], dtype=torch.float) b = torch.tensor([[0, 0], [2, 1], [1, 0]], dtype=torch.float) @@ -63,19 +50,6 @@ def test_computes_cubic_kernel(self): actual = actual.diagonal(dim1=-1, dim2=-2) self.assertLess(torch.norm(res - actual), 1e-5) - # batch_dims - actual = torch.zeros(2, 3, 3) - for i in range(2): - actual[i] = kernel(a[:, i].unsqueeze(-1), b[:, i].unsqueeze(-1)).to_dense() - - res = kernel(a, b, last_dim_is_batch=True).to_dense() - self.assertLess(torch.norm(res - actual), 1e-5) - - # batch_dims + diag - res = kernel(a, b, last_dim_is_batch=True).diagonal(dim1=-1, dim2=-2) - actual = torch.cat([actual[i].diagonal(dim1=-1, dim2=-2).unsqueeze(0) for i in range(actual.size(0))]) - self.assertLess(torch.norm(res - actual), 1e-5) - def test_quadratic_kernel_batch(self): a = torch.tensor([[4, 2, 8], [1, 2, 3]], dtype=torch.float).view(2, 3, 1) b = torch.tensor([[0, 2, 1], [-1, 2, 0]], dtype=torch.float).view(2, 3, 1) diff --git a/test/kernels/test_rbf_kernel.py b/test/kernels/test_rbf_kernel.py index 718fb4e26..38f369dee 100644 --- a/test/kernels/test_rbf_kernel.py +++ b/test/kernels/test_rbf_kernel.py @@ -19,8 +19,8 @@ def create_kernel_ard(self, num_dims, **kwargs): return RBFKernel(ard_num_dims=num_dims, **kwargs) def test_ard(self): - a = torch.tensor([[1, 2], [2, 4]], dtype=torch.float) - b = torch.tensor([[1, 3], [0, 4]], dtype=torch.float) + a = torch.tensor([[1, 2], [2, 4], [1, 2]], dtype=torch.float) + b = torch.tensor([[1, 3], [0, 4], [0, 3]], dtype=torch.float) lengthscales = torch.tensor([1, 2], dtype=torch.float).view(1, 2) kernel = RBFKernel(ard_num_dims=2) @@ -38,17 +38,6 @@ def test_ard(self): actual = actual.diagonal(dim1=-1, dim2=-2) self.assertLess(torch.norm(res - actual), 1e-5) - # batch_dims - actual = scaled_a.transpose(-1, -2).unsqueeze(-1) - scaled_b.transpose(-1, -2).unsqueeze(-2) - actual = actual.pow(2).mul_(-0.5).exp() - res = kernel(a, b, last_dim_is_batch=True).to_dense() - self.assertLess(torch.norm(res - actual), 1e-5) - - # batch_dims and diag - res = kernel(a, b, last_dim_is_batch=True).diagonal(dim1=-1, dim2=-2) - actual = actual.diagonal(dim1=-1, dim2=-2) - self.assertLess(torch.norm(res - actual), 1e-5) - def test_ard_batch(self): a = torch.tensor([[[1, 2, 3], [2, 4, 0]], [[-1, 1, 2], [2, 1, 4]]], dtype=torch.float) b = torch.tensor([[[1, 3, 1]], [[2, -1, 0]]], dtype=torch.float).repeat(1, 2, 1) @@ -69,19 +58,6 @@ def test_ard_batch(self): actual = actual.diagonal(dim1=-1, 
dim2=-2) self.assertLess(torch.norm(res - actual), 1e-5) - # batch_dims - double_batch_a = scaled_a.transpose(-1, -2).unsqueeze(-1) - double_batch_b = scaled_b.transpose(-1, -2).unsqueeze(-2) - actual = double_batch_a - double_batch_b - actual = actual.pow(2).mul_(-0.5).exp() - res = kernel(a, b, last_dim_is_batch=True).to_dense() - self.assertLess(torch.norm(res - actual), 1e-5) - - # batch_dims and diag - res = kernel(a, b, last_dim_is_batch=True).diagonal(dim1=-1, dim2=-2) - actual = actual.diagonal(dim1=-2, dim2=-1) - self.assertLess(torch.norm(res - actual), 1e-5) - def test_ard_separate_batch(self): a = torch.tensor([[[1, 2, 3], [2, 4, 0]], [[-1, 1, 2], [2, 1, 4]]], dtype=torch.float) b = torch.tensor([[[1, 3, 1]], [[2, -1, 0]]], dtype=torch.float).repeat(1, 2, 1) diff --git a/test/kernels/test_rq_kernel.py b/test/kernels/test_rq_kernel.py index b7a92e726..ddd5cd059 100644 --- a/test/kernels/test_rq_kernel.py +++ b/test/kernels/test_rq_kernel.py @@ -38,23 +38,12 @@ def test_ard(self): actual = actual.diagonal(dim1=-1, dim2=-2) self.assertLess(torch.norm(res - actual), 1e-5) - # batch_dims - diff = scaled_a.transpose(-1, -2).unsqueeze(-1) - scaled_b.transpose(-1, -2).unsqueeze(-2) - actual = diff.pow(2).div_(2 * kernel.alpha).add_(1.0).pow(-kernel.alpha) - res = kernel(a, b, last_dim_is_batch=True).to_dense() - self.assertLess(torch.norm(res - actual), 1e-5) - - # batch_dims and diag - res = kernel(a, b, last_dim_is_batch=True).diagonal(dim1=-1, dim2=-2) - actual = actual.diagonal(dim1=-1, dim2=-2) - self.assertLess(torch.norm(res - actual), 1e-5) - def test_ard_batch(self): a = torch.tensor([[[1, 2, 3], [2, 4, 0]], [[-1, 1, 2], [2, 1, 4]]], dtype=torch.float) b = torch.tensor([[[1, 3, 1]], [[2, -1, 0]]], dtype=torch.float).repeat(1, 2, 1) - lengthscales = torch.tensor([[[1, 2, 1]]], dtype=torch.float) + lengthscales = torch.tensor([[1, 2, 1]], dtype=torch.float) - kernel = RQKernel(batch_shape=torch.Size([2]), ard_num_dims=3) + kernel = RQKernel(batch_shape=torch.Size([]), ard_num_dims=3) kernel.initialize(lengthscale=lengthscales) kernel.initialize(alpha=3.0) kernel.eval() @@ -71,20 +60,6 @@ def test_ard_batch(self): actual = actual.diagonal(dim1=-1, dim2=-2) self.assertLess(torch.norm(res - actual), 1e-5) - # # batch_dims - double_batch_a = scaled_a.transpose(-1, -2).unsqueeze(-1) - double_batch_b = scaled_b.transpose(-1, -2).unsqueeze(-2) - actual = double_batch_a - double_batch_b - alpha = kernel.alpha.view(2, 1, 1, 1) - actual = actual.pow_(2).div_(2 * alpha).add_(1.0).pow(-alpha) - res = kernel(a, b, last_dim_is_batch=True).to_dense() - self.assertLess(torch.norm(res - actual), 1e-5) - - # batch_dims and diag - res = kernel(a, b, last_dim_is_batch=True).diagonal(dim1=-1, dim2=-2) - actual = actual.diagonal(dim1=-2, dim2=-1) - self.assertLess(torch.norm(res - actual), 1e-5) - def test_ard_separate_batch(self): a = torch.tensor([[[1, 2, 3], [2, 4, 0]], [[-1, 1, 2], [2, 1, 4]]], dtype=torch.float) b = torch.tensor([[[1, 3, 1]], [[2, -1, 0]]], dtype=torch.float).repeat(1, 2, 1) diff --git a/test/kernels/test_scale_kernel.py b/test/kernels/test_scale_kernel.py index 57ad0bdf8..3959064bb 100644 --- a/test/kernels/test_scale_kernel.py +++ b/test/kernels/test_scale_kernel.py @@ -44,18 +44,6 @@ def test_ard(self): actual = actual.diagonal(dim1=-1, dim2=-2) self.assertLess(torch.norm(res - actual), 1e-5) - # batch_dims - actual = scaled_a.transpose(-1, -2).unsqueeze(-1) - scaled_b.transpose(-1, -2).unsqueeze(-2) - actual = actual.pow(2).mul_(-0.5).exp().view(2, 2, 2) - actual.mul_(3) - 
res = kernel(a, b, last_dim_is_batch=True).to_dense() - self.assertLess(torch.norm(res - actual), 1e-5) - - # batch_dims and diag - res = kernel(a, b, last_dim_is_batch=True).diagonal(dim1=-1, dim2=-2) - actual = torch.cat([actual[i].diagonal(dim1=-1, dim2=-2).unsqueeze(0) for i in range(actual.size(0))]) - self.assertLess(torch.norm(res - actual), 1e-5) - def test_ard_batch(self): a = torch.tensor([[[1, 2, 3], [2, 4, 0]], [[-1, 1, 2], [2, 1, 4]]], dtype=torch.float) b = torch.tensor([[[1, 3, 1]], [[2, -1, 0]]], dtype=torch.float).repeat(1, 2, 1) @@ -79,20 +67,6 @@ def test_ard_batch(self): actual = torch.cat([actual[i].diagonal(dim1=-1, dim2=-2).unsqueeze(0) for i in range(actual.size(0))]) self.assertLess(torch.norm(res - actual), 1e-5) - # batch_dims - double_batch_a = scaled_a.transpose(-1, -2) - double_batch_b = scaled_b.transpose(-1, -2) - actual = double_batch_a.unsqueeze(-1) - double_batch_b.unsqueeze(-2) - actual = actual.pow(2).mul_(-0.5).exp() - actual[1, :, :, :].mul_(2) - res = kernel(a, b, last_dim_is_batch=True).to_dense() - self.assertLess(torch.norm(res - actual), 1e-5) - - # batch_dims and diag - res = kernel(a, b, last_dim_is_batch=True).diagonal(dim1=-1, dim2=-2) - actual = actual.diagonal(dim1=-2, dim2=-1) - self.assertLess(torch.norm(res - actual), 1e-5) - def test_initialize_outputscale(self): kernel = ScaleKernel(RBFKernel()) kernel.initialize(outputscale=3.14) diff --git a/test/lazy/test_lazy_evaluated_kernel_tensor.py b/test/lazy/test_lazy_evaluated_kernel_tensor.py index 5a3528704..2041f8ca4 100644 --- a/test/lazy/test_lazy_evaluated_kernel_tensor.py +++ b/test/lazy/test_lazy_evaluated_kernel_tensor.py @@ -181,33 +181,3 @@ def test_half(self): lazy_tensor = self.create_linear_op() lazy_tensor.kernel.data_covar_module.raw_lengthscale_constraint.transform = lambda x: x + 0.1 self._test_half(lazy_tensor) - - -class TestLazyEvaluatedKernelTensorAdditive(TestLazyEvaluatedKernelTensorBatch): - seed = 0 - - def create_linear_op(self): - kern = gpytorch.kernels.AdditiveStructureKernel(gpytorch.kernels.RBFKernel(), num_dims=6) - mat1 = torch.randn(5, 6) - mat2 = mat1.detach().clone() - return kern(mat1, mat2) - - def evaluate_linear_op(self, lazy_tensor): - res = to_dense( - gpytorch.Module.__call__( - lazy_tensor.kernel.base_kernel, - lazy_tensor.x1.transpose(-1, -2).unsqueeze(-1), - lazy_tensor.x2.transpose(-1, -2).unsqueeze(-1), - ) - ).sum(0) - return res - - def test_inv_matmul_matrix_with_checkpointing(self): - pass - - def test_half(self): - # many transform operations aren't supported in half so we overwrite - # this test - lazy_tensor = self.create_linear_op() - lazy_tensor.kernel.base_kernel.raw_lengthscale_constraint.transform = lambda x: x + 0.1 - self._test_half(lazy_tensor) diff --git a/test/utils/test_interpolation.py b/test/utils/test_interpolation.py index c296d75bf..d0642bb5a 100644 --- a/test/utils/test_interpolation.py +++ b/test/utils/test_interpolation.py @@ -12,11 +12,11 @@ class TestCubicInterpolation(unittest.TestCase): def test_interpolation(self): x = torch.linspace(0.01, 1, 100).unsqueeze(1) - grid = torch.linspace(-0.05, 1.05, 50).unsqueeze(1) + grid = [torch.linspace(-0.05, 1.05, 50)] indices, values = Interpolation().interpolate(grid, x) indices = indices.squeeze_(0) values = values.squeeze_(0) - test_func_grid = grid.squeeze(1).pow(2) + test_func_grid = grid[0].pow(2) test_func_x = x.pow(2).squeeze(-1) interp_func_x = left_interp(indices, values, test_func_grid.unsqueeze(1)).squeeze() @@ -25,7 +25,7 @@ def test_interpolation(self): def 
test_multidim_interpolation(self): x = torch.tensor([[0.25, 0.45, 0.65, 0.85], [0.35, 0.375, 0.4, 0.425], [0.45, 0.5, 0.55, 0.6]]).t().contiguous() - grid = torch.linspace(0.0, 1.0, 11).unsqueeze(1).repeat(1, 3) + grid = [torch.linspace(0.0, 1.0, 11) for _ in range(3)] indices, values = Interpolation().interpolate(grid, x)
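
The pattern that replaces the removed structure kernels in the updated tests can be sketched as follows. This is a minimal sketch, assuming an RBFKernel base kernel and synthetic inputs; the variable names below are illustrative and not part of the patch.

import torch
from gpytorch.kernels import RBFKernel

n, d = 5, 3
x = torch.randn(n, d)  # n points in d dimensions

# One 1D kernel per input dimension, evaluated as a batch:
# reshaping (n, d) -> (d, n, 1) makes each dimension its own batch entry.
base_kernel = RBFKernel(batch_shape=torch.Size([d]))
covar_per_dim = base_kernel(x.mT[..., None]).to_dense()  # shape (d, n, n)

# Additive structure: sum the per-dimension covariances.
additive_covar = covar_per_dim.sum(dim=-3)

# Product structure: multiply them instead.
product_covar = covar_per_dim.prod(dim=-3)

Reducing over the dimension batch with sum(dim=-3) or prod(dim=-3) mirrors the prod(-2 if diag else -3) reduction that ProductStructureKernel.forward performed internally, and matches the covar_module calls used in the updated KISS-GP regression tests.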