diff --git a/README.md b/README.md index b204e40..0945716 100644 --- a/README.md +++ b/README.md @@ -20,7 +20,29 @@ pip install git+https://github.com/activatedgeek/torch-sgld.git ## Usage +The general idea is to modify the usual gradient-based update loops +in PyTorch with the `SGLD` optimizer. +```python +from torch_sgld import SGLD + +f = module() ## construct PyTorch nn.Module. + +sgld = SGLD(f.parameters(), lr=lr, momentum=.9) ## Add momentum to make it SG-HMC. +sgld_scheduler = ## Optionally add a step-size scheduler. + +for _ in range(num_steps): + energy = f() + energy.backward() + + sgld.step() + + sgld_scheduler.step() ## Optional scheduler step. +``` + +`cSGLD` can be implemented by using a cyclical learning rate schedule. +See the [toy_csgld.ipynb](./notebooks/toy_csgld.ipynb) notebook for a +complete example. ## License diff --git a/notebooks/toy_csgld.ipynb b/notebooks/toy_csgld.ipynb new file mode 100644 index 0000000..9d75d19 --- /dev/null +++ b/notebooks/toy_csgld.ipynb @@ -0,0 +1,192 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import matplotlib.pyplot as plt\n", + "import seaborn as sns\n", + "import numpy as np\n", + "from tqdm.auto import tqdm\n", + "\n", + "sns.set(font_scale=1.5, style='whitegrid')" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Mixture of Two Gaussians" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import torch.nn as nn\n", + "from torch.distributions import MultivariateNormal\n", + "\n", + "class MoG2(nn.Module):\n", + " def __init__(self):\n", + " super().__init__()\n", + "\n", + " self.p1 = MultivariateNormal(torch.zeros(2) + 2., covariance_matrix=.5 * torch.eye(2))\n", + " self.p2 = MultivariateNormal(torch.zeros(2) - 2., covariance_matrix=torch.eye(2))\n", + "\n", + " def forward(self, x):\n", + " log_half = torch.tensor(1/2).log()\n", + " v1 = self.p1.log_prob(x) + log_half\n", + " v2 = self.p2.log_prob(x) + log_half\n", + "\n", + " return torch.stack([v1, v2]).logsumexp(0)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Likelihood Module" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "class LL(nn.Module):\n", + " '''Log-likelihood Module'''\n", + " def __init__(self):\n", + " super().__init__()\n", + "\n", + " self.theta = nn.Parameter(2. * torch.randn(1,2))\n", + " self.mog = MoG2()\n", + "\n", + " def forward(self):\n", + " return self.mog(self.theta)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "5e613bab4a9b4817a6022cd58693ab4c", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/10000 [00:00" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "grid = torch.stack(torch.meshgrid(\n", + " torch.linspace(-6, 6, 100), torch.linspace(-6, 6, 100), indexing='ij')).permute(1, 2, 0) # 100 x 100 x 2\n", + "\n", + "mog = MoG2()\n", + "logpgrid = mog(grid)\n", + "\n", + "fig, ax = plt.subplots(figsize=(7,7))\n", + "\n", + "ax.contourf(grid[..., 0].numpy(), grid[..., 1].numpy(), logpgrid.exp().numpy(), levels=10,\n", + " cmap=sns.color_palette(\"crest_r\", as_cmap=True))\n", + "\n", + "ax.scatter(samples[:, 0].numpy(), samples[:, 1].numpy(), c='red', alpha=.1)\n", + "\n", + "ax.set(xlabel='x', ylabel='y')\n", + "\n", + "fig.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "fspace", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.6" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +}