diff --git a/.conda/meta.yaml b/.conda/meta.yaml index 79c4f9714..7bf9d09ce 100644 --- a/.conda/meta.yaml +++ b/.conda/meta.yaml @@ -1,8 +1,12 @@ -{% set data = load_setup_py_data(setup_file="../setup.py", from_recipe_dir=True) %} +{% set _version_match = load_file_regex( + load_file="gpytorch/version.py", + regex_pattern="__version__ = version = '(.+)'" +) %} +{% set version = _version_match[1] %} package: - name: {{ data.get("name")|lower }} - version: {{ data.get("version") }} + name: gpytorch + version: {{ version }} source: path: ../ @@ -17,9 +21,10 @@ requirements: run: - python>=3.8 - - pytorch>=1.11 + - pytorch>=2.0 - scikit-learn - - linear_operator>=0.5.2 + - jaxtyping==0.2.19 + - linear_operator>=0.5.3 test: imports: diff --git a/.github/ISSUE_TEMPLATE/---documentation-examples.md b/.github/ISSUE_TEMPLATE/---documentation-examples.md index b9be64b58..4d2c5e188 100644 --- a/.github/ISSUE_TEMPLATE/---documentation-examples.md +++ b/.github/ISSUE_TEMPLATE/---documentation-examples.md @@ -21,4 +21,4 @@ assignees: '' ** Think you know how to fix the docs? ** (If so, we'd love a pull request from you!) - Link to [GPyTorch documentation](https://gpytorch.readthedocs.io) -- Link to [GPyTorch examples](https://github.com/cornellius-gp/gpytorch/tree/master/examples) +- Link to [GPyTorch examples](https://github.com/cornellius-gp/gpytorch/tree/main/examples) diff --git a/.github/workflows/deploy.yml b/.github/workflows/deploy.yml index c5d7105d8..1e04bf0b6 100644 --- a/.github/workflows/deploy.yml +++ b/.github/workflows/deploy.yml @@ -52,9 +52,8 @@ jobs: conda config --set anaconda_upload yes conda config --append channels pytorch conda config --append channels gpytorch + conda config --append channels conda-forge /usr/share/miniconda/bin/anaconda login --username ${{ secrets.CONDA_USERNAME }} --password ${{ secrets.CONDA_PASSWORD }} python -m setuptools_scm - cd .conda - conda build . + conda build .conda /usr/share/miniconda/bin/anaconda logout - cd .. diff --git a/.github/workflows/run_test_suite.yml b/.github/workflows/run_test_suite.yml index 5b12018b3..c8a0519d9 100644 --- a/.github/workflows/run_test_suite.yml +++ b/.github/workflows/run_test_suite.yml @@ -5,9 +5,9 @@ name: Run Test Suite on: push: - branches: [ master ] + branches: [ main, develop ] pull_request: - branches: [ master ] + branches: [ main, develop ] workflow_call: jobs: @@ -50,7 +50,7 @@ jobs: if [[ ${{ matrix.pytorch-version }} = "master" ]]; then pip install --pre torch -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html; else - pip install torch==1.11+cpu -f https://download.pytorch.org/whl/torch_stable.html; + pip install torch==2.0.1 --index-url https://download.pytorch.org/whl/cpu fi pip install -e . if [[ ${{ matrix.extras }} == "with-extras" ]]; then diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 317914983..ae5aba426 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -16,7 +16,7 @@ repos: hooks: - id: flake8 args: [--config=setup.cfg] - exclude: ^(examples/*)|(docs/*) + exclude: ^(examples/.*)|(docs/.*) - repo: https://github.com/omnilib/ufmt rev: v2.0.0 hooks: @@ -24,15 +24,15 @@ repos: additional_dependencies: - black == 22.3.0 - usort == 1.0.3 - exclude: ^(build/*)|(docs/*)|(examples/*) + exclude: ^(build/.*)|(docs/.*)|(examples/.*) - repo: https://github.com/jumanjihouse/pre-commit-hooks rev: 2.1.6 hooks: - id: require-ascii - exclude: ^(examples/.*\.ipynb)|(.github/ISSUE_TEMPLATE/*) + exclude: ^(examples/.*\.ipynb)|(.github/ISSUE_TEMPLATE/.*) - id: script-must-have-extension - id: forbid-binary - exclude: ^(examples/*)|(test/examples/old_variational_strategy_model.pth) + exclude: ^(examples/.*)|(test/examples/old_variational_strategy_model.pth) - repo: https://github.com/Lucas-C/pre-commit-hooks rev: v1.1.13 hooks: diff --git a/README.md b/README.md index 70a292690..498206039 100644 --- a/README.md +++ b/README.md @@ -30,7 +30,7 @@ See our [**documentation, examples, tutorials**](https://gpytorch.readthedocs.io **Requirements**: - Python >= 3.8 -- PyTorch >= 1.11 +- PyTorch >= 2.0 Install GPyTorch using pip or conda: @@ -88,7 +88,7 @@ If you use GPyTorch, please cite the following papers: ## Contributing -See the contributing guidelines [CONTRIBUTING.md](https://github.com/cornellius-gp/gpytorch/blob/master/CONTRIBUTING.md) +See the contributing guidelines [CONTRIBUTING.md](https://github.com/cornellius-gp/gpytorch/blob/main/CONTRIBUTING.md) for information on submitting issues and pull requests. diff --git a/docs/source/conf.py b/docs/source/conf.py index 0b872c98a..ec10b7878 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -19,7 +19,9 @@ import sys import sphinx_rtd_theme # noqa import warnings -from typing import ForwardRef + +import jaxtyping +from uncompyle6.semantics.fragments import code_deparse def read(*names, **kwargs): @@ -112,7 +114,8 @@ def find_version(*file_paths): intersphinx_mapping = { "python": ("https://docs.python.org/3/", None), "torch": ("https://pytorch.org/docs/stable/", None), - "linear_operator": ("https://linear-operator.readthedocs.io/en/stable/", None), + "linear_operator": ("https://linear-operator.readthedocs.io/en/stable/", "linear_operator_objects.inv"), + # The local mapping here is temporary until we get a new release of linear_operator } # Disable docstring inheritance @@ -237,41 +240,81 @@ def find_version(*file_paths): ] -# -- Function to format typehints ---------------------------------------------- +# -- Functions to format typehints ---------------------------------------------- # Adapted from # https://github.com/cornellius-gp/linear_operator/blob/2b33b9f83b45f0cb8cb3490fc5f254cc59393c25/docs/source/conf.py + + +# Helper function +# Convert any class (i.e. torch.Tensor, LinearOperator, etc.) into appropriate strings +# For external classes, the format will be e.g. "torch.Tensor" +# For any internal class, the format will be e.g. "~linear_operator.operators.TriangularLinearOperator" +def _convert_internal_and_external_class_to_strings(annotation): + module = annotation.__module__ + "." + if module.split(".")[0] == "gpytorch": + module = "~" + module + elif module == "torch.": + module = "~torch." + elif module == "linear_operator.operators._linear_operator.": + module = "~linear_operator." + elif module == "builtins.": + module = "" + res = f"{module}{annotation.__name__}" + return res + + +# Convert jaxtyping dimensions into strings +def _dim_to_str(dim): + if isinstance(dim, jaxtyping.array_types._NamedVariadicDim): + return "..." + elif isinstance(dim, jaxtyping.array_types._FixedDim): + res = str(dim.size) + if dim.broadcastable: + res = "#" + res + return res + elif isinstance(dim, jaxtyping.array_types._SymbolicDim): + expr = code_deparse(dim.expr).text.strip().split("return ")[1] + return f"({expr})" + elif "jaxtyping" not in str(dim.__class__): # Probably the case that we have an ellipsis + return "..." + else: + res = str(dim.name) + if dim.broadcastable: + res = "#" + res + return res + + +# Function to format type hints def _process(annotation, config): """ A function to convert a type/rtype typehint annotation into a :type:/:rtype: string. This function is a bit hacky, and specific to the type annotations we use most frequently. + This function is recursive. """ # Simple/base case: any string annotation is ready to go if type(annotation) == str: return annotation + # Jaxtyping: shaped tensors or linear operator + elif hasattr(annotation, "__module__") and "jaxtyping" == annotation.__module__: + cls_annotation = _convert_internal_and_external_class_to_strings(annotation.array_type) + shape = " x ".join([_dim_to_str(dim) for dim in annotation.dims]) + return f"{cls_annotation} ({shape})" + # Convert Ellipsis into "..." elif annotation == Ellipsis: return "..." - # Convert any class (i.e. torch.Tensor, LinearOperator, gpytorch, etc.) into appropriate strings - # For external classes, the format will be e.g. "torch.Tensor" - # For any linear_operator class, the format will be e.g. "~linear_operator.operators.TriangularLinearOperator" - # For any internal class, the format will be e.g. "~gpytorch.kernels.RBFKernel" + # Convert any class (i.e. torch.Tensor, LinearOperator, etc.) into appropriate strings elif hasattr(annotation, "__name__"): - module = annotation.__module__ + "." - if module.split(".")[0] == "linear_operator": - if annotation.__name__.endswith("LinearOperator"): - module = "~linear_operator." - elif annotation.__name__.endswith("LinearOperator"): - module = "~linear_operator.operators." - else: - module = "~" + module - elif module.split(".")[0] == "gpytorch": - module = "~" + module - elif module == "builtins.": - module = "" - res = f"{module}{annotation.__name__}" + res = _convert_internal_and_external_class_to_strings(annotation) + + elif str(annotation).startswith("typing.Callable"): + if len(annotation.__args__) == 2: + res = f"Callable[{_process(annotation.__args__[0], config)} -> {_process(annotation.__args__[1], config)}]" + else: + res = "Callable" # Convert any Union[*A*, *B*, *C*] into "*A* or *B* or *C*" # Also, convert any Optional[*A*] into "*A*, optional" @@ -291,33 +334,14 @@ def _process(annotation, config): args = list(annotation.__args__) res = "(" + ", ".join(_process(arg, config) for arg in args) + ")" - # Convert any List[*A*] into "list(*A*)" - elif str(annotation).startswith("typing.List"): - arg = annotation.__args__[0] - res = "list(" + _process(arg, config) + ")" - - # Convert any List[*A*] into "list(*A*)" - elif str(annotation).startswith("typing.Dict"): - res = str(annotation) - - # Convert any Iterable[*A*] into "iterable(*A*)" - elif str(annotation).startswith("typing.Iterable"): - arg = annotation.__args__[0] - res = "iterable(" + _process(arg, config) + ")" - - # Handle "Callable" - elif str(annotation).startswith("typing.Callable"): - res = "callable" - - # Handle "Any" - elif str(annotation).startswith("typing.Any"): - res = "" + # Convert any List[*A*] or Iterable[*A*] into "[*A*, ...]" + elif str(annotation).startswith("typing.Iterable") or str(annotation).startswith("typing.List"): + arg = list(annotation.__args__)[0] + res = f"[{_process(arg, config)}, ...]" - # Special cases for forward references. - # This is brittle, as it only contains case for a select few forward refs - # All others that aren't caught by this are handled by the default case - elif isinstance(annotation, ForwardRef): - res = str(annotation.__forward_arg__) + # Callable typing annotation + elif str(annotation).startswith("typing."): + return str(annotation)[7:] # For everything we didn't catch: use the simplist string representation else: diff --git a/docs/source/distributions.rst b/docs/source/distributions.rst index a98bd862b..e8ceeb9ee 100644 --- a/docs/source/distributions.rst +++ b/docs/source/distributions.rst @@ -5,7 +5,7 @@ gpytorch.distributions =================================== GPyTorch distribution objects are essentially the same as torch distribution objects. -For the most part, GpyTorch relies on torch's distribution library. +For the most part, GPyTorch relies on torch's distribution library. However, we offer two custom distributions. We implement a custom :obj:`~gpytorch.distributions.MultivariateNormal` that accepts diff --git a/docs/source/kernels.rst b/docs/source/kernels.rst index 714e46a6c..c11c030d5 100644 --- a/docs/source/kernels.rst +++ b/docs/source/kernels.rst @@ -9,7 +9,7 @@ gpytorch.kernels If you don't know what kernel to use, we recommend that you start out with a -:code:`gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel)`. +:code:`gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel()) + gpytorch.kernels.ConstantKernel()`. Kernel @@ -22,6 +22,13 @@ Kernel Standard Kernels ----------------------------- +:hidden:`ConstantKernel` +~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: ConstantKernel + :members: + + :hidden:`CosineKernel` ~~~~~~~~~~~~~~~~~~~~~~ diff --git a/docs/source/linear_operator_objects.inv b/docs/source/linear_operator_objects.inv new file mode 100644 index 000000000..2de1dfa8b Binary files /dev/null and b/docs/source/linear_operator_objects.inv differ diff --git a/docs/source/utils.rst b/docs/source/utils.rst index e0b941a65..23728f75d 100644 --- a/docs/source/utils.rst +++ b/docs/source/utils.rst @@ -20,6 +20,12 @@ Interpolation Utilities .. automodule:: gpytorch.utils.interpolation :members: +Nearest Neighbors Utilities +--------------------------------- + +.. automodule:: gpytorch.utils.nearest_neighbors + :members: + Quadrature Utilities ---------------------------- @@ -31,9 +37,3 @@ Transform Utilities .. automodule:: gpytorch.utils.transforms :members: - -Nearest Neighbors Utilities -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -.. automodule:: gpytorch.utils.nearest_neighbors - :members: diff --git a/examples/00_Basic_Usage/Implementing_a_custom_Kernel.ipynb b/examples/00_Basic_Usage/Implementing_a_custom_Kernel.ipynb index 3bb5e4c07..1883b85b0 100644 --- a/examples/00_Basic_Usage/Implementing_a_custom_Kernel.ipynb +++ b/examples/00_Basic_Usage/Implementing_a_custom_Kernel.ipynb @@ -209,7 +209,7 @@ "source": [ "### Adding hyperparameters\n", "\n", - "Althogh the `FirstSincKernel` can be used for defining a model, it lacks a parameter that controls the correlation length. This lengthscale will be implemented as a hyperparameter. See also the [tutorial on hyperparamaters](https://docs.gpytorch.ai/en/latest/examples/00_Basic_Usage/Hyperparameters.html), for information on raw vs. actual parameters.\n", + "Although the `FirstSincKernel` can be used for defining a model, it lacks a parameter that controls the correlation length. This lengthscale will be implemented as a hyperparameter. See also the [tutorial on hyperparamaters](https://docs.gpytorch.ai/en/latest/examples/00_Basic_Usage/Hyperparameters.html), for information on raw vs. actual parameters.\n", "\n", "The parameter has to be registered, using the method `register_parameter()`, which `Kernel` inherits from `Module`. Similarly, we register constraints and priors." ] diff --git a/examples/00_Basic_Usage/index.rst b/examples/00_Basic_Usage/index.rst index a87f37cab..386bc3f24 100644 --- a/examples/00_Basic_Usage/index.rst +++ b/examples/00_Basic_Usage/index.rst @@ -6,10 +6,14 @@ parameter constraints and priors, and saving and loading models. Before checking these out, you may want to check out our `simple GP regression tutorial`_ that details the anatomy of a GPyTorch model. -- Check out our `Tutorial on Hyperparameters`_ for information on things like raw versus actual +* Check out our `Tutorial on Hyperparameters`_ for information on things like raw versus actual parameters, constraints, priors and more. -- The `Saving and Loading Models`_ notebook details how to save and load GPyTorch models +* The `Saving and Loading Models`_ notebook details how to save and load GPyTorch models on disk. +* The `Kernels with Additive or Product Structure`_ notebook describes how to compose kernels additively or multiplicatively, + whether for expressivity, sample efficiency, or scalability. +* The `Implementing a Custom Kernel`_ notebook details how to write your own custom kernel in GPyTorch. +* The `Tutorial on Metrics`_ describes various metrics provided by GPyTorch for assessing the generalization of GP models. .. toctree:: :maxdepth: 1 @@ -17,6 +21,7 @@ Before checking these out, you may want to check out our `simple GP regression t Hyperparameters.ipynb Saving_and_Loading_Models.ipynb + kernels_with_additive_or_product_structure.ipynb Implementing_a_custom_Kernel.ipynb Metrics.ipynb @@ -29,6 +34,9 @@ Before checking these out, you may want to check out our `simple GP regression t .. _Saving and Loading Models: Saving_and_Loading_Models.ipynb +.. _Kernels with Additive or Product Structure: + kernels_with_additive_or_product_structure.ipynb + .. _Implementing a custom Kernel: Implementing_a_custom_Kernel.ipynb diff --git a/examples/00_Basic_Usage/kernels_with_additive_or_product_structure.ipynb b/examples/00_Basic_Usage/kernels_with_additive_or_product_structure.ipynb new file mode 100644 index 000000000..45dd43fa4 --- /dev/null +++ b/examples/00_Basic_Usage/kernels_with_additive_or_product_structure.ipynb @@ -0,0 +1,565 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "6f63eab6-9d70-497c-9743-2acd652b7b8c", + "metadata": {}, + "outputs": [], + "source": [ + "# smoke_test = True\n", + "\n", + "import gpytorch\n", + "import torch\n", + "\n", + "from torch.utils import benchmark" + ] + }, + { + "cell_type": "markdown", + "id": "057df92a-445c-4d33-b914-4fb39d9b7c7b", + "metadata": {}, + "source": [ + "# Kernels with Additive or Product Structure\n", + "\n", + "One of the most powerful properties of kernels is their closure under various composition operation.\n", + "Many important covariance functions can be written as the sum or the product of $m$ component kernels:\n", + "\n", + "$$\n", + " k_\\mathrm{sum}(\\boldsymbol x, \\boldsymbol x') = \\sum_{i=1}^m k_i(\\boldsymbol x, \\boldsymbol x'), \\qquad\n", + " k_\\mathrm{prod}(\\boldsymbol x, \\boldsymbol x') = \\prod_{i=1}^m k_i(\\boldsymbol x, \\boldsymbol x')\n", + "$$\n", + "\n", + "Additive and product kernels are used for a variety of reasons.\n", + "1. They are often more interpretable, as argued in [Duvenaud et al. (2011)](https://arxiv.org/pdf/1112.4394).\n", + "2. They can be extremely powerful and expressive, as demonstrated by [Wilson and Adams (2013)](https://proceedings.mlr.press/v28/wilson13.pdf).\n", + "3. They can be extremely sample efficient for Bayesian optimization, as demonstrated by [Kandasamy et al. (2015)](https://arxiv.org/pdf/1503.01673) and [Gardner et al. (2017)](https://proceedings.mlr.press/v54/gardner17a/gardner17a.pdf).\n", + "\n", + "We will discuss various ways to perform additive and product compositions of kernels in GPyTorch.\n", + "The simplest mechanism is to add/multiply the kernel objects together, or add/multiply their outputs.\n", + "However, there are more complex but **far more efficient ways** for adding/multiplying kernels with similar functional forms, which will enable significant parallelism especially on GPUs." + ] + }, + { + "cell_type": "markdown", + "id": "610e82ae-227b-4c0d-b817-9cd6baa73d92", + "metadata": {}, + "source": [ + "## Simple Sums and Products\n", + "\n", + "As an example, consider the [spectral mixture kernel](https://docs.gpytorch.ai/en/stable/kernels.html#spectralmixturekernel) with two components on a univariate input.\n", + "If we remove the scaling components, it can be implemented as:\n", + "\n", + "$$\n", + " k_\\mathrm{SM}(x, x') =\n", + " k_\\mathrm{RBF}(x, x', \\ell_1) k_\\mathrm{cos}(x, x'; \\omega_1) +\n", + " k_\\mathrm{RBF}(x, x', \\ell_2) k_\\mathrm{cos}(x, x'; \\omega_2),\n", + "$$\n", + "\n", + "where $\\ell_1, \\ell_2, \\omega_1, \\omega_2$ are hyperparameters. We can naively implement this kernel in two ways..." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "e060386d-5192-4d03-9002-2f4169be7507", + "metadata": {}, + "outputs": [], + "source": [ + "# Toy data\n", + "X = torch.randn(10, 1)\n", + "\n", + "# Base kernels\n", + "rbf_kernel_1 = gpytorch.kernels.RBFKernel()\n", + "cos_kernel_1 = gpytorch.kernels.CosineKernel()\n", + "rbf_kernel_2 = gpytorch.kernels.RBFKernel()\n", + "cos_kernel_2 = gpytorch.kernels.CosineKernel()\n", + "\n", + "# Implementation 1:\n", + "spectral_mixture_kernel = (rbf_kernel_1 * cos_kernel_1) + (rbf_kernel_2 * cos_kernel_2)\n", + "covar = spectral_mixture_kernel(X)\n", + "\n", + "# Implementation 2:\n", + "covar = rbf_kernel_1(X) * cos_kernel_1(X) + rbf_kernel_2(X) * cos_kernel_2(X)" + ] + }, + { + "cell_type": "markdown", + "id": "ddf6ee94-f7eb-4bb2-935a-7a76787fa596", + "metadata": {}, + "source": [ + "Implementation 1 constructs a `spectral_mixture_kernel` object by applying `+` and `*` directly to the component kernel objects.\n", + "Implementation 2 constructrs the resulting covariance matrix by applying `+` and `*` to the outputs of the component kernels.\n", + "Both implementations are equivalent (the `spectral_mixture_kernel` object created by Implementation 1 essentially performs Implementation 2) under the hood.\n", + "\n", + "(Of course, neither implementation should be used in practice for the spectral mixture kernel. The built-in [SpectralMixtureKernel](https://docs.gpytorch.ai/en/stable/kernels.html#spectralmixturekernel) class is far more efficient.)" + ] + }, + { + "cell_type": "markdown", + "id": "cfbc6181-752d-4905-8e08-df27fb5a93d5", + "metadata": {}, + "source": [ + "## Efficient Parallel Implementations of Additive Structure or Product Structure Kernels \n", + "\n", + "Above we considered the sum and products of kernels with different functional forms.\n", + "However, often we are considering the sum/product over kernels with \n", + "The above example is simple to read, but quite slow in practice.\n", + "Under the hood, each of the kernels (and their compositions) are computed sequentially.\n", + "GPyTorch will compute the first cosine kernel, followed by the first RBF kernel, followed by their product, and so on.\n", + "\n", + "When the component kernels have the same function form,\n", + "we can get massive efficieny gains by exploiting parallelism.\n", + "We combine all of the component kernels into a **batch kernel**\n", + "so that each component kernel can be computed simultaneously.\n", + "We then compute the `sum` or `prod` over the batch dimension.\n", + "This strategy will yield significant speedups especially on the GPU." + ] + }, + { + "cell_type": "markdown", + "id": "731fc597-ccbc-441a-830d-0b8d4efe100d", + "metadata": {}, + "source": [ + "### Example #1: Efficient Summations of Univariate Kernels\n", + "\n", + "As an example, let's assume that we have $d$-dimensional input data $\\boldsymbol x, \\boldsymbol x' \\in \\mathbb R^d$.\n", + "We can define an *additive kernel* that is the sum of $d$ univariate RBF kernels, each of which acts on a single dimension of $\\boldsymbol x$ and $\\boldsymbol x'$.\n", + "\n", + "$$\n", + " k_\\mathrm{additive}(\\boldsymbol x, \\boldsymbol x') = \\prod_{i=1}^d k_\\mathrm{RBF}(x^{(i)}, x^{\\prime(i)}; \\ell^{(i)}).\n", + "$$\n", + "\n", + "Here, $\\ell^{(i)}$ is the lengthscale associated with dimension $i$.\n", + "Note that we are using a different lengthscale for each of the component kernels.\n", + "Nevertheless, we can efficiently compute each of the component kernels in parallel using batching.\n", + "First we define a RBFKernel object designed to compute a **batch of $d$ univariate kernels**:" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "90e52419-38cd-43e7-a599-bfb7abea89b8", + "metadata": {}, + "outputs": [], + "source": [ + "d = 3\n", + "\n", + "batch_univariate_rbf_kernel = gpytorch.kernels.RBFKernel(\n", + " batch_shape=torch.Size([d]), # A batch of d...\n", + " ard_num_dims=1, # ...univariate kernels\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "ba66b398-7429-408f-af23-d64ad6930df6", + "metadata": {}, + "source": [ + "Including the `batch_shape` argument ensures that the `lengthscale` parameter of the `batch_univariate_rbf_kernel` is a `d x 1 x 1` tensor; i.e. each univariate kernel will have its own lengthscale. (We could instead have each univariate kernel share the same lengthscale by omitting the `batch_shape` argument.)\n", + "\n", + "To compute the univariate kernel matrices, we need to feed the appropriate dimensions of $\\boldsymbol X$ into each of the component kernels.\n", + "We accomplish this by reshaping the `n x d` matrix representing $\\boldsymbol X$ into a batch of $d$ `n x 1` matrices\n", + "(i.e. a `d x n x 1` tensor)." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "0edaca08-df1e-423d-b76c-e5f8f9fc5338", + "metadata": {}, + "outputs": [], + "source": [ + "n = 10\n", + "\n", + "X = torch.randn(n, d) # Some random data in a n x d matrix\n", + "batched_dimensions_of_X = X.mT.unsqueeze(-1) # Now a d x n x 1 tensor" + ] + }, + { + "cell_type": "markdown", + "id": "8b40265f-a450-42fc-bcd4-f033750a8f7b", + "metadata": {}, + "source": [ + "We then feed the batches of univariate data into the batched kernel object to get our batch of univariate kernel matrices:" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "e0988873-93fa-4016-880c-d9dff3344032", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([3, 10, 10])" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "univariate_rbf_covars = batch_univariate_rbf_kernel(batched_dimensions_of_X)\n", + "univariate_rbf_covars.shape # d x n x n" + ] + }, + { + "cell_type": "markdown", + "id": "a61442cb-a575-44d9-ae97-cc396bc48a53", + "metadata": {}, + "source": [ + "And finally, to get the multivariate kernel, we can compute the sum over the batch (i.e. the sum over the univariate kernels)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "65cbeb00-4dde-4720-8d80-8a8b683caf8e", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([10, 10])" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "additive_covar = univariate_rbf_covars.sum(dim=-3) # Computes the sum over the batch dimension\n", + "additive_covar.shape # n x n" + ] + }, + { + "cell_type": "markdown", + "id": "3e52b2c1-35f2-4074-a5f5-36fd826d3b2e", + "metadata": {}, + "source": [ + "On a small dataset, this approach is comparable to the naive approach described above. However, it will become much faster on a larger and more high dimensional dataset, especially on the GPU." + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "6a5d9b77-cd31-40b4-bc89-d913a601ea9b", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "naive_additive_kernel(X)\n", + " 3.37 ms\n", + " 1 measurement, 100 runs , 1 thread\n" + ] + } + ], + "source": [ + "d = 10\n", + "n = 500\n", + "device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n", + "\n", + "X = torch.randn(n, d, device=device)\n", + "\n", + "naive_additive_kernel = (\n", + " gpytorch.kernels.RBFKernel(ard_num_dims=1, active_dims=[0]) +\n", + " gpytorch.kernels.RBFKernel(ard_num_dims=1, active_dims=[1]) +\n", + " gpytorch.kernels.RBFKernel(ard_num_dims=1, active_dims=[2]) +\n", + " gpytorch.kernels.RBFKernel(ard_num_dims=1, active_dims=[3]) +\n", + " gpytorch.kernels.RBFKernel(ard_num_dims=1, active_dims=[4]) +\n", + " gpytorch.kernels.RBFKernel(ard_num_dims=1, active_dims=[5]) +\n", + " gpytorch.kernels.RBFKernel(ard_num_dims=1, active_dims=[6]) +\n", + " gpytorch.kernels.RBFKernel(ard_num_dims=1, active_dims=[7]) +\n", + " gpytorch.kernels.RBFKernel(ard_num_dims=1, active_dims=[8]) +\n", + " gpytorch.kernels.RBFKernel(ard_num_dims=1, active_dims=[9])\n", + ").to(device=device)\n", + "\n", + "with gpytorch.settings.lazily_evaluate_kernels(False):\n", + " print(benchmark.Timer(\n", + " stmt=\"naive_additive_kernel(X)\",\n", + " globals={\"naive_additive_kernel\": naive_additive_kernel, \"X\": X}\n", + " ).timeit(100))" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "8ce95d7a-0419-4d15-b48c-892df4c7711e", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "batch_univariate_rbf_kernel(X.mT.unsqueeze(-1)).sum(dim=-3)\n", + " 940.30 us\n", + " 1 measurement, 100 runs , 1 thread\n" + ] + } + ], + "source": [ + "batch_univariate_rbf_kernel = gpytorch.kernels.RBFKernel(\n", + " batch_shape=torch.Size([d]), ard_num_dims=1,\n", + ").to(device=device)\n", + "with gpytorch.settings.lazily_evaluate_kernels(False):\n", + " print(benchmark.Timer(\n", + " stmt=\"batch_univariate_rbf_kernel(X.mT.unsqueeze(-1)).sum(dim=-3)\",\n", + " globals={\"batch_univariate_rbf_kernel\": batch_univariate_rbf_kernel, \"X\": X}\n", + " ).timeit(100))" + ] + }, + { + "cell_type": "markdown", + "id": "aa47ff89-dd55-44cf-b267-ac28ba0a9c27", + "metadata": {}, + "source": [ + "### Full Example\n", + "\n", + "Putting it all together, a GP using this efficient additive kernel would look something like..." + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "c32b0df1-1fa8-4b27-ab12-2d352ea80064", + "metadata": {}, + "outputs": [], + "source": [ + "class AdditiveKernelGP(gpytorch.models.ExactGP):\n", + " def __init__(self, X_train, y_train, d):\n", + " likelihood = gpytorch.likelihoods.GaussianLikelihood()\n", + " super().__init__(X_train, y_train, likelihood)\n", + "\n", + " self.mean_module = gpytorch.means.ConstantMean()\n", + " self.covar_module = gpytorch.kernels.ScaleKernel(\n", + " gpytorch.kernels.RBFKernel(batch_shape=torch.Size([d]), ard_num_dims=1)\n", + " )\n", + "\n", + " def forward(self, X):\n", + " mean = self.mean_module(X)\n", + " batched_dimensions_of_X = X.mT.unsqueeze(-1) # Now a d x n x 1 tensor\n", + " covar = self.covar_module(batched_dimensions_of_X).sum(dim=-3)\n", + " return gpytorch.distributions.MultivariateNormal(mean, covar)" + ] + }, + { + "cell_type": "markdown", + "id": "6a9c2497-1200-44ed-b398-18a424976ea3", + "metadata": {}, + "source": [ + "### Example #2: Efficient Products of Univariate Kernels\n", + "\n", + "As another example, we can consider a multivariate kernel defined as the product of univariate kernels, i.e.:\n", + "\n", + "$$\n", + " k_\\mathrm{RBF}(\\boldsymbol x, \\boldsymbol x'; \\boldsymbol \\ell) = \\prod_{i=1}^d k_\\mathrm{RBF}(x^{(i)}, x^{\\prime(i)}; \\ell^{(i)}).\n", + "$$" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "8122f229-c7ff-49df-a4ff-2cb3f9f05fe4", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([10, 10])" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "d = 3\n", + "n = 10\n", + "\n", + "batch_univariate_rbf_kernel = gpytorch.kernels.RBFKernel(\n", + " batch_shape=torch.Size([d]), ard_num_dims=1,\n", + ")\n", + "X = torch.randn(n, d)\n", + "\n", + "univariate_rbf_covars = batch_univariate_rbf_kernel(X.mT.unsqueeze(-1))\n", + "with gpytorch.settings.lazily_evaluate_kernels(False):\n", + " prod_covar = univariate_rbf_covars.prod(dim=-3)\n", + "prod_covar.shape # n x n" + ] + }, + { + "cell_type": "markdown", + "id": "68eaaa54-c481-45bb-aa77-fe79b6a6957d", + "metadata": {}, + "source": [ + "This particular example is a bit silly, since the multivariate RBF kernel is exactly equivalent to the product of $d$ univariate RBF kernels,\n", + "\n", + "$$\n", + " k_\\mathrm{RBF}(\\boldsymbol x, \\boldsymbol x') = \\prod_{i=1}^d k_\\mathrm{RBF}(x^{(i)}, x^{\\prime(i)}).\n", + "$$\n", + "\n", + "However, this strategy can actually become advantageous when we approximate each of the univariate component kernels using a scalable $\\ll \\mathcal O(n^3)$ approximation for each of the univariate kernels.\n", + "See [the tutorial on SKIP (structured kernel interpolation of products)](../02_Scalable_Exact_GPs/Scalable_Kernel_Interpolation_for_Products_CUDA.ipynb) for an example of exploiting product structure for scalability." + ] + }, + { + "cell_type": "markdown", + "id": "e83a742f-c24b-4cb2-be86-0c6e0cd67026", + "metadata": {}, + "source": [ + "## Summing Higher Order Interactions Between Univariate Kernels (Additive Gaussian Processes)\n", + "\n", + "[Duvenaud et al. (2011)](https://arxiv.org/pdf/1112.4394) introduce \"Additive Gaussian Processes,\" which are GPs that additively compose interaction terms between univariate kernels.\n", + "For example, with $d$-dimensional data and a max-degree of $3$ interaction terms, the corresponding kernel would be:\n", + "\n", + "$$\n", + "\\begin{align*}\n", + " k(\\boldsymbol x, \\boldsymbol x')\n", + " &= \\sum_{i=1}^d k_i(x^{(i)}, x^{\\prime(i)}; \\ell^{(i)}) \\\\\n", + " &+ \\sum_{i \\ne j} k_i(x^{(i)}, x^{\\prime(i)}; \\ell^{(i)}) k_j(x^{(j)}, x^{\\prime(j)}; \\ell^{(j)}) \\\\\n", + " &+ \\sum_{h \\ne i \\ne j} k_h(x^{(h)}, x^{\\prime(h)}; \\ell^{(h)}) k_i(x^{(i)}, x^{\\prime(i)}; \\ell^{(j)}) k_j(x^{(j)}, x^{\\prime(j)}; \\ell^{(j)})\n", + "\\end{align*}\n", + "$$\n", + "\n", + "Despite the summations having an exponential number of terms, this kernel can be computed in $\\mathcal O(d^2)$ time using the Newton-Girard formula.\n", + "\n", + "To compute this kernel in GPyTorch, we begin with a batch of the univariate covariance matrices (stored in a `d x n x n` Tensor or LinearOperator). We follow the same techniques as we used before:" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "7ea7337c-6375-4287-8133-6117316e0791", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([4, 10, 10])" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "d = 4\n", + "n = 10\n", + "\n", + "batch_univariate_rbf_kernel = gpytorch.kernels.RBFKernel(\n", + " batch_shape=torch.Size([d]), ard_num_dims=1,\n", + ")\n", + "X = torch.randn(n, d)\n", + "\n", + "with gpytorch.settings.lazily_evaluate_kernels(False):\n", + " univariate_rbf_covars = batch_univariate_rbf_kernel(X.mT.unsqueeze(-1))\n", + "univariate_rbf_covars.shape # d x n x n" + ] + }, + { + "cell_type": "markdown", + "id": "fd7ba35a-87e3-4a86-acd7-de44184462cc", + "metadata": {}, + "source": [ + "We then use the `gpytorch.utils.sum_interaction_terms` to compute and sum all of the higher-order interaction terms in $\\mathcal O(d^2)$ time:" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "7f6cb2da-46b6-4cd5-9e46-6e3ae49e8444", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([10, 10])" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "covar = gpytorch.utils.sum_interaction_terms(univariate_rbf_covars, max_degree=3, dim=-3)\n", + "covar.shape # n x n" + ] + }, + { + "cell_type": "markdown", + "id": "1492c4cc-874b-4423-aab3-ea97501730bc", + "metadata": {}, + "source": [ + "The full GP proposed by [Duvenaud et al. (2011)](https://arxiv.org/pdf/1112.4394) would then look like:" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "f6dc72f7-d077-426c-9219-15b5cb371198", + "metadata": {}, + "outputs": [], + "source": [ + "class AdditiveGP(gpytorch.models.ExactGP):\n", + " def __init__(self, X_train, y_train, d, max_degree):\n", + " likelihood = gpytorch.likelihoods.GaussianLikelihood()\n", + " super().__init__(X_train, y_train, likelihood)\n", + "\n", + " self.mean_module = gpytorch.means.ConstantMean()\n", + " self.covar_module = gpytorch.kernels.ScaleKernel(\n", + " gpytorch.kernels.RBFKernel(batch_shape=torch.Size([d]), ard_num_dims=1)\n", + " )\n", + " self.max_degree = max_degree\n", + "\n", + " def forward(self, X):\n", + " mean = self.mean_module(X)\n", + " batched_dimensions_of_X = X.mT.unsqueeze(-1) # Now a d x n x 1 tensor\n", + " univariate_rbf_covars = self.covar_module(batched_dimensions_of_X)\n", + " covar = gpytorch.utils.sum_interaction_terms(\n", + " univariate_rbf_covars, max_degree=self.max_degree, dim=-3\n", + " )\n", + " return gpytorch.distributions.MultivariateNormal(mean, covar)" + ] + }, + { + "cell_type": "markdown", + "id": "ba613c79-92a4-41e7-92f7-4b6f598286bd", + "metadata": {}, + "source": [ + "*(For those familiar with previous versions of GPyTorch, `sum_interaction_terms` replaces what was previously implemented by `NewtonGirardAdditiveKernel`.)*" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.0" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/01_Exact_GPs/GP_Regression_Fully_Bayesian.ipynb b/examples/01_Exact_GPs/GP_Regression_Fully_Bayesian.ipynb index 10efcfc0a..4a34a6083 100644 --- a/examples/01_Exact_GPs/GP_Regression_Fully_Bayesian.ipynb +++ b/examples/01_Exact_GPs/GP_Regression_Fully_Bayesian.ipynb @@ -4,7 +4,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "# Fully Bayesian GPs - Sampling Hyperparamters with NUTS\n", + "# Fully Bayesian GPs - Sampling Hyperparameters with NUTS\n", "\n", "In this notebook, we'll demonstrate how to integrate GPyTorch and NUTS to sample GP hyperparameters and perform GP inference in a fully Bayesian way.\n", "\n", diff --git a/examples/01_Exact_GPs/Simple_GP_Regression.ipynb b/examples/01_Exact_GPs/Simple_GP_Regression.ipynb index 8acaf75bc..92b3c6672 100644 --- a/examples/01_Exact_GPs/Simple_GP_Regression.ipynb +++ b/examples/01_Exact_GPs/Simple_GP_Regression.ipynb @@ -112,7 +112,7 @@ "\n", "The simplest likelihood for regression is the `gpytorch.likelihoods.GaussianLikelihood`. This assumes a homoskedastic noise model (i.e. all inputs have the same observational noise).\n", "\n", - "There are other options for exact GP regression, such as the [FixedNoiseGaussianLikelihood](http://docs.gpytorch.ai/likelihoods.html#fixednoisegaussianlikelihood), which assigns a different observed noise value to different training inputs." + "There are other options for exact GP regression, such as the [FixedNoiseGaussianLikelihood](https://docs.gpytorch.ai/en/latest/likelihoods.html#fixednoisegaussianlikelihood), which assigns a different observed noise value to different training inputs." ] }, { diff --git a/examples/02_Scalable_Exact_GPs/KISSGP_Regression.ipynb b/examples/02_Scalable_Exact_GPs/KISSGP_Regression.ipynb index 05e6e89e4..81a5ae71c 100644 --- a/examples/02_Scalable_Exact_GPs/KISSGP_Regression.ipynb +++ b/examples/02_Scalable_Exact_GPs/KISSGP_Regression.ipynb @@ -424,77 +424,26 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## KISS-GP for higher dimensional data w/ Additive Structure" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "The method above won't scale to data with much more than ~4 dimensions, since the cost of creating the grid grows exponentially in the amount of data. Therefore, we'll need to make some additional approximations.\n", + "## Scaling to more dimensions\n", "\n", - "If the function you are modeling has additive structure across its dimensions, then SKI can be one of the most efficient methods for your problem.\n", + "If your data is high dimensional, try one of the following methods:\n", "\n", - "To set this up, we'll wrap the `GridInterpolationKernel` used in the previous two models with one additional kernel: the `AdditiveStructureKernel`. The model will look something like this:" + "1. SKIP - or [Scalable Kernel Interpolation for Products](./Scalable_Kernel_Interpolation_for_Products_CUDA.ipynb)\n", + "2. KISS-GP with [Deep Kernel Learning](../06_PyTorch_NN_Integration_DKL/KISSGP_Deep_Kernel_Regression_CUDA.ipynb)" ] }, { "cell_type": "code", - "execution_count": 17, + "execution_count": null, "metadata": {}, "outputs": [], - "source": [ - "class GPRegressionModel(gpytorch.models.ExactGP):\n", - " def __init__(self, train_x, train_y, likelihood):\n", - " super(GPRegressionModel, self).__init__(train_x, train_y, likelihood)\n", - " \n", - " # SKI requires a grid size hyperparameter. This util can help with that\n", - " # We're setting Kronecker structure to False because we're using an additive structure decomposition\n", - " grid_size = gpytorch.utils.grid.choose_grid_size(train_x, kronecker_structure=False)\n", - " \n", - " self.mean_module = gpytorch.means.ConstantMean()\n", - " self.covar_module = gpytorch.kernels.AdditiveStructureKernel(\n", - " gpytorch.kernels.ScaleKernel(\n", - " gpytorch.kernels.GridInterpolationKernel(\n", - " gpytorch.kernels.RBFKernel(), grid_size=128, num_dims=1\n", - " )\n", - " ), num_dims=2\n", - " )\n", - "\n", - " def forward(self, x):\n", - " mean_x = self.mean_module(x)\n", - " covar_x = self.covar_module(x)\n", - " return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)\n", - "\n", - " \n", - "likelihood = gpytorch.likelihoods.GaussianLikelihood()\n", - "model = GPRegressionModel(train_x, train_y, likelihood)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Essentially, the `AdditiveStructureKernel` makes the base kernel (in this case, the `GridInterpolationKernel` wrapping the `RBFKernel`) to act as 1D kernels on each data dimension. The final kernel matrix will be a sum of these 1D kernel matrices." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Scaling to more dimensions (without additive structure)\n", - "\n", - "If you can't exploit additive structure, then try one of these other two methods:\n", - "\n", - "1. SKIP - or [Scalable Kernel Interpolation for Products](./Scalable_Kernel_Interpolation_for_Products_CUDA.ipynb)\n", - "2. KISS-GP with [Deep Kernel Learning](../06_PyTorch_NN_Integration_DKL/KISSGP_Deep_Kernel_Regression_CUDA.ipynb)" - ] + "source": [] } ], "metadata": { "anaconda-cloud": {}, "kernelspec": { - "display_name": "Python 3", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, @@ -508,9 +457,9 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.3" + "version": "3.10.0" } }, "nbformat": 4, - "nbformat_minor": 2 + "nbformat_minor": 4 } diff --git a/examples/02_Scalable_Exact_GPs/Scalable_Kernel_Interpolation_for_Products_CUDA.ipynb b/examples/02_Scalable_Exact_GPs/Scalable_Kernel_Interpolation_for_Products_CUDA.ipynb index f8963e2e7..05f1ccb7c 100644 --- a/examples/02_Scalable_Exact_GPs/Scalable_Kernel_Interpolation_for_Products_CUDA.ipynb +++ b/examples/02_Scalable_Exact_GPs/Scalable_Kernel_Interpolation_for_Products_CUDA.ipynb @@ -17,18 +17,7 @@ "cell_type": "code", "execution_count": 1, "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/jrg365/anaconda3/lib/python3.6/site-packages/matplotlib/__init__.py:465: UserWarning: matplotlibrc text.usetex option can not be used unless TeX is installed on your system\n", - " warnings.warn('matplotlibrc text.usetex option can not be used unless '\n", - "/home/jrg365/anaconda3/lib/python3.6/site-packages/matplotlib/__init__.py:473: UserWarning: matplotlibrc text.usetex can not be used with *Agg backend unless dvipng-1.6 or later is installed on your system\n", - " 'your system' % dvipng_req)\n" - ] - } - ], + "outputs": [], "source": [ "import math\n", "import torch\n", @@ -50,7 +39,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ @@ -92,7 +81,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 3, "metadata": {}, "outputs": [ { @@ -101,7 +90,7 @@ "torch.Size([16599, 18])" ] }, - "execution_count": 4, + "execution_count": 3, "metadata": {}, "output_type": "execute_result" } @@ -118,44 +107,46 @@ "\n", "We now define the GP model. For more details on the use of GP models, see our simpler examples. This model uses a `GridInterpolationKernel` (SKI) with an RBF base kernel. To use SKIP, we make two changes:\n", "\n", - "- First, we use only a 1 dimensional `GridInterpolationKernel` (e.g., by passing `num_dims=1`). The idea of SKIP is to use a product of 1 dimensional `GridInterpolationKernel`s instead of a single `d` dimensional one.\n", - "- Next, we create a `ProductStructureKernel` that wraps our 1D `GridInterpolationKernel` with `num_dims=18`. This specifies that we want to use product structure over 18 dimensions, using the 1D `GridInterpolationKernel` in each dimension.\n", + "- First, we define our `base_covar_module` to have a batch_shape equal to the dimensionality of the data.\n", + " We make this change because we will the `base_covar_module` to construct a batch of univariate kernels\n", + " which we will then multiply using SKIP.\n", + "- We use only a 1 dimensional `GridInterpolationKernel` (e.g., by passing `num_dims=1`). The idea of SKIP is to use a product of 1 dimensional `GridInterpolationKernel`s instead of a single `d` dimensional one.\n", + "- In the `forward` call, we reshape `x` to be `d x n x 1` before passing it through the `covar_module`.\n", + " Our `covar_module` produces a batch of univariate kernels, and `x` must treat each dimension as a batch.\n", + "- After constructing our univariate covariance matrices, we multiply them all together by calling `.prod(dim=-3)`.\n", "\n", - "**Note:** If you've explored the rest of the package, you may be wondering what the differences between `AdditiveKernel`, `AdditiveStructureKernel`, `ProductKernel`, and `ProductStructureKernel` are. The `Structure` kernels (1) assume that we want to apply a single base kernel over a fully decomposed dataset (e.g., every dimension is additive or has product structure), and (2) are significantly more efficient as a result, because they can exploit batch parallel operations instead of using for loops." + "For more details on this construction, see the [Kernels with Additive or Product Structure tutorial](../00_Basic_Usage/kernels_with_additive_or_product_structure.ipynb)." ] }, { "cell_type": "code", - "execution_count": 5, - "metadata": { - "collapsed": true - }, + "execution_count": 4, + "metadata": {}, "outputs": [], "source": [ "from gpytorch.means import ConstantMean\n", - "from gpytorch.kernels import ScaleKernel, RBFKernel, ProductStructureKernel, GridInterpolationKernel\n", + "from gpytorch.kernels import ScaleKernel, RBFKernel, GridInterpolationKernel\n", "from gpytorch.distributions import MultivariateNormal\n", "\n", "class GPRegressionModel(gpytorch.models.ExactGP):\n", " def __init__(self, train_x, train_y, likelihood):\n", " super(GPRegressionModel, self).__init__(train_x, train_y, likelihood)\n", " self.mean_module = ConstantMean()\n", - " self.base_covar_module = RBFKernel()\n", - " self.covar_module = ProductStructureKernel(\n", - " ScaleKernel(\n", - " GridInterpolationKernel(self.base_covar_module, grid_size=100, num_dims=1)\n", - " ), num_dims=18\n", + " self.base_covar_module = RBFKernel(batch_shape=torch.Size([train_x.size(-1)]))\n", + " self.covar_module = ScaleKernel(\n", + " GridInterpolationKernel(self.base_covar_module, grid_size=100, num_dims=1)\n", " )\n", "\n", " def forward(self, x):\n", " mean_x = self.mean_module(x)\n", - " covar_x = self.covar_module(x)\n", + " univariate_covars = self.covar_module(x.mT.unsqueeze(-1))\n", + " covar_x = univariate_covars.prod(dim=-3)\n", " return MultivariateNormal(mean_x, covar_x)" ] }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 5, "metadata": {}, "outputs": [], "source": [ @@ -180,42 +171,77 @@ }, { "cell_type": "code", - "execution_count": 7, - "metadata": { - "scrolled": false - }, + "execution_count": 6, + "metadata": {}, "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/gpleiss/workspace/linear_operator/linear_operator/utils/sparse.py:51: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly. To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()\n", + " if nonzero_indices.storage():\n", + "/home/gpleiss/workspace/linear_operator/linear_operator/utils/sparse.py:66: UserWarning: The torch.cuda.*DtypeTensor constructors are no longer recommended. It's best to use methods such as torch.tensor(data, dtype=*, device='cuda') to create tensors. (Triggered internally at ../torch/csrc/tensor/python_tensor.cpp:78.)\n", + " res = cls(index_tensor, value_tensor, interp_size)\n", + "/home/gpleiss/workspace/linear_operator/linear_operator/utils/sparse.py:66: UserWarning: torch.sparse.SparseTensor(indices, values, shape, *, device=) is deprecated. Please use torch.sparse_coo_tensor(indices, values, shape, dtype=, device=). (Triggered internally at ../torch/csrc/utils/tensor_new.cpp:621.)\n", + " res = cls(index_tensor, value_tensor, interp_size)\n" + ] + }, { "name": "stdout", "output_type": "stream", "text": [ - "Iter 1/25 - Loss: 0.942\n", - "Iter 2/25 - Loss: 0.919\n", - "Iter 3/25 - Loss: 0.888\n", - "Iter 4/25 - Loss: 0.864\n", - "Iter 5/25 - Loss: 0.840\n", - "Iter 6/25 - Loss: 0.816\n", - "Iter 7/25 - Loss: 0.792\n", - "Iter 8/25 - Loss: 0.767\n", - "Iter 9/25 - Loss: 0.743\n", - "Iter 10/25 - Loss: 0.719\n", - "Iter 11/25 - Loss: 0.698\n", - "Iter 12/25 - Loss: 0.671\n", - "Iter 13/25 - Loss: 0.651\n", - "Iter 14/25 - Loss: 0.624\n", - "Iter 15/25 - Loss: 0.600\n", - "Iter 16/25 - Loss: 0.576\n", - "Iter 17/25 - Loss: 0.553\n", - "Iter 18/25 - Loss: 0.529\n", - "Iter 19/25 - Loss: 0.506\n", - "Iter 20/25 - Loss: 0.483\n", - "Iter 21/25 - Loss: 0.460\n", - "Iter 22/25 - Loss: 0.441\n", - "Iter 23/25 - Loss: 0.413\n", - "Iter 24/25 - Loss: 0.391\n", - "Iter 25/25 - Loss: 0.375\n", - "CPU times: user 1min 6s, sys: 26.8 s, total: 1min 33s\n", - "Wall time: 1min 48s\n" + "Iter 1/50 - Loss: 0.782\n", + "Iter 2/50 - Loss: 0.767\n", + "Iter 3/50 - Loss: 0.749\n", + "Iter 4/50 - Loss: 0.733\n", + "Iter 5/50 - Loss: 0.717\n", + "Iter 6/50 - Loss: 0.699\n", + "Iter 7/50 - Loss: 0.682\n", + "Iter 8/50 - Loss: 0.665\n", + "Iter 9/50 - Loss: 0.648\n", + "Iter 10/50 - Loss: 0.631\n", + "Iter 11/50 - Loss: 0.613\n", + "Iter 12/50 - Loss: 0.596\n", + "Iter 13/50 - Loss: 0.578\n", + "Iter 14/50 - Loss: 0.561\n", + "Iter 15/50 - Loss: 0.544\n", + "Iter 16/50 - Loss: 0.526\n", + "Iter 17/50 - Loss: 0.509\n", + "Iter 18/50 - Loss: 0.491\n", + "Iter 19/50 - Loss: 0.474\n", + "Iter 20/50 - Loss: 0.457\n", + "Iter 21/50 - Loss: 0.439\n", + "Iter 22/50 - Loss: 0.422\n", + "Iter 23/50 - Loss: 0.405\n", + "Iter 24/50 - Loss: 0.388\n", + "Iter 25/50 - Loss: 0.372\n", + "Iter 26/50 - Loss: 0.355\n", + "Iter 27/50 - Loss: 0.339\n", + "Iter 28/50 - Loss: 0.322\n", + "Iter 29/50 - Loss: 0.306\n", + "Iter 30/50 - Loss: 0.291\n", + "Iter 31/50 - Loss: 0.276\n", + "Iter 32/50 - Loss: 0.261\n", + "Iter 33/50 - Loss: 0.246\n", + "Iter 34/50 - Loss: 0.232\n", + "Iter 35/50 - Loss: 0.218\n", + "Iter 36/50 - Loss: 0.204\n", + "Iter 37/50 - Loss: 0.191\n", + "Iter 38/50 - Loss: 0.179\n", + "Iter 39/50 - Loss: 0.167\n", + "Iter 40/50 - Loss: 0.155\n", + "Iter 41/50 - Loss: 0.144\n", + "Iter 42/50 - Loss: 0.134\n", + "Iter 43/50 - Loss: 0.124\n", + "Iter 44/50 - Loss: 0.114\n", + "Iter 45/50 - Loss: 0.106\n", + "Iter 46/50 - Loss: 0.097\n", + "Iter 47/50 - Loss: 0.089\n", + "Iter 48/50 - Loss: 0.082\n", + "Iter 49/50 - Loss: 0.075\n", + "Iter 50/50 - Loss: 0.068\n", + "CPU times: user 53.2 s, sys: 3.96 s, total: 57.1 s\n", + "Wall time: 1min 16s\n" ] } ], @@ -227,7 +253,7 @@ "likelihood.train()\n", "\n", "# Use the adam optimizer\n", - "optimizer = torch.optim.Adam(model.parameters(), lr=0.01)\n", + "optimizer = torch.optim.Adam(model.parameters(), lr=0.05)\n", "\n", "# \"Loss\" for GPs - the marginal log likelihood\n", "mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)\n", @@ -246,9 +272,7 @@ " optimizer.step()\n", " torch.cuda.empty_cache()\n", " \n", - "# See dkl_mnist.ipynb for explanation of this flag\n", - "with gpytorch.settings.use_toeplitz(True):\n", - " %time train()" + "%time train()" ] }, { @@ -262,48 +286,39 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "model.eval()\n", "likelihood.eval()\n", "with gpytorch.settings.max_preconditioner_size(10), torch.no_grad():\n", - " with gpytorch.settings.use_toeplitz(False), gpytorch.settings.max_root_decomposition_size(30), gpytorch.settings.fast_pred_var():\n", + " with gpytorch.settings.max_root_decomposition_size(30), gpytorch.settings.fast_pred_var():\n", " preds = model(test_x)" ] }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Test MAE: 0.07745790481567383\n" + "Test MAE: 0.18244513869285583\n" ] } ], "source": [ "print('Test MAE: {}'.format(torch.mean(torch.abs(preds.mean - test_y))))" ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "collapsed": true - }, - "outputs": [], - "source": [] } ], "metadata": { "anaconda-cloud": {}, "kernelspec": { - "display_name": "Python 3", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, @@ -317,9 +332,9 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.3" + "version": "3.10.0" } }, "nbformat": 4, - "nbformat_minor": 2 + "nbformat_minor": 4 } diff --git a/examples/02_Scalable_Exact_GPs/Simple_GP_Regression_CUDA.ipynb b/examples/02_Scalable_Exact_GPs/Simple_GP_Regression_CUDA.ipynb index 296302040..db920a3fe 100644 --- a/examples/02_Scalable_Exact_GPs/Simple_GP_Regression_CUDA.ipynb +++ b/examples/02_Scalable_Exact_GPs/Simple_GP_Regression_CUDA.ipynb @@ -21,11 +21,11 @@ "$$\n", "\\begin{align}\n", " y &= \\sin(2\\pi x) + \\epsilon \\\\ \n", - " \\epsilon &\\sim \\mathcal{N}(0, 0.2) \n", + " \\epsilon &\\sim \\mathcal{N}(0, 0.04) \n", "\\end{align}\n", "$$\n", "\n", - "with 11 training examples, and testing on 51 test examples." + "with 100 training examples, and testing on 51 test examples." ] }, { @@ -50,7 +50,7 @@ "source": [ "### Set up training data\n", "\n", - "In the next cell, we set up the training data for this example. We'll be using 11 regularly spaced points on [0,1] which we evaluate the function on and add Gaussian noise to get the training labels." + "In the next cell, we set up the training data for this example. We'll be using 100 regularly spaced points on [0,1] which we evaluate the function on and add Gaussian noise to get the training labels." ] }, { @@ -59,10 +59,10 @@ "metadata": {}, "outputs": [], "source": [ - "# Training data is 11 points in [0,1] inclusive regularly spaced\n", + "# Training data is 100 points in [0,1] inclusive regularly spaced\n", "train_x = torch.linspace(0, 1, 100)\n", "# True function is sin(2*pi*x) with Gaussian noise\n", - "train_y = torch.sin(train_x * (2 * math.pi)) + torch.randn(train_x.size()) * 0.2" + "train_y = torch.sin(train_x * (2 * math.pi)) + torch.randn(train_x.size()) * math.sqrt(0.04)" ] }, { diff --git a/examples/04_Variational_and_Approximate_GPs/SVGP_Multitask_GP_Regression.ipynb b/examples/04_Variational_and_Approximate_GPs/SVGP_Multitask_GP_Regression.ipynb index fd5e2433c..5fdcf70e9 100644 --- a/examples/04_Variational_and_Approximate_GPs/SVGP_Multitask_GP_Regression.ipynb +++ b/examples/04_Variational_and_Approximate_GPs/SVGP_Multitask_GP_Regression.ipynb @@ -73,8 +73,8 @@ "\n", "We are going to construct a batch variational GP - using a `CholeskyVariationalDistribution` and a `VariationalStrategy`. Each of the batch dimensions is going to correspond to one of the outputs. In addition, we will wrap the variational strategy to make the output appear as a `MultitaskMultivariateNormal` distribution. Here are the changes that we'll need to make:\n", "\n", - "1. Our inducing points will need to have shape `2 x m x 1` (where `m` is the number of inducing points). This ensures that we learn a different set of inducing points for each output dimension.\n", - "1. The `CholeskyVariationalDistribution`, mean module, and covariance modules will all need to include a `batch_shape=torch.Size([2])` argument. This ensures that we learn a different set of variational parameters and hyperparameters for each output dimension.\n", + "1. Our inducing points will need to have shape `4 x m x 1` (where `m` is the number of inducing points). This ensures that we learn a different set of inducing points for each output dimension.\n", + "1. The `CholeskyVariationalDistribution`, mean module, and covariance modules will all need to include a `batch_shape=torch.Size([4])` argument. This ensures that we learn a different set of variational parameters and hyperparameters for each output dimension.\n", "1. The `VariationalStrategy` object should be wrapped by a variational strategy that handles multitask models. We describe them below:\n", "\n", "\n", @@ -97,7 +97,7 @@ "num_tasks = 4\n", "\n", "class MultitaskGPModel(gpytorch.models.ApproximateGP):\n", - " def __init__(self):\n", + " def __init__(self, num_latents, num_tasks):\n", " # Let's use a different set of inducing points for each latent function\n", " inducing_points = torch.rand(num_latents, 16, 1)\n", " \n", @@ -113,8 +113,8 @@ " gpytorch.variational.VariationalStrategy(\n", " self, inducing_points, variational_distribution, learn_inducing_locations=True\n", " ),\n", - " num_tasks=4,\n", - " num_latents=3,\n", + " num_tasks=num_tasks,\n", + " num_latents=num_latents,\n", " latent_dim=-1\n", " )\n", " \n", @@ -136,7 +136,7 @@ " return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)\n", "\n", "\n", - "model = MultitaskGPModel()\n", + "model = MultitaskGPModel(num_latents, num_tasks)\n", "likelihood = gpytorch.likelihoods.MultitaskGaussianLikelihood(num_tasks=num_tasks)" ] }, @@ -181,7 +181,7 @@ "outputs": [], "source": [ "class IndependentMultitaskGPModel(gpytorch.models.ApproximateGP):\n", - " def __init__(self):\n", + " def __init__(self, num_tasks):\n", " # Let's use a different set of inducing points for each task\n", " inducing_points = torch.rand(num_tasks, 16, 1)\n", " \n", @@ -195,7 +195,7 @@ " gpytorch.variational.VariationalStrategy(\n", " self, inducing_points, variational_distribution, learn_inducing_locations=True\n", " ),\n", - " num_tasks=4,\n", + " num_tasks=num_tasks,\n", " )\n", " \n", " super().__init__(variational_strategy)\n", diff --git a/examples/08_Advanced_Usage/.gitignore b/examples/08_Advanced_Usage/.gitignore new file mode 100644 index 000000000..4b6ebe5ff --- /dev/null +++ b/examples/08_Advanced_Usage/.gitignore @@ -0,0 +1 @@ +*.pt diff --git a/gpytorch/kernels/__init__.py b/gpytorch/kernels/__init__.py index cc85fe624..55119b784 100644 --- a/gpytorch/kernels/__init__.py +++ b/gpytorch/kernels/__init__.py @@ -2,6 +2,7 @@ from . import keops from .additive_structure_kernel import AdditiveStructureKernel from .arc_kernel import ArcKernel +from .constant_kernel import ConstantKernel from .cosine_kernel import CosineKernel from .cylindrical_kernel import CylindricalKernel from .distributional_input_kernel import DistributionalInputKernel @@ -14,6 +15,7 @@ from .kernel import AdditiveKernel, Kernel, ProductKernel from .lcm_kernel import LCMKernel from .linear_kernel import LinearKernel +from .matern52_kernel_grad import Matern52KernelGrad from .matern_kernel import MaternKernel from .multi_device_kernel import MultiDeviceKernel from .multitask_kernel import MultitaskKernel @@ -38,6 +40,7 @@ "ArcKernel", "AdditiveKernel", "AdditiveStructureKernel", + "ConstantKernel", "CylindricalKernel", "MultiDeviceKernel", "CosineKernel", @@ -67,4 +70,5 @@ "ScaleKernel", "SpectralDeltaKernel", "SpectralMixtureKernel", + "Matern52KernelGrad", ] diff --git a/gpytorch/kernels/additive_structure_kernel.py b/gpytorch/kernels/additive_structure_kernel.py index 1c2ba9c6e..35dfea259 100644 --- a/gpytorch/kernels/additive_structure_kernel.py +++ b/gpytorch/kernels/additive_structure_kernel.py @@ -1,5 +1,6 @@ #!/usr/bin/env python3 +import warnings from typing import Optional, Tuple from .kernel import Kernel @@ -47,6 +48,12 @@ def __init__( num_dims: int, active_dims: Optional[Tuple[int, ...]] = None, ): + warnings.warn( + "AdditiveStructureKernel is deprecated, and will be removed in GPyTorch 2.0. " + 'Please refer to the "Kernels with Additive or Product Structure" tutorial ' + "in the GPyTorch docs for how to implement GPs with additive structure.", + DeprecationWarning, + ) super(AdditiveStructureKernel, self).__init__(active_dims=active_dims) self.base_kernel = base_kernel self.num_dims = num_dims diff --git a/gpytorch/kernels/constant_kernel.py b/gpytorch/kernels/constant_kernel.py new file mode 100644 index 000000000..98a3560e2 --- /dev/null +++ b/gpytorch/kernels/constant_kernel.py @@ -0,0 +1,123 @@ +#!/usr/bin/env python3 + +from typing import Optional, Tuple + +import torch +from torch import Tensor + +from ..constraints import Interval, Positive +from ..priors import Prior +from .kernel import Kernel + + +class ConstantKernel(Kernel): + """ + Constant covariance kernel for the probabilistic inference of constant coefficients. + + ConstantKernel represents the prior variance `k(x1, x2) = var(c)` of a constant `c`. + The prior variance of the constant is optimized during the GP hyper-parameter + optimization stage. The actual value of the constant is computed (implicitly) using + the linear algebraic approaches for the computation of GP samples and posteriors. + + The constant kernel `k_constant` is most useful as a modification of an arbitrary + base kernel `k_base`: + 1) Additive constants: The modification `k_base + k_constant` allows the GP to + infer a non-zero asymptotic value far from the training data, which generally + leads to more accurate extrapolation. Notably, the uncertainty in this constant + value affects the posterior covariances through the posterior inference equations. + This is not the case when a constant prior mean is not used, since the prior mean + does not show up the posterior covariance and is regularized by the log-determinant + during the optimization of the marginal likelihood. + 2) Multiplicative constants: The modification `k_base * k_constant` allows the GP to + modulate the variance of the kernel `k_base`, and is mathematically identical to + `ScaleKernel(base_kernel)` with the same constant. + """ + + has_lengthscale = False + + def __init__( + self, + batch_shape: Optional[torch.Size] = None, + constant_prior: Optional[Prior] = None, + constant_constraint: Optional[Interval] = None, + active_dims: Optional[Tuple[int, ...]] = None, + ): + """Constructor of ConstantKernel. + + Args: + batch_shape: The batch shape of the kernel. + constant_prior: Prior over the constant parameter. + constant_constraint: Constraint to place on constant parameter. + active_dims: The dimensions of the input with which to evaluate the kernel. + This is mute for the constant kernel, but added for compatability with + the Kernel API. + """ + super().__init__(batch_shape=batch_shape, active_dims=active_dims) + + self.register_parameter( + name="raw_constant", + parameter=torch.nn.Parameter(torch.zeros(*self.batch_shape, 1)), + ) + + if constant_prior is not None: + if not isinstance(constant_prior, Prior): + raise TypeError("Expected gpytorch.priors.Prior but got " + type(constant_prior).__name__) + self.register_prior( + "constant_prior", + constant_prior, + lambda m: m.constant, + lambda m, v: m._set_constant(v), + ) + + if constant_constraint is None: + constant_constraint = Positive() + self.register_constraint("raw_constant", constant_constraint) + + @property + def constant(self) -> Tensor: + return self.raw_constant_constraint.transform(self.raw_constant) + + @constant.setter + def constant(self, value: Tensor) -> None: + self._set_constant(value) + + def _set_constant(self, value: Tensor) -> None: + value = value.view(*self.batch_shape, 1) + self.initialize(raw_constant=self.raw_constant_constraint.inverse_transform(value)) + + def forward( + self, + x1: Tensor, + x2: Tensor, + diag: Optional[bool] = False, + last_dim_is_batch: Optional[bool] = False, + ) -> Tensor: + """Evaluates the constant kernel. + + Args: + x1: First input tensor of shape (batch_shape x n1 x d). + x2: Second input tensor of shape (batch_shape x n2 x d). + diag: If True, returns the diagonal of the covariance matrix. + last_dim_is_batch: If True, the last dimension of size `d` of the input + tensors are treated as a batch dimension. + + Returns: + A (batch_shape x n1 x n2)-dim, resp. (batch_shape x n1)-dim, tensor of + constant covariance values if diag is False, resp. True. + """ + if last_dim_is_batch: + x1 = x1.transpose(-1, -2).unsqueeze(-1) + x2 = x2.transpose(-1, -2).unsqueeze(-1) + + dtype = torch.promote_types(x1.dtype, x2.dtype) + batch_shape = torch.broadcast_shapes(x1.shape[:-2], x2.shape[:-2]) + shape = batch_shape + (x1.shape[-2],) + (() if diag else (x2.shape[-2],)) + constant = self.constant.to(dtype=dtype, device=x1.device) + + if not diag: + constant = constant.unsqueeze(-1) + + if last_dim_is_batch: + constant = constant.unsqueeze(-1) + + return constant.expand(shape) diff --git a/gpytorch/kernels/grid_kernel.py b/gpytorch/kernels/grid_kernel.py index 3c9b33b70..8a3503943 100644 --- a/gpytorch/kernels/grid_kernel.py +++ b/gpytorch/kernels/grid_kernel.py @@ -1,5 +1,6 @@ #!/usr/bin/env python3 +import warnings from typing import Optional import torch @@ -139,7 +140,9 @@ def forward(self, x1, x2, diag=False, last_dim_is_batch=False, **params): # Use padded grid for batch mode first_grid_point = torch.stack([proj[0].unsqueeze(0) for proj in grid], dim=-1) full_grid = torch.stack(padded_grid, dim=-1) - covars = to_dense(self.base_kernel(first_grid_point, full_grid, last_dim_is_batch=True, **params)) + with warnings.catch_warnings(): # Hide the GPyTorch 2.0 deprecation warning + warnings.simplefilter("ignore", DeprecationWarning) + covars = to_dense(self.base_kernel(first_grid_point, full_grid, last_dim_is_batch=True, **params)) if last_dim_is_batch: # Toeplitz expects batches of columns so we concatenate the @@ -155,7 +158,9 @@ def forward(self, x1, x2, diag=False, last_dim_is_batch=False, **params): covar = KroneckerProductLinearOperator(*covars[::-1]) else: full_grid = torch.stack(padded_grid, dim=-1) - covars = to_dense(self.base_kernel(full_grid, full_grid, last_dim_is_batch=True, **params)) + with warnings.catch_warnings(): # Hide the GPyTorch 2.0 deprecation warning + warnings.simplefilter("ignore", DeprecationWarning) + covars = to_dense(self.base_kernel(full_grid, full_grid, last_dim_is_batch=True, **params)) if last_dim_is_batch: # Note that this requires all the dimensions to have the same number of grid points covar = covars diff --git a/gpytorch/kernels/kernel.py b/gpytorch/kernels/kernel.py index b8e5f34ec..67e576db3 100644 --- a/gpytorch/kernels/kernel.py +++ b/gpytorch/kernels/kernel.py @@ -236,7 +236,7 @@ def forward( ) -> Union[Tensor, LinearOperator]: r""" Computes the covariance between :math:`\mathbf x_1` and :math:`\mathbf x_2`. - This method should be imlemented by all Kernel subclasses. + This method should be implemented by all Kernel subclasses. :param x1: First set of data (... x N x D). :param x2: Second set of data (... x M x D). @@ -485,6 +485,15 @@ def __call__( * `diag`: `... x N` * `diag` with `last_dim_is_batch=True`: `... x K x N` """ + if last_dim_is_batch: + warnings.warn( + "The last_dim_is_batch argument is deprecated, and will be removed in GPyTorch 2.0. " + "If you are using it as part of AdditiveStructureKernel or ProductStructureKernel, " + 'please update your code according to the "Kernels with Additive or Product Structure" ' + "tutorial in the GPyTorch docs.", + DeprecationWarning, + ) + x1_, x2_ = x1, x2 # Select the active dimensions diff --git a/gpytorch/kernels/linear_kernel.py b/gpytorch/kernels/linear_kernel.py index d7ecd1014..51936766e 100644 --- a/gpytorch/kernels/linear_kernel.py +++ b/gpytorch/kernels/linear_kernel.py @@ -1,6 +1,5 @@ #!/usr/bin/env python3 -import warnings from typing import Optional, Union import torch @@ -40,15 +39,15 @@ class LinearKernel(Kernel): \top} \mathbf v)`, where the base multiply :math:`\mathbf X \mathbf v` takes only :math:`\mathcal O(ND)` time and space. + :param ard_num_dims: Set this if you want a separate variance priors for each weight. (Default: `None`) :param variance_prior: Prior over the variance parameter. (Default `None`.) :param variance_constraint: Constraint to place on variance parameter. (Default: `Positive`.) - :param active_dims: List of data dimensions to operate on. `len(active_dims)` should equal `num_dimensions`. + :param active_dims: List of data dimensions to operate on. """ def __init__( self, - num_dimensions: Optional[int] = None, - offset_prior: Optional[Prior] = None, + ard_num_dims: Optional[int] = None, variance_prior: Optional[Prior] = None, variance_constraint: Optional[Interval] = None, **kwargs, @@ -56,15 +55,12 @@ def __init__( super(LinearKernel, self).__init__(**kwargs) if variance_constraint is None: variance_constraint = Positive() - - if num_dimensions is not None: - # Remove after 1.0 - warnings.warn("The `num_dimensions` argument is deprecated and no longer used.", DeprecationWarning) - self.register_parameter(name="offset", parameter=torch.nn.Parameter(torch.zeros(1, 1, num_dimensions))) - if offset_prior is not None: - # Remove after 1.0 - warnings.warn("The `offset_prior` argument is deprecated and no longer used.", DeprecationWarning) - self.register_parameter(name="raw_variance", parameter=torch.nn.Parameter(torch.zeros(*self.batch_shape, 1, 1))) + self.register_parameter( + name="raw_variance", + parameter=torch.nn.Parameter( + torch.zeros(*self.batch_shape, 1, 1 if ard_num_dims is None else ard_num_dims) + ), + ) if variance_prior is not None: if not isinstance(variance_prior, Prior): raise TypeError("Expected gpytorch.priors.Prior but got " + type(variance_prior).__name__) diff --git a/gpytorch/kernels/matern52_kernel_grad.py b/gpytorch/kernels/matern52_kernel_grad.py new file mode 100644 index 000000000..04aa95c2f --- /dev/null +++ b/gpytorch/kernels/matern52_kernel_grad.py @@ -0,0 +1,152 @@ +#!/usr/bin/env python3 + +import math + +import torch +from linear_operator.operators import KroneckerProductLinearOperator + +from gpytorch.kernels.matern_kernel import MaternKernel + +sqrt5 = math.sqrt(5) +five_thirds = 5.0 / 3.0 + + +class Matern52KernelGrad(MaternKernel): + r""" + Computes a covariance matrix of the Matern52 kernel that models the covariance + between the values and partial derivatives for inputs :math:`\mathbf{x_1}` + and :math:`\mathbf{x_2}`. + + See :class:`gpytorch.kernels.Kernel` for descriptions of the lengthscale options. + + .. note:: + + This kernel does not have an `outputscale` parameter. To add a scaling parameter, + decorate this kernel with a :class:`gpytorch.kernels.ScaleKernel`. + + :param ard_num_dims: Set this if you want a separate lengthscale for each input + dimension. It should be `d` if x1 is a `n x d` matrix. (Default: `None`.) + :param batch_shape: Set this if you want a separate lengthscale for each batch of input + data. It should be :math:`B_1 \times \ldots \times B_k` if :math:`\mathbf x1` is + a :math:`B_1 \times \ldots \times B_k \times N \times D` tensor. + :param active_dims: Set this if you want to compute the covariance of only + a few input dimensions. The ints corresponds to the indices of the + dimensions. (Default: `None`.) + :param lengthscale_prior: Set this if you want to apply a prior to the + lengthscale parameter. (Default: `None`) + :param lengthscale_constraint: Set this if you want to apply a constraint + to the lengthscale parameter. (Default: `Positive`.) + :param eps: The minimum value that the lengthscale can take (prevents + divide by zero errors). (Default: `1e-6`.) + + :ivar torch.Tensor lengthscale: The lengthscale parameter. Size/shape of parameter depends on the + ard_num_dims and batch_shape arguments. + + Example: + >>> x = torch.randn(10, 5) + >>> # Non-batch: Simple option + >>> covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.Matern52KernelGrad()) + >>> covar = covar_module(x) # Output: LinearOperator of size (60 x 60), where 60 = n * (d + 1) + >>> + >>> batch_x = torch.randn(2, 10, 5) + >>> # Batch: Simple option + >>> covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.Matern52KernelGrad()) + >>> # Batch: different lengthscale for each batch + >>> covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.Matern52KernelGrad(batch_shape=torch.Size([2]))) # noqa: E501 + >>> covar = covar_module(x) # Output: LinearOperator of size (2 x 60 x 60) + """ + + def __init__(self, **kwargs): + + # remove nu in case it was set + kwargs.pop("nu", None) + super(Matern52KernelGrad, self).__init__(nu=2.5, **kwargs) + + def forward(self, x1, x2, diag=False, **params): + + lengthscale = self.lengthscale + + batch_shape = x1.shape[:-2] + n_batch_dims = len(batch_shape) + n1, d = x1.shape[-2:] + n2 = x2.shape[-2] + + if not diag: + + K = torch.zeros(*batch_shape, n1 * (d + 1), n2 * (d + 1), device=x1.device, dtype=x1.dtype) + + distance_matrix = self.covar_dist(x1.div(lengthscale), x2.div(lengthscale), diag=diag, **params) + exp_neg_sqrt5r = torch.exp(-sqrt5 * distance_matrix) + + # differences matrix in each dimension to be used for derivatives + # shape of n1 x n2 x d + outer = x1.view(*batch_shape, n1, 1, d) - x2.view(*batch_shape, 1, n2, d) + outer = outer / lengthscale.unsqueeze(-2) ** 2 + # shape of n1 x d x n2 + outer = torch.transpose(outer, -1, -2).contiguous() + + # 1) Kernel block, cov(f^m, f^n) + # shape is n1 x n2 + exp_component = torch.exp(-sqrt5 * distance_matrix) + constant_component = (sqrt5 * distance_matrix).add(1).add(five_thirds * distance_matrix**2) + + K[..., :n1, :n2] = constant_component * exp_component + + # 2) First gradient block, cov(f^m, omega^n_d) + outer1 = outer.view(*batch_shape, n1, n2 * d) + K[..., :n1, n2:] = outer1 * (-five_thirds * (1 + sqrt5 * distance_matrix) * exp_neg_sqrt5r).repeat( + [*([1] * (n_batch_dims + 1)), d] + ) + + # 3) Second gradient block, cov(omega^m_d, f^n) + outer2 = outer.transpose(-1, -3).reshape(*batch_shape, n2, n1 * d) + outer2 = outer2.transpose(-1, -2) + # the - signs on -outer2 and -five_thirds cancel out + K[..., n1:, :n2] = outer2 * (five_thirds * (1 + sqrt5 * distance_matrix) * exp_neg_sqrt5r).repeat( + [*([1] * n_batch_dims), d, 1] + ) + + # 4) Hessian block, cov(omega^m_d, omega^n_d) + outer3 = outer1.repeat([*([1] * n_batch_dims), d, 1]) * outer2.repeat([*([1] * (n_batch_dims + 1)), d]) + kp = KroneckerProductLinearOperator( + torch.eye(d, d, device=x1.device, dtype=x1.dtype).repeat(*batch_shape, 1, 1) / lengthscale**2, + torch.ones(n1, n2, device=x1.device, dtype=x1.dtype).repeat(*batch_shape, 1, 1), + ) + + part1 = -five_thirds * exp_neg_sqrt5r + part2 = 5 * outer3 + part3 = 1 + sqrt5 * distance_matrix + + K[..., n1:, n2:] = part1.repeat([*([1] * n_batch_dims), d, d]).mul_( + # need to use kp.to_dense().mul instead of kp.to_dense().mul_ + # because otherwise a RuntimeError is raised due to how autograd works with + # view + inplace operations in the case of 1-dimensional input + part2.sub_(kp.to_dense().mul(part3.repeat([*([1] * n_batch_dims), d, d]))) + ) + + # Symmetrize for stability + if n1 == n2 and torch.eq(x1, x2).all(): + K = 0.5 * (K.transpose(-1, -2) + K) + + # Apply a perfect shuffle permutation to match the MutiTask ordering + pi1 = torch.arange(n1 * (d + 1)).view(d + 1, n1).t().reshape((n1 * (d + 1))) + pi2 = torch.arange(n2 * (d + 1)).view(d + 1, n2).t().reshape((n2 * (d + 1))) + K = K[..., pi1, :][..., :, pi2] + + return K + else: + if not (n1 == n2 and torch.eq(x1, x2).all()): + raise RuntimeError("diag=True only works when x1 == x2") + + # nu is set to 2.5 + kernel_diag = super(Matern52KernelGrad, self).forward(x1, x2, diag=True) + grad_diag = ( + five_thirds * torch.ones(*batch_shape, n2, d, device=x1.device, dtype=x1.dtype) + ) / lengthscale**2 + grad_diag = grad_diag.transpose(-1, -2).reshape(*batch_shape, n2 * d) + k_diag = torch.cat((kernel_diag, grad_diag), dim=-1) + pi = torch.arange(n2 * (d + 1)).view(d + 1, n2).t().reshape((n2 * (d + 1))) + return k_diag[..., pi] + + def num_outputs_per_input(self, x1, x2): + return x1.size(-1) + 1 diff --git a/gpytorch/kernels/newton_girard_additive_kernel.py b/gpytorch/kernels/newton_girard_additive_kernel.py index 13a9e5ae8..89be591cf 100644 --- a/gpytorch/kernels/newton_girard_additive_kernel.py +++ b/gpytorch/kernels/newton_girard_additive_kernel.py @@ -1,3 +1,6 @@ +#!/usr/bin/env python3 + +import warnings from typing import Optional, Tuple import torch @@ -23,6 +26,13 @@ def __init__( :param active_dims: :param kwargs: """ + + warnings.warn( + "NewtonGirardAdditiveKernel is deprecated, and will be removed in GPyTorch 2.0. " + 'Please refer to the "Kernels with Additive or Product Structure" tutorial ' + "in the GPyTorch docs for how to implement GPs with additive structure.", + DeprecationWarning, + ) super(NewtonGirardAdditiveKernel, self).__init__(active_dims=active_dims, **kwargs) self.base_kernel = base_kernel diff --git a/gpytorch/kernels/periodic_kernel.py b/gpytorch/kernels/periodic_kernel.py index 1232b96ae..2972b523a 100644 --- a/gpytorch/kernels/periodic_kernel.py +++ b/gpytorch/kernels/periodic_kernel.py @@ -78,7 +78,7 @@ class PeriodicKernel(Kernel): >>> covar = covar_module(x) # Output: LazyVariable of size (2 x 10 x 10) .. _David Mackay's Introduction to Gaussian Processes equation 47: - http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.81.1927&rep=rep1&type=pdf + https://inference.org.uk/mackay/gpB.pdf """ has_lengthscale = True diff --git a/gpytorch/kernels/product_structure_kernel.py b/gpytorch/kernels/product_structure_kernel.py index f25f8d7a7..49f782876 100644 --- a/gpytorch/kernels/product_structure_kernel.py +++ b/gpytorch/kernels/product_structure_kernel.py @@ -1,5 +1,6 @@ #!/usr/bin/env python3 +import warnings from typing import Optional, Tuple from linear_operator.operators import to_linear_operator @@ -54,6 +55,13 @@ def __init__( num_dims: int, active_dims: Optional[Tuple[int, ...]] = None, ): + warnings.warn( + "ProductStructureKernel is deprecated, and will be removed in GPyTorch 2.0. " + 'Please refer to the "Kernels with Additive or Product Structure" tutorial ' + "in the GPyTorch docs for how to implement GPs with product structure.", + DeprecationWarning, + ) + super(ProductStructureKernel, self).__init__(active_dims=active_dims) self.base_kernel = base_kernel self.num_dims = num_dims diff --git a/gpytorch/kernels/rbf_kernel.py b/gpytorch/kernels/rbf_kernel.py index 932e59724..97cdc23c4 100644 --- a/gpytorch/kernels/rbf_kernel.py +++ b/gpytorch/kernels/rbf_kernel.py @@ -30,9 +30,9 @@ class RBFKernel(Kernel): decorate this kernel with a :class:`gpytorch.kernels.ScaleKernel`. :param ard_num_dims: Set this if you want a separate lengthscale for each input - dimension. It should be `d` if x1 is a `n x d` matrix. (Default: `None`.) + dimension. It should be `d` if :math:`\mathbf{x_1}` is a `n x d` matrix. (Default: `None`.) :param batch_shape: Set this if you want a separate lengthscale for each batch of input - data. It should be :math:`B_1 \times \ldots \times B_k` if :math:`\mathbf x1` is + data. It should be :math:`B_1 \times \ldots \times B_k` if :math:`\mathbf{x_1}` is a :math:`B_1 \times \ldots \times B_k \times N \times D` tensor. :param active_dims: Set this if you want to compute the covariance of only a few input dimensions. The ints corresponds to the indices of the diff --git a/gpytorch/kernels/rff_kernel.py b/gpytorch/kernels/rff_kernel.py index 68455d220..c6b5e4ccd 100644 --- a/gpytorch/kernels/rff_kernel.py +++ b/gpytorch/kernels/rff_kernel.py @@ -35,7 +35,7 @@ class RFFKernel(Kernel): .. math:: \begin{equation} - k(\Delta) = \exp{(-\frac{\Delta^2}{2\sigma^2})}$ and $p(\omega) = \exp{(-\frac{\sigma^2\omega^2}{2})} + k(\Delta) = \exp{(-\frac{\Delta^2}{2\sigma^2})} \text{ and } p(\omega) = \exp{(-\frac{\sigma^2\omega^2}{2})} \end{equation} where :math:`\Delta = x - x'`. diff --git a/gpytorch/kernels/spectral_delta_kernel.py b/gpytorch/kernels/spectral_delta_kernel.py index 4262c0178..176e788ae 100644 --- a/gpytorch/kernels/spectral_delta_kernel.py +++ b/gpytorch/kernels/spectral_delta_kernel.py @@ -52,7 +52,7 @@ def initialize_from_data(self, train_x, train_y): """ import numpy as np from scipy.fftpack import fft - from scipy.integrate import cumtrapz + from scipy.integrate import cumulative_trapezoid N = train_x.size(-2) emp_spect = np.abs(fft(train_y.cpu().detach().numpy())) ** 2 / N @@ -65,7 +65,7 @@ def initialize_from_data(self, train_x, train_y): emp_spect = emp_spect[: M + 1] total_area = np.trapz(emp_spect, freq) - spec_cdf = np.hstack((np.zeros(1), cumtrapz(emp_spect, freq))) + spec_cdf = np.hstack((np.zeros(1), cumulative_trapezoid(emp_spect, freq))) spec_cdf = spec_cdf / total_area a = np.random.rand(self.raw_Z.size(-2), 1) diff --git a/gpytorch/kernels/spectral_mixture_kernel.py b/gpytorch/kernels/spectral_mixture_kernel.py index e63185ff4..c8de79010 100644 --- a/gpytorch/kernels/spectral_mixture_kernel.py +++ b/gpytorch/kernels/spectral_mixture_kernel.py @@ -167,7 +167,7 @@ def initialize_from_data_empspect(self, train_x: torch.Tensor, train_y: torch.Te import numpy as np from scipy.fftpack import fft - from scipy.integrate import cumtrapz + from scipy.integrate import cumulative_trapezoid with torch.no_grad(): if not torch.is_tensor(train_x) or not torch.is_tensor(train_y): @@ -192,7 +192,7 @@ def initialize_from_data_empspect(self, train_x: torch.Tensor, train_y: torch.Te emp_spect = emp_spect[: M + 1] total_area = np.trapz(emp_spect, freq) - spec_cdf = np.hstack((np.zeros(1), cumtrapz(emp_spect, freq))) + spec_cdf = np.hstack((np.zeros(1), cumulative_trapezoid(emp_spect, freq))) spec_cdf = spec_cdf / total_area a = np.random.rand(1000, self.ard_num_dims) diff --git a/gpytorch/mlls/exact_marginal_log_likelihood.py b/gpytorch/mlls/exact_marginal_log_likelihood.py index 7b2987f50..66f92eafd 100644 --- a/gpytorch/mlls/exact_marginal_log_likelihood.py +++ b/gpytorch/mlls/exact_marginal_log_likelihood.py @@ -49,7 +49,7 @@ def _add_other_terms(self, res, params): return res - def forward(self, function_dist, target, *params): + def forward(self, function_dist, target, *params, **kwargs): r""" Computes the MLL given :math:`p(\mathbf f)` and :math:`\mathbf y`. @@ -63,7 +63,7 @@ def forward(self, function_dist, target, *params): raise RuntimeError("ExactMarginalLogLikelihood can only operate on Gaussian random variables") # Determine output likelihood - output = self.likelihood(function_dist, *params) + output = self.likelihood(function_dist, *params, **kwargs) # Remove NaN values if enabled if settings.observation_nan_policy.value() == "mask": diff --git a/gpytorch/mlls/leave_one_out_pseudo_likelihood.py b/gpytorch/mlls/leave_one_out_pseudo_likelihood.py index f71c819e8..becb3b06e 100644 --- a/gpytorch/mlls/leave_one_out_pseudo_likelihood.py +++ b/gpytorch/mlls/leave_one_out_pseudo_likelihood.py @@ -47,7 +47,7 @@ def __init__(self, likelihood, model): def forward(self, function_dist: MultivariateNormal, target: Tensor, *params) -> Tensor: r""" - Computes the leave one out likelihood given :math:`p(\mathbf f)` and `\mathbf y` + Computes the leave one out likelihood given :math:`p(\mathbf f)` and :math:`\mathbf y` :param ~gpytorch.distributions.MultivariateNormal output: the outputs of the latent function (the :obj:`~gpytorch.models.GP`) diff --git a/gpytorch/models/exact_prediction_strategies.py b/gpytorch/models/exact_prediction_strategies.py index 2b716d73f..0a8092e15 100644 --- a/gpytorch/models/exact_prediction_strategies.py +++ b/gpytorch/models/exact_prediction_strategies.py @@ -10,7 +10,6 @@ AddedDiagLinearOperator, BatchRepeatLinearOperator, ConstantMulLinearOperator, - DenseLinearOperator, InterpolatedLinearOperator, LinearOperator, LowRankRootAddedDiagLinearOperator, @@ -211,8 +210,8 @@ def get_fantasy_strategy(self, inputs, targets, full_inputs, full_targets, full_ # now update the root and root inverse new_lt = self.lik_train_train_covar.cat_rows(fant_train_covar, fant_fant_covar) - new_root = new_lt.root_decomposition().root.to_dense() - new_covar_cache = new_lt.root_inv_decomposition().root.to_dense() + new_root = new_lt.root_decomposition().root + new_covar_cache = new_lt.root_inv_decomposition().root # Expand inputs accordingly if necessary (for fantasies at the same points) if full_inputs[0].dim() <= full_targets.dim(): @@ -222,7 +221,7 @@ def get_fantasy_strategy(self, inputs, targets, full_inputs, full_targets, full_ full_inputs = [fi.expand(fant_batch_shape + fi.shape) for fi in full_inputs] full_mean = full_mean.expand(fant_batch_shape + full_mean.shape) full_covar = BatchRepeatLinearOperator(full_covar, repeat_shape) - new_root = BatchRepeatLinearOperator(DenseLinearOperator(new_root), repeat_shape) + new_root = BatchRepeatLinearOperator(new_root, repeat_shape) # no need to repeat the covar cache, broadcasting will do the right thing if isinstance(full_output, MultitaskMultivariateNormal): @@ -238,7 +237,7 @@ def get_fantasy_strategy(self, inputs, targets, full_inputs, full_targets, full_ inv_root=new_covar_cache, ) add_to_cache(fant_strat, "mean_cache", fant_mean_cache) - add_to_cache(fant_strat, "covar_cache", new_covar_cache) + add_to_cache(fant_strat, "covar_cache", new_covar_cache.to_dense()) return fant_strat @property @@ -866,5 +865,5 @@ def exact_predictive_covar(self, test_test_covar, test_train_covar): "This is likely a bug in GPyTorch." ) - res = test_test_covar - (L @ (covar_cache @ L.transpose(-1, -2))) + res = test_test_covar - MatmulLinearOperator(L, covar_cache @ L.mT) return res diff --git a/gpytorch/priors/prior.py b/gpytorch/priors/prior.py index 1a6e6e1f7..2c5468bf1 100644 --- a/gpytorch/priors/prior.py +++ b/gpytorch/priors/prior.py @@ -1,10 +1,18 @@ #!/usr/bin/env python3 from abc import ABC +from typing import Any, Mapping +from torch.distributions import TransformedDistribution from torch.nn import Module from ..distributions import Distribution +from .utils import _load_transformed_to_base_dist + + +TRANSFORMED_ERROR_MSG = """Priors of TransformedDistributions should not have their \ +'_transformed' attributes modified, these are just copies of the base attribute. \ +Please modify the base attribute (e.g. {}) instead.""" class Prior(Distribution, Module, ABC): @@ -25,3 +33,20 @@ def log_prob(self, x): :rtype: torch.Tensor """ return super(Prior, self).log_prob(self.transform(x)) + + def load_state_dict(self, state_dict: Mapping[str, Any], *args, **kwargs): + Module.load_state_dict(self, state_dict, *args, **kwargs) + if isinstance(self, TransformedDistribution): + _load_transformed_to_base_dist(self) + + def __setattr__(self, name: str, value: Any) -> None: + if hasattr(self, name) and "_transformed_" in name: + base_attr_name = name.replace("_transformed_", "") + raise AttributeError(TRANSFORMED_ERROR_MSG.format(base_attr_name)) + + elif hasattr(self, f"_transformed_{name}"): + self.base_dist.__setattr__(name, value) + super().__setattr__(f"_transformed_{name}", value) + + else: + return super().__setattr__(name, value) diff --git a/gpytorch/priors/torch_priors.py b/gpytorch/priors/torch_priors.py index 5e5dd2669..a3e243384 100644 --- a/gpytorch/priors/torch_priors.py +++ b/gpytorch/priors/torch_priors.py @@ -40,6 +40,7 @@ class HalfNormalPrior(Prior, HalfNormal): def __init__(self, scale, validate_args=None, transform=None): TModule.__init__(self) HalfNormal.__init__(self, scale=scale, validate_args=validate_args) + _bufferize_attributes(self, ("scale",)) self._transform = transform def expand(self, batch_shape): @@ -54,6 +55,7 @@ class LogNormalPrior(Prior, LogNormal): def __init__(self, loc, scale, validate_args=None, transform=None): TModule.__init__(self) LogNormal.__init__(self, loc=loc, scale=scale, validate_args=validate_args) + _bufferize_attributes(self, ("loc", "scale")) self._transform = transform def expand(self, batch_shape): @@ -84,6 +86,7 @@ class HalfCauchyPrior(Prior, HalfCauchy): def __init__(self, scale, validate_args=None, transform=None): TModule.__init__(self) HalfCauchy.__init__(self, scale=scale, validate_args=validate_args) + _bufferize_attributes(self, ("scale",)) self._transform = transform def expand(self, batch_shape): diff --git a/gpytorch/priors/utils.py b/gpytorch/priors/utils.py index 3cfce190e..e4468ab78 100644 --- a/gpytorch/priors/utils.py +++ b/gpytorch/priors/utils.py @@ -1,11 +1,36 @@ #!/usr/bin/env python3 +from torch.distributions import TransformedDistribution + def _bufferize_attributes(module, attributes): - attr_clones = {attr: getattr(module, attr).clone() for attr in attributes} - for attr, value in attr_clones.items(): - delattr(module, attr) - module.register_buffer(attr, value) + r""" + Adds the parameters of the prior as a torch buffer to enable saving/ + loading to/from state_dicts. + For TransformedDistributions Adds a _transformed_ attribute to the + parameters. This enables its parameters to be saved and + loaded to/from state_dicts, as the original parameters cannot be. + """ + if isinstance(module, TransformedDistribution): + for attr in attributes: + module.register_buffer(f"_transformed_{attr}", getattr(module, attr)) + else: + attr_clones = {attr: getattr(module, attr).clone() for attr in attributes} + for attr, value in attr_clones.items(): + delattr(module, attr) + module.register_buffer(attr, value) + + +def _load_transformed_to_base_dist(module): + r"""loads the _transformed_ attributes to the parameters of a torch + TransformedDistribution. This enables its parameters to be saved and + loaded to/from state_dicts, as the original parameters cannot be. + """ + transf_str = "_transformed_" + transformed_attrs = [attr for attr in dir(module) if transf_str in attr] + for transf_attr in transformed_attrs: + base_attr_name = transf_attr.replace(transf_str, "") + setattr(module.base_dist, base_attr_name, getattr(module, transf_attr)) def _del_attributes(module, attributes, raise_on_error=False): diff --git a/gpytorch/test/base_keops_test_case.py b/gpytorch/test/base_keops_test_case.py index fb261c860..ca32b4d64 100644 --- a/gpytorch/test/base_keops_test_case.py +++ b/gpytorch/test/base_keops_test_case.py @@ -66,7 +66,7 @@ def test_forward_x1_neq_x2(self, use_keops=True, ard=False, **kwargs): # The patch makes sure that we're actually using KeOps k1 = kern1(x1, x2).to_dense() k2 = kern2(x1, x2).to_dense() - self.assertLess(torch.norm(k1 - k2), 1e-4) + self.assertLess(torch.norm(k1 - k2), 1e-3) if use_keops: self.assertTrue(keops_mock.called) @@ -86,7 +86,7 @@ def test_batch_matmul(self, use_keops=True, **kwargs): # The patch makes sure that we're actually using KeOps res1 = kern1(x1, x1).matmul(rhs) res2 = kern2(x1, x1).matmul(rhs) - self.assertLess(torch.norm(res1 - res2), 1e-4) + self.assertLess(torch.norm(res1 - res2), 1e-3) if use_keops: self.assertTrue(keops_mock.called) @@ -115,7 +115,7 @@ def test_gradient(self, use_keops=True, ard=False, **kwargs): # stack all gradients into a tensor grad_s1 = torch.vstack(torch.autograd.grad(s1, [*kern1.hyperparameters()])) grad_s2 = torch.vstack(torch.autograd.grad(s2, [*kern2.hyperparameters()])) - self.assertAllClose(grad_s1, grad_s2, rtol=1e-4, atol=1e-5) + self.assertAllClose(grad_s1, grad_s2, rtol=1e-3, atol=1e-3) if use_keops: self.assertTrue(keops_mock.called) diff --git a/gpytorch/test/base_kernel_test_case.py b/gpytorch/test/base_kernel_test_case.py index 5301ce2d9..88f6afbd5 100644 --- a/gpytorch/test/base_kernel_test_case.py +++ b/gpytorch/test/base_kernel_test_case.py @@ -122,23 +122,21 @@ def test_no_batch_kernel_double_batch_x_ard(self): actual_diag = actual_covar_mat.diagonal(dim1=-1, dim2=-2) self.assertAllClose(kernel_diag, actual_diag, rtol=1e-3, atol=1e-5) - def test_smoke_double_batch_kernel_double_batch_x_no_ard(self): + def test_smoke_double_batch_kernel_double_batch_x_no_ard(self) -> None: kernel = self.create_kernel_no_ard(batch_shape=torch.Size([3, 2])) x = self.create_data_double_batch() - batch_covar_mat = kernel(x).evaluate_kernel().to_dense() + kernel(x).evaluate_kernel().to_dense() kernel(x, diag=True) - return batch_covar_mat - def test_smoke_double_batch_kernel_double_batch_x_ard(self): + def test_smoke_double_batch_kernel_double_batch_x_ard(self) -> None: try: kernel = self.create_kernel_ard(num_dims=2, batch_shape=torch.Size([3, 2])) except NotImplementedError: return x = self.create_data_double_batch() - batch_covar_mat = kernel(x).evaluate_kernel().to_dense() + kernel(x).evaluate_kernel().to_dense() kernel(x, diag=True) - return batch_covar_mat def test_kernel_getitem_single_batch(self): kernel = self.create_kernel_no_ard(batch_shape=torch.Size([2])) diff --git a/gpytorch/utils/__init__.py b/gpytorch/utils/__init__.py index 5b8d9e391..76ce357c2 100644 --- a/gpytorch/utils/__init__.py +++ b/gpytorch/utils/__init__.py @@ -10,6 +10,7 @@ from . import deprecation, errors, generic, grid, interpolation, quadrature, transforms, warnings from .memoize import cached from .nearest_neighbors import NNUtil +from .sum_interaction_terms import sum_interaction_terms __all__ = [ "cached", @@ -19,6 +20,7 @@ "grid", "interpolation", "quadrature", + "sum_interaction_terms", "transforms", "warnings", "NNUtil", diff --git a/gpytorch/utils/sum_interaction_terms.py b/gpytorch/utils/sum_interaction_terms.py new file mode 100644 index 000000000..be7462132 --- /dev/null +++ b/gpytorch/utils/sum_interaction_terms.py @@ -0,0 +1,59 @@ +from typing import Optional, Union + +import torch + +from jaxtyping import Float +from linear_operator import LinearOperator, to_dense +from torch import Tensor + + +def sum_interaction_terms( + covars: Float[Union[LinearOperator, Tensor], "... D N N"], + max_degree: Optional[int] = None, + dim: int = -3, +) -> Float[Tensor, "... N N"]: + r""" + Given a batch of D x N x N covariance matrices :math:`\boldsymbol K_1, \ldots, \boldsymbol K_D`, + compute the sum of each covariance matrix as well as the interaction terms up to degree `max_degree` + (denoted as :math:`M` below): + + .. math:: + + \sum_{1 \leq i_1 < i_2 < \ldots < i_M < D} \left[ + \prod_{j=1}^M \boldsymbol K_{i_j} + \right]. + + This function is useful for computing the sum of additive kernels as defined in + `Additive Gaussian Processes (Duvenaud et al., 2011)`_. + + Note that the summation is computed in :math:`\mathcal O(D)` time using the Newton-Girard formula. + + .. _Additive Gaussian Processes (Duvenaud et al., 2011): + https://arxiv.org/pdf/1112.4394 + + :param covars: A batch of covariance matrices, representing the base covariances to sum over + :param max_degree: The maximum degree of the interaction terms to compute. + If not provided, this will default to `D`. + :param dim: The dimension to sum over (i.e. the batch dimension containing the base covariance matrices). + Note that dim must be a negative integer (i.e. -3, not 0). + """ + if dim >= 0: + raise ValueError("Argument 'dim' must be a negative integer.") + + covars = to_dense(covars) + ks = torch.arange(max_degree, dtype=covars.dtype, device=covars.device) + neg_one = torch.tensor(-1.0, dtype=covars.dtype, device=covars.device) + + # S_times_factor[k] = factor[k] * S[k] + # = (-1)^{k} * \sum_{i=1}^D covar_i^{k+1} + S_times_factor_ks = torch.vmap(lambda k: neg_one.pow(k) * torch.sum(covars.pow(k + 1), dim=dim))(ks) + + # E[deg] = 1/(deg+1) \sum_{j=0}^{deg} factor[k] * S[k] * E[deg-k] + # = 1/(deg+1) [ (factor[deg] * S[deg]) + \sum_{j=1}^{deg - 1} factor * S_ks[k] * E_ks[deg-k] ] + E_ks = torch.empty_like(S_times_factor_ks) + E_ks[0] = S_times_factor_ks[0] + for deg in range(1, max_degree): + sum_term = torch.einsum("m...,m...->...", S_times_factor_ks[:deg], E_ks[:deg].flip(0)) + E_ks[deg] = (S_times_factor_ks[deg] + sum_term) / (deg + 1) + + return E_ks.sum(0) diff --git a/gpytorch/variational/nearest_neighbor_variational_strategy.py b/gpytorch/variational/nearest_neighbor_variational_strategy.py index 6f9b429b4..b1c3e8b4b 100644 --- a/gpytorch/variational/nearest_neighbor_variational_strategy.py +++ b/gpytorch/variational/nearest_neighbor_variational_strategy.py @@ -3,6 +3,7 @@ from typing import Any, Optional import torch +from jaxtyping import Float from linear_operator import to_dense from linear_operator.operators import DiagLinearOperator, LinearOperator, TriangularLinearOperator from linear_operator.utils.cholesky import psd_safe_cholesky @@ -62,7 +63,8 @@ class NNVariationalStrategy(UnwhitenedVariationalStrategy): VariationalDistribution object that represents the form of the variational distribution :math:`q(\mathbf u)` :param k: Number of nearest neighbors. :param training_batch_size: The number of data points that will be in the training batch size. - :param jitter_val: Amount of diagonal jitter to add for Cholesky factorization numerical stability + :param jitter_val: Amount of diagonal jitter to add for covariance matrix numerical stability. + :param compute_full_kl: Whether to compute full kl divergence or stochastic estimate. .. _Wu et al (2022): https://arxiv.org/pdf/2202.01694.pdf @@ -75,11 +77,12 @@ class NNVariationalStrategy(UnwhitenedVariationalStrategy): def __init__( self, model: ApproximateGP, - inducing_points: Tensor, - variational_distribution: _VariationalDistribution, + inducing_points: Float[Tensor, "... M D"], + variational_distribution: Float[_VariationalDistribution, "... M"], k: int, - training_batch_size: int, - jitter_val: Optional[float] = None, + training_batch_size: Optional[int] = None, + jitter_val: Optional[float] = 1e-3, + compute_full_kl: Optional[bool] = False, ): assert isinstance( variational_distribution, MeanFieldVariationalDistribution @@ -88,18 +91,15 @@ def __init__( super().__init__( model, inducing_points, variational_distribution, learn_inducing_locations=False, jitter_val=jitter_val ) - # Make sure we don't try to initialize variational parameters - because of minibatching - self.variational_params_initialized.fill_(1) # Model object.__setattr__(self, "model", model) self.inducing_points = inducing_points - self.M: int = inducing_points.shape[-2] - self.D: int = inducing_points.shape[-1] + self.M, self.D = inducing_points.shape[-2:] self.k = k - assert self.k <= self.M, ( - f"Number of nearest neighbors k must be smaller than or equal to number of inducing points, " + assert self.k < self.M, ( + f"Number of nearest neighbors k must be smaller than the number of inducing points, " f"but got k = {k}, M = {self.M}." ) @@ -111,37 +111,61 @@ def __init__( k, dim=self.D, batch_shape=self._inducing_batch_shape, device=inducing_points.device ) self._compute_nn() + # otherwise, no nearest neighbor approximation is used - self.training_batch_size = training_batch_size + self.training_batch_size = training_batch_size if training_batch_size is not None else self.M self._set_training_iterator() + self.compute_full_kl = compute_full_kl + @property @cached(name="prior_distribution_memo") - def prior_distribution(self) -> MultivariateNormal: + def prior_distribution(self) -> Float[MultivariateNormal, "... M"]: out = self.model.forward(self.inducing_points) res = MultivariateNormal(out.mean, out.lazy_covariance_matrix.add_jitter(self.jitter_val)) return res - def _cholesky_factor(self, induc_induc_covar: LinearOperator) -> TriangularLinearOperator: + def _cholesky_factor( + self, induc_induc_covar: Float[LinearOperator, "... M M"] + ) -> Float[TriangularLinearOperator, "... M M"]: # Uncached version L = psd_safe_cholesky(to_dense(induc_induc_covar)) return TriangularLinearOperator(L) - def __call__(self, x: Tensor, prior: bool = False, **kwargs: Any) -> MultivariateNormal: + def __call__( + self, x: Float[Tensor, "... N D"], prior: bool = False, **kwargs: Any + ) -> Float[MultivariateNormal, "... N"]: # If we're in prior mode, then we're done! if prior: return self.model.forward(x, **kwargs) if x is not None: - assert self.inducing_points.shape[:-2] == x.shape[:-2], ( - f"x batch shape must matches inducing points batch shape, " - f"but got train data batch shape = {x.shape[:-2]}, " - f"inducing points batch shape = {self.inducing_points.shape[:-2]}." - ) + # Make sure x and inducing points have the same batch shape + if not (self.inducing_points.shape[:-2] == x.shape[:-2]): + try: + x = x.expand(*self.inducing_points.shape[:-2], *x.shape[-2:]).contiguous() + except RuntimeError: + raise RuntimeError( + f"x batch shape must match or broadcast with the inducing points' batch shape, " + f"but got x batch shape = {x.shape[:-2]}, " + f"inducing points batch shape = {self.inducing_points.shape[:-2]}." + ) # Delete previously cached items from the training distribution if self.training: self._clear_cache() + + # (Maybe) initialize variational distribution + if not self.variational_params_initialized.item(): + prior_dist = self.prior_distribution + self._variational_distribution.variational_mean.data.copy_(prior_dist.mean) + self._variational_distribution.variational_mean.data.add_( + torch.randn_like(prior_dist.mean), alpha=self._variational_distribution.mean_init_std + ) + # initialize with a small variational stddev for quicker conv. of kl divergence + self._variational_distribution._variational_stddev.data.copy_(torch.tensor(1e-2)) + self.variational_params_initialized.fill_(1) + return self.forward( x, self.inducing_points, inducing_values=None, variational_inducing_covar=None, **kwargs ) @@ -152,12 +176,12 @@ def __call__(self, x: Tensor, prior: bool = False, **kwargs: Any) -> Multivariat def forward( self, - x: Tensor, - inducing_points: Tensor, - inducing_values: Optional[Tensor] = None, - variational_inducing_covar: Optional[LinearOperator] = None, + x: Float[Tensor, "... N D"], + inducing_points: Float[Tensor, "... M D"], + inducing_values: Float[Tensor, "... M"], + variational_inducing_covar: Optional[Float[LinearOperator, "... M M"]] = None, **kwargs: Any, - ) -> MultivariateNormal: + ) -> Float[MultivariateNormal, "... N"]: if self.training: # In training mode, note that the full inducing points set = full training dataset # Users have the option to choose input None or a tensor of training data for x @@ -193,20 +217,20 @@ def forward( return MultivariateNormal(predictive_mean, DiagLinearOperator(predictive_var)) else: - nn_indices = self.nn_util.find_nn_idx(x.float()) x_batch_shape = x.shape[:-2] + batch_shape = torch.broadcast_shapes(self._batch_shape, x_batch_shape) x_bsz = x.shape[-2] assert nn_indices.shape == (*x_batch_shape, x_bsz, self.k), nn_indices.shape + # select K nearest neighbors from inducing points for test point x expanded_nn_indices = nn_indices.unsqueeze(-1).expand(*x_batch_shape, x_bsz, self.k, self.D) expanded_inducing_points = inducing_points.unsqueeze(-2).expand(*x_batch_shape, self.M, self.k, self.D) inducing_points = expanded_inducing_points.gather(-3, expanded_nn_indices) assert inducing_points.shape == (*x_batch_shape, x_bsz, self.k, self.D) # get variational mean and covar for nearest neighbors - batch_shape = torch.broadcast_shapes(self._model_batch_shape, x_batch_shape) inducing_values = self._variational_distribution.variational_mean expanded_inducing_values = inducing_values.unsqueeze(-1).expand(*batch_shape, self.M, self.k) expanded_nn_indices = nn_indices.expand(*batch_shape, x_bsz, self.k) @@ -224,11 +248,24 @@ def forward( # Make everything batch mode x = x.unsqueeze(-2) assert x.shape == (*x_batch_shape, x_bsz, 1, self.D) + x = x.expand(*batch_shape, x_bsz, 1, self.D) # Compute forward mode in the standard way - dist = super().forward(x, inducing_points, inducing_values, variational_inducing_covar, **kwargs) - predictive_mean = dist.mean # (*batch_shape, x_bsz, 1) - predictive_covar = dist.covariance_matrix # (*batch_shape, x_bsz, 1, 1) + _batch_dims = tuple(range(len(batch_shape))) + _x = x.permute((-3,) + _batch_dims + (-2, -1)) # (x_bsz, *batch_shape, 1, D) + + # inducing_points.shape (*x_batch_shape, x_bsz, self.k, self.D) + inducing_points = inducing_points.expand(*batch_shape, x_bsz, self.k, self.D) + _inducing_points = inducing_points.permute((-3,) + _batch_dims + (-2, -1)) # (x_bsz, *batch_shape, k, D) + _inducing_values = inducing_values.permute((-2,) + _batch_dims + (-1,)) + _variational_inducing_covar = variational_inducing_covar.permute((-3,) + _batch_dims + (-2, -1)) + dist = super().forward(_x, _inducing_points, _inducing_values, _variational_inducing_covar, **kwargs) + + _x_batch_dims = tuple(range(1, 1 + len(batch_shape))) + predictive_mean = dist.mean # (x_bsz, *x_batch_shape, 1) + predictive_covar = dist.covariance_matrix # (x_bsz, *x_batch_shape, 1, 1) + predictive_mean = predictive_mean.permute(_x_batch_dims + (0, -1)) + predictive_covar = predictive_covar.permute(_x_batch_dims + (0, -2, -1)) # Undo batch mode predictive_mean = predictive_mean.squeeze(-1) @@ -241,8 +278,8 @@ def forward( def get_fantasy_model( self, - inputs: Tensor, - targets: Tensor, + inputs: Float[Tensor, "... N D"], + targets: Float[Tensor, "... N"], mean_module: Optional[Module] = None, covar_module: Optional[Module] = None, **kwargs, @@ -254,8 +291,15 @@ def get_fantasy_model( def _set_training_iterator(self) -> None: self._training_indices_iter = 0 - training_indices = torch.randperm(self.M - self.k, device=self.inducing_points.device) + self.k - self._training_indices_iterator = (torch.arange(self.k),) + training_indices.split(self.training_batch_size) + if self.training_batch_size == self.M: + self._training_indices_iterator = (torch.arange(self.M, device=self.inducing_points.device),) + else: + # The first training batch always contains the first k inducing points + # This is because computing the KL divergence for the first k inducing points is special-cased + # (since the first k inducing points have < k neighbors) + # Note that there is a special function _firstk_kl_helper for this + training_indices = torch.randperm(self.M - self.k, device=self.inducing_points.device) + self.k + self._training_indices_iterator = (torch.arange(self.k),) + training_indices.split(self.training_batch_size) self._total_training_batches = len(self._training_indices_iterator) def _get_training_indices(self) -> LongTensor: @@ -265,7 +309,7 @@ def _get_training_indices(self) -> LongTensor: self._set_training_iterator() return self.current_training_indices - def _firstk_kl_helper(self) -> Tensor: + def _firstk_kl_helper(self) -> Float[Tensor, "..."]: # Compute the KL divergence for first k inducing points train_x_firstk = self.inducing_points[..., : self.k, :] full_output = self.model.forward(train_x_firstk) @@ -283,77 +327,122 @@ def _firstk_kl_helper(self) -> Tensor: kl = torch.distributions.kl.kl_divergence(variational_distribution, prior_dist) # model_batch_shape return kl - def _stochastic_kl_helper(self, kl_indices: Tensor) -> Tensor: - # Compute the KL divergence for a mini batch of the rest M-1 inducing points + def _stochastic_kl_helper(self, kl_indices: Float[Tensor, "n_batch"]) -> Float[Tensor, "..."]: # noqa: F821 + # Compute the KL divergence for a mini batch of the rest M-k inducing points # See paper appendix for kl breakdown - kl_bs = len(kl_indices) - variational_mean = self._variational_distribution.variational_mean + kl_bs = len(kl_indices) # training_batch_size + variational_mean = self._variational_distribution.variational_mean # (*model_bs, M) variational_stddev = self._variational_distribution._variational_stddev - # compute logdet_q + # (1) compute logdet_q inducing_point_log_variational_covar = (variational_stddev[..., kl_indices] ** 2).log() - logdet_q = torch.sum(inducing_point_log_variational_covar, dim=-1) + logdet_q = torch.sum(inducing_point_log_variational_covar, dim=-1) # model_bs - # Select a mini-batch of inducing points according to kl_indices, and their k-nearest neighbors - inducing_points = self.inducing_points[..., kl_indices, :] + # (2) compute lodet_p + # Select a mini-batch of inducing points according to kl_indices + inducing_points = self.inducing_points[..., kl_indices, :].expand(*self._batch_shape, kl_bs, self.D) + # (*bs, kl_bs, D) + # Select their K nearest neighbors nearest_neighbor_indices = self.nn_xinduce_idx[..., kl_indices - self.k, :].to(inducing_points.device) + # (*bs, kl_bs, K) expanded_inducing_points_all = self.inducing_points.unsqueeze(-2).expand( - *self._inducing_batch_shape, self.M, self.k, self.D + *self._batch_shape, self.M, self.k, self.D ) expanded_nearest_neighbor_indices = nearest_neighbor_indices.unsqueeze(-1).expand( - *self._inducing_batch_shape, kl_bs, self.k, self.D + *self._batch_shape, kl_bs, self.k, self.D ) nearest_neighbors = expanded_inducing_points_all.gather(-3, expanded_nearest_neighbor_indices) + # (*bs, kl_bs, K, D) + + # Compute prior distribution + # Move the kl_bs dimension to the first dimension to enable batch covar_module computation + nearest_neighbors_ = nearest_neighbors.permute((-3,) + tuple(range(len(self._batch_shape))) + (-2, -1)) + # (kl_bs, *bs, K, D) + inducing_points_ = inducing_points.permute((-2,) + tuple(range(len(self._batch_shape))) + (-1,)) + # (kl_bs, *bs, D) + full_output = self.model.forward(torch.cat([nearest_neighbors_, inducing_points_.unsqueeze(-2)], dim=-2)) + full_mean, full_covar = full_output.mean, full_output.covariance_matrix + + # Mean terms + _undo_permute_dims = tuple(range(1, 1 + len(self._batch_shape))) + (0, -1) + nearest_neighbors_prior_mean = full_mean[..., : self.k].permute(_undo_permute_dims) # (*inducing_bs, kl_bs, K) + inducing_prior_mean = full_mean[..., self.k :].permute(_undo_permute_dims).squeeze(-1) # (*inducing_bs, kl_bs) + # Covar terms + nearest_neighbors_prior_cov = full_covar[..., : self.k, : self.k] + nearest_neighbors_inducing_prior_cross_cov = full_covar[..., : self.k, self.k :] + inducing_prior_cov = full_covar[..., self.k :, self.k :] + inducing_prior_cov = ( + inducing_prior_cov.squeeze(-1).squeeze(-1).permute((-1,) + tuple(range(len(self._batch_shape)))) + ) - # compute interp_term - cov = self.model.covar_module.forward(nearest_neighbors, nearest_neighbors) - cross_cov = to_dense(self.model.covar_module.forward(nearest_neighbors, inducing_points.unsqueeze(-2))) + # Interpolation term K_nn^{-1} k_{nu} interp_term = torch.linalg.solve( - cov + self.jitter_val * torch.eye(self.k, device=self.inducing_points.device), cross_cov - ).squeeze(-1) - - # compte logdet_p - invquad_term_for_F = torch.sum(interp_term * cross_cov.squeeze(-1), dim=-1) - cov_inducing_points = self.model.covar_module.forward(inducing_points, inducing_points, diag=True) - F = cov_inducing_points - invquad_term_for_F + nearest_neighbors_prior_cov + self.jitter_val * torch.eye(self.k, device=self.inducing_points.device), + nearest_neighbors_inducing_prior_cross_cov, + ).squeeze( + -1 + ) # (kl_bs, *inducing_bs, K) + interp_term = interp_term.permute(_undo_permute_dims) # (*inducing_bs, kl_bs, K) + nearest_neighbors_inducing_prior_cross_cov = nearest_neighbors_inducing_prior_cross_cov.squeeze(-1).permute( + _undo_permute_dims + ) # k_{n(j),j}, (*inducing_bs, kl_bs, K) + + invquad_term_for_F = torch.sum( + interp_term * nearest_neighbors_inducing_prior_cross_cov, dim=-1 + ) # (*inducing_bs, kl_bs) + + inducing_prior_cov = self.model.covar_module.forward( + inducing_points, inducing_points, diag=True + ) # (*inducing_bs, kl_bs) + + F = inducing_prior_cov - invquad_term_for_F F = F + self.jitter_val - logdet_p = F.log().sum(dim=-1) + # K_uu - k_un K_nn^{-1} k_nu + logdet_p = F.log().sum(dim=-1) # shape: inducing_bs - # compute trace_term + # (3) compute trace_term expanded_variational_stddev = variational_stddev.unsqueeze(-1).expand(*self._batch_shape, self.M, self.k) expanded_variational_mean = variational_mean.unsqueeze(-1).expand(*self._batch_shape, self.M, self.k) expanded_nearest_neighbor_indices = nearest_neighbor_indices.expand(*self._batch_shape, kl_bs, self.k) nearest_neighbor_variational_covar = ( expanded_variational_stddev.gather(-2, expanded_nearest_neighbor_indices) ** 2 + ) # (*batch_shape, kl_bs, k) + bjsquared_s_nearest_neighbors = torch.sum( + interp_term**2 * nearest_neighbor_variational_covar, dim=-1 + ) # (*batch_shape, kl_bs) + inducing_point_variational_covar = variational_stddev[..., kl_indices] ** 2 # (model_bs, kl_bs) + trace_term = (1.0 / F * (bjsquared_s_nearest_neighbors + inducing_point_variational_covar)).sum( + dim=-1 + ) # batch_shape + + # (4) compute invquad_term + nearest_neighbors_variational_mean = expanded_variational_mean.gather(-2, expanded_nearest_neighbor_indices) + Bj_m_nearest_neighbors = torch.sum( + interp_term * (nearest_neighbors_variational_mean - nearest_neighbors_prior_mean), dim=-1 + ) + inducing_variational_mean = variational_mean[..., kl_indices] + invquad_term = torch.sum( + (inducing_variational_mean - inducing_prior_mean - Bj_m_nearest_neighbors) ** 2 / F, dim=-1 ) - bjsquared_s = torch.sum(interp_term**2 * nearest_neighbor_variational_covar, dim=-1) - inducing_point_covar = variational_stddev[..., kl_indices] ** 2 - trace_term = (1.0 / F * (bjsquared_s + inducing_point_covar)).sum(dim=-1) - - # compute invquad_term - nearest_neighbor_variational_mean = expanded_variational_mean.gather(-2, expanded_nearest_neighbor_indices) - Bj_m = torch.sum(interp_term * nearest_neighbor_variational_mean, dim=-1) - inducing_point_variational_mean = variational_mean[..., kl_indices] - invquad_term = torch.sum((inducing_point_variational_mean - Bj_m) ** 2 / F, dim=-1) kl = (logdet_p - logdet_q - kl_bs + trace_term + invquad_term) * (1.0 / 2) assert kl.shape == self._batch_shape, kl.shape - kl = kl.mean() return kl def _kl_divergence( - self, kl_indices: Optional[LongTensor] = None, compute_full: bool = False, batch_size: Optional[int] = None - ) -> Tensor: - if compute_full: + self, kl_indices: Optional[LongTensor] = None, batch_size: Optional[int] = None + ) -> Float[Tensor, "..."]: + if self.compute_full_kl or (self._total_training_batches == 1): if batch_size is None: batch_size = self.training_batch_size kl = self._firstk_kl_helper() for kl_indices in torch.split(torch.arange(self.k, self.M), batch_size): kl += self._stochastic_kl_helper(kl_indices) else: + # compute a stochastic estimate assert kl_indices is not None - if (self._training_indices_iter == 1) or (self.M == self.k): + if self._training_indices_iter == 1: assert len(kl_indices) == self.k, ( f"kl_indices sould be the first batch data of length k, " f"but got len(kl_indices) = {len(kl_indices)} and k = {self.k}." @@ -363,7 +452,7 @@ def _kl_divergence( kl = self._stochastic_kl_helper(kl_indices) * self.M / len(kl_indices) return kl - def kl_divergence(self) -> Tensor: + def kl_divergence(self) -> Float[Tensor, "..."]: try: return pop_from_cache(self, "kl_divergence_memo") except CachingError: @@ -374,4 +463,5 @@ def _compute_nn(self) -> "NNVariationalStrategy": inducing_points_fl = self.inducing_points.data.float() self.nn_util.set_nn_idx(inducing_points_fl) self.nn_xinduce_idx = self.nn_util.build_sequential_nn_idx(inducing_points_fl) + # shape (*_inducing_batch_shape, M-k, k) return self diff --git a/setup.cfg b/setup.cfg index 5e41c09c4..080837e12 100644 --- a/setup.cfg +++ b/setup.cfg @@ -3,7 +3,7 @@ max-line-length = 120 [flake8] max-line-length = 120 -ignore = E203, F403, F405, E731, E741, W503, W605 +ignore = E203, E731, E741, F403, F405, F722, W503, W605 exclude = build,examples diff --git a/setup.py b/setup.py index f86a41a7c..8464d2b14 100644 --- a/setup.py +++ b/setup.py @@ -37,11 +37,13 @@ def find_version(*file_paths): readme = open("README.md").read() -torch_min = "1.11" +torch_min = "2.0" install_requires = [ + "jaxtyping==0.2.19", + "mpmath>=0.19,<=1.3", # avoid incompatibiltiy with torch+sympy with mpmath 1.4 "scikit-learn", - "scipy", - "linear_operator>=0.5.2", + "scipy>=1.6.0", + "linear_operator>=0.5.3", ] # if recent dev version of PyTorch is installed, no need to install stable try: @@ -81,11 +83,13 @@ def find_version(*file_paths): "nbclient<=0.7.3", "nbformat<=5.8.0", "nbsphinx<=0.9.1", + "lxml_html_clean", "platformdirs<=3.2.0", "setuptools_scm<=7.1.0", "sphinx<=6.2.1", "sphinx_autodoc_typehints<=1.23.0", "sphinx_rtd_theme<0.5", + "uncompyle6<=3.9.0", ], "examples": ["ipython", "jupyter", "matplotlib", "scipy", "torchvision", "tqdm"], "keops": ["pykeops>=1.1.1"], diff --git a/test/examples/test_keops_gp_regression.py b/test/examples/test_keops_gp_regression.py index f47785adc..f422dadef 100644 --- a/test/examples/test_keops_gp_regression.py +++ b/test/examples/test_keops_gp_regression.py @@ -16,7 +16,7 @@ # Simple training data: let's try to learn a sine function -train_x = torch.randn(1000, 2) +train_x = torch.randn(300, 2) train_y = torch.sin(train_x[..., 0] * (2 * pi) + train_x[..., 1]) train_y = train_y + torch.randn_like(train_y).mul(0.001) @@ -52,11 +52,11 @@ def test_keops_gp_mean_abs_error(self): # Optimize the model gp_model.train() likelihood.train() - optimizer = optim.Adam(list(gp_model.parameters()), lr=0.01) + optimizer = optim.Adam(list(gp_model.parameters()), lr=0.1) optimizer.n_iter = 0 with gpytorch.settings.max_cholesky_size(0): # Ensure that we're using KeOps - for i in range(300): + for i in range(25): optimizer.zero_grad() output = gp_model(train_x) loss = -mll(output, train_y) diff --git a/test/examples/test_svgp_gp_classification.py b/test/examples/test_svgp_gp_classification.py index 1645b8c70..8a6efe689 100644 --- a/test/examples/test_svgp_gp_classification.py +++ b/test/examples/test_svgp_gp_classification.py @@ -16,7 +16,7 @@ def train_data(cuda=False): - train_x = torch.linspace(0, 1, 260) + train_x = torch.linspace(0, 1, 150) train_y = torch.cos(train_x * (2 * math.pi)).gt(0).float() if cuda: return train_x.cuda(), train_y.cuda() @@ -49,7 +49,7 @@ class TestSVGPClassification(BaseTestCase, unittest.TestCase): def test_classification_error(self, cuda=False, mll_cls=gpytorch.mlls.VariationalELBO): train_x, train_y = train_data(cuda=cuda) likelihood = BernoulliLikelihood() - model = SVGPClassificationModel(torch.linspace(0, 1, 25)) + model = SVGPClassificationModel(torch.linspace(0, 1, 64)) mll = mll_cls(likelihood, model, num_data=len(train_y)) if cuda: likelihood = likelihood.cuda() @@ -59,12 +59,12 @@ def test_classification_error(self, cuda=False, mll_cls=gpytorch.mlls.Variationa # Find optimal model hyperparameters model.train() likelihood.train() - optimizer = optim.Adam([{"params": model.parameters()}, {"params": likelihood.parameters()}], lr=0.1) + optimizer = optim.Adam([{"params": model.parameters()}, {"params": likelihood.parameters()}], lr=0.03) _wrapped_cg = MagicMock(wraps=linear_operator.utils.linear_cg) _cg_mock = patch("linear_operator.utils.linear_cg", new=_wrapped_cg) with _cg_mock as cg_mock: - for _ in range(400): + for _ in range(100): optimizer.zero_grad() output = model(train_x) loss = -mll(output, train_y) diff --git a/test/kernels/test_constant_kernel.py b/test/kernels/test_constant_kernel.py new file mode 100644 index 000000000..849ec3996 --- /dev/null +++ b/test/kernels/test_constant_kernel.py @@ -0,0 +1,113 @@ +#!/usr/bin/env python3 + +import itertools +import unittest + +import torch + +from torch import Tensor + +from gpytorch.kernels import AdditiveKernel, ConstantKernel, MaternKernel, ProductKernel, ScaleKernel +from gpytorch.lazy import LazyEvaluatedKernelTensor +from gpytorch.priors.torch_priors import GammaPrior +from gpytorch.test.base_kernel_test_case import BaseKernelTestCase + + +class TestConstantKernel(unittest.TestCase, BaseKernelTestCase): + def create_kernel_no_ard(self, **kwargs): + return ConstantKernel(**kwargs) + + def test_constant_kernel(self): + with self.subTest(device="cpu"): + self._test_constant_kernel(torch.device("cpu")) + + if torch.cuda.is_available(): + with self.subTest(device="cuda"): + self._test_constant_kernel(torch.device("cuda")) + + def _test_constant_kernel(self, device: torch.device): + n, d = 3, 5 + dtypes = [torch.float, torch.double] + batch_shapes = [(), (2,), (7, 2)] + torch.manual_seed(123) + for dtype, batch_shape in itertools.product(dtypes, batch_shapes): + tkwargs = {"dtype": dtype, "device": device} + places = 6 if dtype == torch.float else 12 + X = torch.rand(*batch_shape, n, d, **tkwargs) + + constant_kernel = ConstantKernel(batch_shape=batch_shape) + KL = constant_kernel(X) + self.assertIsInstance(KL, LazyEvaluatedKernelTensor) + KM = KL.to_dense() + self.assertIsInstance(KM, Tensor) + self.assertEqual(KM.shape, (*batch_shape, n, n)) + self.assertEqual(KM.dtype, dtype) + self.assertEqual(KM.device.type, device.type) + # standard deviation is zero iff KM is constant + self.assertAlmostEqual(KM.std().item(), 0, places=places) + + # testing last_dim_is_batch + with self.subTest(last_dim_is_batch=True): + KD = constant_kernel(X, last_dim_is_batch=True).to(device=device) + self.assertIsInstance(KD, LazyEvaluatedKernelTensor) + KM = KD.to_dense() + self.assertIsInstance(KM, Tensor) + self.assertEqual(KM.shape, (*batch_shape, d, n, n)) + self.assertAlmostEqual(KM.std().item(), 0, places=places) + self.assertEqual(KM.dtype, dtype) + self.assertEqual(KM.device.type, device.type) + + # testing diag + with self.subTest(diag=True): + KD = constant_kernel(X, diag=True) + self.assertIsInstance(KD, Tensor) + self.assertEqual(KD.shape, (*batch_shape, n)) + self.assertAlmostEqual(KD.std().item(), 0, places=places) + self.assertEqual(KD.dtype, dtype) + self.assertEqual(KD.device.type, device.type) + + # testing diag and last_dim_is_batch + with self.subTest(diag=True, last_dim_is_batch=True): + KD = constant_kernel(X, diag=True, last_dim_is_batch=True) + self.assertIsInstance(KD, Tensor) + self.assertEqual(KD.shape, (*batch_shape, d, n)) + self.assertAlmostEqual(KD.std().item(), 0, places=places) + self.assertEqual(KD.dtype, dtype) + self.assertEqual(KD.device.type, device.type) + + # testing AD + with self.subTest(requires_grad=True): + X.requires_grad = True + constant_kernel(X).to_dense().sum().backward() + self.assertIsNone(X.grad) # constant kernel is not dependent on X + + # testing algebraic combinations with another kernel + base_kernel = MaternKernel().to(device=device) + + with self.subTest(additive=True): + sum_kernel = base_kernel + constant_kernel + self.assertIsInstance(sum_kernel, AdditiveKernel) + self.assertAllClose( + sum_kernel(X).to_dense(), + base_kernel(X).to_dense() + constant_kernel.constant.unsqueeze(-1), + ) + + # product with constant is equivalent to scale kernel + with self.subTest(product=True): + product_kernel = base_kernel * constant_kernel + self.assertIsInstance(product_kernel, ProductKernel) + + scale_kernel = ScaleKernel(base_kernel, batch_shape=batch_shape) + scale_kernel.to(device=device) + self.assertAllClose(scale_kernel(X).to_dense(), product_kernel(X).to_dense()) + + # setting constant + pies = torch.full_like(constant_kernel.constant, torch.pi) + constant_kernel.constant = pies + self.assertAllClose(constant_kernel.constant, pies) + + # specifying prior + constant_kernel = ConstantKernel(constant_prior=GammaPrior(concentration=2.4, rate=2.7)) + + with self.assertRaisesRegex(TypeError, "Expected gpytorch.priors.Prior but got"): + ConstantKernel(constant_prior=1) diff --git a/test/kernels/test_linear_kernel.py b/test/kernels/test_linear_kernel.py index b520842fd..d4eb2a033 100644 --- a/test/kernels/test_linear_kernel.py +++ b/test/kernels/test_linear_kernel.py @@ -10,14 +10,16 @@ class TestLinearKernel(unittest.TestCase, BaseKernelTestCase): + kernel_kwargs = {} + def create_kernel_no_ard(self, **kwargs): - return LinearKernel(**kwargs) + return LinearKernel(**kwargs, **self.kernel_kwargs) def test_computes_linear_function_rectangular(self): a = torch.tensor([4, 2, 8], dtype=torch.float).view(3, 1) b = torch.tensor([0, 2, 1], dtype=torch.float).view(3, 1) - kernel = LinearKernel().initialize(variance=1.0) + kernel = self.create_kernel_no_ard().initialize(variance=1.0) kernel.eval() actual = torch.matmul(a, b.t()) res = kernel(a, b).to_dense() @@ -31,7 +33,7 @@ def test_computes_linear_function_rectangular(self): def test_computes_linear_function_square(self): a = torch.tensor([[4, 1], [2, 0], [8, 3]], dtype=torch.float) - kernel = LinearKernel().initialize(variance=3.14) + kernel = self.create_kernel_no_ard().initialize(variance=3.14) kernel.eval() actual = torch.matmul(a, a.t()) * 3.14 res = kernel(a, a).to_dense() @@ -57,7 +59,7 @@ def test_computes_linear_function_square(self): def test_computes_linear_function_square_batch(self): a = torch.tensor([[[4, 1], [2, 0], [8, 3]], [[1, 1], [2, 1], [1, 3]]], dtype=torch.float) - kernel = LinearKernel().initialize(variance=1.0) + kernel = self.create_kernel_no_ard().initialize(variance=1.0) kernel.eval() actual = torch.matmul(a, a.transpose(-1, -2)) res = kernel(a, a).to_dense() @@ -92,5 +94,20 @@ def test_prior_type(self): self.assertRaises(TypeError, self.create_kernel_with_prior, 1) +class TestLinearKernelARD(TestLinearKernel): + def test_kernel_ard(self) -> None: + self.kernel_kwargs = {"ard_num_dims": 2} + kernel = self.create_kernel_no_ard() + self.assertEqual(kernel.variance.shape, torch.Size([1, 2])) + + def test_computes_linear_function_rectangular(self): + self.kernel_kwargs = {"ard_num_dims": 1} + super().test_computes_linear_function_rectangular() + + def test_computes_linear_function_square_batch(self): + self.kernel_kwargs = {"ard_num_dims": 2} + super().test_computes_linear_function_square_batch() + + if __name__ == "__main__": unittest.main() diff --git a/test/kernels/test_matern52_kernel_grad.py b/test/kernels/test_matern52_kernel_grad.py new file mode 100644 index 000000000..5a76a2a33 --- /dev/null +++ b/test/kernels/test_matern52_kernel_grad.py @@ -0,0 +1,78 @@ +#!/usr/bin/env python3 + +import unittest + +import torch + +from gpytorch.kernels import Matern52KernelGrad +from gpytorch.test.base_kernel_test_case import BaseKernelTestCase + + +class TestMatern52KernelGrad(unittest.TestCase, BaseKernelTestCase): + def create_kernel_no_ard(self, **kwargs): + return Matern52KernelGrad(**kwargs) + + def create_kernel_ard(self, num_dims, **kwargs): + return Matern52KernelGrad(ard_num_dims=num_dims, **kwargs) + + def test_kernel(self, cuda=False): + a = torch.tensor([[[1, 2], [2, 4]]], dtype=torch.float) + b = torch.tensor([[[1, 3], [0, 4]]], dtype=torch.float) + + actual = torch.tensor( + [ + [0.3056225, -0.0000000, 0.5822443, 0.0188260, -0.0209871, 0.0419742], + [0.0000000, 0.5822443, 0.0000000, 0.0209871, -0.0056045, 0.0531832], + [-0.5822443, 0.0000000, -0.8515886, -0.0419742, 0.0531832, -0.0853792], + [0.1304891, -0.2014212, -0.2014212, 0.0336440, -0.0815567, -0.0000000], + [0.2014212, -0.1754366, -0.3768578, 0.0815567, -0.1870145, -0.0000000], + [0.2014212, -0.3768578, -0.1754366, 0.0000000, -0.0000000, 0.0407784], + ] + ) + + kernel = Matern52KernelGrad() + + if cuda: + a = a.cuda() + b = b.cuda() + actual = actual.cuda() + kernel = kernel.cuda() + + res = kernel(a, b).to_dense() + + self.assertLess(torch.norm(res - actual), 1e-5) + + def test_kernel_cuda(self): + if torch.cuda.is_available(): + self.test_kernel(cuda=True) + + def test_kernel_batch(self): + a = torch.tensor([[[1, 2, 3], [2, 4, 0]], [[-1, 1, 2], [2, 1, 4]]], dtype=torch.float) + b = torch.tensor([[[1, 3, 1]], [[2, -1, 0]]], dtype=torch.float).repeat(1, 2, 1) + + kernel = Matern52KernelGrad() + res = kernel(a, b).to_dense() + + # Compute each batch separately + actual = torch.zeros(2, 8, 8) + actual[0, :, :] = kernel(a[0, :, :].squeeze(), b[0, :, :].squeeze()).to_dense() + actual[1, :, :] = kernel(a[1, :, :].squeeze(), b[1, :, :].squeeze()).to_dense() + + self.assertLess(torch.norm(res - actual), 1e-5) + + def test_initialize_lengthscale(self): + kernel = Matern52KernelGrad() + kernel.initialize(lengthscale=3.14) + actual_value = torch.tensor(3.14).view_as(kernel.lengthscale) + self.assertLess(torch.norm(kernel.lengthscale - actual_value), 1e-5) + + def test_initialize_lengthscale_batch(self): + kernel = Matern52KernelGrad(batch_shape=torch.Size([2])) + ls_init = torch.tensor([3.14, 4.13]) + kernel.initialize(lengthscale=ls_init) + actual_value = ls_init.view_as(kernel.lengthscale) + self.assertLess(torch.norm(kernel.lengthscale - actual_value), 1e-5) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/priors/test_prior.py b/test/priors/test_prior.py new file mode 100644 index 000000000..53fef5976 --- /dev/null +++ b/test/priors/test_prior.py @@ -0,0 +1,70 @@ +#!/usr/bin/env python3 + +import unittest + +from torch import Tensor + +from gpytorch.priors import GammaPrior, HalfCauchyPrior, LogNormalPrior, NormalPrior + + +TRANSFORMED_ERROR_MSG = """Priors of TransformedDistributions should not have their \ +'_transformed' attributes modified, these are just copies of the base attribute. \ +Please modify the base attribute (e.g. {}) instead.""" + + +class TestPrior(unittest.TestCase): + def test_state_dict(self): + normal = NormalPrior(0.1, 1).state_dict() + self.assertTrue("loc" in normal) + self.assertTrue("scale" in normal) + self.assertEqual(normal["loc"], 0.1) + + gamma = GammaPrior(1.1, 2).state_dict() + self.assertTrue("concentration" in gamma) + self.assertTrue("rate" in gamma) + self.assertEqual(gamma["concentration"], 1.1) + + ln = LogNormalPrior(2.1, 1.2).state_dict() + self.assertTrue("_transformed_loc" in ln) + self.assertTrue("_transformed_scale" in ln) + self.assertEqual(ln["_transformed_loc"], 2.1) + + hc = HalfCauchyPrior(1.3).state_dict() + self.assertTrue("_transformed_scale" in hc) + + def test_load_state_dict(self): + ln1 = LogNormalPrior(loc=0.5, scale=0.1) + ln2 = LogNormalPrior(loc=2.5, scale=2.1) + gm1 = GammaPrior(concentration=0.5, rate=0.1) + gm2 = GammaPrior(concentration=2.5, rate=2.1) + hc1 = HalfCauchyPrior(scale=1.1) + hc2 = HalfCauchyPrior(scale=101.1) + + ln2.load_state_dict(ln1.state_dict()) + self.assertEqual(ln2.loc, ln1.loc) + self.assertEqual(ln2.scale, ln1.scale) + + gm2.load_state_dict(gm1.state_dict()) + self.assertEqual(gm2.concentration, gm1.concentration) + self.assertEqual(gm2.rate, gm1.rate) + + hc2.load_state_dict(hc1.state_dict()) + self.assertEqual(hc2.scale, hc1.scale) + + def test_transformed_attributes(self): + norm = NormalPrior(loc=2.5, scale=2.1) + ln = LogNormalPrior(loc=2.5, scale=2.1) + hc = HalfCauchyPrior(scale=2.2) + + with self.assertRaisesRegex(AttributeError, "'NormalPrior' object has no attribute '_transformed_loc'"): + getattr(norm, "_transformed_loc") + + self.assertTrue(getattr(ln, "_transformed_loc"), 2.5) + norm.loc = Tensor([1.01]) + ln.loc = Tensor([1.01]) + self.assertEqual(ln._transformed_loc, 1.01) + with self.assertRaises(AttributeError): + ln._transformed_loc = 1.1 + + with self.assertRaises(AttributeError): + hc._transformed_scale = 1.01 diff --git a/test/priors/test_utils.py b/test/priors/test_utils.py new file mode 100644 index 000000000..c62bbaefb --- /dev/null +++ b/test/priors/test_utils.py @@ -0,0 +1,61 @@ +#!/usr/bin/env python3 + +import unittest + +from torch import Tensor + +from gpytorch.priors import GammaPrior, HalfCauchyPrior, LogNormalPrior, NormalPrior + + +class TestPrior(unittest.TestCase): + def test_state_dict(self): + normal = NormalPrior(0.1, 1).state_dict() + self.assertTrue("loc" in normal) + self.assertTrue("scale" in normal) + self.assertEqual(normal["loc"], 0.1) + + gamma = GammaPrior(1.1, 2).state_dict() + self.assertTrue("concentration" in gamma) + self.assertTrue("rate" in gamma) + self.assertEqual(gamma["concentration"], 1.1) + + ln = LogNormalPrior(2.1, 1.2).state_dict() + self.assertTrue("_transformed_loc" in ln) + self.assertTrue("_transformed_scale" in ln) + self.assertEqual(ln["_transformed_loc"], 2.1) + + hc = HalfCauchyPrior(1.3).state_dict() + self.assertTrue("_transformed_scale" in hc) + + def test_load_state_dict(self): + ln1 = LogNormalPrior(loc=0.5, scale=0.1) + ln2 = LogNormalPrior(loc=2.5, scale=2.1) + gm1 = GammaPrior(concentration=0.5, rate=0.1) + gm2 = GammaPrior(concentration=2.5, rate=2.1) + hc1 = HalfCauchyPrior(scale=1.1) + hc2 = HalfCauchyPrior(scale=101.1) + + ln2.load_state_dict(ln1.state_dict()) + self.assertEqual(ln2.loc, ln1.loc) + self.assertEqual(ln2.scale, ln1.scale) + + gm2.load_state_dict(gm1.state_dict()) + self.assertEqual(gm2.concentration, gm1.concentration) + self.assertEqual(gm2.rate, gm1.rate) + + hc2.load_state_dict(hc1.state_dict()) + self.assertEqual(hc2.scale, hc1.scale) + + def test_transformed_attributes(self): + norm = NormalPrior(loc=2.5, scale=2.1) + ln = LogNormalPrior(loc=2.5, scale=2.1) + hc = HalfCauchyPrior(scale=2.2) + + with self.assertRaisesRegex(AttributeError, "'NormalPrior' object has no attribute '_transformed_loc'"): + getattr(norm, "_transformed_loc") + + self.assertTrue(getattr(ln, "_transformed_loc"), 2.5) + norm.loc = Tensor([1.01]) + ln.loc = Tensor([1.01]) + self.assertEqual(ln._transformed_loc, 1.01) + self.assertEqual(hc._transformed_scale, 2.2) diff --git a/test/utils/test_sum_interaction_terms.py b/test/utils/test_sum_interaction_terms.py new file mode 100644 index 000000000..559364356 --- /dev/null +++ b/test/utils/test_sum_interaction_terms.py @@ -0,0 +1,37 @@ +#!/usr/bin/env python3 + +import unittest +from functools import reduce +from itertools import combinations +from operator import mul + +import torch +from linear_operator import to_dense + +import gpytorch +from gpytorch.test.base_test_case import BaseTestCase + + +def prod(iterable): + return reduce(mul, iterable, 1) + + +class TestSumInteractionTerms(BaseTestCase, unittest.TestCase): + def test_sum_interaction_terms(self): + batch_shape = torch.Size([2, 1]) + D = 5 + M = 4 + N = 20 + + base_kernel = gpytorch.kernels.RBFKernel(batch_shape=torch.Size([D])) + x = torch.randn(*batch_shape, D, N, 1) + with torch.no_grad(), gpytorch.settings.lazily_evaluate_kernels(False): + covars = base_kernel(x) + + actual = torch.zeros(*batch_shape, N, N) + for degree in range(1, M + 1): + for interaction_term_indices in combinations(range(D), degree): + actual = actual + prod([to_dense(covars[..., i, :, :]) for i in interaction_term_indices]) + + res = gpytorch.utils.sum_interaction_terms(covars, max_degree=M) + self.assertAllClose(res, actual) diff --git a/test/variational/test_nearest_neighbor_variational_strategy.py b/test/variational/test_nearest_neighbor_variational_strategy.py index 91a7594f7..e827d2f63 100644 --- a/test/variational/test_nearest_neighbor_variational_strategy.py +++ b/test/variational/test_nearest_neighbor_variational_strategy.py @@ -8,7 +8,7 @@ from gpytorch.test.variational_test_case import VariationalTestCase -class TestVNNGP(VariationalTestCase, unittest.TestCase): +class TestVNNGPNonInducingData(VariationalTestCase, unittest.TestCase): @property def batch_shape(self): return torch.Size([]) @@ -33,6 +33,15 @@ def likelihood_cls(self): def event_shape(self): return torch.Size([32]) + # VNNGP specific + @property + def full_batch(self): + return False + + @property + def computed_full_kl(self): + return False + def _make_model_and_likelihood( self, num_inducing=32, @@ -42,11 +51,21 @@ def _make_model_and_likelihood( distribution_cls=gpytorch.variational.MeanFieldVariationalDistribution, constant_mean=True, ): + # VNNGP variational strategy takes slightly different inputs than other variational strategies + # (i.e. it does not accept a learn_inducing_locations argument, and it expects + # a k and training_batch_size argument) + # We supply a custom method here for that purpose + class _VNNGPRegressionModel(gpytorch.models.ApproximateGP): - def __init__(self, inducing_points, k, training_batch_size): + def __init__(self, inducing_points, k, training_batch_size, compute_full_kl): variational_distribution = distribution_cls(num_inducing, batch_shape=batch_shape) variational_strategy = strategy_cls( - self, inducing_points, variational_distribution, k=k, training_batch_size=training_batch_size + self, + inducing_points, + variational_distribution, + k=k, + training_batch_size=training_batch_size, + compute_full_kl=compute_full_kl, ) super().__init__(variational_strategy) @@ -56,7 +75,10 @@ def __init__(self, inducing_points, k, training_batch_size): else: self.mean_module = gpytorch.means.ZeroMean() - self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel()) + self.covar_module = gpytorch.kernels.ScaleKernel( + gpytorch.kernels.RBFKernel(batch_shape=batch_shape, ard_num_dims=2), + batch_shape=batch_shape, + ) def forward(self, x): mean_x = self.mean_module(x) @@ -71,10 +93,43 @@ def __call__(self, x, prior=False, **kwargs): k = 3 d = 2 - training_batch_size = 4 + training_batch_size = num_inducing if self.full_batch else 4 + compute_full_kl = self.computed_full_kl inducing_points = torch.randn(*inducing_batch_shape, num_inducing, d) - return _VNNGPRegressionModel(inducing_points, k, training_batch_size), self.likelihood_cls() + return _VNNGPRegressionModel(inducing_points, k, training_batch_size, compute_full_kl), self.likelihood_cls() + + def test_training_iteration_batch_data(self): + # Data batch shape must always be subsumed by the inducing batch shape for VNNGP models + # So this test does not apply to VNNGP models + pass + + def test_eval_smaller_pred_batch(self): + # Data batch shape must always be subsumed by the inducing batch shape for VNNGP models + # So this test does not apply to VNNGP models + pass + def test_eval_larger_pred_batch(self): + # Data batch shape must always be subsumed by the inducing batch shape for VNNGP models + # So this test does not apply to VNNGP models + pass + + def test_training_all_batch_zero_mean(self): + # Original test in VariationalTestCase has a data_batch_shape that is not subsumed + # by the inducing_batch_shape (not allowed for VNNGP models). + return self.test_training_iteration( + model_batch_shape=(torch.Size([3, 4]) + self.batch_shape), + inducing_batch_shape=(torch.Size([3, 4]) + self.batch_shape), + data_batch_shape=(torch.Size([4]) + self.batch_shape), + expected_batch_shape=(torch.Size([3, 4]) + self.batch_shape), + constant_mean=False, + ) + + def test_fantasy_call(self, *args, **kwargs): + with self.assertRaises(NotImplementedError): + super().test_fantasy_call(*args, **kwargs) + + +class TestVNNGP(TestVNNGPNonInducingData, unittest.TestCase): def _training_iter( self, model, @@ -98,8 +153,10 @@ def _training_iter( # Single optimization iteration model.train() likelihood.train() - output = model(train_x) - loss = -mll(output, train_y) + output = model(x=None) + current_training_indices = model.variational_strategy.current_training_indices + y_batch = train_y[..., current_training_indices] + loss = -mll(output, y_batch) loss.sum().backward() # Make sure we have gradients for all parameters @@ -112,20 +169,6 @@ def _training_iter( return output, loss - def _eval_iter(self, model, cuda=False): - inducing_batch_shape = model.variational_strategy.inducing_points.shape[:-2] - test_x = torch.randn(*inducing_batch_shape, 32, 2).clamp(-2.5, 2.5) - if cuda: - test_x = test_x.cuda() - model = model.cuda() - - # Single optimization iteration - model.eval() - with torch.no_grad(): - output = model(test_x) - - return output - def test_training_iteration( self, data_batch_shape=None, @@ -134,8 +177,9 @@ def test_training_iteration( expected_batch_shape=None, constant_mean=True, ): - # We cannot inheret the superclass method - # Because it expects `variational_params_intialized` to be set to 0 + # We cannot inheret the superclass method because it expects the + # expected output.event_shape should be the training_batch_size not + # self.event_shape (which is reserved for test_eval_iteration) # Batch shapes model_batch_shape = model_batch_shape if model_batch_shape is not None else self.batch_shape @@ -154,6 +198,7 @@ def test_training_iteration( # Do forward pass # Iter 1 + self.assertEqual(model.variational_strategy.variational_params_initialized.item(), 0) self._training_iter( model, likelihood, @@ -162,6 +207,7 @@ def test_training_iteration( cuda=self.cuda, ) # Iter 2 + self.assertEqual(model.variational_strategy.variational_params_initialized.item(), 1) output, loss = self._training_iter( model, likelihood, @@ -170,88 +216,20 @@ def test_training_iteration( cuda=self.cuda, ) self.assertEqual(output.batch_shape, expected_batch_shape) - self.assertEqual(output.event_shape, self.event_shape) + self.assertEqual(output.event_shape, torch.Size([model.variational_strategy.training_batch_size])) self.assertEqual(loss.shape, expected_batch_shape) - def test_training_iteration_batch_inducing(self): - # We need different batch sizes than the superclass - return self.test_training_iteration( - model_batch_shape=(torch.Size([3]) + self.batch_shape), - inducing_batch_shape=(torch.Size([3]) + self.batch_shape), - expected_batch_shape=(torch.Size([3]) + self.batch_shape), - ) - - def test_training_iteration_batch_data(self): - # We need different batch sizes than the superclass - return self.test_training_iteration( - model_batch_shape=self.batch_shape, - inducing_batch_shape=self.batch_shape, - expected_batch_shape=(self.batch_shape), - ) - - def test_training_iteration_batch_model(self): - # We need different batch sizes than the superclass - return self.test_training_iteration( - model_batch_shape=(torch.Size([3]) + self.batch_shape), - inducing_batch_shape=self.batch_shape, - expected_batch_shape=(torch.Size([3]) + self.batch_shape), - ) - - def test_training_all_batch_zero_mean(self): - # We need different batch sizes than the superclass - return self.test_training_iteration( - model_batch_shape=(torch.Size([3, 4]) + self.batch_shape), - inducing_batch_shape=(torch.Size([3, 1]) + self.batch_shape), - expected_batch_shape=(torch.Size([3, 4]) + self.batch_shape), - constant_mean=False, - ) - - def test_eval_iteration( - self, - inducing_batch_shape=None, - model_batch_shape=None, - expected_batch_shape=None, - ): - # Batch shapes - model_batch_shape = model_batch_shape if model_batch_shape is not None else self.batch_shape - inducing_batch_shape = inducing_batch_shape if inducing_batch_shape is not None else self.batch_shape - expected_batch_shape = expected_batch_shape if expected_batch_shape is not None else self.batch_shape - # Make model and likelihood - model, likelihood = self._make_model_and_likelihood( - batch_shape=model_batch_shape, - inducing_batch_shape=inducing_batch_shape, - ) - - # Do one forward pass - self._training_iter(model, likelihood, mll_cls=self.mll_cls, cuda=self.cuda) - - # Now do evaluation - # Iter 1 - _ = self._eval_iter(model, cuda=self.cuda) - output = self._eval_iter(model, cuda=self.cuda) - self.assertEqual(output.batch_shape, expected_batch_shape) - self.assertEqual(output.event_shape, self.event_shape) - - def test_eval_smaller_pred_batch(self): - # We need different batch sizes than the superclass - return self.test_eval_iteration( - model_batch_shape=(torch.Size([3, 4]) + self.batch_shape), - inducing_batch_shape=(torch.Size([3, 1]) + self.batch_shape), - expected_batch_shape=(torch.Size([3, 4]) + self.batch_shape), - ) +class TestVNNGPFullBatch(TestVNNGP, unittest.TestCase): + @property + def full_batch(self): + return True - def test_eval_larger_pred_batch(self): - # We need different batch sizes than the superclass - return self.test_eval_iteration( - model_batch_shape=(torch.Size([4]) + self.batch_shape), - inducing_batch_shape=(self.batch_shape), - expected_batch_shape=(torch.Size([4]) + self.batch_shape), - ) - def test_fantasy_call(self, *args, **kwargs): - with self.assertRaises(NotImplementedError): - super().test_fantasy_call(*args, **kwargs) +class TestVNNGPFullKL(TestVNNGP, unittest.TestCase): + @property + def compute_full_kl(self): + return True if __name__ == "__main__":