Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove unused imports in GPyTorch Fully Bayesian example #2612

Merged
merged 6 commits into from
Dec 6, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 35 additions & 37 deletions examples/01_Exact_GPs/GP_Regression_Fully_Bayesian.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -23,44 +23,57 @@
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"%matplotlib inline\n",
"%load_ext autoreload\n",
"%autoreload 2"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import math\n",
"import torch\n",
"import os\n",
"\n",
"import gpytorch\n",
"from gpytorch.priors import UniformPrior\n",
"import matplotlib.pyplot as plt\n",
"import pyro\n",
"from pyro.infer.mcmc import NUTS, MCMC, HMC\n",
"from matplotlib import pyplot as plt\n",
"from pyro.infer.mcmc import NUTS, MCMC\n",
"import torch\n",
"\n",
"%matplotlib inline\n",
"%load_ext autoreload\n",
"%autoreload 2"
"# this is for running the notebook in our testing framework\n",
"smoke_test = ('CI' in os.environ)"
]
},
{
"cell_type": "code",
"execution_count": 29,
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# Training data is 11 points in [0,1] inclusive regularly spaced\n",
"# Training data is 4 points in [0,1] inclusive regularly spaced\n",
"train_x = torch.linspace(0, 1, 4)\n",
"# True function is sin(2*pi*x) with Gaussian noise\n",
"train_y = torch.sin(train_x * (2 * math.pi)) + torch.randn(train_x.size()) * 0.2"
]
},
{
"cell_type": "code",
"execution_count": 41,
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"# We will use the simplest form of GP model, exact inference\n",
"class ExactGPModel(gpytorch.models.ExactGP):\n",
" def __init__(self, train_x, train_y, likelihood):\n",
" super(ExactGPModel, self).__init__(train_x, train_y, likelihood)\n",
" super().__init__(train_x, train_y, likelihood)\n",
" self.mean_module = gpytorch.means.ConstantMean()\n",
" self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())\n",
" \n",
"\n",
" def forward(self, x):\n",
" mean_x = self.mean_module(x)\n",
" covar_x = self.covar_module(x)\n",
Expand All @@ -78,7 +91,7 @@
},
{
"cell_type": "code",
"execution_count": 59,
"execution_count": 5,
"metadata": {},
"outputs": [
{
Expand All @@ -90,25 +103,17 @@
}
],
"source": [
"# this is for running the notebook in our testing framework\n",
"import os\n",
"smoke_test = ('CI' in os.environ)\n",
"num_samples = 2 if smoke_test else 100\n",
"warmup_steps = 2 if smoke_test else 100\n",
"\n",
"\n",
"from gpytorch.priors import LogNormalPrior, NormalPrior, UniformPrior\n",
"# Use a positive constraint instead of usual GreaterThan(1e-4) so that LogNormal has support over full range.\n",
"likelihood = gpytorch.likelihoods.GaussianLikelihood(noise_constraint=gpytorch.constraints.Positive())\n",
"likelihood = gpytorch.likelihoods.GaussianLikelihood()\n",
"model = ExactGPModel(train_x, train_y, likelihood)\n",
"\n",
"model.mean_module.register_prior(\"mean_prior\", UniformPrior(-1, 1), \"constant\")\n",
"model.covar_module.base_kernel.register_prior(\"lengthscale_prior\", UniformPrior(0.01, 0.5), \"lengthscale\")\n",
"model.covar_module.register_prior(\"outputscale_prior\", UniformPrior(1, 2), \"outputscale\")\n",
"likelihood.register_prior(\"noise_prior\", UniformPrior(0.01, 0.5), \"noise\")\n",
"\n",
"mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)\n",
"\n",
"def pyro_model(x, y):\n",
" with gpytorch.settings.fast_computations(False, False, False):\n",
" sampled_model = model.pyro_sample_from_prior()\n",
Expand All @@ -132,7 +137,7 @@
},
{
"cell_type": "code",
"execution_count": 60,
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -141,7 +146,7 @@
},
{
"cell_type": "code",
"execution_count": 61,
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -158,12 +163,12 @@
"source": [
"## Plot Mean Functions\n",
"\n",
"In the next cell, we plot the first 25 mean functions on the samep lot. This particular example has a fairly large amount of data for only 1 dimension, so the hyperparameter posterior is quite tight and there is relatively little variance."
"In the next cell, we plot the first 25 mean functions on the same plot. This particular example has a fairly large amount of data for only 1 dimension, so the hyperparameter posterior is quite tight and there is relatively little variance."
]
},
{
"cell_type": "code",
"execution_count": 62,
"execution_count": 8,
"metadata": {
"scrolled": false
},
Expand All @@ -185,14 +190,14 @@
"with torch.no_grad():\n",
" # Initialize plot\n",
" f, ax = plt.subplots(1, 1, figsize=(4, 3))\n",
" \n",
"\n",
" # Plot training data as black stars\n",
" ax.plot(train_x.numpy(), train_y.numpy(), 'k*', zorder=10)\n",
" \n",
"\n",
" for i in range(min(num_samples, 25)):\n",
" # Plot predictive means as blue line\n",
" ax.plot(test_x.numpy(), output.mean[i].detach().numpy(), 'b', linewidth=0.3)\n",
" \n",
"\n",
" # Shade between the lower and upper confidence bounds\n",
" # ax.fill_between(test_x.numpy(), lower.numpy(), upper.numpy(), alpha=0.5)\n",
" ax.set_ylim([-3, 3])\n",
Expand All @@ -212,7 +217,7 @@
},
{
"cell_type": "code",
"execution_count": 63,
"execution_count": 9,
"metadata": {},
"outputs": [
{
Expand All @@ -221,7 +226,7 @@
"<All keys matched successfully>"
]
},
"execution_count": 63,
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
Expand All @@ -235,13 +240,6 @@
"\n",
"model.load_state_dict(state_dict)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
Expand Down
Loading