Skip to content

Commit

Permalink
Use LogDensityProblems.jl (#301)
Browse files Browse the repository at this point in the history
* removed AD compats and added implementation of LogDensityProblems

* improved README a bit

* added LogDensityProblems to test deps

* updated tests

* Apply suggestions from code review

Co-authored-by: Hong Ge <[email protected]>

* fixed tests

* depend on LogDensityProblemsAD and new version of AbstractMCMC

* removed unnecessary comment and export of no-longer-existing model

* updated to work with new AbstractMCMC

* bump compat entry for AbstractMCMC

* fixed bug in tests

* added compat entries for LogDensityProblems and AD

* removed type-piracy in tests

* Apply suggestions from code review

Co-authored-by: David Widmann <[email protected]>

* added FillArrays to tests

* added manual specification of the gradientconfig when working with ComponentArrays

* Apply suggestions from code review

Co-authored-by: David Widmann <[email protected]>

* updated tests to work with LogDensityProblemsAD

* remove usage of Distributions in README

* added a TODO

* Update src/AdvancedHMC.jl

Co-authored-by: David Widmann <[email protected]>

* fixed missing end

Co-authored-by: Hong Ge <[email protected]>
Co-authored-by: David Widmann <[email protected]>
  • Loading branch information
3 people authored Dec 25, 2022
1 parent 6a55a3f commit 3dc2822
Show file tree
Hide file tree
Showing 14 changed files with 117 additions and 160 deletions.
6 changes: 5 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
InplaceOps = "505f98c9-085e-5b2c-8e89-488be7bf1f34"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1"
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
Expand All @@ -18,10 +20,12 @@ StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"

[compat]
AbstractMCMC = "3.2, 4"
AbstractMCMC = "4.2"
ArgCheck = "1, 2"
DocStringExtensions = "0.8, 0.9"
InplaceOps = "0.3"
LogDensityProblems = "2"
LogDensityProblemsAD = "1"
ProgressMeter = "1"
Requires = "0.5, 1"
Setfield = "0.7, 0.8, 1"
Expand Down
14 changes: 10 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,14 +50,20 @@ In this section we demonstrate a minimal example of sampling from a multivariate
</details>

```julia
using AdvancedHMC, Distributions, ForwardDiff
using AdvancedHMC, ForwardDiff
using LogDensityProblems
using LinearAlgebra

# Define the target distribution using the `LogDensityProblem` interface
struct LogTargetDensity
dim::Int
end
LogDensityProblems.logdensity(p::LogTargetDensity, θ) = -sum(abs2, θ) / 2 # standard multivariate normal
LogDensityProblems.dimension(p::LogTargetDensity) = p.dim

# Choose parameter dimensionality and initial parameter value
D = 10; initial_θ = rand(D)

# Define the target distribution
ℓπ(θ) = logpdf(MvNormal(zeros(D), I), θ)
ℓπ = LogTargetDensity(D)

# Set the number of samples to draw and warmup iterations
n_samples, n_adapts = 2_000, 1_000
Expand Down
42 changes: 32 additions & 10 deletions src/AdvancedHMC.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,11 @@ using ArgCheck: @argcheck

using DocStringExtensions

using LogDensityProblems
using LogDensityProblemsAD: LogDensityProblemsAD

import AbstractMCMC
using AbstractMCMC: LogDensityModel

import StatsBase: sample

Expand Down Expand Up @@ -139,9 +143,35 @@ include("sampler.jl")
export sample

include("abstractmcmc.jl")
export DifferentiableDensityModel

include("contrib/ad.jl")
function Hamiltonian(metric::AbstractMetric, ℓ::LogDensityModel)
ℓπ =.logdensity

# Check we're capable of computing gradients.
cap = LogDensityProblems.capabilities(ℓπ)
if cap === nothing
throw(ArgumentError("The log density function does not support the LogDensityProblems.jl interface"))
end

if cap === LogDensityProblems.LogDensityOrder{0}()
throw(ArgumentError("The gradient of the log density function is not defined: Implement `LogDensityProblems.logdensity_and_gradient` or use automatic differentiation by calling `Hamiltionian(metric, model, AD; kwargs...)` where AD is one of the backends supported by LogDensityProblemsAD.jl"))
end

return Hamiltonian(
metric,
Base.Fix1(LogDensityProblems.logdensity, ℓ.logdensity),
Base.Fix1(LogDensityProblems.logdensity_and_gradient, ℓ.logdensity),
)
end
function Hamiltonian(metric::AbstractMetric, ℓπ::LogDensityModel, kind::Union{Symbol,Val}; kwargs...)
= LogDensityModel(LogDensityProblemsAD.ADgradient(kind, ℓπ.logdensity; kwargs...))
return Hamiltonian(metric, ℓ)
end
function Hamiltonian(metric::AbstractMetric, ℓπ, kind::Union{Symbol,Val} = Val{:ForwardDiff}(); kwargs...)
= LogDensityModel(LogDensityProblemsAD.ADgradient(kind, ℓπ; kwargs...))
return Hamiltonian(metric, ℓ)
end
Hamiltonian(metric::AbstractMetric, ℓπ, m::Module; kwargs...) = Hamiltonian(metric, ℓπ, Val(Symbol(m)); kwargs...)

### Init

Expand All @@ -153,14 +183,6 @@ function __init__()
include("contrib/diffeq.jl")
end

@require ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" begin
include("contrib/forwarddiff.jl")
end

@require Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" begin
include("contrib/zygote.jl")
end

@require CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" begin
include("contrib/cuda.jl")
end
Expand Down
48 changes: 8 additions & 40 deletions src/abstractmcmc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,38 +25,6 @@ struct HMCSampler{K, M, A} <: AbstractMCMC.AbstractSampler
end
HMCSampler(kernel, metric) = HMCSampler(kernel, metric, Adaptation.NoAdaptation())

"""
DifferentiableDensityModel(ℓπ, ∂ℓπ∂θ)
DifferentiableDensityModel(ℓπ, m::Module)
A `AbstractMCMC.AbstractMCMCModel` representing a differentiable log-density.
If a module `m` is given as the second argument, then `m` is assumed to be an
automatic-differentiation package and this will be used to compute the gradients.
Note that the module `m` must be imported before usage, e.g.
```julia
using Zygote: Zygote
model = DifferentiableDensityModel(ℓπ, Zygote)
```
results in a `model` which will use Zygote.jl as its AD-backend.
# Fields
$(FIELDS)
"""
struct DifferentiableDensityModel{Tlogπ, T∂logπ∂θ} <: AbstractMCMC.AbstractModel
"Log-density. Maps `AbstractArray` to value of the log-density."
ℓπ::Tlogπ
"Gradient of log-density. Returns a tuple of `ℓπ` and the gradient evaluated at the given point."
∂ℓπ∂θ::T∂logπ∂θ
end

struct DummyMetric <: AbstractMetric end
function DifferentiableDensityModel(ℓπ, m::Module)
h = Hamiltonian(DummyMetric(), ℓπ, m)
return DifferentiableDensityModel(h.ℓπ, h.∂ℓπ∂θ)
end

"""
HMCState
Expand Down Expand Up @@ -91,7 +59,7 @@ end
A convenient wrapper around `AbstractMCMC.sample` avoiding explicit construction of [`HMCSampler`](@ref).
"""
function AbstractMCMC.sample(
model::DifferentiableDensityModel,
model::LogDensityModel,
kernel::AbstractMCMCKernel,
metric::AbstractMetric,
adaptor::AbstractAdaptor,
Expand All @@ -103,7 +71,7 @@ end

function AbstractMCMC.sample(
rng::Random.AbstractRNG,
model::DifferentiableDensityModel,
model::LogDensityModel,
kernel::AbstractMCMCKernel,
metric::AbstractMetric,
adaptor::AbstractAdaptor,
Expand All @@ -129,7 +97,7 @@ function AbstractMCMC.sample(
end

function AbstractMCMC.sample(
model::DifferentiableDensityModel,
model::LogDensityModel,
kernel::AbstractMCMCKernel,
metric::AbstractMetric,
adaptor::AbstractAdaptor,
Expand All @@ -146,7 +114,7 @@ end

function AbstractMCMC.sample(
rng::Random.AbstractRNG,
model::DifferentiableDensityModel,
model::LogDensityModel,
kernel::AbstractMCMCKernel,
metric::AbstractMetric,
adaptor::AbstractAdaptor,
Expand Down Expand Up @@ -175,7 +143,7 @@ end

function AbstractMCMC.step(
rng::AbstractRNG,
model::DifferentiableDensityModel,
model::LogDensityModel,
spl::HMCSampler;
init_params = nothing,
kwargs...
Expand All @@ -189,7 +157,7 @@ function AbstractMCMC.step(
end

# Construct the hamiltonian using the initial metric
hamiltonian = Hamiltonian(metric, model.ℓπ, model.∂ℓπ∂θ)
hamiltonian = Hamiltonian(metric, model)

# Get an initial sample.
h, t = AdvancedHMC.sample_init(rng, hamiltonian, init_params)
Expand All @@ -203,7 +171,7 @@ end

function AbstractMCMC.step(
rng::AbstractRNG,
model::DifferentiableDensityModel,
model::LogDensityModel,
spl::HMCSampler,
state::HMCState;
nadapts::Int = 0,
Expand All @@ -220,7 +188,7 @@ function AbstractMCMC.step(
metric = state.metric

# Reconstruct hamiltonian.
h = Hamiltonian(metric, model.ℓπ, model.∂ℓπ∂θ)
h = Hamiltonian(metric, model)

# Make new transition.
t = transition(rng, h, κ, t_old.z)
Expand Down
18 changes: 0 additions & 18 deletions src/contrib/ad.jl

This file was deleted.

50 changes: 0 additions & 50 deletions src/contrib/forwarddiff.jl

This file was deleted.

18 changes: 0 additions & 18 deletions src/contrib/zygote.jl

This file was deleted.

6 changes: 6 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,11 @@ CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
Comonicon = "863f3e99-da2a-4334-8734-de3dacbe5542"
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
Expand All @@ -16,3 +19,6 @@ Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
UnicodePlots = "b8865327-cd53-5732-bb35-84acbb429228"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
LogDensityProblemsAD = "1.1.1"
2 changes: 1 addition & 1 deletion test/abstractmcmc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ include("common.jl")

θ_init = randn(rng, 2)

model = AdvancedHMC.DifferentiableDensityModel(ℓπ_gdemo, ForwardDiff)
model = AdvancedHMC.LogDensityModel(LogDensityProblemsAD.ADgradient(Val(:ForwardDiff), ℓπ_gdemo))
init_eps = Leapfrog(1e-3)
κ = NUTS(init_eps)
metric = DiagEuclideanMetric(2)
Expand Down
9 changes: 3 additions & 6 deletions test/adaptation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -124,8 +124,7 @@ end
σ² = 1 .+ abs.(randn(D))

# Diagonal Gaussian
target = MvNormal(zeros(D), Diagonal(σ²))
ℓπ = θ -> logpdf(target, θ)
ℓπ = LogDensityDistribution(MvNormal(Diagonal(σ²)))

res = runnuts(ℓπ, DiagEuclideanMetric(D))
@test res.adaptor.pc.var σ² rtol=0.2
Expand All @@ -142,8 +141,7 @@ end
Σ = m' * m

# Correlated Gaussian
target = MvNormal(zeros(D), Σ)
ℓπ = θ -> logpdf(target, θ)
ℓπ = LogDensityDistribution(MvNormal(Σ))

res = runnuts(ℓπ, DiagEuclideanMetric(D))
@test res.adaptor.pc.var diag(Σ) rtol=0.2
Expand All @@ -156,8 +154,7 @@ end
end

@testset "Initialisation adaptor by metric" begin
target = MvNormal(zeros(D), I)
ℓπ = θ -> logpdf(target, θ)
ℓπ = LogDensityDistribution(MvNormal(Eye(D)))

mass_init = fill(0.5, D)
res = runnuts(ℓπ, DiagEuclideanMetric(mass_init); n_samples=1)
Expand Down
Loading

2 comments on commit 3dc2822

@torfjelde
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/74630

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.4.0 -m "<description of version>" 3dc2822c5a6270e6e49e8e257be4064b5570d9cc
git push origin v0.4.0

Please sign in to comment.