diff --git a/test/mcmc/hmc.jl b/test/mcmc/hmc.jl index 6ac96a551..855d729e9 100644 --- a/test/mcmc/hmc.jl +++ b/test/mcmc/hmc.jl @@ -111,7 +111,7 @@ Enzyme.API.runtimeActivity!(true) alpha = 0.16 # regularizatin term var_prior = sqrt(1.0 / alpha) # variance of the Gaussian prior - @model function bnn(ts) + @model function bnn(ts, var_prior) b1 ~ MvNormal([0. ;0.; 0.], [var_prior 0. 0.; 0. var_prior 0.; 0. 0. var_prior]) w11 ~ MvNormal([0.; 0.], [var_prior 0.; 0. var_prior]) @@ -129,7 +129,7 @@ Enzyme.API.runtimeActivity!(true) end # Sampling - chain = sample(rng, bnn(ts), HMC(0.1, 5; adtype=adbackend), 10) + chain = sample(rng, bnn(ts, var_prior), HMC(0.1, 5; adtype=adbackend), 10) end @testset "hmcda inference" begin