diff --git a/Lab.ipynb b/Lab.ipynb new file mode 100644 index 00000000..2d7849ef --- /dev/null +++ b/Lab.ipynb @@ -0,0 +1,683 @@ +{ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "id": "91129cb1", + "metadata": {}, + "source": [ + "# No-glue-code" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "896323ee", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[32m\u001b[1m Activating\u001b[22m\u001b[39m project at `~/Cambdrige`\n" + ] + } + ], + "source": [ + "using Pkg\n", + "Pkg.activate(\"..\")" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "baed58e3", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[36m\u001b[1m[ \u001b[22m\u001b[39m\u001b[36m\u001b[1mInfo: \u001b[22m\u001b[39mPrecompiling AdvancedHMC [0bf59076-c3b1-5ca4-86bd-e02cd72cde3d]\n", + "\u001b[36m\u001b[1m[ \u001b[22m\u001b[39m\u001b[36m\u001b[1mInfo: \u001b[22m\u001b[39mPrecompiling Turing [fce5fe82-541a-59a6-adf8-730c64b5f9a0]\n", + "WARNING: Method definition sample(Random.AbstractRNG, AbstractMCMC.AbstractModel, AbstractMCMC.AbstractSampler, AbstractMCMC.AbstractMCMCEnsemble, Integer, Integer) in module AbstractMCMC at /home/jaimerz/.julia/packages/AbstractMCMC/bE6VB/src/sample.jl:81 overwritten in module Inference at /home/jaimerz/Cambdrige/Turing.jl/src/inference/Inference.jl:210.\n", + " ** incremental compilation may be fatally broken for this module **\n", + "\n", + "WARNING: Method definition kwcall(Any, typeof(StatsBase.sample), Random.AbstractRNG, AbstractMCMC.AbstractModel, AbstractMCMC.AbstractSampler, AbstractMCMC.AbstractMCMCEnsemble, Integer, Integer) in module AbstractMCMC at /home/jaimerz/.julia/packages/AbstractMCMC/bE6VB/src/sample.jl:81 overwritten in module Inference at /home/jaimerz/Cambdrige/Turing.jl/src/inference/Inference.jl:210.\n", + " ** incremental compilation may be fatally broken for this module 
**\n", + "\n" + ] + } + ], + "source": [ + "using Random\n", + "using LinearAlgebra\n", + "using PyPlot\n", + "\n", + "#What we are tweaking\n", + "using Revise\n", + "using AdvancedHMC\n", + "using Turing\n", + "using DynamicPPL" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "3d76390f", + "metadata": {}, + "source": [ + "## Model" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "a7d6f81c", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "funnel (generic function with 2 methods)" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Just a simple Neal Funnel\n", + "d = 21\n", + "@model function funnel()\n", + " θ ~ Uniform(-1, 1) #Normal(0, 3)\n", + " z ~ MvNormal(zeros(d-1), exp(θ)*I)\n", + " x ~ MvNormal(z, I)\n", + "end" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "5f408f2b", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Model{typeof(funnel), (), (), (), Tuple{}, Tuple{}, ConditionContext{NamedTuple{(:x,), Tuple{Vector{Float64}}}, DefaultContext}}(funnel, NamedTuple(), NamedTuple(), ConditionContext((x = [1.2142074831535152, 1.23371919965455, -0.8480146960461767, 0.1600994648479841, 1.9180385508479283, -3.401523464506408, -0.0957684186471088, 0.6734622629464286, -3.2749467689509633, -1.6760091758453226, 1.9567202902549736, 0.1136169088905351, 0.11117896909388916, -0.5373922347882832, -0.12436857036298687, -1.2901071061088532, 1.702584517514787, -0.44460133117954226, 1.0818722439221686, 1.2208011493237483],), DefaultContext()))" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "Random.seed!(1)\n", + "(;x) = rand(funnel() | (θ=0,))\n", + "funnel_model = funnel() | (;x)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "d852c160", + "metadata": {}, + "source": [ + "## Sampling" + ] + }, + { + "cell_type": "code", + 
"execution_count": 5, + "id": "486d475d", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "AdvancedHMC.HMCSampler{Nothing, Nothing, Nothing, Nothing}(AdvancedHMC.NUTS_alg(500, 0.95, 10, 1000.0, 0.1), nothing, nothing, nothing, nothing)" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "nadapts=500 \n", + "TAP=0.95\n", + "ϵ=0.1\n", + "nuts = AdvancedHMC.NUTS(nadapts, TAP; ϵ=ϵ)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "9e114ad8", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "AdvancedHMC.HMCSampler{Nothing, Nothing, Nothing, Nothing}(AdvancedHMC.HMC_alg(0.1, 20), nothing, nothing, nothing, nothing)" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ϵ=0.1\n", + "n_leapfrog=20\n", + "hmc = AdvancedHMC.HMC(ϵ, n_leapfrog)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "1f729dc6", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "AdvancedHMC.HMCSampler{Nothing, Nothing, Nothing, Nothing}(AdvancedHMC.HMCDA_alg(500, 0.95, 1.0, 0.1), nothing, nothing, nothing, nothing)" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "n_adapts = 500\n", + "TAP = 0.95\n", + "λ = 0.1 * 10\n", + "ϵ=0.1\n", + "hmcda = AdvancedHMC.HMCDA(n_adapts, TAP, λ; ϵ=ϵ)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "b0193663", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[33m\u001b[1m┌ \u001b[22m\u001b[39m\u001b[33m\u001b[1mWarning: \u001b[22m\u001b[39m[DynamicPPL] attempt to link a linked vi\n", + "\u001b[33m\u001b[1m└ \u001b[22m\u001b[39m\u001b[90m@ DynamicPPL ~/.julia/packages/DynamicPPL/jjVG9/src/varinfo.jl:791\u001b[39m\n", + "\u001b[33m\u001b[1m┌ \u001b[22m\u001b[39m\u001b[33m\u001b[1mWarning: \u001b[22m\u001b[39m[DynamicPPL] 
attempt to link a linked vi\n", + "\u001b[33m\u001b[1m└ \u001b[22m\u001b[39m\u001b[90m@ DynamicPPL ~/.julia/packages/DynamicPPL/jjVG9/src/varinfo.jl:791\u001b[39m\n", + "\u001b[32mSampling: 100%|█████████████████████████████████████████| Time: 0:00:02\u001b[39m\n" + ] + }, + { + "data": { + "text/plain": [ + "Chains MCMC chain (5000×34×1 Array{Real, 3}):\n", + "\n", + "Iterations = 1:1:5000\n", + "Number of chains = 1\n", + "Samples per chain = 5000\n", + "parameters = param_1, param_2, param_3, param_4, param_5, param_6, param_7, param_8, param_9, param_10, param_11, param_12, param_13, param_14, param_15, param_16, param_17, param_18, param_19, param_20, param_21\n", + "internals = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size, is_adapt\n", + "\n", + "Summary Statistics\n", + " \u001b[1m parameters \u001b[0m \u001b[1m mean \u001b[0m \u001b[1m std \u001b[0m \u001b[1m mcse \u001b[0m \u001b[1m ess_bulk \u001b[0m \u001b[1m rhat \u001b[0m \u001b[1m ess_per_se\u001b[0m ⋯\n", + " \u001b[90m Symbol \u001b[0m \u001b[90m Float64 \u001b[0m \u001b[90m Float64 \u001b[0m \u001b[90m Float64 \u001b[0m \u001b[90m Real \u001b[0m \u001b[90m Float64 \u001b[0m \u001b[90m Missin\u001b[0m ⋯\n", + "\n", + " param_1 0.1027 0.4682 0.0125 1316.8261 1.0006 missin ⋯\n", + " param_2 0.6380 0.7443 0.0088 7305.3358 1.0007 missin ⋯\n", + " param_3 0.6571 0.7388 0.0087 7222.3134 0.9999 missin ⋯\n", + " param_4 -0.4590 0.7424 0.0081 8600.6777 0.9998 missin ⋯\n", + " param_5 0.0827 0.7254 0.0078 8658.7613 1.0009 missin ⋯\n", + " param_6 1.0204 0.7597 0.0109 4919.8215 0.9999 missin ⋯\n", + " param_7 -1.7932 0.8261 0.0145 3273.3659 1.0001 missin ⋯\n", + " param_8 -0.0484 0.7195 0.0071 10192.8327 1.0002 missin ⋯\n", + " param_9 0.3575 0.7262 0.0076 9149.6800 1.0002 missin ⋯\n", + " param_10 -1.7292 0.8133 0.0135 3701.3245 0.9999 missin ⋯\n", + " param_11 -0.8752 
0.7379 0.0093 6376.3368 1.0004 missin ⋯\n", + " param_12 1.0242 0.7599 0.0103 5479.1056 1.0000 missin ⋯\n", + " param_13 0.0675 0.7458 0.0079 8945.0993 1.0009 missin ⋯\n", + " param_14 0.0668 0.7140 0.0072 9814.6348 1.0006 missin ⋯\n", + " param_15 -0.2908 0.7255 0.0076 9112.4223 0.9998 missin ⋯\n", + " param_16 -0.0508 0.7068 0.0070 10008.6090 1.0001 missin ⋯\n", + " param_17 -0.6693 0.7322 0.0087 7073.7412 0.9999 missin ⋯\n", + " param_18 0.8904 0.7460 0.0093 6393.5556 1.0004 missin ⋯\n", + " param_19 -0.2438 0.7394 0.0079 8715.5189 1.0000 missin ⋯\n", + " param_20 0.5602 0.7217 0.0082 7751.5157 1.0000 missin ⋯\n", + " param_21 0.6376 0.7380 0.0084 7807.3097 1.0011 missin ⋯\n", + "\u001b[36m 1 column omitted\u001b[0m\n", + "\n", + "Quantiles\n", + " \u001b[1m parameters \u001b[0m \u001b[1m 2.5% \u001b[0m \u001b[1m 25.0% \u001b[0m \u001b[1m 50.0% \u001b[0m \u001b[1m 75.0% \u001b[0m \u001b[1m 97.5% \u001b[0m\n", + " \u001b[90m Symbol \u001b[0m \u001b[90m Float64 \u001b[0m \u001b[90m Float64 \u001b[0m \u001b[90m Float64 \u001b[0m \u001b[90m Float64 \u001b[0m \u001b[90m Float64 \u001b[0m\n", + "\n", + " param_1 -0.8352 -0.2336 0.1326 0.4584 0.9071\n", + " param_2 -0.7920 0.1390 0.6151 1.1211 2.1535\n", + " param_3 -0.7435 0.1493 0.6307 1.1539 2.1429\n", + " param_4 -1.9727 -0.9420 -0.4536 0.0433 0.9569\n", + " param_5 -1.3355 -0.4084 0.0832 0.5671 1.4991\n", + " param_6 -0.3763 0.4823 1.0017 1.5315 2.5668\n", + " param_7 -3.4720 -2.3401 -1.7762 -1.2272 -0.2403\n", + " param_8 -1.4292 -0.5439 -0.0520 0.4395 1.3675\n", + " param_9 -1.0777 -0.1229 0.3547 0.8448 1.7903\n", + " param_10 -3.4370 -2.2589 -1.6796 -1.1744 -0.2281\n", + " param_11 -2.4021 -1.3726 -0.8447 -0.3686 0.5171\n", + " param_12 -0.4100 0.5065 0.9956 1.5327 2.5705\n", + " param_13 -1.4140 -0.4160 0.0706 0.5537 1.5270\n", + " param_14 -1.3651 -0.4031 0.0653 0.5342 1.4844\n", + " param_15 -1.7440 -0.7812 -0.2779 0.1959 1.0957\n", + " param_16 -1.3863 -0.5423 -0.0520 0.4442 1.3074\n", + " param_17 -2.1487 
-1.1499 -0.6642 -0.1710 0.6959\n", + " param_18 -0.5586 0.3798 0.8693 1.3917 2.4085\n", + " param_19 -1.7016 -0.7273 -0.2266 0.2350 1.2082\n", + " param_20 -0.8251 0.0794 0.5603 1.0190 1.9974\n", + " param_21 -0.7633 0.1377 0.6239 1.1340 2.1128\n" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[33m\u001b[1m┌ \u001b[22m\u001b[39m\u001b[33m\u001b[1mWarning: \u001b[22m\u001b[39mTail ESS calculation failed: OverflowError(\"4750 * 4503599627370496 overflowed for type Int64\")\n", + "\u001b[33m\u001b[1m└ \u001b[22m\u001b[39m\u001b[90m@ MCMCChains ~/.julia/packages/MCMCChains/OVsxE/src/stats.jl:319\u001b[39m\n" + ] + } + ], + "source": [ + "nuts_samples = sample(funnel_model, nuts, 5000)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "f610b909", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[33m\u001b[1m┌ \u001b[22m\u001b[39m\u001b[33m\u001b[1mWarning: \u001b[22m\u001b[39m[DynamicPPL] attempt to link a linked vi\n", + "\u001b[33m\u001b[1m└ \u001b[22m\u001b[39m\u001b[90m@ DynamicPPL ~/.julia/packages/DynamicPPL/jjVG9/src/varinfo.jl:791\u001b[39m\n", + "\u001b[33m\u001b[1m┌ \u001b[22m\u001b[39m\u001b[33m\u001b[1mWarning: \u001b[22m\u001b[39m[DynamicPPL] attempt to link a linked vi\n", + "\u001b[33m\u001b[1m└ \u001b[22m\u001b[39m\u001b[90m@ DynamicPPL ~/.julia/packages/DynamicPPL/jjVG9/src/varinfo.jl:791\u001b[39m\n", + "\u001b[32mSampling: 100%|█████████████████████████████████████████| Time: 0:00:02\u001b[39m\n" + ] + }, + { + "data": { + "text/plain": [ + "Chains MCMC chain (5000×32×1 Array{Real, 3}):\n", + "\n", + "Iterations = 1:1:5000\n", + "Number of chains = 1\n", + "Samples per chain = 5000\n", + "parameters = param_1, param_2, param_3, param_4, param_5, param_6, param_7, param_8, param_9, param_10, param_11, param_12, param_13, param_14, param_15, param_16, param_17, 
param_18, param_19, param_20, param_21\n", + "internals = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, numerical_error, step_size, nom_step_size, is_adapt\n", + "\n", + "Summary Statistics\n", + " \u001b[1m parameters \u001b[0m \u001b[1m mean \u001b[0m \u001b[1m std \u001b[0m \u001b[1m mcse \u001b[0m \u001b[1m ess_bulk \u001b[0m \u001b[1m rhat \u001b[0m \u001b[1m ess_per_se\u001b[0m ⋯\n", + " \u001b[90m Symbol \u001b[0m \u001b[90m Float64 \u001b[0m \u001b[90m Float64 \u001b[0m \u001b[90m Float64 \u001b[0m \u001b[90m Real \u001b[0m \u001b[90m Float64 \u001b[0m \u001b[90m Missin\u001b[0m ⋯\n", + "\n", + " param_1 0.1116 0.4844 0.0126 1412.2510 1.0030 missin ⋯\n", + " param_2 0.6409 0.7630 0.0056 18494.8500 1.0003 missin ⋯\n", + " param_3 0.6563 0.7341 0.0054 18494.8500 1.0023 missin ⋯\n", + " param_4 -0.4489 0.7738 0.0057 18494.8500 1.0013 missin ⋯\n", + " param_5 0.0916 0.7387 0.0054 18494.8500 1.0008 missin ⋯\n", + " param_6 1.0122 0.7602 0.0068 13709.0981 1.0030 missin ⋯\n", + " param_7 -1.7991 0.8076 0.0124 4323.3788 1.0009 missin ⋯\n", + " param_8 -0.0475 0.7271 0.0053 18494.8500 1.0059 missin ⋯\n", + " param_9 0.3593 0.7176 0.0053 18494.8500 0.9999 missin ⋯\n", + " param_10 -1.7389 0.8314 0.0122 4786.2571 1.0019 missin ⋯\n", + " param_11 -0.8884 0.7405 0.0064 17067.3833 1.0013 missin ⋯\n", + " param_12 1.0324 0.7586 0.0068 12775.6485 1.0027 missin ⋯\n", + " param_13 0.0612 0.7115 0.0052 18494.8500 1.0026 missin ⋯\n", + " param_14 0.0576 0.7049 0.0052 18494.8500 1.0025 missin ⋯\n", + " param_15 -0.2848 0.7059 0.0052 18494.8500 0.9999 missin ⋯\n", + " param_16 -0.0663 0.7493 0.0055 18494.8500 1.0001 missin ⋯\n", + " param_17 -0.6799 0.7329 0.0054 18494.8500 1.0002 missin ⋯\n", + " param_18 0.9009 0.7595 0.0060 16083.8415 1.0022 missin ⋯\n", + " param_19 -0.2384 0.7235 0.0053 18494.8500 0.9999 missin ⋯\n", + " param_20 0.5663 0.7420 0.0055 18494.8500 1.0001 missin ⋯\n", + " param_21 0.6437 0.7433 0.0055 
18494.8500 1.0003 missin ⋯\n", + "\u001b[36m 1 column omitted\u001b[0m\n", + "\n", + "Quantiles\n", + " \u001b[1m parameters \u001b[0m \u001b[1m 2.5% \u001b[0m \u001b[1m 25.0% \u001b[0m \u001b[1m 50.0% \u001b[0m \u001b[1m 75.0% \u001b[0m \u001b[1m 97.5% \u001b[0m\n", + " \u001b[90m Symbol \u001b[0m \u001b[90m Float64 \u001b[0m \u001b[90m Float64 \u001b[0m \u001b[90m Float64 \u001b[0m \u001b[90m Float64 \u001b[0m \u001b[90m Float64 \u001b[0m\n", + "\n", + " param_1 -0.8729 -0.2411 0.1414 0.4873 0.9276\n", + " param_2 -0.7746 0.1213 0.6172 1.1341 2.1519\n", + " param_3 -0.7742 0.1636 0.6370 1.1344 2.1403\n", + " param_4 -1.9930 -0.9673 -0.4454 0.0924 1.0236\n", + " param_5 -1.3644 -0.4021 0.0955 0.5800 1.5932\n", + " param_6 -0.4151 0.4951 0.9882 1.5015 2.5778\n", + " param_7 -3.4943 -2.3275 -1.7638 -1.2414 -0.2990\n", + " param_8 -1.4757 -0.5401 -0.0424 0.4405 1.3866\n", + " param_9 -1.0262 -0.1276 0.3563 0.8391 1.8048\n", + " param_10 -3.4816 -2.2922 -1.6942 -1.1709 -0.2376\n", + " param_11 -2.4214 -1.3706 -0.8625 -0.3788 0.5300\n", + " param_12 -0.4144 0.5254 1.0030 1.5337 2.5786\n", + " param_13 -1.3274 -0.4277 0.0578 0.5478 1.4726\n", + " param_14 -1.3147 -0.4071 0.0520 0.5357 1.4133\n", + " param_15 -1.7091 -0.7450 -0.2665 0.1876 1.0607\n", + " param_16 -1.5507 -0.5647 -0.0675 0.4274 1.4156\n", + " param_17 -2.1845 -1.1587 -0.6694 -0.1713 0.6950\n", + " param_18 -0.5178 0.3903 0.8748 1.4069 2.4258\n", + " param_19 -1.6924 -0.6976 -0.2310 0.2270 1.1589\n", + " param_20 -0.8190 0.0547 0.5392 1.0695 2.0687\n", + " param_21 -0.8290 0.1653 0.6314 1.1214 2.1541\n" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[33m\u001b[1m┌ \u001b[22m\u001b[39m\u001b[33m\u001b[1mWarning: \u001b[22m\u001b[39mTail ESS calculation failed: OverflowError(\"4750 * 4503599627370496 overflowed for type Int64\")\n", + "\u001b[33m\u001b[1m└ \u001b[22m\u001b[39m\u001b[90m@ 
MCMCChains ~/.julia/packages/MCMCChains/OVsxE/src/stats.jl:319\u001b[39m\n" + ] + } + ], + "source": [ + "hmc_samples = sample(funnel_model, hmc, 5000)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "88df45a3", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[33m\u001b[1m┌ \u001b[22m\u001b[39m\u001b[33m\u001b[1mWarning: \u001b[22m\u001b[39m[DynamicPPL] attempt to link a linked vi\n", + "\u001b[33m\u001b[1m└ \u001b[22m\u001b[39m\u001b[90m@ DynamicPPL ~/.julia/packages/DynamicPPL/jjVG9/src/varinfo.jl:791\u001b[39m\n", + "\u001b[33m\u001b[1m┌ \u001b[22m\u001b[39m\u001b[33m\u001b[1mWarning: \u001b[22m\u001b[39m[DynamicPPL] attempt to link a linked vi\n", + "\u001b[33m\u001b[1m└ \u001b[22m\u001b[39m\u001b[90m@ DynamicPPL ~/.julia/packages/DynamicPPL/jjVG9/src/varinfo.jl:791\u001b[39m\n", + "\u001b[32mSampling: 100%|█████████████████████████████████████████| Time: 0:00:01\u001b[39m\n" + ] + }, + { + "data": { + "text/plain": [ + "Chains MCMC chain (5000×32×1 Array{Real, 3}):\n", + "\n", + "Iterations = 1:1:5000\n", + "Number of chains = 1\n", + "Samples per chain = 5000\n", + "parameters = param_1, param_2, param_3, param_4, param_5, param_6, param_7, param_8, param_9, param_10, param_11, param_12, param_13, param_14, param_15, param_16, param_17, param_18, param_19, param_20, param_21\n", + "internals = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, numerical_error, step_size, nom_step_size, is_adapt\n", + "\n", + "Summary Statistics\n", + " \u001b[1m parameters \u001b[0m \u001b[1m mean \u001b[0m \u001b[1m std \u001b[0m \u001b[1m mcse \u001b[0m \u001b[1m ess_bulk \u001b[0m \u001b[1m rhat \u001b[0m \u001b[1m ess_per_sec\u001b[0m ⋯\n", + " \u001b[90m Symbol \u001b[0m \u001b[90m Float64 \u001b[0m \u001b[90m Float64 \u001b[0m \u001b[90m Float64 \u001b[0m \u001b[90m Real \u001b[0m \u001b[90m Float64 \u001b[0m \u001b[90m Missing\u001b[0m ⋯\n", + 
"\n", + " param_1 0.0979 0.4865 0.0229 427.6675 1.0077 missing ⋯\n", + " param_2 0.6547 0.7415 0.0160 2189.7809 1.0004 missing ⋯\n", + " param_3 0.6347 0.7416 0.0140 2846.6874 1.0009 missing ⋯\n", + " param_4 -0.4482 0.7324 0.0148 2459.9117 1.0002 missing ⋯\n", + " param_5 0.0916 0.7201 0.0128 3150.8292 1.0022 missing ⋯\n", + " param_6 0.9939 0.7645 0.0163 2285.0805 1.0002 missing ⋯\n", + " param_7 -1.7991 0.8208 0.0261 1001.8156 1.0031 missing ⋯\n", + " param_8 -0.0504 0.7234 0.0136 2815.2275 1.0008 missing ⋯\n", + " param_9 0.3700 0.7229 0.0132 3028.1210 0.9998 missing ⋯\n", + " param_10 -1.7251 0.8101 0.0261 966.5697 1.0029 missing ⋯\n", + " param_11 -0.8600 0.7541 0.0168 2021.1769 1.0020 missing ⋯\n", + " param_12 1.0075 0.7484 0.0167 2050.6918 1.0005 missing ⋯\n", + " param_13 0.0569 0.7187 0.0117 3750.8085 1.0008 missing ⋯\n", + " param_14 0.0608 0.7254 0.0134 2916.2452 1.0003 missing ⋯\n", + " param_15 -0.2655 0.7254 0.0126 3303.5375 1.0016 missing ⋯\n", + " param_16 -0.0366 0.7243 0.0128 3216.3677 1.0016 missing ⋯\n", + " param_17 -0.6590 0.7431 0.0154 2371.9178 1.0009 missing ⋯\n", + " param_18 0.8751 0.7536 0.0160 2242.7235 1.0004 missing ⋯\n", + " param_19 -0.2233 0.7202 0.0123 3419.9118 1.0002 missing ⋯\n", + " param_20 0.6038 0.7478 0.0142 2803.1610 1.0011 missing ⋯\n", + " param_21 0.6409 0.7377 0.0137 2922.8470 1.0005 missing ⋯\n", + "\n", + "Quantiles\n", + " \u001b[1m parameters \u001b[0m \u001b[1m 2.5% \u001b[0m \u001b[1m 25.0% \u001b[0m \u001b[1m 50.0% \u001b[0m \u001b[1m 75.0% \u001b[0m \u001b[1m 97.5% \u001b[0m\n", + " \u001b[90m Symbol \u001b[0m \u001b[90m Float64 \u001b[0m \u001b[90m Float64 \u001b[0m \u001b[90m Float64 \u001b[0m \u001b[90m Float64 \u001b[0m \u001b[90m Float64 \u001b[0m\n", + "\n", + " param_1 -0.8759 -0.2497 0.1436 0.4725 0.9130\n", + " param_2 -0.7777 0.1435 0.6347 1.1346 2.1667\n", + " param_3 -0.7896 0.1384 0.6153 1.1279 2.1692\n", + " param_4 -1.9185 -0.9338 -0.4423 0.0496 0.9832\n", + " param_5 -1.3330 -0.3826 0.0886 
0.5713 1.4915\n", + " param_6 -0.4397 0.4663 0.9664 1.4970 2.5635\n", + " param_7 -3.4716 -2.3299 -1.7589 -1.2145 -0.2936\n", + " param_8 -1.4562 -0.5463 -0.0707 0.4393 1.3843\n", + " param_9 -1.0222 -0.1147 0.3627 0.8514 1.8522\n", + " param_10 -3.3582 -2.2815 -1.6821 -1.1519 -0.2374\n", + " param_11 -2.3854 -1.3465 -0.8462 -0.3597 0.6050\n", + " param_12 -0.4173 0.4949 0.9801 1.4995 2.5221\n", + " param_13 -1.3876 -0.4168 0.0545 0.5379 1.4619\n", + " param_14 -1.3516 -0.4284 0.0526 0.5433 1.4733\n", + " param_15 -1.7321 -0.7393 -0.2599 0.2137 1.1228\n", + " param_16 -1.4597 -0.5141 -0.0427 0.4371 1.4198\n", + " param_17 -2.1839 -1.1502 -0.6285 -0.1511 0.7155\n", + " param_18 -0.6034 0.3688 0.8616 1.3863 2.3647\n", + " param_19 -1.6165 -0.7083 -0.2238 0.2560 1.1999\n", + " param_20 -0.7926 0.0770 0.5904 1.0971 2.1209\n", + " param_21 -0.7719 0.1303 0.6252 1.1271 2.1344\n" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[33m\u001b[1m┌ \u001b[22m\u001b[39m\u001b[33m\u001b[1mWarning: \u001b[22m\u001b[39mTail ESS calculation failed: OverflowError(\"4750 * 4503599627370496 overflowed for type Int64\")\n", + "\u001b[33m\u001b[1m└ \u001b[22m\u001b[39m\u001b[90m@ MCMCChains ~/.julia/packages/MCMCChains/OVsxE/src/stats.jl:319\u001b[39m\n" + ] + } + ], + "source": [ + "hmcda_samples = sample(funnel_model, hmcda, 5000)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "bbf0131e", + "metadata": {}, + "source": [ + "### Plotting" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "9c61e0ab", + "metadata": {}, + "outputs": [], + "source": [ + "theta_nuts = Vector(nuts_samples[\"param_1\"][:, 1])\n", + "x10_nuts =Vector(nuts_samples[\"param_11\"][:, 1]);" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "0b0923f1", + "metadata": {}, + "outputs": [], + "source": [ + "theta_hmc = 
Vector(hmc_samples[\"param_1\"][:, 1])\n", + "x10_hmc =Vector(hmc_samples[\"param_11\"][:, 1]);" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "fec8ace5", + "metadata": {}, + "outputs": [], + "source": [ + "theta_hmcda = Vector(hmcda_samples[\"param_1\"][:, 1])\n", + "x10_hmcda =Vector(hmcda_samples[\"param_11\"][:, 1]);" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "8869229b", + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "Figure(PyObject
)" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "fig, axis = plt.subplots(2, 2, figsize=(8,8))\n", + "fig.suptitle(\"AdvancedHMC's NUTS - 21-D Neal's Funnel\", fontsize=16)\n", + "\n", + "fig.delaxes(axis[1,2])\n", + "fig.subplots_adjust(hspace=0)\n", + "fig.subplots_adjust(wspace=0)\n", + "\n", + "axis[1,1].hist(x10_nuts, bins=100, range=[-6,2])\n", + "axis[1,1].set_yticks([])\n", + "\n", + "axis[2,2].hist(theta_nuts, bins=100, orientation=\"horizontal\", range=[-4, 2])\n", + "axis[2,2].set_xticks([])\n", + "axis[2,2].set_yticks([])\n", + "\n", + "axis[2,1].hist2d(x10_nuts, theta_nuts, bins=100, range=[[-6,2],[-4, 2]])\n", + "axis[2,1].set_xlabel(\"x10\")\n", + "axis[2,1].set_ylabel(\"theta\");" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "fe4c8b70", + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "Figure(PyObject
)" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "fig, axis = plt.subplots(2, 2, figsize=(8,8))\n", + "fig.suptitle(\"HMC - 21-D Neal's Funnel\", fontsize=16)\n", + "\n", + "fig.delaxes(axis[1,2])\n", + "fig.subplots_adjust(hspace=0)\n", + "fig.subplots_adjust(wspace=0)\n", + "\n", + "axis[1,1].hist(x10_hmc, bins=100, range=[-6,2])\n", + "axis[1,1].set_yticks([])\n", + "\n", + "axis[2,2].hist(theta_hmc, bins=100, orientation=\"horizontal\", range=[-4, 2])\n", + "axis[2,2].set_xticks([])\n", + "axis[2,2].set_yticks([])\n", + "\n", + "axis[2,1].hist2d(x10_hmc, theta_hmc, bins=100, range=[[-6,2],[-4, 2]])\n", + "axis[2,1].set_xlabel(\"x10\")\n", + "axis[2,1].set_ylabel(\"theta\");" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "2c9052ab", + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "Figure(PyObject
)" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "fig, axis = plt.subplots(2, 2, figsize=(8,8))\n", + "fig.suptitle(\"HMCDA - 21-D Neal's Funnel\", fontsize=16)\n", + "\n", + "fig.delaxes(axis[1,2])\n", + "fig.subplots_adjust(hspace=0)\n", + "fig.subplots_adjust(wspace=0)\n", + "\n", + "axis[1,1].hist(x10_hmcda, bins=100, range=[-6,2])\n", + "axis[1,1].set_yticks([])\n", + "\n", + "axis[2,2].hist(theta_hmcda, bins=100, orientation=\"horizontal\", range=[-4, 2])\n", + "axis[2,2].set_xticks([])\n", + "axis[2,2].set_yticks([])\n", + "\n", + "axis[2,1].hist2d(x10_hmcda, theta_hmcda, bins=100, range=[[-6,2],[-4, 2]])\n", + "axis[2,1].set_xlabel(\"x10\")\n", + "axis[2,1].set_ylabel(\"theta\");" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "843becb3", + "metadata": {}, + "source": [] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "91baadc8", + "metadata": {}, + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Julia 1.9.0", + "language": "julia", + "name": "julia-1.9" + }, + "language_info": { + "file_extension": ".jl", + "mimetype": "application/julia", + "name": "julia", + "version": "1.9.0" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/src/AdvancedHMC.jl b/src/AdvancedHMC.jl index 28bc440f..df123320 100644 --- a/src/AdvancedHMC.jl +++ b/src/AdvancedHMC.jl @@ -169,6 +169,7 @@ include("sampler.jl") export sample include("abstractmcmc.jl") +include("constructors.jl") ## Without explicit AD backend function Hamiltonian(metric::AbstractMetric, ℓ::LogDensityModel; kwargs...) @@ -263,4 +264,4 @@ function __init__() end end -end # module +end # module \ No newline at end of file diff --git a/src/abstractmcmc.jl b/src/abstractmcmc.jl index e491b53b..b4344ef0 100644 --- a/src/abstractmcmc.jl +++ b/src/abstractmcmc.jl @@ -1,30 +1,3 @@ -""" - HMCSampler - -A `AbstractMCMC.AbstractSampler` for kernels in AdvancedHMC.jl. 
- -# Fields - -$(FIELDS) - -# Notes - -Note that all the fields have the prefix `initial_` to indicate -that these will not necessarily correspond to the `kernel`, `metric`, -and `adaptor` after sampling. - -To access the updated fields use the resulting [`HMCState`](@ref). -""" -struct HMCSampler{K,M,A} <: AbstractMCMC.AbstractSampler - "Initial [`AbstractMCMCKernel`](@ref)." - initial_kernel::K - "Initial [`AbstractMetric`](@ref)." - initial_metric::M - "Initial [`AbstractAdaptor`](@ref)." - initial_adaptor::A -end -HMCSampler(kernel, metric) = HMCSampler(kernel, metric, Adaptation.NoAdaptation()) - """ HMCState @@ -53,140 +26,39 @@ struct HMCState{ adaptor::TAdapt end -""" - $(TYPEDSIGNATURES) - -A convenient wrapper around `AbstractMCMC.sample` avoiding explicit construction of [`HMCSampler`](@ref). -""" -function AbstractMCMC.sample( - model::LogDensityModel, - kernel::AbstractMCMCKernel, - metric::AbstractMetric, - adaptor::AbstractAdaptor, - N::Integer; - kwargs..., -) - return AbstractMCMC.sample( - Random.GLOBAL_RNG, - model, - kernel, - metric, - adaptor, - N; - kwargs..., - ) -end - -function AbstractMCMC.sample( - rng::Random.AbstractRNG, - model::LogDensityModel, - kernel::AbstractMCMCKernel, - metric::AbstractMetric, - adaptor::AbstractAdaptor, - N::Integer; - progress = true, - verbose = false, - callback = nothing, - kwargs..., -) - sampler = HMCSampler(kernel, metric, adaptor) - if callback === nothing - callback = HMCProgressCallback(N, progress = progress, verbose = verbose) - progress = false # don't use AMCMC's progress-funtionality - end - - return AbstractMCMC.mcmcsample( - rng, - model, - sampler, - N; - progress = progress, - verbose = verbose, - callback = callback, - kwargs..., - ) -end - -function AbstractMCMC.sample( - model::LogDensityModel, - kernel::AbstractMCMCKernel, - metric::AbstractMetric, - adaptor::AbstractAdaptor, - parallel::AbstractMCMC.AbstractMCMCEnsemble, - N::Integer, - nchains::Integer; - kwargs..., -) - return 
AbstractMCMC.sample( - Random.GLOBAL_RNG, - model, - kernel, - metric, - adaptor, - N, - nchains; - kwargs..., - ) -end - -function AbstractMCMC.sample( - rng::Random.AbstractRNG, - model::LogDensityModel, - kernel::AbstractMCMCKernel, - metric::AbstractMetric, - adaptor::AbstractAdaptor, - parallel::AbstractMCMC.AbstractMCMCEnsemble, - N::Integer, - nchains::Integer; - progress = true, - verbose = false, - callback = nothing, - kwargs..., -) - sampler = HMCSampler(kernel, metric, adaptor) - if callback === nothing - callback = HMCProgressCallback(N, progress = progress, verbose = verbose) - progress = false # don't use AMCMC's progress-funtionality - end - - return AbstractMCMC.mcmcsample( - rng, - model, - sampler, - parallel, - N, - nchains; - progress = progress, - verbose = verbose, - callback = callback, - kwargs..., - ) -end - function AbstractMCMC.step( rng::AbstractRNG, - model::LogDensityModel, - spl::HMCSampler; + model::AbstractMCMC.LogDensityModel, + spl::AbstractMCMC.AbstractSampler; init_params = nothing, kwargs..., -) - metric = spl.initial_metric - κ = spl.initial_kernel - adaptor = spl.initial_adaptor +) + # Unpack model; when no `init_params` is given, start from a random draw + logdensity = model.logdensity + isnothing(init_params) && (init_params = randn(rng, LogDensityProblems.dimension(logdensity))) - if init_params === nothing - init_params = randn(rng, size(metric, 1)) - end + # Define metric + metric = make_metric(spl, logdensity) # Construct the hamiltonian using the initial metric hamiltonian = Hamiltonian(metric, model) + # Define integration algorithm + # Find good eps if not provided one + integrator = make_integrator(rng, spl, hamiltonian, init_params) + + # Make kernel + κ = make_kernel(spl, integrator) + + # Make adaptor + n_adapts, adaptor = make_adaptor(spl, metric, integrator) + # Get an initial sample. h, t = AdvancedHMC.sample_init(rng, hamiltonian, init_params) # Compute next transition and state. - state = HMCState(0, t, h.metric, κ, adaptor) - + state = HMCState(0, t, metric, κ, adaptor) + # Take actual first step.
return AbstractMCMC.step(rng, model, spl, state; kwargs...) end @@ -194,7 +66,7 @@ end function AbstractMCMC.step( rng::AbstractRNG, model::LogDensityModel, - spl::HMCSampler, + spl::AbstractMCMC.AbstractSampler, state::HMCState; nadapts::Int = 0, kwargs..., @@ -302,4 +174,4 @@ function (cb::HMCProgressCallback)(rng, model, spl, t, state, i; nadapts = 0, kw elseif verbose && isadapted && i == nadapts @info "Finished $nadapts adapation steps" adaptor κ.τ.integrator metric end -end +end \ No newline at end of file diff --git a/src/constructors.jl b/src/constructors.jl new file mode 100644 index 00000000..9d33eff6 --- /dev/null +++ b/src/constructors.jl @@ -0,0 +1,205 @@ +abstract type AbstractHMCSampler <:AbstractMCMC.AbstractSampler end + +########## +# Custom # +########## +""" + CustomHMC + +An `AbstractMCMC.AbstractSampler` for kernels in AdvancedHMC.jl. + +# Fields + +$(FIELDS) + +# Notes + +Note that the fields stored here are the initial values: they will +not necessarily correspond to the `kernel`, `metric`, +and `adaptor` after sampling. + +To access the updated fields use the resulting [`HMCState`](@ref). +""" +Base.@kwdef struct CustomHMC{I,K,M,A} <: AbstractMCMC.AbstractSampler + "[`integrator`](@ref)." + integrator::I=Leapfrog + "[`AbstractMCMCKernel`](@ref)." + kernel::K=nothing + "[`AbstractMetric`](@ref)." + metric::M=nothing + "[`AbstractAdaptor`](@ref)." + adaptor::A=nothing +end + +######## +# NUTS # +######## +""" + NUTS(n_adapts::Int, δ::Float64; max_depth::Int=10, Δ_max::Float64=1000.0, init_ϵ::Float64=0.0) + +No-U-Turn Sampler (NUTS). + +Usage: + +```julia +NUTS() # Use default NUTS configuration. +NUTS(1000, 0.65) # Use 1000 adaptation steps, and target accept ratio 0.65. +``` + +Arguments: + +- `n_adapts::Int` : The number of samples to use with adaptation. +- `δ::Float64` : Target acceptance rate for dual averaging. +- `max_depth::Int` : Maximum doubling tree depth.
+- `Δ_max::Float64` : Maximum divergence during doubling tree. +- `init_ϵ::Float64` : Initial step size; 0 means automatically searching using a heuristic procedure. + +""" +Base.@kwdef struct NUTS_alg <: AbstractMCMC.AbstractSampler + n_adapts::Int # number of samples with adaption for ϵ + δ::Float64 # target accept rate + max_depth::Int=10 # maximum tree depth + Δ_max::Float64=1000.0 # maximum error + init_ϵ::Float64=0.0 # (initial) step size + integrator_method=Leapfrog # integrator method + metric_type=DiagEuclideanMetric # metric type +end + +####### +# HMC # +####### +""" + HMC(ϵ::Float64, n_leapfrog::Int) + +Hamiltonian Monte Carlo sampler with static trajectory. + +Arguments: + +- `ϵ::Float64` : The leapfrog step size to use. +- `n_leapfrog::Int` : The number of leapfrog steps to use. + +Usage: + +```julia +HMC(0.05, 10) +``` + +Tips: + +- If you are receiving gradient errors when using `HMC`, try reducing the leapfrog step size `ϵ`, e.g. + +```julia +# Original step size +sample(gdemo([1.5, 2]), HMC(0.1, 10), 1000) + +# Reduced step size +sample(gdemo([1.5, 2]), HMC(0.01, 10), 1000) +``` +""" +Base.@kwdef struct HMC_alg <: AbstractMCMC.AbstractSampler + init_ϵ::Float64 # leapfrog step size + n_leapfrog::Int # leapfrog step number + integrator_method=Leapfrog # integrator method + metric_type=DiagEuclideanMetric # metric type +end + +######### +# HMCDA # +######### +""" + HMCDA(n_adapts::Int, δ::Float64, λ::Float64; ϵ::Float64=0.0) + +Hamiltonian Monte Carlo sampler with Dual Averaging algorithm. + +Usage: + +```julia +HMCDA(200, 0.65, 0.3) +``` + +Arguments: + +- `n_adapts::Int` : Numbers of samples to use for adaptation. +- `δ::Float64` : Target acceptance rate. 65% is often recommended. +- `λ::Float64` : Target leapfrog length. +- `ϵ::Float64=0.0` : Initial step size; 0 means automatically search by Turing. 
+ +For more information, please view the following paper ([arXiv link](https://arxiv.org/abs/1111.4246)): + +- Hoffman, Matthew D., and Andrew Gelman. "The No-U-turn sampler: adaptively + setting path lengths in Hamiltonian Monte Carlo." Journal of Machine Learning + Research 15, no. 1 (2014): 1593-1623. +""" +Base.@kwdef struct HMCDA_alg <: AbstractMCMC.AbstractSampler + n_adapts::Int # number of samples with adaption for ϵ + δ::Float64 # target accept rate + λ::Float64 # target leapfrog length + init_ϵ::Float64=0.0 # (initial) step size + integrator_method=Leapfrog # integrator method + metric_type=DiagEuclideanMetric # metric type +end + +export CustomHMC, HMC_alg, NUTS_alg, HMCDA_alg +######### +# Utils # +######### + +function make_integrator(rng, spl::Union{HMC_alg, NUTS_alg, HMCDA_alg}, + hamiltonian, init_params) + init_ϵ = spl.init_ϵ + if iszero(init_ϵ) + init_ϵ = find_good_stepsize(rng, hamiltonian, init_params) + @info string("Found initial step size ", init_ϵ) + end + return spl.integrator_method(init_ϵ) +end + +function make_integrator(rng, spl::CustomHMC, hamiltonian, init_params) + return spl.integrator +end + +######### + +function make_metric(spl::Union{HMC_alg, NUTS_alg, HMCDA_alg}, logdensity) + d = LogDensityProblems.dimension(logdensity) + return spl.metric_type(d) +end + +function make_metric(spl::CustomHMC, logdensity) + return spl.metric +end + +######### + +function make_adaptor(spl::Union{NUTS_alg, HMCDA_alg}, metric, integrator) + adaptor = StanHMCAdaptor(MassMatrixAdaptor(metric), + StepSizeAdaptor(spl.δ, integrator)) + n_adapts = spl.n_adapts + return n_adapts, adaptor + end + +function make_adaptor(spl::HMC_alg, metric, integrator) + return 0, NoAdaptation() + end + + function make_adaptor(spl::CustomHMC, metric, integrator) + return spl.n_adapts, spl.adaptor + end + +######### + +function make_kernel(spl::NUTS_alg, integrator) + return HMCKernel(Trajectory{MultinomialTS}(integrator, GeneralisedNoUTurn())) +end + +function 
make_kernel(spl::HMC_alg, integrator) + return HMCKernel(Trajectory{EndPointTS}(integrator, FixedNSteps(spl.n_leapfrog))) +end + +function make_kernel(spl::HMCDA_alg, integrator) + return HMCKernel(Trajectory{EndPointTS}(integrator, FixedIntegrationTime(spl.λ))) +end + +function make_kernel(spl::CustomHMC, integrator) + return spl.kernel +end \ No newline at end of file diff --git a/src/sampler.jl b/src/sampler.jl index 7d1b7eb5..d8b63ce8 100644 --- a/src/sampler.jl +++ b/src/sampler.jl @@ -246,4 +246,4 @@ function sample( @info "Finished $n_samples sampling steps for $n_chains chains in $time (s)" h κ EBFMI_est average_acceptance_rate end return θs, stats -end +end \ No newline at end of file