Merge branch 'main' into module
gpleiss authored Dec 6, 2024
2 parents 3c0a274 + 1a8aa8f commit 4398b9d
Showing 2 changed files with 40 additions and 42 deletions.
72 changes: 35 additions & 37 deletions examples/01_Exact_GPs/GP_Regression_Fully_Bayesian.ipynb
@@ -23,44 +23,57 @@
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"%matplotlib inline\n",
"%load_ext autoreload\n",
"%autoreload 2"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import math\n",
"import torch\n",
"import os\n",
"\n",
"import gpytorch\n",
"from gpytorch.priors import UniformPrior\n",
"import matplotlib.pyplot as plt\n",
"import pyro\n",
"from pyro.infer.mcmc import NUTS, MCMC, HMC\n",
"from matplotlib import pyplot as plt\n",
"from pyro.infer.mcmc import NUTS, MCMC\n",
"import torch\n",
"\n",
"%matplotlib inline\n",
"%load_ext autoreload\n",
"%autoreload 2"
"# this is for running the notebook in our testing framework\n",
"smoke_test = ('CI' in os.environ)"
]
},
{
"cell_type": "code",
"execution_count": 29,
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# Training data is 11 points in [0,1] inclusive regularly spaced\n",
"# Training data is 4 points in [0,1] inclusive regularly spaced\n",
"train_x = torch.linspace(0, 1, 4)\n",
"# True function is sin(2*pi*x) with Gaussian noise\n",
"train_y = torch.sin(train_x * (2 * math.pi)) + torch.randn(train_x.size()) * 0.2"
]
},
{
"cell_type": "code",
"execution_count": 41,
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"# We will use the simplest form of GP model, exact inference\n",
"class ExactGPModel(gpytorch.models.ExactGP):\n",
" def __init__(self, train_x, train_y, likelihood):\n",
" super(ExactGPModel, self).__init__(train_x, train_y, likelihood)\n",
" super().__init__(train_x, train_y, likelihood)\n",
" self.mean_module = gpytorch.means.ConstantMean()\n",
" self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())\n",
" \n",
"\n",
" def forward(self, x):\n",
" mean_x = self.mean_module(x)\n",
" covar_x = self.covar_module(x)\n",
@@ -78,7 +91,7 @@
},
{
"cell_type": "code",
"execution_count": 59,
"execution_count": 5,
"metadata": {},
"outputs": [
{
@@ -90,25 +103,17 @@
}
],
"source": [
"# this is for running the notebook in our testing framework\n",
"import os\n",
"smoke_test = ('CI' in os.environ)\n",
"num_samples = 2 if smoke_test else 100\n",
"warmup_steps = 2 if smoke_test else 100\n",
"\n",
"\n",
"from gpytorch.priors import LogNormalPrior, NormalPrior, UniformPrior\n",
"# Use a positive constraint instead of usual GreaterThan(1e-4) so that LogNormal has support over full range.\n",
"likelihood = gpytorch.likelihoods.GaussianLikelihood(noise_constraint=gpytorch.constraints.Positive())\n",
"likelihood = gpytorch.likelihoods.GaussianLikelihood()\n",
"model = ExactGPModel(train_x, train_y, likelihood)\n",
"\n",
"model.mean_module.register_prior(\"mean_prior\", UniformPrior(-1, 1), \"constant\")\n",
"model.covar_module.base_kernel.register_prior(\"lengthscale_prior\", UniformPrior(0.01, 0.5), \"lengthscale\")\n",
"model.covar_module.register_prior(\"outputscale_prior\", UniformPrior(1, 2), \"outputscale\")\n",
"likelihood.register_prior(\"noise_prior\", UniformPrior(0.01, 0.5), \"noise\")\n",
"\n",
"mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)\n",
"\n",
"def pyro_model(x, y):\n",
" with gpytorch.settings.fast_computations(False, False, False):\n",
" sampled_model = model.pyro_sample_from_prior()\n",
@@ -132,7 +137,7 @@
},
{
"cell_type": "code",
"execution_count": 60,
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
@@ -141,7 +146,7 @@
},
{
"cell_type": "code",
"execution_count": 61,
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
@@ -158,12 +163,12 @@
"source": [
"## Plot Mean Functions\n",
"\n",
"In the next cell, we plot the first 25 mean functions on the samep lot. This particular example has a fairly large amount of data for only 1 dimension, so the hyperparameter posterior is quite tight and there is relatively little variance."
"In the next cell, we plot the first 25 mean functions on the same plot. This particular example has a fairly large amount of data for only 1 dimension, so the hyperparameter posterior is quite tight and there is relatively little variance."
]
},
{
"cell_type": "code",
"execution_count": 62,
"execution_count": 8,
"metadata": {
"scrolled": false
},
@@ -185,14 +190,14 @@
"with torch.no_grad():\n",
" # Initialize plot\n",
" f, ax = plt.subplots(1, 1, figsize=(4, 3))\n",
" \n",
"\n",
" # Plot training data as black stars\n",
" ax.plot(train_x.numpy(), train_y.numpy(), 'k*', zorder=10)\n",
" \n",
"\n",
" for i in range(min(num_samples, 25)):\n",
" # Plot predictive means as blue line\n",
" ax.plot(test_x.numpy(), output.mean[i].detach().numpy(), 'b', linewidth=0.3)\n",
" \n",
"\n",
" # Shade between the lower and upper confidence bounds\n",
" # ax.fill_between(test_x.numpy(), lower.numpy(), upper.numpy(), alpha=0.5)\n",
" ax.set_ylim([-3, 3])\n",
@@ -212,7 +217,7 @@
},
{
"cell_type": "code",
"execution_count": 63,
"execution_count": 9,
"metadata": {},
"outputs": [
{
@@ -221,7 +226,7 @@
"<All keys matched successfully>"
]
},
"execution_count": 63,
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
@@ -235,13 +240,6 @@
"\n",
"model.load_state_dict(state_dict)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
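For readers skimming the notebook diff, here is a minimal end-to-end sketch of the fully Bayesian workflow the changed cells assemble: register UniformPrior objects on the hyperparameters, wrap the model in a pyro_model via pyro_sample_from_prior, run NUTS, and load the posterior samples back into the model as a batch. The cell bodies the diff truncates (the interior of pyro_model, the MCMC call, and the batched prediction) are filled in here as assumptions from the surrounding context, not verbatim from the commit.

import math
import os

import gpytorch
import pyro
import torch
from gpytorch.priors import UniformPrior
from pyro.infer.mcmc import MCMC, NUTS

# this is for running the notebook in our testing framework
smoke_test = ('CI' in os.environ)
num_samples = 2 if smoke_test else 100
warmup_steps = 2 if smoke_test else 100

# Training data is 4 points in [0, 1]; targets are noisy sin(2*pi*x)
train_x = torch.linspace(0, 1, 4)
train_y = torch.sin(train_x * (2 * math.pi)) + torch.randn(train_x.size()) * 0.2

class ExactGPModel(gpytorch.models.ExactGP):
    def __init__(self, train_x, train_y, likelihood):
        super().__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())

    def forward(self, x):
        return gpytorch.distributions.MultivariateNormal(self.mean_module(x), self.covar_module(x))

likelihood = gpytorch.likelihoods.GaussianLikelihood()
model = ExactGPModel(train_x, train_y, likelihood)

# UniformPriors on every hyperparameter, as registered in the diff above
model.mean_module.register_prior("mean_prior", UniformPrior(-1, 1), "constant")
model.covar_module.base_kernel.register_prior("lengthscale_prior", UniformPrior(0.01, 0.5), "lengthscale")
model.covar_module.register_prior("outputscale_prior", UniformPrior(1, 2), "outputscale")
likelihood.register_prior("noise_prior", UniformPrior(0.01, 0.5), "noise")

def pyro_model(x, y):
    # Draw hyperparameters from the registered priors, then score the
    # training targets under the resulting GP (assumed completion of the
    # truncated cell)
    with gpytorch.settings.fast_computations(False, False, False):
        sampled_model = model.pyro_sample_from_prior()
        output = sampled_model.likelihood(sampled_model(x))
        pyro.sample("obs", output, obs=y)
    return y

nuts_kernel = NUTS(pyro_model)
mcmc_run = MCMC(nuts_kernel, num_samples=num_samples, warmup_steps=warmup_steps, disable_progbar=smoke_test)
mcmc_run.run(train_x, train_y)

# Load the hyperparameter samples back into the model; predictions now carry
# a leading batch dimension of size num_samples
model.pyro_load_from_samples(mcmc_run.get_samples())
model.eval()

test_x = torch.linspace(0, 1, 101).unsqueeze(-1)
expanded_test_x = test_x.unsqueeze(0).repeat(num_samples, 1, 1)
with torch.no_grad():
    output = model(expanded_test_x)  # one predictive distribution per hyperparameter sample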
10 changes: 5 additions & 5 deletions gpytorch/kernels/linear_kernel.py
@@ -52,7 +52,7 @@ def __init__(
variance_constraint: Optional[Interval] = None,
**kwargs,
):
super(LinearKernel, self).__init__(**kwargs)
super().__init__(**kwargs)
if variance_constraint is None:
variance_constraint = Positive()
self.register_parameter(
@@ -73,17 +73,17 @@ def variance(self) -> Tensor:
return self.raw_variance_constraint.transform(self.raw_variance)

@variance.setter
def variance(self, value: Union[float, Tensor]):
def variance(self, value: Union[float, Tensor]) -> None:
self._set_variance(value)

def _set_variance(self, value: Union[float, Tensor]):
def _set_variance(self, value: Union[float, Tensor]) -> None:
if not torch.is_tensor(value):
value = torch.as_tensor(value).to(self.raw_variance)
self.initialize(raw_variance=self.raw_variance_constraint.inverse_transform(value))

def forward(
self, x1: Tensor, x2: Tensor, diag: Optional[bool] = False, last_dim_is_batch: Optional[bool] = False, **params
) -> LinearOperator:
self, x1: Tensor, x2: Tensor, diag: bool = False, last_dim_is_batch: Optional[bool] = False, **params
) -> Union[Tensor, LinearOperator]:
x1_ = x1 * self.variance.sqrt()
if last_dim_is_batch:
x1_ = x1_.transpose(-1, -2).unsqueeze(-1)
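The forward annotation change above records behavior the kernel already had: with diag=True the method returns a plain torch.Tensor holding the diagonal entries, while the default path returns a lazy LinearOperator. A small illustrative sketch, not part of the commit:

import torch
import gpytorch
from linear_operator.operators import LinearOperator

k = gpytorch.kernels.LinearKernel()
x = torch.randn(5, 3)

full = k.forward(x, x)             # lazy low-rank LinearOperator
diag = k.forward(x, x, diag=True)  # plain torch.Tensor of the diagonal

assert isinstance(full, LinearOperator)
assert isinstance(diag, torch.Tensor)
assert torch.allclose(diag, full.to_dense().diagonal(dim1=-2, dim2=-1))

Annotating diag as a plain bool (rather than Optional[bool]) matches how the flag is actually used: it is only ever True or False, never None.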
