Skip to content

Commit

Permalink
Remove unused imports in GPyTorch Fully Bayesian example (#2612)
Browse files Browse the repository at this point in the history
* Remove unused imports in GPyTorch Fully Bayesian example

* Remove spaces, fix typos

* Remove unused comment about LogNormal prior

* Simplify super() call

* Remove unused mll

---------

Co-authored-by: Geoff Pleiss <[email protected]>
  • Loading branch information
chrisyeh96 and gpleiss authored Dec 6, 2024
1 parent 8940078 commit 1a8aa8f
Showing 1 changed file with 35 additions and 37 deletions.
72 changes: 35 additions & 37 deletions examples/01_Exact_GPs/GP_Regression_Fully_Bayesian.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -23,44 +23,57 @@
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"%matplotlib inline\n",
"%load_ext autoreload\n",
"%autoreload 2"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import math\n",
"import torch\n",
"import os\n",
"\n",
"import gpytorch\n",
"from gpytorch.priors import UniformPrior\n",
"import matplotlib.pyplot as plt\n",
"import pyro\n",
"from pyro.infer.mcmc import NUTS, MCMC, HMC\n",
"from matplotlib import pyplot as plt\n",
"from pyro.infer.mcmc import NUTS, MCMC\n",
"import torch\n",
"\n",
"%matplotlib inline\n",
"%load_ext autoreload\n",
"%autoreload 2"
"# this is for running the notebook in our testing framework\n",
"smoke_test = ('CI' in os.environ)"
]
},
{
"cell_type": "code",
"execution_count": 29,
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# Training data is 11 points in [0,1] inclusive regularly spaced\n",
"# Training data is 4 points in [0,1] inclusive regularly spaced\n",
"train_x = torch.linspace(0, 1, 4)\n",
"# True function is sin(2*pi*x) with Gaussian noise\n",
"train_y = torch.sin(train_x * (2 * math.pi)) + torch.randn(train_x.size()) * 0.2"
]
},
{
"cell_type": "code",
"execution_count": 41,
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"# We will use the simplest form of GP model, exact inference\n",
"class ExactGPModel(gpytorch.models.ExactGP):\n",
" def __init__(self, train_x, train_y, likelihood):\n",
" super(ExactGPModel, self).__init__(train_x, train_y, likelihood)\n",
" super().__init__(train_x, train_y, likelihood)\n",
" self.mean_module = gpytorch.means.ConstantMean()\n",
" self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())\n",
" \n",
"\n",
" def forward(self, x):\n",
" mean_x = self.mean_module(x)\n",
" covar_x = self.covar_module(x)\n",
Expand All @@ -78,7 +91,7 @@
},
{
"cell_type": "code",
"execution_count": 59,
"execution_count": 5,
"metadata": {},
"outputs": [
{
Expand All @@ -90,25 +103,17 @@
}
],
"source": [
"# this is for running the notebook in our testing framework\n",
"import os\n",
"smoke_test = ('CI' in os.environ)\n",
"num_samples = 2 if smoke_test else 100\n",
"warmup_steps = 2 if smoke_test else 100\n",
"\n",
"\n",
"from gpytorch.priors import LogNormalPrior, NormalPrior, UniformPrior\n",
"# Use a positive constraint instead of usual GreaterThan(1e-4) so that LogNormal has support over full range.\n",
"likelihood = gpytorch.likelihoods.GaussianLikelihood(noise_constraint=gpytorch.constraints.Positive())\n",
"likelihood = gpytorch.likelihoods.GaussianLikelihood()\n",
"model = ExactGPModel(train_x, train_y, likelihood)\n",
"\n",
"model.mean_module.register_prior(\"mean_prior\", UniformPrior(-1, 1), \"constant\")\n",
"model.covar_module.base_kernel.register_prior(\"lengthscale_prior\", UniformPrior(0.01, 0.5), \"lengthscale\")\n",
"model.covar_module.register_prior(\"outputscale_prior\", UniformPrior(1, 2), \"outputscale\")\n",
"likelihood.register_prior(\"noise_prior\", UniformPrior(0.01, 0.5), \"noise\")\n",
"\n",
"mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)\n",
"\n",
"def pyro_model(x, y):\n",
" with gpytorch.settings.fast_computations(False, False, False):\n",
" sampled_model = model.pyro_sample_from_prior()\n",
Expand All @@ -132,7 +137,7 @@
},
{
"cell_type": "code",
"execution_count": 60,
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -141,7 +146,7 @@
},
{
"cell_type": "code",
"execution_count": 61,
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -158,12 +163,12 @@
"source": [
"## Plot Mean Functions\n",
"\n",
"In the next cell, we plot the first 25 mean functions on the samep lot. This particular example has a fairly large amount of data for only 1 dimension, so the hyperparameter posterior is quite tight and there is relatively little variance."
"In the next cell, we plot the first 25 mean functions on the same plot. This particular example has a fairly large amount of data for only 1 dimension, so the hyperparameter posterior is quite tight and there is relatively little variance."
]
},
{
"cell_type": "code",
"execution_count": 62,
"execution_count": 8,
"metadata": {
"scrolled": false
},
Expand All @@ -185,14 +190,14 @@
"with torch.no_grad():\n",
" # Initialize plot\n",
" f, ax = plt.subplots(1, 1, figsize=(4, 3))\n",
" \n",
"\n",
" # Plot training data as black stars\n",
" ax.plot(train_x.numpy(), train_y.numpy(), 'k*', zorder=10)\n",
" \n",
"\n",
" for i in range(min(num_samples, 25)):\n",
" # Plot predictive means as blue line\n",
" ax.plot(test_x.numpy(), output.mean[i].detach().numpy(), 'b', linewidth=0.3)\n",
" \n",
"\n",
" # Shade between the lower and upper confidence bounds\n",
" # ax.fill_between(test_x.numpy(), lower.numpy(), upper.numpy(), alpha=0.5)\n",
" ax.set_ylim([-3, 3])\n",
Expand All @@ -212,7 +217,7 @@
},
{
"cell_type": "code",
"execution_count": 63,
"execution_count": 9,
"metadata": {},
"outputs": [
{
Expand All @@ -221,7 +226,7 @@
"<All keys matched successfully>"
]
},
"execution_count": 63,
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
Expand All @@ -235,13 +240,6 @@
"\n",
"model.load_state_dict(state_dict)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
Expand Down

0 comments on commit 1a8aa8f

Please sign in to comment.