diff --git a/Project.toml b/Project.toml index 2bcc8bee57..bad56848a2 100644 --- a/Project.toml +++ b/Project.toml @@ -41,10 +41,12 @@ StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c" [weakdeps] DynamicHMC = "bbc10e6e-7c05-544b-b16e-64fede858acb" +MarginalLogDensities = "f0c3360a-fb8d-11e9-1194-5521fd7ee392" Optim = "429524aa-4258-5aef-a3af-852621145aeb" [extensions] TuringDynamicHMCExt = "DynamicHMC" +TuringMarginalLogDensitiesExt = "MarginalLogDensities" TuringOptimExt = "Optim" [compat] @@ -63,13 +65,14 @@ Distributions = "0.23.3, 0.24, 0.25" DistributionsAD = "0.6" DocStringExtensions = "0.8, 0.9" DynamicHMC = "3.4" -DynamicPPL = "0.29, 0.30.4, 0.31" +DynamicPPL = "0.31.4" EllipticalSliceSampling = "0.5, 1, 2" ForwardDiff = "0.10.3" Libtask = "0.8.8" LinearAlgebra = "1" LogDensityProblems = "2" LogDensityProblemsAD = "1.7.0" +MarginalLogDensities = "0.3.6" MCMCChains = "5, 6" NamedArrays = "0.9, 0.10" Optim = "1" diff --git a/ext/TuringMarginalLogDensitiesExt.jl b/ext/TuringMarginalLogDensitiesExt.jl new file mode 100644 index 0000000000..d26a6b4a64 --- /dev/null +++ b/ext/TuringMarginalLogDensitiesExt.jl @@ -0,0 +1,39 @@ +module TuringMarginalLogDensitiesExt + +using Turing: Turing, DynamicPPL +using Turing.Inference: LogDensityProblems +using MarginalLogDensities: MarginalLogDensities + +# Use a struct for this to avoid closure overhead. +struct Drop2ndArgAndFlipSign{F} + f::F +end + +(f::Drop2ndArgAndFlipSign)(x, _) = -f.f(x) + +function Turing.marginalize( + model::DynamicPPL.Model, + varnames::Vector, + method::MarginalLogDensities.AbstractMarginalizer=MarginalLogDensities.LaplaceApprox(), +) + # Determine the indices for the variables to marginalise out. + varinfo = DynamicPPL.typed_varinfo(model) + varindices = DynamicPPL.vector_getranges(varinfo, varnames) + # Construct the marginal log-density model. + # Use linked `varinfo` to that we're working in unconstrained space and `OptimizationContext` to ensure + # that the log-abs-det jacobian terms are not included. + context = Turing.Optimisation.OptimizationContext(DynamicPPL.leafcontext(model.context)) + varinfo_linked = DynamicPPL.link(varinfo, model) + f = Base.Fix1( + LogDensityProblems.logdensity, + DynamicPPL.LogDensityFunction(varinfo_linked, model, context), + ) + # HACK: need the sign-flip here because `OptimizationContext` is a hacky impl which + # represent the _negative_ log-density. + mdl = MarginalLogDensities.MarginalLogDensity( + Drop2ndArgAndFlipSign(f), varinfo_linked[:], varindices, (), method + ) + return mdl +end + +end diff --git a/src/Turing.jl b/src/Turing.jl index dbfd5c5cf0..cc741e2890 100644 --- a/src/Turing.jl +++ b/src/Turing.jl @@ -58,6 +58,9 @@ using .Optimisation include("experimental/Experimental.jl") include("deprecated.jl") # to be removed in the next minor version release +include("extensions.jl") +export marginalize + ########### # Exports # ########### diff --git a/test/ext/TuringMarginalLogDensitiesExt.jl b/test/ext/TuringMarginalLogDensitiesExt.jl new file mode 100644 index 0000000000..2249653954 --- /dev/null +++ b/test/ext/TuringMarginalLogDensitiesExt.jl @@ -0,0 +1,16 @@ +module TuringMarginalLogDensitiesExt + +using Turing, MarginalLogDensities, Test + +@testset "MarginalLogDensities" begin + # Simple test case. + @model function demo() + x ~ Normal(0, 1) + y ~ Normal(x, 1) + end + model = demo(); + # Marginalize out `x`. + marginalized = marginalize(model, [@varname(x)]); + # Compute the marginal log-density of `y = 0.0`. + @test marginalized([0.0]) ≈ logpdf(Normal(0, √2), 0.0) atol=2e-1 +end diff --git a/test/runtests.jl b/test/runtests.jl index 530219c83b..543e1dc565 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -89,6 +89,10 @@ end @testset "utilities" begin @timeit_include("mcmc/utilities.jl") end + + @testset "extensions" begin + @timeit_include("ext/TuringMarginalLogDensitiesExt.jl") + end end show(TIMEROUTPUT; compact=true, sortby=:firstexec)