Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use LogDensityProblems.jl #301

Merged
merged 24 commits into from
Dec 25, 2022
Merged
Show file tree
Hide file tree
Changes from 22 commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
6326691
removed AD compats and added implementation of LogDensityProblems
torfjelde Nov 17, 2022
364dfb4
improved README a bit
torfjelde Nov 17, 2022
4e6301b
added LogDensityProblems to test deps
torfjelde Nov 17, 2022
7dea4db
updated tests
torfjelde Nov 17, 2022
9300af9
Apply suggestions from code review
torfjelde Nov 17, 2022
c9a6792
fixed tests
torfjelde Nov 19, 2022
ee10ad3
depend on LogDensityProblemsAD and new version of AbstractMCMC
torfjelde Dec 9, 2022
0096c4d
Merge branch 'torfjelde/logdensityproblems' of github.com:TuringLang/…
torfjelde Dec 9, 2022
1f918e8
removed unnecessary comment and export of no-longer-existing model
torfjelde Dec 17, 2022
15c41df
updated to work with new AbstractMCMC
torfjelde Dec 17, 2022
89de349
Merge branch 'master' into torfjelde/logdensityproblems
torfjelde Dec 19, 2022
574427f
bump compat entry for AbstractMCMC
torfjelde Dec 19, 2022
afaea4f
fixed bug in tests
torfjelde Dec 20, 2022
4532725
added compat entries for LogDensityProblems and AD
torfjelde Dec 20, 2022
c7afe42
removed type-piracy in tests
torfjelde Dec 20, 2022
7d88cee
Apply suggestions from code review
torfjelde Dec 23, 2022
3b009cf
added FillArrays to tests
torfjelde Dec 23, 2022
b655d50
added manual specification of the gradientconfig when working with Co…
torfjelde Dec 23, 2022
aa0c11a
Apply suggestions from code review
torfjelde Dec 25, 2022
0f7bbec
updated tests to work with LogDensityProblemsAD
torfjelde Dec 25, 2022
5cd6f4a
remove usage of Distributions in README
torfjelde Dec 25, 2022
0fff007
added a TODO
torfjelde Dec 25, 2022
ff14099
Update src/AdvancedHMC.jl
torfjelde Dec 25, 2022
b683501
fixed missing end
torfjelde Dec 25, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
InplaceOps = "505f98c9-085e-5b2c-8e89-488be7bf1f34"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1"
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
Expand All @@ -18,10 +20,12 @@ StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"

[compat]
AbstractMCMC = "3.2, 4"
AbstractMCMC = "4.2"
ArgCheck = "1, 2"
DocStringExtensions = "0.8, 0.9"
InplaceOps = "0.3"
LogDensityProblems = "2"
LogDensityProblemsAD = "1"
ProgressMeter = "1"
Requires = "0.5, 1"
Setfield = "0.7, 0.8, 1"
Expand Down
14 changes: 10 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,14 +50,20 @@ In this section we demonstrate a minimal example of sampling from a multivariate
</details>

```julia
using AdvancedHMC, Distributions, ForwardDiff
using AdvancedHMC, ForwardDiff
using LogDensityProblems
using LinearAlgebra

# Define the target distribution using the `LogDensityProblem` interface
struct LogTargetDensity
dim::Int
end
LogDensityProblems.logdensity(p::LogTargetDensity, θ) = -sum(abs2, θ) / 2 # standard multivariate normal
LogDensityProblems.dimension(p::LogTargetDensity) = p.dim

# Choose parameter dimensionality and initial parameter value
D = 10; initial_θ = rand(D)

# Define the target distribution
ℓπ(θ) = logpdf(MvNormal(zeros(D), I), θ)
ℓπ = LogTargetDensity(D)

# Set the number of samples to draw and warmup iterations
n_samples, n_adapts = 2_000, 1_000
Expand Down
28 changes: 18 additions & 10 deletions src/AdvancedHMC.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,11 @@ using ArgCheck: @argcheck

using DocStringExtensions

using LogDensityProblems
using LogDensityProblemsAD: LogDensityProblemsAD

import AbstractMCMC
using AbstractMCMC: LogDensityModel

import StatsBase: sample

Expand Down Expand Up @@ -139,9 +143,21 @@ include("sampler.jl")
export sample

include("abstractmcmc.jl")
export DifferentiableDensityModel

include("contrib/ad.jl")
Hamiltonian(metric::AbstractMetric, ℓ::LogDensityModel) = Hamiltonian(
metric,
Base.Fix1(LogDensityProblems.logdensity, ℓ.logdensity),
Base.Fix1(LogDensityProblems.logdensity_and_gradient, ℓ.logdensity),
torfjelde marked this conversation as resolved.
Show resolved Hide resolved
)
devmotion marked this conversation as resolved.
Show resolved Hide resolved
torfjelde marked this conversation as resolved.
Show resolved Hide resolved
function Hamiltonian(metric::AbstractMetric, ℓπ::LogDensityModel, kind::Union{Symbol,Val}; kwargs...)
ℓ = LogDensityModel(LogDensityProblemsAD.ADgradient(kind, ℓπ.logdensity; kwargs...))
return Hamiltonian(metric, ℓ)
end
function Hamiltonian(metric::AbstractMetric, ℓπ, kind::Union{Symbol,Val} = Val{:ForwardDiff}(); kwargs...)
ℓ = LogDensityModel(LogDensityProblemsAD.ADgradient(kind, ℓπ; kwargs...))
return Hamiltonian(metric, ℓ)
end
Hamiltonian(metric::AbstractMetric, ℓπ, m::Module; kwargs...) = Hamiltonian(metric, ℓπ, Val(Symbol(m)); kwargs...)

### Init

Expand All @@ -153,14 +169,6 @@ function __init__()
include("contrib/diffeq.jl")
end

@require ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" begin
include("contrib/forwarddiff.jl")
end

@require Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" begin
include("contrib/zygote.jl")
end

@require CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" begin
include("contrib/cuda.jl")
end
Expand Down
48 changes: 8 additions & 40 deletions src/abstractmcmc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,38 +25,6 @@ struct HMCSampler{K, M, A} <: AbstractMCMC.AbstractSampler
end
HMCSampler(kernel, metric) = HMCSampler(kernel, metric, Adaptation.NoAdaptation())

"""
DifferentiableDensityModel(ℓπ, ∂ℓπ∂θ)
DifferentiableDensityModel(ℓπ, m::Module)

A `AbstractMCMC.AbstractMCMCModel` representing a differentiable log-density.

If a module `m` is given as the second argument, then `m` is assumed to be an
automatic-differentiation package and this will be used to compute the gradients.

Note that the module `m` must be imported before usage, e.g.
```julia
using Zygote: Zygote
model = DifferentiableDensityModel(ℓπ, Zygote)
```
results in a `model` which will use Zygote.jl as its AD-backend.

# Fields
$(FIELDS)
"""
struct DifferentiableDensityModel{Tlogπ, T∂logπ∂θ} <: AbstractMCMC.AbstractModel
"Log-density. Maps `AbstractArray` to value of the log-density."
ℓπ::Tlogπ
"Gradient of log-density. Returns a tuple of `ℓπ` and the gradient evaluated at the given point."
∂ℓπ∂θ::T∂logπ∂θ
end

struct DummyMetric <: AbstractMetric end
function DifferentiableDensityModel(ℓπ, m::Module)
h = Hamiltonian(DummyMetric(), ℓπ, m)
return DifferentiableDensityModel(h.ℓπ, h.∂ℓπ∂θ)
end

"""
HMCState

Expand Down Expand Up @@ -91,7 +59,7 @@ end
A convenient wrapper around `AbstractMCMC.sample` avoiding explicit construction of [`HMCSampler`](@ref).
"""
function AbstractMCMC.sample(
model::DifferentiableDensityModel,
model::LogDensityModel,
kernel::AbstractMCMCKernel,
metric::AbstractMetric,
adaptor::AbstractAdaptor,
Expand All @@ -103,7 +71,7 @@ end

function AbstractMCMC.sample(
rng::Random.AbstractRNG,
model::DifferentiableDensityModel,
model::LogDensityModel,
kernel::AbstractMCMCKernel,
metric::AbstractMetric,
adaptor::AbstractAdaptor,
Expand All @@ -129,7 +97,7 @@ function AbstractMCMC.sample(
end

function AbstractMCMC.sample(
model::DifferentiableDensityModel,
model::LogDensityModel,
kernel::AbstractMCMCKernel,
metric::AbstractMetric,
adaptor::AbstractAdaptor,
Expand All @@ -146,7 +114,7 @@ end

function AbstractMCMC.sample(
rng::Random.AbstractRNG,
model::DifferentiableDensityModel,
model::LogDensityModel,
kernel::AbstractMCMCKernel,
metric::AbstractMetric,
adaptor::AbstractAdaptor,
Expand Down Expand Up @@ -175,7 +143,7 @@ end

function AbstractMCMC.step(
rng::AbstractRNG,
model::DifferentiableDensityModel,
model::LogDensityModel,
spl::HMCSampler;
init_params = nothing,
kwargs...
Expand All @@ -189,7 +157,7 @@ function AbstractMCMC.step(
end

# Construct the hamiltonian using the initial metric
hamiltonian = Hamiltonian(metric, model.ℓπ, model.∂ℓπ∂θ)
hamiltonian = Hamiltonian(metric, model)

# Get an initial sample.
h, t = AdvancedHMC.sample_init(rng, hamiltonian, init_params)
Expand All @@ -203,7 +171,7 @@ end

function AbstractMCMC.step(
rng::AbstractRNG,
model::DifferentiableDensityModel,
model::LogDensityModel,
spl::HMCSampler,
state::HMCState;
nadapts::Int = 0,
Expand All @@ -220,7 +188,7 @@ function AbstractMCMC.step(
metric = state.metric

# Reconstruct hamiltonian.
h = Hamiltonian(metric, model.ℓπ, model.∂ℓπ∂θ)
h = Hamiltonian(metric, model)

# Make new transition.
t = transition(rng, h, κ, t_old.z)
Expand Down
18 changes: 0 additions & 18 deletions src/contrib/ad.jl

This file was deleted.

50 changes: 0 additions & 50 deletions src/contrib/forwarddiff.jl

This file was deleted.

18 changes: 0 additions & 18 deletions src/contrib/zygote.jl

This file was deleted.

6 changes: 6 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,11 @@ CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
Comonicon = "863f3e99-da2a-4334-8734-de3dacbe5542"
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
Expand All @@ -16,3 +19,6 @@ Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
UnicodePlots = "b8865327-cd53-5732-bb35-84acbb429228"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
LogDensityProblemsAD = "1.1.1"
2 changes: 1 addition & 1 deletion test/abstractmcmc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ include("common.jl")

θ_init = randn(rng, 2)

model = AdvancedHMC.DifferentiableDensityModel(ℓπ_gdemo, ForwardDiff)
model = AdvancedHMC.LogDensityModel(LogDensityProblemsAD.ADgradient(Val(:ForwardDiff), ℓπ_gdemo))
init_eps = Leapfrog(1e-3)
κ = NUTS(init_eps)
metric = DiagEuclideanMetric(2)
Expand Down
9 changes: 3 additions & 6 deletions test/adaptation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -124,8 +124,7 @@ end
σ² = 1 .+ abs.(randn(D))

# Diagonal Gaussian
target = MvNormal(zeros(D), Diagonal(σ²))
ℓπ = θ -> logpdf(target, θ)
ℓπ = LogDensityDistribution(MvNormal(Diagonal(σ²)))

res = runnuts(ℓπ, DiagEuclideanMetric(D))
@test res.adaptor.pc.var ≈ σ² rtol=0.2
Expand All @@ -142,8 +141,7 @@ end
Σ = m' * m

# Correlated Gaussian
target = MvNormal(zeros(D), Σ)
ℓπ = θ -> logpdf(target, θ)
ℓπ = LogDensityDistribution(MvNormal(Σ))

res = runnuts(ℓπ, DiagEuclideanMetric(D))
@test res.adaptor.pc.var ≈ diag(Σ) rtol=0.2
Expand All @@ -156,8 +154,7 @@ end
end

@testset "Initialisation adaptor by metric" begin
target = MvNormal(zeros(D), I)
ℓπ = θ -> logpdf(target, θ)
ℓπ = LogDensityDistribution(MvNormal(Eye(D)))

mass_init = fill(0.5, D)
res = runnuts(ℓπ, DiagEuclideanMetric(mass_init); n_samples=1)
Expand Down
Loading