From fce3b9d6a477115fe7df16c3aef8b5f430f453e6 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 27 Dec 2022 23:38:11 +0000 Subject: [PATCH 1/7] initial commit --- Project.toml | 3 ++- src/AdvancedMH.jl | 22 ++++++++++++++++------ src/MALA.jl | 15 +++++++++++---- src/emcee.jl | 10 +++++----- src/mcmcchains-connect.jl | 6 +++--- src/mh-core.jl | 12 ++++++------ src/proposal.jl | 18 +++++++++--------- src/structarray-connect.jl | 2 +- 8 files changed, 53 insertions(+), 35 deletions(-) diff --git a/Project.toml b/Project.toml index 637ca7e..a0254de 100644 --- a/Project.toml +++ b/Project.toml @@ -5,11 +5,12 @@ version = "0.6.4" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" +LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Requires = "ae029012-a4dd-5104-9daa-d747884805df" [compat] -AbstractMCMC = "2, 3.0" +AbstractMCMC = "4" Distributions = "0.20, 0.21, 0.22, 0.23, 0.24, 0.25" Requires = "1" julia = "1" diff --git a/src/AdvancedMH.jl b/src/AdvancedMH.jl index d41c3e2..a72d937 100644 --- a/src/AdvancedMH.jl +++ b/src/AdvancedMH.jl @@ -5,12 +5,14 @@ using AbstractMCMC using Distributions using Requires +using LogDensityProblems: LogDensityProblems + import Random # Exports export MetropolisHastings, - DensityModel, + DensityModelOrLogDensityModel, RWMH, StaticMH, StaticProposal, @@ -48,6 +50,8 @@ struct DensityModel{F<:Function} <: AbstractMCMC.AbstractModel logdensity :: F end +const DensityModelOrLogDensityModel = Union{<:DensityModel,<:AbstractMCMC.LogDensityModel} + # Create a very basic Transition type, only stores the # parameter draws and the log probability of the draw. struct Transition{T<:Union{Vector, Real, NamedTuple}, L<:Real} <: AbstractTransition @@ -56,16 +60,22 @@ struct Transition{T<:Union{Vector, Real, NamedTuple}, L<:Real} <: AbstractTransi end # Store the new draw and its log density. -Transition(model::DensityModel, params) = Transition(params, logdensity(model, params)) +Transition(model::DensityModelOrLogDensityModel, params) = Transition(params, logdensity(model, params)) +function Transition(model::AbstractMCMC.LogDensityModel, params) + return Transition(params, LogDensityProblems.logdensity(model.logdensity, params)) +end # Calculate the density of the model given some parameterization. -logdensity(model::DensityModel, params) = model.logdensity(params) -logdensity(model::DensityModel, t::Transition) = t.lp +logdensity(model::DensityModelOrLogDensityModel, params) = model.logdensity(params) +logdensity(model::DensityModelOrLogDensityModel, t::Transition) = t.lp + +logdensity(model::AbstractMCMC.LogDensityModel, params) = LogDensityProblems.logdensity(model.logdensity, params) +logdensity(model::AbstractMCMC.LogDensityModel, t::Transition) = t.lp # A basic chains constructor that works with the Transition struct we defined. function AbstractMCMC.bundle_samples( ts::Vector{<:AbstractTransition}, - model::DensityModel, + model::Union{<:DensityModelOrLogDensityModel,<:AbstractMCMC.LogDensityModel}, sampler::MHSampler, state, chain_type::Type{Vector{NamedTuple}}; @@ -91,7 +101,7 @@ end function AbstractMCMC.bundle_samples( ts::Vector{<:Transition{<:NamedTuple}}, - model::DensityModel, + model::Union{<:DensityModelOrLogDensityModel,<:AbstractMCMC.LogDensityModel}, sampler::MHSampler, state, chain_type::Type{Vector{NamedTuple}}; diff --git a/src/MALA.jl b/src/MALA.jl index f7a2a03..04bb3d9 100644 --- a/src/MALA.jl +++ b/src/MALA.jl @@ -19,20 +19,27 @@ struct GradientTransition{T<:Union{Vector, Real, NamedTuple}, L<:Real, G<:Union{ gradient::G end -logdensity(model::DensityModel, t::GradientTransition) = t.lp +logdensity(model::DensityModelOrLogDensityModel, t::GradientTransition) = t.lp propose(rng::Random.AbstractRNG, ::MALA, model) = error("please specify initial parameters") -function transition(sampler::MALA, model::DensityModel, params) +function transition(sampler::MALA, model::DensityModelOrLogDensityModel, params) return GradientTransition(params, logdensity_and_gradient(model, params)...) end +check_capabilities(model::DensityModelOrLogDensityModel) = nothing +function check_capabilities(model::AbstractMCMC.LogDensityModel) + @assert LogDensityProblems.capabilities(model.logdensity) !== LogDensityProblems.LogDensityOrder{0}() +end + function AbstractMCMC.step( rng::Random.AbstractRNG, - model::DensityModel, + model::DensityModelOrLogDensityModel, sampler::MALA, transition_prev::GradientTransition; kwargs... ) + check_capabilities(model) + # Extract value and gradient of the log density of the current state. state = transition_prev.params logdensity_state = transition_prev.lp @@ -70,7 +77,7 @@ end Return the value and gradient of the log density of the parameters `params` for the `model`. """ -function logdensity_and_gradient(model::DensityModel, params) +function logdensity_and_gradient(model::DensityModelOrLogDensityModel, params) res = GradientResult(params) gradient!(res, model.logdensity, params) return value(res), gradient(res) diff --git a/src/emcee.jl b/src/emcee.jl index 2504122..eca191d 100644 --- a/src/emcee.jl +++ b/src/emcee.jl @@ -3,7 +3,7 @@ struct Ensemble{D} <: MHSampler proposal::D end -function transition(sampler::Ensemble, model::DensityModel, params) +function transition(sampler::Ensemble, model::DensityModelOrLogDensityModel, params) return [Transition(model, x) for x in params] end @@ -13,7 +13,7 @@ end # (if accepted) or the previous proposal (if not accepted). function AbstractMCMC.step( rng::Random.AbstractRNG, - model::DensityModel, + model::DensityModelOrLogDensityModel, spl::Ensemble, params_prev::Vector{<:Transition}; kwargs..., @@ -26,7 +26,7 @@ end # # Initial proposal # -function propose(rng::Random.AbstractRNG, spl::Ensemble, model::DensityModel) +function propose(rng::Random.AbstractRNG, spl::Ensemble, model::DensityModelOrLogDensityModel) # Make the first proposal with a static draw from the prior. static_prop = StaticProposal(spl.proposal.proposal) mh_spl = MetropolisHastings(static_prop) @@ -39,7 +39,7 @@ end function propose( rng::Random.AbstractRNG, spl::Ensemble, - model::DensityModel, + model::DensityModelOrLogDensityModel, walkers::Vector{<:Transition}, ) new_walkers = similar(walkers) @@ -68,7 +68,7 @@ StretchProposal(p) = StretchProposal(p, 2.0) function move( rng::Random.AbstractRNG, spl::Ensemble{<:StretchProposal}, - model::DensityModel, + model::DensityModelOrLogDensityModel, walker::Transition, other_walker::Transition, ) diff --git a/src/mcmcchains-connect.jl b/src/mcmcchains-connect.jl index 6e442e0..14d9730 100644 --- a/src/mcmcchains-connect.jl +++ b/src/mcmcchains-connect.jl @@ -3,7 +3,7 @@ import .MCMCChains: Chains # A basic chains constructor that works with the Transition struct we defined. function AbstractMCMC.bundle_samples( ts::Vector{<:AbstractTransition}, - model::DensityModel, + model::DensityModelOrLogDensityModel, sampler::MHSampler, state, chain_type::Type{Chains}; @@ -32,7 +32,7 @@ end function AbstractMCMC.bundle_samples( ts::Vector{<:Transition{<:NamedTuple}}, - model::DensityModel, + model::DensityModelOrLogDensityModel, sampler::MHSampler, state, chain_type::Type{Chains}; @@ -71,7 +71,7 @@ end function AbstractMCMC.bundle_samples( ts::Vector{<:Vector{<:AbstractTransition}}, - model::DensityModel, + model::DensityModelOrLogDensityModel, sampler::Ensemble, state, chain_type::Type{Chains}; diff --git a/src/mh-core.jl b/src/mh-core.jl index def5553..e2fc794 100644 --- a/src/mh-core.jl +++ b/src/mh-core.jl @@ -48,23 +48,23 @@ end StaticMH(d) = MetropolisHastings(StaticProposal(d)) RWMH(d) = MetropolisHastings(RandomWalkProposal(d)) -function propose(rng::Random.AbstractRNG, sampler::MHSampler, model::DensityModel) +function propose(rng::Random.AbstractRNG, sampler::MHSampler, model::DensityModelOrLogDensityModel) return propose(rng, sampler.proposal, model) end function propose( rng::Random.AbstractRNG, sampler::MHSampler, - model::DensityModel, + model::DensityModelOrLogDensityModel, transition_prev::Transition, ) return propose(rng, sampler.proposal, model, transition_prev.params) end -function transition(sampler::MHSampler, model::DensityModel, params) +function transition(sampler::MHSampler, model::DensityModelOrLogDensityModel, params) logdensity = AdvancedMH.logdensity(model, params) return transition(sampler, model, params, logdensity) end -function transition(sampler::MHSampler, model::DensityModel, params, logdensity::Real) +function transition(sampler::MHSampler, model::DensityModelOrLogDensityModel, params, logdensity::Real) return Transition(params, logdensity) end @@ -73,7 +73,7 @@ end # In this case they are identical. function AbstractMCMC.step( rng::Random.AbstractRNG, - model::DensityModel, + model::DensityModelOrLogDensityModel, sampler::MHSampler; init_params=nothing, kwargs... @@ -89,7 +89,7 @@ end # or the previous proposal (if not accepted). function AbstractMCMC.step( rng::Random.AbstractRNG, - model::DensityModel, + model::DensityModelOrLogDensityModel, sampler::MHSampler, transition_prev::AbstractTransition; kwargs... diff --git a/src/proposal.jl b/src/proposal.jl index c02e5f7..c5d58a3 100644 --- a/src/proposal.jl +++ b/src/proposal.jl @@ -41,7 +41,7 @@ end function propose( rng::Random.AbstractRNG, proposal::RandomWalkProposal{issymmetric,<:Union{Distribution,AbstractArray}}, - ::DensityModel + ::DensityModelOrLogDensityModel ) where {issymmetric} return rand(rng, proposal) end @@ -49,7 +49,7 @@ end function propose( rng::Random.AbstractRNG, proposal::RandomWalkProposal{issymmetric,<:Union{Distribution,AbstractArray}}, - model::DensityModel, + model::DensityModelOrLogDensityModel, t ) where {issymmetric} return t + rand(rng, proposal) @@ -70,7 +70,7 @@ end function propose( rng::Random.AbstractRNG, proposal::StaticProposal{issymmetric,<:Union{Distribution,AbstractArray}}, - model::DensityModel, + model::DensityModelOrLogDensityModel, t=nothing ) where {issymmetric} return rand(rng, proposal) @@ -103,7 +103,7 @@ end function propose( rng::Random.AbstractRNG, proposal::Proposal{<:Function}, - model::DensityModel + model::DensityModelOrLogDensityModel ) return propose(rng, proposal(), model) end @@ -111,7 +111,7 @@ end function propose( rng::Random.AbstractRNG, proposal::Proposal{<:Function}, - model::DensityModel, + model::DensityModelOrLogDensityModel, t ) return propose(rng, proposal(t), model) @@ -132,7 +132,7 @@ end function propose( rng::Random.AbstractRNG, proposals::AbstractArray{<:Proposal}, - model::DensityModel, + model::DensityModelOrLogDensityModel, ) return map(proposals) do proposal return propose(rng, proposal, model) @@ -141,7 +141,7 @@ end function propose( rng::Random.AbstractRNG, proposals::AbstractArray{<:Proposal}, - model::DensityModel, + model::DensityModelOrLogDensityModel, ts, ) return map(proposals, ts) do proposal, t @@ -152,7 +152,7 @@ end @generated function propose( rng::Random.AbstractRNG, proposals::NamedTuple{names}, - model::DensityModel, + model::DensityModelOrLogDensityModel, ) where {names} isempty(names) && return :(NamedTuple()) expr = Expr(:tuple) @@ -163,7 +163,7 @@ end @generated function propose( rng::Random.AbstractRNG, proposals::NamedTuple{names}, - model::DensityModel, + model::DensityModelOrLogDensityModel, ts, ) where {names} isempty(names) && return :(NamedTuple()) diff --git a/src/structarray-connect.jl b/src/structarray-connect.jl index a1e559c..5249a30 100644 --- a/src/structarray-connect.jl +++ b/src/structarray-connect.jl @@ -3,7 +3,7 @@ import .StructArrays: StructArray # A basic chains constructor that works with the Transition struct we defined. function AbstractMCMC.bundle_samples( ts::Vector{<:AbstractTransition}, - model::DensityModel, + model::DensityModelOrLogDensityModel, sampler::MHSampler, state, chain_type::Type{StructArray}; From 94905962f1914df8be60003f45fa078c554c2b7d Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 27 Dec 2022 23:57:40 +0000 Subject: [PATCH 2/7] added capabilities check for MALA, which requires order 1 --- src/MALA.jl | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/src/MALA.jl b/src/MALA.jl index 04bb3d9..926dd5d 100644 --- a/src/MALA.jl +++ b/src/MALA.jl @@ -28,7 +28,14 @@ end check_capabilities(model::DensityModelOrLogDensityModel) = nothing function check_capabilities(model::AbstractMCMC.LogDensityModel) - @assert LogDensityProblems.capabilities(model.logdensity) !== LogDensityProblems.LogDensityOrder{0}() + cap = LogDensityProblems.capabilities(model.logdensity) + if cap === nothing + throw(ArgumentError("The log density function does not support the LogDensityProblems.jl interface")) + end + + if cap === LogDensityProblems.LogDensityOrder{0}() + throw(ArgumentError("The gradient of the log density function is not defined: Implement `LogDensityProblems.logdensity_and_gradient` or use automatic differentiation provided by LogDensityProblemsAD.jl")) + end end function AbstractMCMC.step( From 84afbe559069073ec31fd72177a6939c28fa0976 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 28 Dec 2022 11:53:44 +0000 Subject: [PATCH 3/7] added logdensity_and_gradient implementation for LogDensityModel --- src/MALA.jl | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/src/MALA.jl b/src/MALA.jl index 926dd5d..4b04768 100644 --- a/src/MALA.jl +++ b/src/MALA.jl @@ -84,9 +84,19 @@ end Return the value and gradient of the log density of the parameters `params` for the `model`. """ -function logdensity_and_gradient(model::DensityModelOrLogDensityModel, params) +function logdensity_and_gradient(model::DensityModel, params) res = GradientResult(params) gradient!(res, model.logdensity, params) return value(res), gradient(res) end +""" + logdensity_and_gradient(model::AbstractMCMC.LogDensityModel, params) + +Return the value and gradient of the log density of the parameters `params` for the `model`. +""" +function logdensity_and_gradient(model::AbstractMCMC.LogDensityModel, params) + return LogDensityProblems.logdensity_and_gradient(model.logdensity, params) + end + + From 9fd2150c38a426939c3c993638201a04fbe4c948 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 28 Dec 2022 12:48:41 +0000 Subject: [PATCH 4/7] fixed export --- src/AdvancedMH.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/AdvancedMH.jl b/src/AdvancedMH.jl index a72d937..b7767ef 100644 --- a/src/AdvancedMH.jl +++ b/src/AdvancedMH.jl @@ -12,7 +12,7 @@ import Random # Exports export MetropolisHastings, - DensityModelOrLogDensityModel, + DensityModel, RWMH, StaticMH, StaticProposal, From 75bffde252c7933a5d71c435d34578a2ba9c21a1 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 28 Dec 2022 12:48:57 +0000 Subject: [PATCH 5/7] added tests for LogDensityModel usage --- test/runtests.jl | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/test/runtests.jl b/test/runtests.jl index cff349b..4ae849f 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -5,6 +5,9 @@ using ForwardDiff using MCMCChains using StructArrays +using LogDensityProblems: LogDensityProblems +using LogDensityProblemsAD: LogDensityProblemsAD + using LinearAlgebra using Random using Test @@ -26,6 +29,10 @@ include("util.jl") # Construct a DensityModel. model = DensityModel(density) + # `LogDensityModel` + LogDensityProblems.logdensity(::typeof(density), θ) = density(θ) + LogDensityProblems.dimension(::typeof(density)) = 2 + @testset "StaticMH" begin # Set up our sampler with initial parameters. spl1 = StaticMH([Normal(0,1), Normal(0, 1)]) @@ -254,6 +261,21 @@ include("util.jl") @test mean(chain1.μ) ≈ 0.0 atol=0.1 @test mean(chain1.σ) ≈ 1.0 atol=0.1 + + @testset "LogDensityProblems interface" begin + admodel = LogDensityProblemsAD.ADgradient(Val(:ForwardDiff), density) + chain2 = sample( + AdvancedMH.AbstractMCMC.LogDensityModel(admodel), + spl1, + 100000; + init_params=ones(2), + chain_type=StructArray, + param_names=["μ", "σ"] + ) + + @test mean(chain2.μ) ≈ 0.0 atol=0.1 + @test mean(chain2.σ) ≈ 1.0 atol=0.1 + end end @testset "EMCEE" begin include("emcee.jl") end From 7c322eb49bb7b1f38f7b511a30c138b4d4d83907 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 28 Dec 2022 12:49:44 +0000 Subject: [PATCH 6/7] added example of using the LogDensityInterface to the README --- README.md | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/README.md b/README.md index 6f05a18..256c2d9 100644 --- a/README.md +++ b/README.md @@ -68,6 +68,24 @@ Quantiles ``` +### Usage with [`LogDensityProblems.jl`](https://github.com/tpapp/LogDensityProblems.jl) + +It can also be used with models defining the [`LogDensityProblems.jl`](https://github.com/tpapp/LogDensityProblems.jl) interface by wrapping it in `AbstractMCMC.LogDensityModel` before passing it to `sample`: + +``` julia +using AbstractMCMC: LogDensityModel +using LogDensityProblems + +# Use a struct instead of `typeof(density)` for sake of readability. +struct LogTargetDensity end + +LogDensityProblems.logdensity(p::LogTargetDensity, θ) = density(θ) # standard multivariate normal +LogDensityProblems.dimension(p::LogTargetDensity) = 2 +LogDensityProblems.capabilities(::LogTargetDensity) = LogDensityProblems.LogDensityOrder{0}() + +sample(LogDensityModel(LogTargetDensity()), spl, 100000; param_names=["μ", "σ"], chain_type=Chains) +``` + ## Proposals AdvancedMH offers various methods of defining your inference problem. Behind the scenes, a `MetropolisHastings` sampler simply holds @@ -157,3 +175,13 @@ spl = MALA(x -> MvNormal((σ² / 2) .* x, σ² * I)) # Sample from the posterior. chain = sample(model, spl, 100000; init_params=ones(2), chain_type=StructArray, param_names=["μ", "σ"]) ``` + +### Usage with [`LogDensityProblemsAD.jl`](https://github.com/tpapp/LogDensityProblemsAD.jl) + +Using our implementation of the `LogDensityProblems.jl` interface from earlier, we can use [`LogDensityProblemsAD.jl`](https://github.com/tpapp/LogDensityProblemsAD.jl) to provide us with the gradient computation used in MALA: + +```julia +using LogDensityProblemsAD +model_with_ad = LogDensityProblemsAD.ADgradient(Val(:ForwardDiff), LogTargetDensity()) +sample(LogDensityModel(model_with_ad), spl, 100000; init_params=ones(2), chain_type=StructArray, param_names=["μ", "σ"]) +``` From 8ad165a2c47713596061e58340e781e85264b384 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 28 Dec 2022 12:52:22 +0000 Subject: [PATCH 7/7] bumped minor version and added test deps --- Project.toml | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index a0254de..fa76bb6 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "AdvancedMH" uuid = "5b7e9947-ddc0-4b3f-9b55-0d8042f74170" -version = "0.6.4" +version = "0.7.0" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" @@ -19,9 +19,11 @@ julia = "1" DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" +LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1" MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["DiffResults", "ForwardDiff", "LinearAlgebra", "MCMCChains", "StructArrays", "Test"] +test = ["DiffResults", "ForwardDiff", "LinearAlgebra", "LogDensityProblems", "LogDensityProblemsAD", "MCMCChains", "StructArrays", "Test"]