Skip to content

Commit

Permalink
Merge pull request #70 from TuringLang/torfjelde/logdensitymodel
Browse files Browse the repository at this point in the history
Added support for models using the LogDensityProblems.jl interface
  • Loading branch information
cpfiffer authored Dec 28, 2022
2 parents 5950f4a + f58c636 commit e174117
Show file tree
Hide file tree
Showing 10 changed files with 122 additions and 35 deletions.
9 changes: 6 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
name = "AdvancedMH"
uuid = "5b7e9947-ddc0-4b3f-9b55-0d8042f74170"
version = "0.6.8"
version = "0.7.0"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"

[compat]
AbstractMCMC = "2, 3.0, 4"
AbstractMCMC = "4"
Distributions = "0.20, 0.21, 0.22, 0.23, 0.24, 0.25"
Requires = "1"
julia = "1"
Expand All @@ -18,9 +19,11 @@ julia = "1"
DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1"
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["DiffResults", "ForwardDiff", "LinearAlgebra", "MCMCChains", "StructArrays", "Test"]
test = ["DiffResults", "ForwardDiff", "LinearAlgebra", "LogDensityProblems", "LogDensityProblemsAD", "MCMCChains", "StructArrays", "Test"]
28 changes: 28 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,24 @@ Quantiles

```

### Usage with [`LogDensityProblems.jl`](https://github.com/tpapp/LogDensityProblems.jl)

It can also be used with models defining the [`LogDensityProblems.jl`](https://github.com/tpapp/LogDensityProblems.jl) interface by wrapping it in `AbstractMCMC.LogDensityModel` before passing it to `sample`:

``` julia
using AbstractMCMC: LogDensityModel
using LogDensityProblems

# Use a struct instead of `typeof(density)` for sake of readability.
struct LogTargetDensity end

LogDensityProblems.logdensity(p::LogTargetDensity, θ) = density(θ) # standard multivariate normal
LogDensityProblems.dimension(p::LogTargetDensity) = 2
LogDensityProblems.capabilities(::LogTargetDensity) = LogDensityProblems.LogDensityOrder{0}()

sample(LogDensityModel(LogTargetDensity()), spl, 100000; param_names=["μ", "σ"], chain_type=Chains)
```

## Proposals

AdvancedMH offers various methods of defining your inference problem. Behind the scenes, a `MetropolisHastings` sampler simply holds
Expand Down Expand Up @@ -162,3 +180,13 @@ spl = MALA(x -> MvNormal((σ² / 2) .* x, σ² * I))
# Sample from the posterior.
chain = sample(model, spl, 100000; init_params=ones(2), chain_type=StructArray, param_names=["μ", "σ"])
```

### Usage with [`LogDensityProblemsAD.jl`](https://github.com/tpapp/LogDensityProblemsAD.jl)

Using our implementation of the `LogDensityProblems.jl` interface from earlier, we can use [`LogDensityProblemsAD.jl`](https://github.com/tpapp/LogDensityProblemsAD.jl) to provide us with the gradient computation used in MALA:

```julia
using LogDensityProblemsAD
model_with_ad = LogDensityProblemsAD.ADgradient(Val(:ForwardDiff), LogTargetDensity())
sample(LogDensityModel(model_with_ad), spl, 100000; init_params=ones(2), chain_type=StructArray, param_names=["μ", "σ"])
```
20 changes: 15 additions & 5 deletions src/AdvancedMH.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ using AbstractMCMC
using Distributions
using Requires

using LogDensityProblems: LogDensityProblems

import Random

# Exports
Expand Down Expand Up @@ -48,6 +50,8 @@ struct DensityModel{F} <: AbstractMCMC.AbstractModel
logdensity :: F
end

const DensityModelOrLogDensityModel = Union{<:DensityModel,<:AbstractMCMC.LogDensityModel}

# Create a very basic Transition type, only stores the
# parameter draws and the log probability of the draw.
struct Transition{T,L<:Real} <: AbstractTransition
Expand All @@ -56,16 +60,22 @@ struct Transition{T,L<:Real} <: AbstractTransition
end

# Store the new draw and its log density.
Transition(model::DensityModel, params) = Transition(params, logdensity(model, params))
Transition(model::DensityModelOrLogDensityModel, params) = Transition(params, logdensity(model, params))
function Transition(model::AbstractMCMC.LogDensityModel, params)
return Transition(params, LogDensityProblems.logdensity(model.logdensity, params))
end

# Calculate the density of the model given some parameterization.
logdensity(model::DensityModel, params) = model.logdensity(params)
logdensity(model::DensityModel, t::Transition) = t.lp
logdensity(model::DensityModelOrLogDensityModel, params) = model.logdensity(params)
logdensity(model::DensityModelOrLogDensityModel, t::Transition) = t.lp

logdensity(model::AbstractMCMC.LogDensityModel, params) = LogDensityProblems.logdensity(model.logdensity, params)
logdensity(model::AbstractMCMC.LogDensityModel, t::Transition) = t.lp

# A basic chains constructor that works with the Transition struct we defined.
function AbstractMCMC.bundle_samples(
ts::Vector{<:AbstractTransition},
model::DensityModel,
model::Union{<:DensityModelOrLogDensityModel,<:AbstractMCMC.LogDensityModel},
sampler::MHSampler,
state,
chain_type::Type{Vector{NamedTuple}};
Expand All @@ -91,7 +101,7 @@ end

function AbstractMCMC.bundle_samples(
ts::Vector{<:Transition{<:NamedTuple}},
model::DensityModel,
model::Union{<:DensityModelOrLogDensityModel,<:AbstractMCMC.LogDensityModel},
sampler::MHSampler,
state,
chain_type::Type{Vector{NamedTuple}};
Expand Down
30 changes: 27 additions & 3 deletions src/MALA.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,20 +19,34 @@ struct GradientTransition{T<:Union{Vector, Real, NamedTuple}, L<:Real, G<:Union{
gradient::G
end

logdensity(model::DensityModel, t::GradientTransition) = t.lp
logdensity(model::DensityModelOrLogDensityModel, t::GradientTransition) = t.lp

propose(rng::Random.AbstractRNG, ::MALA, model) = error("please specify initial parameters")
function transition(sampler::MALA, model::DensityModel, params)
function transition(sampler::MALA, model::DensityModelOrLogDensityModel, params)
return GradientTransition(params, logdensity_and_gradient(model, params)...)
end

check_capabilities(model::DensityModelOrLogDensityModel) = nothing
function check_capabilities(model::AbstractMCMC.LogDensityModel)
cap = LogDensityProblems.capabilities(model.logdensity)
if cap === nothing
throw(ArgumentError("The log density function does not support the LogDensityProblems.jl interface"))
end

if cap === LogDensityProblems.LogDensityOrder{0}()
throw(ArgumentError("The gradient of the log density function is not defined: Implement `LogDensityProblems.logdensity_and_gradient` or use automatic differentiation provided by LogDensityProblemsAD.jl"))
end
end

function AbstractMCMC.step(
rng::Random.AbstractRNG,
model::DensityModel,
model::DensityModelOrLogDensityModel,
sampler::MALA,
transition_prev::GradientTransition;
kwargs...
)
check_capabilities(model)

# Extract value and gradient of the log density of the current state.
state = transition_prev.params
logdensity_state = transition_prev.lp
Expand Down Expand Up @@ -76,3 +90,13 @@ function logdensity_and_gradient(model::DensityModel, params)
return value(res), gradient(res)
end

"""
logdensity_and_gradient(model::AbstractMCMC.LogDensityModel, params)
Return the value and gradient of the log density of the parameters `params` for the `model`.
"""
function logdensity_and_gradient(model::AbstractMCMC.LogDensityModel, params)
return LogDensityProblems.logdensity_and_gradient(model.logdensity, params)
end


10 changes: 5 additions & 5 deletions src/emcee.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ struct Ensemble{D} <: MHSampler
proposal::D
end

function transition(sampler::Ensemble, model::DensityModel, params)
function transition(sampler::Ensemble, model::DensityModelOrLogDensityModel, params)
return [Transition(model, x) for x in params]
end

Expand All @@ -13,7 +13,7 @@ end
# (if accepted) or the previous proposal (if not accepted).
function AbstractMCMC.step(
rng::Random.AbstractRNG,
model::DensityModel,
model::DensityModelOrLogDensityModel,
spl::Ensemble,
params_prev::Vector{<:Transition};
kwargs...,
Expand All @@ -26,7 +26,7 @@ end
#
# Initial proposal
#
function propose(rng::Random.AbstractRNG, spl::Ensemble, model::DensityModel)
function propose(rng::Random.AbstractRNG, spl::Ensemble, model::DensityModelOrLogDensityModel)
# Make the first proposal with a static draw from the prior.
static_prop = StaticProposal(spl.proposal.proposal)
mh_spl = MetropolisHastings(static_prop)
Expand All @@ -39,7 +39,7 @@ end
function propose(
rng::Random.AbstractRNG,
spl::Ensemble,
model::DensityModel,
model::DensityModelOrLogDensityModel,
walkers::Vector{<:Transition},
)
new_walkers = similar(walkers)
Expand Down Expand Up @@ -68,7 +68,7 @@ StretchProposal(p) = StretchProposal(p, 2.0)
function move(
rng::Random.AbstractRNG,
spl::Ensemble{<:StretchProposal},
model::DensityModel,
model::DensityModelOrLogDensityModel,
walker::Transition,
other_walker::Transition,
)
Expand Down
6 changes: 3 additions & 3 deletions src/mcmcchains-connect.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ import .MCMCChains: Chains
# A basic chains constructor that works with the Transition struct we defined.
function AbstractMCMC.bundle_samples(
ts::Vector{<:AbstractTransition},
model::DensityModel,
model::DensityModelOrLogDensityModel,
sampler::MHSampler,
state,
chain_type::Type{Chains};
Expand Down Expand Up @@ -32,7 +32,7 @@ end

function AbstractMCMC.bundle_samples(
ts::Vector{<:Transition{<:NamedTuple}},
model::DensityModel,
model::DensityModelOrLogDensityModel,
sampler::MHSampler,
state,
chain_type::Type{Chains};
Expand Down Expand Up @@ -71,7 +71,7 @@ end

function AbstractMCMC.bundle_samples(
ts::Vector{<:Vector{<:AbstractTransition}},
model::DensityModel,
model::DensityModelOrLogDensityModel,
sampler::Ensemble,
state,
chain_type::Type{Chains};
Expand Down
12 changes: 6 additions & 6 deletions src/mh-core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,23 +48,23 @@ end
StaticMH(d) = MetropolisHastings(StaticProposal(d))
RWMH(d) = MetropolisHastings(RandomWalkProposal(d))

function propose(rng::Random.AbstractRNG, sampler::MHSampler, model::DensityModel)
function propose(rng::Random.AbstractRNG, sampler::MHSampler, model::DensityModelOrLogDensityModel)
return propose(rng, sampler.proposal, model)
end
function propose(
rng::Random.AbstractRNG,
sampler::MHSampler,
model::DensityModel,
model::DensityModelOrLogDensityModel,
transition_prev::Transition,
)
return propose(rng, sampler.proposal, model, transition_prev.params)
end

function transition(sampler::MHSampler, model::DensityModel, params)
function transition(sampler::MHSampler, model::DensityModelOrLogDensityModel, params)
logdensity = AdvancedMH.logdensity(model, params)
return transition(sampler, model, params, logdensity)
end
function transition(sampler::MHSampler, model::DensityModel, params, logdensity::Real)
function transition(sampler::MHSampler, model::DensityModelOrLogDensityModel, params, logdensity::Real)
return Transition(params, logdensity)
end

Expand All @@ -73,7 +73,7 @@ end
# In this case they are identical.
function AbstractMCMC.step(
rng::Random.AbstractRNG,
model::DensityModel,
model::DensityModelOrLogDensityModel,
sampler::MHSampler;
init_params=nothing,
kwargs...
Expand All @@ -89,7 +89,7 @@ end
# or the previous proposal (if not accepted).
function AbstractMCMC.step(
rng::Random.AbstractRNG,
model::DensityModel,
model::DensityModelOrLogDensityModel,
sampler::MHSampler,
transition_prev::AbstractTransition;
kwargs...
Expand Down
18 changes: 9 additions & 9 deletions src/proposal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,15 +41,15 @@ end
function propose(
rng::Random.AbstractRNG,
proposal::RandomWalkProposal{issymmetric,<:Union{Distribution,AbstractArray}},
::DensityModel
::DensityModelOrLogDensityModel
) where {issymmetric}
return rand(rng, proposal)
end

function propose(
rng::Random.AbstractRNG,
proposal::RandomWalkProposal{issymmetric,<:Union{Distribution,AbstractArray}},
model::DensityModel,
model::DensityModelOrLogDensityModel,
t
) where {issymmetric}
return t + rand(rng, proposal)
Expand All @@ -70,7 +70,7 @@ end
function propose(
rng::Random.AbstractRNG,
proposal::StaticProposal{issymmetric,<:Union{Distribution,AbstractArray}},
model::DensityModel,
model::DensityModelOrLogDensityModel,
t=nothing
) where {issymmetric}
return rand(rng, proposal)
Expand Down Expand Up @@ -103,15 +103,15 @@ end
function propose(
rng::Random.AbstractRNG,
proposal::Proposal{<:Function},
model::DensityModel
model::DensityModelOrLogDensityModel
)
return propose(rng, proposal(), model)
end

function propose(
rng::Random.AbstractRNG,
proposal::Proposal{<:Function},
model::DensityModel,
model::DensityModelOrLogDensityModel,
t
)
return propose(rng, proposal(t), model)
Expand All @@ -132,7 +132,7 @@ end
function propose(
rng::Random.AbstractRNG,
proposals::AbstractArray{<:Proposal},
model::DensityModel,
model::DensityModelOrLogDensityModel,
)
return map(proposals) do proposal
return propose(rng, proposal, model)
Expand All @@ -141,7 +141,7 @@ end
function propose(
rng::Random.AbstractRNG,
proposals::AbstractArray{<:Proposal},
model::DensityModel,
model::DensityModelOrLogDensityModel,
ts,
)
return map(proposals, ts) do proposal, t
Expand All @@ -152,7 +152,7 @@ end
@generated function propose(
rng::Random.AbstractRNG,
proposals::NamedTuple{names},
model::DensityModel,
model::DensityModelOrLogDensityModel,
) where {names}
isempty(names) && return :(NamedTuple())
expr = Expr(:tuple)
Expand All @@ -163,7 +163,7 @@ end
@generated function propose(
rng::Random.AbstractRNG,
proposals::NamedTuple{names},
model::DensityModel,
model::DensityModelOrLogDensityModel,
ts,
) where {names}
isempty(names) && return :(NamedTuple())
Expand Down
2 changes: 1 addition & 1 deletion src/structarray-connect.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ import .StructArrays: StructArray
# A basic chains constructor that works with the Transition struct we defined.
function AbstractMCMC.bundle_samples(
ts::Vector{<:AbstractTransition},
model::DensityModel,
model::DensityModelOrLogDensityModel,
sampler::MHSampler,
state,
chain_type::Type{StructArray};
Expand Down
Loading

0 comments on commit e174117

Please sign in to comment.