Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added support for models using the LogDensityProblems.jl interface #70

Merged
merged 8 commits into from
Dec 28, 2022
Merged
9 changes: 6 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
name = "AdvancedMH"
uuid = "5b7e9947-ddc0-4b3f-9b55-0d8042f74170"
version = "0.6.8"
version = "0.7.0"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"

[compat]
AbstractMCMC = "2, 3.0, 4"
AbstractMCMC = "4"
Distributions = "0.20, 0.21, 0.22, 0.23, 0.24, 0.25"
Requires = "1"
julia = "1"
Expand All @@ -18,9 +19,11 @@ julia = "1"
DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1"
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["DiffResults", "ForwardDiff", "LinearAlgebra", "MCMCChains", "StructArrays", "Test"]
test = ["DiffResults", "ForwardDiff", "LinearAlgebra", "LogDensityProblems", "LogDensityProblemsAD", "MCMCChains", "StructArrays", "Test"]
28 changes: 28 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,24 @@ Quantiles

```

### Usage with [`LogDensityProblems.jl`](https://github.com/tpapp/LogDensityProblems.jl)

It can also be used with models defining the [`LogDensityProblems.jl`](https://github.com/tpapp/LogDensityProblems.jl) interface by wrapping it in `AbstractMCMC.LogDensityModel` before passing it to `sample`:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it really necessary to wrap the log density? Couldn't we wrap anything in a LogDensityModel that's not one of our specialized model types? We could even check LogDensityProblems.capabilities !== nothing to ensure that we do not wrap anything accidentally that does not implement the interface.


``` julia
using AbstractMCMC: LogDensityModel
using LogDensityProblems

# Use a struct instead of `typeof(density)` for sake of readability.
struct LogTargetDensity end

LogDensityProblems.logdensity(p::LogTargetDensity, θ) = density(θ) # standard multivariate normal
LogDensityProblems.dimension(p::LogTargetDensity) = 2
LogDensityProblems.capabilities(::LogTargetDensity) = LogDensityProblems.LogDensityOrder{0}()

sample(LogDensityModel(LogTargetDensity()), spl, 100000; param_names=["μ", "σ"], chain_type=Chains)
```

## Proposals

AdvancedMH offers various methods of defining your inference problem. Behind the scenes, a `MetropolisHastings` sampler simply holds
Expand Down Expand Up @@ -162,3 +180,13 @@ spl = MALA(x -> MvNormal((σ² / 2) .* x, σ² * I))
# Sample from the posterior.
chain = sample(model, spl, 100000; init_params=ones(2), chain_type=StructArray, param_names=["μ", "σ"])
```

### Usage with [`LogDensityProblemsAD.jl`](https://github.com/tpapp/LogDensityProblemsAD.jl)

Using our implementation of the `LogDensityProblems.jl` interface from earlier, we can use [`LogDensityProblemsAD.jl`](https://github.com/tpapp/LogDensityProblemsAD.jl) to provide us with the gradient computation used in MALA:

```julia
using LogDensityProblemsAD
model_with_ad = LogDensityProblemsAD.ADgradient(Val(:ForwardDiff), LogTargetDensity())
sample(LogDensityModel(model_with_ad), spl, 100000; init_params=ones(2), chain_type=StructArray, param_names=["μ", "σ"])
```
20 changes: 15 additions & 5 deletions src/AdvancedMH.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ using AbstractMCMC
using Distributions
using Requires

using LogDensityProblems: LogDensityProblems

import Random

# Exports
Expand Down Expand Up @@ -48,6 +50,8 @@ struct DensityModel{F} <: AbstractMCMC.AbstractModel
logdensity :: F
end

const DensityModelOrLogDensityModel = Union{<:DensityModel,<:AbstractMCMC.LogDensityModel}

# Create a very basic Transition type, only stores the
# parameter draws and the log probability of the draw.
struct Transition{T,L<:Real} <: AbstractTransition
Expand All @@ -56,16 +60,22 @@ struct Transition{T,L<:Real} <: AbstractTransition
end

# Store the new draw and its log density.
Transition(model::DensityModel, params) = Transition(params, logdensity(model, params))
Transition(model::DensityModelOrLogDensityModel, params) = Transition(params, logdensity(model, params))
function Transition(model::AbstractMCMC.LogDensityModel, params)
return Transition(params, LogDensityProblems.logdensity(model.logdensity, params))
end

# Calculate the density of the model given some parameterization.
logdensity(model::DensityModel, params) = model.logdensity(params)
logdensity(model::DensityModel, t::Transition) = t.lp
logdensity(model::DensityModelOrLogDensityModel, params) = model.logdensity(params)
logdensity(model::DensityModelOrLogDensityModel, t::Transition) = t.lp

logdensity(model::AbstractMCMC.LogDensityModel, params) = LogDensityProblems.logdensity(model.logdensity, params)
logdensity(model::AbstractMCMC.LogDensityModel, t::Transition) = t.lp

# A basic chains constructor that works with the Transition struct we defined.
function AbstractMCMC.bundle_samples(
ts::Vector{<:AbstractTransition},
model::DensityModel,
model::Union{<:DensityModelOrLogDensityModel,<:AbstractMCMC.LogDensityModel},
sampler::MHSampler,
state,
chain_type::Type{Vector{NamedTuple}};
Expand All @@ -91,7 +101,7 @@ end

function AbstractMCMC.bundle_samples(
ts::Vector{<:Transition{<:NamedTuple}},
model::DensityModel,
model::Union{<:DensityModelOrLogDensityModel,<:AbstractMCMC.LogDensityModel},
sampler::MHSampler,
state,
chain_type::Type{Vector{NamedTuple}};
Expand Down
30 changes: 27 additions & 3 deletions src/MALA.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,20 +19,34 @@ struct GradientTransition{T<:Union{Vector, Real, NamedTuple}, L<:Real, G<:Union{
gradient::G
end

logdensity(model::DensityModel, t::GradientTransition) = t.lp
logdensity(model::DensityModelOrLogDensityModel, t::GradientTransition) = t.lp

propose(rng::Random.AbstractRNG, ::MALA, model) = error("please specify initial parameters")
function transition(sampler::MALA, model::DensityModel, params)
function transition(sampler::MALA, model::DensityModelOrLogDensityModel, params)
return GradientTransition(params, logdensity_and_gradient(model, params)...)
end

check_capabilities(model::DensityModelOrLogDensityModel) = nothing
function check_capabilities(model::AbstractMCMC.LogDensityModel)
cap = LogDensityProblems.capabilities(model.logdensity)
if cap === nothing
throw(ArgumentError("The log density function does not support the LogDensityProblems.jl interface"))
end

if cap === LogDensityProblems.LogDensityOrder{0}()
throw(ArgumentError("The gradient of the log density function is not defined: Implement `LogDensityProblems.logdensity_and_gradient` or use automatic differentiation provided by LogDensityProblemsAD.jl"))
end
end

function AbstractMCMC.step(
rng::Random.AbstractRNG,
model::DensityModel,
model::DensityModelOrLogDensityModel,
sampler::MALA,
transition_prev::GradientTransition;
kwargs...
)
check_capabilities(model)

# Extract value and gradient of the log density of the current state.
state = transition_prev.params
logdensity_state = transition_prev.lp
Expand Down Expand Up @@ -76,3 +90,13 @@ function logdensity_and_gradient(model::DensityModel, params)
return value(res), gradient(res)
end

"""
logdensity_and_gradient(model::AbstractMCMC.LogDensityModel, params)

Return the value and gradient of the log density of the parameters `params` for the `model`.
"""
function logdensity_and_gradient(model::AbstractMCMC.LogDensityModel, params)
return LogDensityProblems.logdensity_and_gradient(model.logdensity, params)
end


10 changes: 5 additions & 5 deletions src/emcee.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ struct Ensemble{D} <: MHSampler
proposal::D
end

function transition(sampler::Ensemble, model::DensityModel, params)
function transition(sampler::Ensemble, model::DensityModelOrLogDensityModel, params)
return [Transition(model, x) for x in params]
end

Expand All @@ -13,7 +13,7 @@ end
# (if accepted) or the previous proposal (if not accepted).
function AbstractMCMC.step(
rng::Random.AbstractRNG,
model::DensityModel,
model::DensityModelOrLogDensityModel,
spl::Ensemble,
params_prev::Vector{<:Transition};
kwargs...,
Expand All @@ -26,7 +26,7 @@ end
#
# Initial proposal
#
function propose(rng::Random.AbstractRNG, spl::Ensemble, model::DensityModel)
function propose(rng::Random.AbstractRNG, spl::Ensemble, model::DensityModelOrLogDensityModel)
# Make the first proposal with a static draw from the prior.
static_prop = StaticProposal(spl.proposal.proposal)
mh_spl = MetropolisHastings(static_prop)
Expand All @@ -39,7 +39,7 @@ end
function propose(
rng::Random.AbstractRNG,
spl::Ensemble,
model::DensityModel,
model::DensityModelOrLogDensityModel,
walkers::Vector{<:Transition},
)
new_walkers = similar(walkers)
Expand Down Expand Up @@ -68,7 +68,7 @@ StretchProposal(p) = StretchProposal(p, 2.0)
function move(
rng::Random.AbstractRNG,
spl::Ensemble{<:StretchProposal},
model::DensityModel,
model::DensityModelOrLogDensityModel,
walker::Transition,
other_walker::Transition,
)
Expand Down
6 changes: 3 additions & 3 deletions src/mcmcchains-connect.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ import .MCMCChains: Chains
# A basic chains constructor that works with the Transition struct we defined.
function AbstractMCMC.bundle_samples(
ts::Vector{<:AbstractTransition},
model::DensityModel,
model::DensityModelOrLogDensityModel,
sampler::MHSampler,
state,
chain_type::Type{Chains};
Expand Down Expand Up @@ -32,7 +32,7 @@ end

function AbstractMCMC.bundle_samples(
ts::Vector{<:Transition{<:NamedTuple}},
model::DensityModel,
model::DensityModelOrLogDensityModel,
sampler::MHSampler,
state,
chain_type::Type{Chains};
Expand Down Expand Up @@ -71,7 +71,7 @@ end

function AbstractMCMC.bundle_samples(
ts::Vector{<:Vector{<:AbstractTransition}},
model::DensityModel,
model::DensityModelOrLogDensityModel,
sampler::Ensemble,
state,
chain_type::Type{Chains};
Expand Down
12 changes: 6 additions & 6 deletions src/mh-core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,23 +48,23 @@ end
StaticMH(d) = MetropolisHastings(StaticProposal(d))
RWMH(d) = MetropolisHastings(RandomWalkProposal(d))

function propose(rng::Random.AbstractRNG, sampler::MHSampler, model::DensityModel)
function propose(rng::Random.AbstractRNG, sampler::MHSampler, model::DensityModelOrLogDensityModel)
return propose(rng, sampler.proposal, model)
end
function propose(
rng::Random.AbstractRNG,
sampler::MHSampler,
model::DensityModel,
model::DensityModelOrLogDensityModel,
transition_prev::Transition,
)
return propose(rng, sampler.proposal, model, transition_prev.params)
end

function transition(sampler::MHSampler, model::DensityModel, params)
function transition(sampler::MHSampler, model::DensityModelOrLogDensityModel, params)
logdensity = AdvancedMH.logdensity(model, params)
return transition(sampler, model, params, logdensity)
end
function transition(sampler::MHSampler, model::DensityModel, params, logdensity::Real)
function transition(sampler::MHSampler, model::DensityModelOrLogDensityModel, params, logdensity::Real)
return Transition(params, logdensity)
end

Expand All @@ -73,7 +73,7 @@ end
# In this case they are identical.
function AbstractMCMC.step(
rng::Random.AbstractRNG,
model::DensityModel,
model::DensityModelOrLogDensityModel,
sampler::MHSampler;
init_params=nothing,
kwargs...
Expand All @@ -89,7 +89,7 @@ end
# or the previous proposal (if not accepted).
function AbstractMCMC.step(
rng::Random.AbstractRNG,
model::DensityModel,
model::DensityModelOrLogDensityModel,
sampler::MHSampler,
transition_prev::AbstractTransition;
kwargs...
Expand Down
18 changes: 9 additions & 9 deletions src/proposal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,15 +41,15 @@ end
function propose(
rng::Random.AbstractRNG,
proposal::RandomWalkProposal{issymmetric,<:Union{Distribution,AbstractArray}},
::DensityModel
::DensityModelOrLogDensityModel
) where {issymmetric}
return rand(rng, proposal)
end

function propose(
rng::Random.AbstractRNG,
proposal::RandomWalkProposal{issymmetric,<:Union{Distribution,AbstractArray}},
model::DensityModel,
model::DensityModelOrLogDensityModel,
t
) where {issymmetric}
return t + rand(rng, proposal)
Expand All @@ -70,7 +70,7 @@ end
function propose(
rng::Random.AbstractRNG,
proposal::StaticProposal{issymmetric,<:Union{Distribution,AbstractArray}},
model::DensityModel,
model::DensityModelOrLogDensityModel,
t=nothing
) where {issymmetric}
return rand(rng, proposal)
Expand Down Expand Up @@ -103,15 +103,15 @@ end
function propose(
rng::Random.AbstractRNG,
proposal::Proposal{<:Function},
model::DensityModel
model::DensityModelOrLogDensityModel
)
return propose(rng, proposal(), model)
end

function propose(
rng::Random.AbstractRNG,
proposal::Proposal{<:Function},
model::DensityModel,
model::DensityModelOrLogDensityModel,
t
)
return propose(rng, proposal(t), model)
Expand All @@ -132,7 +132,7 @@ end
function propose(
rng::Random.AbstractRNG,
proposals::AbstractArray{<:Proposal},
model::DensityModel,
model::DensityModelOrLogDensityModel,
)
return map(proposals) do proposal
return propose(rng, proposal, model)
Expand All @@ -141,7 +141,7 @@ end
function propose(
rng::Random.AbstractRNG,
proposals::AbstractArray{<:Proposal},
model::DensityModel,
model::DensityModelOrLogDensityModel,
ts,
)
return map(proposals, ts) do proposal, t
Expand All @@ -152,7 +152,7 @@ end
@generated function propose(
rng::Random.AbstractRNG,
proposals::NamedTuple{names},
model::DensityModel,
model::DensityModelOrLogDensityModel,
) where {names}
isempty(names) && return :(NamedTuple())
expr = Expr(:tuple)
Expand All @@ -163,7 +163,7 @@ end
@generated function propose(
rng::Random.AbstractRNG,
proposals::NamedTuple{names},
model::DensityModel,
model::DensityModelOrLogDensityModel,
ts,
) where {names}
isempty(names) && return :(NamedTuple())
Expand Down
2 changes: 1 addition & 1 deletion src/structarray-connect.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ import .StructArrays: StructArray
# A basic chains constructor that works with the Transition struct we defined.
function AbstractMCMC.bundle_samples(
ts::Vector{<:AbstractTransition},
model::DensityModel,
model::DensityModelOrLogDensityModel,
sampler::MHSampler,
state,
chain_type::Type{StructArray};
Expand Down
Loading