Commit

Merge branch 'master' into dw/enzyme
devmotion authored Jul 1, 2024
2 parents f4c72bd + cbd5d79 commit 6b7159c
Showing 15 changed files with 134 additions and 27 deletions.
2 changes: 1 addition & 1 deletion Project.toml
@@ -63,7 +63,7 @@ Distributions = "0.23.3, 0.24, 0.25"
DistributionsAD = "0.6"
DocStringExtensions = "0.8, 0.9"
DynamicHMC = "3.4"
DynamicPPL = "0.27.1"
DynamicPPL = "0.28"
Compat = "4.15.0"
EllipticalSliceSampling = "0.5, 1, 2"
ForwardDiff = "0.10.3"
2 changes: 1 addition & 1 deletion benchmarks/benchmarks_suite.jl
@@ -84,5 +84,5 @@ BenchmarkSuite["mnormal"]["forwarddiff"] = @benchmarkable sample(

# ReverseDiff
BenchmarkSuite["mnormal"]["reversediff"] = @benchmarkable sample(
$(mdemo(d, 1)), $(HMC(0.1, 5; adtype=AutoReverseDiff(false))), 5000
$(mdemo(d, 1)), $(HMC(0.1, 5; adtype=AutoReverseDiff(; compile=false))), 5000
)
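Note on the AD-backend spelling used throughout this commit: the positional Boolean constructor `AutoReverseDiff(false)` is replaced everywhere by the keyword form `AutoReverseDiff(; compile=false)`. A minimal sketch of the new spelling in a sampling call (the `demo` model below is hypothetical, for illustration only):

using Turing, ReverseDiff

@model function demo()
    x ~ Normal(0, 1)
    return x
end

# compile=true asks ReverseDiff to compile and cache the gradient tape;
# compile=false (the default) re-records the tape on every evaluation.
sample(demo(), HMC(0.1, 5; adtype=AutoReverseDiff(; compile=false)), 1000)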
15 changes: 14 additions & 1 deletion src/mcmc/Inference.jl
@@ -149,7 +149,7 @@ getADType(::DynamicPPL.IsParent, ctx::DynamicPPL.AbstractContext) = getADType(Dy
getADType(alg::Hamiltonian) = alg.adtype

function LogDensityProblemsAD.ADgradient(ℓ::DynamicPPL.LogDensityFunction)
return LogDensityProblemsAD.ADgradient(getADType(ℓ.context), ℓ)
return LogDensityProblemsAD.ADgradient(getADType(DynamicPPL.getcontext(ℓ)), ℓ)
end

function LogDensityProblems.logdensity(
@@ -238,6 +238,15 @@ DynamicPPL.getlogp(t::Transition) = t.lp
# Metadata of VarInfo object
metadata(vi::AbstractVarInfo) = (lp = getlogp(vi),)

# TODO: Implement additional checks for certain samplers, e.g.
# HMC not supporting discrete parameters.
function _check_model(model::DynamicPPL.Model)
return DynamicPPL.check_model(model; error_on_failure=true)
end
function _check_model(model::DynamicPPL.Model, alg::InferenceAlgorithm)
return _check_model(model)
end

#########################################
# Default definitions for the interface #
#########################################
@@ -256,8 +265,10 @@ function AbstractMCMC.sample(
model::AbstractModel,
alg::InferenceAlgorithm,
N::Integer;
check_model::Bool=true,
kwargs...
)
check_model && _check_model(model, alg)
return AbstractMCMC.sample(rng, model, Sampler(alg, model), N; kwargs...)
end

@@ -280,8 +291,10 @@ function AbstractMCMC.sample(
ensemble::AbstractMCMC.AbstractMCMCEnsemble,
N::Integer,
n_chains::Integer;
check_model::Bool=true,
kwargs...
)
check_model && _check_model(model, alg)
return AbstractMCMC.sample(rng, model, Sampler(alg, model), ensemble, N, n_chains;
kwargs...)
end
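The `check_model` keyword introduced above runs `DynamicPPL.check_model(model; error_on_failure=true)` before sampling starts and can be disabled by the caller. A minimal usage sketch (the `demo` model is hypothetical, for illustration only):

using Turing

@model function demo()
    x ~ Normal(0, 1)
    return x
end

# The check runs by default and raises an error on failure;
# pass check_model=false to skip it.
chain = sample(demo(), NUTS(), 1000; check_model=false)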
2 changes: 1 addition & 1 deletion src/mcmc/mh.jl
@@ -271,7 +271,7 @@ function LogDensityProblems.logdensity(f::MHLogDensityFunction, x::NamedTuple)

x_old, lj_old = vi[sampler], getlogp(vi)
set_namedtuple!(vi, x)
vi_new = last(DynamicPPL.evaluate!!(f.model, vi, f.context))
vi_new = last(DynamicPPL.evaluate!!(f.model, vi, DynamicPPL.getcontext(f)))
lj = getlogp(vi_new)

# Reset old `vi`.
35 changes: 34 additions & 1 deletion src/optimisation/Optimisation.jl
@@ -142,7 +142,9 @@ required by Optimization.jl.
"""
function (f::OptimLogDensity)(z::AbstractVector)
varinfo = DynamicPPL.unflatten(f.varinfo, z)
return -DynamicPPL.getlogp(last(DynamicPPL.evaluate!!(f.model, varinfo, f.context)))
return -DynamicPPL.getlogp(
last(DynamicPPL.evaluate!!(f.model, varinfo, DynamicPPL.getcontext(f)))
)
end

(f::OptimLogDensity)(z, _) = f(z)
@@ -273,6 +275,37 @@ StatsBase.params(m::ModeResult) = StatsBase.coefnames(m)
StatsBase.vcov(m::ModeResult) = inv(StatsBase.informationmatrix(m))
StatsBase.loglikelihood(m::ModeResult) = m.lp

"""
Base.get(m::ModeResult, var_symbol::Symbol)
Base.get(m::ModeResult, var_symbols)

Return the values of all the variables with the symbol(s) `var_symbol` in the mode result
`m`. The return value is a `NamedTuple` with `var_symbols` as the key(s). The second
argument should be either a `Symbol` or an iterator of `Symbol`s.
"""
function Base.get(m::ModeResult, var_symbols)
log_density = m.f
# Get all the variable names in the model. This is the same as the list of keys in
# m.values, but they are more convenient to filter when they are VarNames rather than
# Symbols.
varnames = collect(
map(first, Turing.Inference.getparams(log_density.model, log_density.varinfo))
)
# For each symbol s in var_symbols, pick all the values from m.values for which the
# variable name has that symbol.
et = eltype(m.values)
value_vectors = Vector{et}[]
for s in var_symbols
push!(
value_vectors,
[m.values[Symbol(vn)] for vn in varnames if DynamicPPL.getsym(vn) == s],
)
end
return (; zip(var_symbols, value_vectors)...)
end

Base.get(m::ModeResult, var_symbol::Symbol) = get(m, (var_symbol,))

"""
ModeResult(log_density::OptimLogDensity, solution::SciMLBase.OptimizationSolution)
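The `Base.get` methods added above extract, by symbol, all values stored in a mode estimate. A short sketch of the intended use, following the pattern of the test added later in this commit (the model here is hypothetical):

using Turing

@model function demo(y)
    m ~ Normal(0, 1)
    return y ~ Normal(m, 1)
end

result = maximum_a_posteriori(demo(0.5))
get(result, :m)        # NamedTuple with key :m and a vector of values
get(result, (:m,))     # same result via an iterator of Symbols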
2 changes: 1 addition & 1 deletion test/Project.toml
@@ -47,7 +47,7 @@ Distributions = "0.25"
DistributionsAD = "0.6.3"
DynamicHMC = "2.1.6, 3.0"
Enzyme = "0.12"
DynamicPPL = "0.27"
DynamicPPL = "0.28"
FiniteDifferences = "0.10.8, 0.11, 0.12"
ForwardDiff = "0.10.12 - 0.10.32, 0.10"
HypothesisTests = "0.11"
12 changes: 7 additions & 5 deletions test/essential/ad.jl
@@ -154,22 +154,22 @@ end
return theta ~ Dirichlet(1 ./ fill(4, 4))
end
sample(dir(), HMC(0.01, 1; adtype=AutoZygote()), 1000)
sample(dir(), HMC(0.01, 1; adtype=AutoReverseDiff(false)), 1000)
sample(dir(), HMC(0.01, 1; adtype=AutoReverseDiff(true)), 1000)
sample(dir(), HMC(0.01, 1; adtype=AutoReverseDiff(; compile=false)), 1000)
sample(dir(), HMC(0.01, 1; adtype=AutoReverseDiff(; compile=true)), 1000)
end
@testset "PDMatDistribution AD" begin
@model function wishart()
return theta ~ Wishart(4, Matrix{Float64}(I, 4, 4))
end

sample(wishart(), HMC(0.01, 1; adtype=AutoReverseDiff(false)), 1000)
sample(wishart(), HMC(0.01, 1; adtype=AutoReverseDiff(; compile=false)), 1000)
sample(wishart(), HMC(0.01, 1; adtype=AutoZygote()), 1000)

@model function invwishart()
return theta ~ InverseWishart(4, Matrix{Float64}(I, 4, 4))
end

sample(invwishart(), HMC(0.01, 1; adtype=AutoReverseDiff(false)), 1000)
sample(invwishart(), HMC(0.01, 1; adtype=AutoReverseDiff(; compile=false)), 1000)
sample(invwishart(), HMC(0.01, 1; adtype=AutoZygote()), 1000)
end
@testset "Hessian test" begin
@@ -231,7 +231,9 @@ end
for i in 1:5
d = Normal(0.0, i)
data = rand(d, N)
chn = sample(demo(data), NUTS(0.65; adtype=AutoReverseDiff(true)), 1000)
chn = sample(
demo(data), NUTS(0.65; adtype=AutoReverseDiff(; compile=true)), 1000
)
@test mean(Array(chn[:sigma])) ≈ std(data) atol = 0.5
end
end
25 changes: 23 additions & 2 deletions test/mcmc/Inference.jl
@@ -15,13 +15,12 @@ import ReverseDiff
using Test: @test, @test_throws, @testset
using Turing

# Disable Enzyme warnings
Enzyme.API.typeWarning!(false)

# Enable runtime activity (workaround)
Enzyme.API.runtimeActivity!(true)

# @testset "Testing inference.jl with $adbackend" for adbackend in (AutoForwardDiff(; chunksize=0), AutoReverseDiff(false))
# @testset "Testing inference.jl with $adbackend" for adbackend in (AutoForwardDiff(; chunksize=0), AutoReverseDiff(; compile=false))
@testset "Testing inference.jl with $adbackend" for adbackend in (AutoEnzyme(),)
# Only test threading if 1.3+.
if VERSION > v"1.2"
@@ -570,6 +569,28 @@
@test all(xs[:, 1] .=== [1, missing, 3])
@test all(xs[:, 2] .=== [missing, 2, 4])
end

@testset "check model" begin
@model function demo_repeated_varname()
x ~ Normal(0, 1)
x ~ Normal(x, 1)
end

@test_throws ErrorException sample(
demo_repeated_varname(), NUTS(), 1000; check_model=true
)
# Make sure that disabling the check also works.
@test (sample(
demo_repeated_varname(), Prior(), 10; check_model=false
); true)

@model function demo_incorrect_missing(y)
y[1:1] ~ MvNormal(zeros(1), 1)
end
@test_throws ErrorException sample(
demo_incorrect_missing([missing]), NUTS(), 1000; check_model=true
)
end
end

end
6 changes: 4 additions & 2 deletions test/mcmc/gibbs.jl
@@ -13,7 +13,7 @@ using Turing: Inference
using Turing.RandomMeasures: ChineseRestaurantProcess, DirichletProcess

@testset "Testing gibbs.jl with $adbackend" for adbackend in (
AutoForwardDiff(; chunksize=0), AutoReverseDiff(false)
AutoForwardDiff(; chunksize=0), AutoReverseDiff(; compile=false)
)
@testset "gibbs constructor" begin
N = 500
@@ -50,7 +50,9 @@ using Turing.RandomMeasures: ChineseRestaurantProcess, DirichletProcess
Random.seed!(100)
alg = Gibbs(CSMC(15, :s), HMC(0.2, 4, :m; adtype=adbackend))
chain = sample(gdemo(1.5, 2.0), alg, 10_000)
check_numerical(chain, [:s, :m], [49 / 24, 7 / 6]; atol=0.15)
check_numerical(chain, [:m], [7 / 6]; atol=0.15)
# Be more relaxed with the tolerance of the variance.
check_numerical(chain, [:s], [49 / 24]; atol=0.35)

Random.seed!(100)

2 changes: 1 addition & 1 deletion test/mcmc/gibbs_conditional.jl
@@ -15,7 +15,7 @@ using Test: @test, @testset
using Turing

@testset "Testing gibbs conditionals.jl with $adbackend" for adbackend in (
AutoForwardDiff(; chunksize=0), AutoReverseDiff(false)
AutoForwardDiff(; chunksize=0), AutoReverseDiff(; compile=false)
)
Random.seed!(1000)
rng = StableRNG(123)
4 changes: 2 additions & 2 deletions test/mcmc/hmc.jl
@@ -23,7 +23,7 @@ Enzyme.API.typeWarning!(false)
# Enable runtime activity (workaround)
Enzyme.API.runtimeActivity!(true)

# @testset "Testing hmc.jl with $adbackend" for adbackend in (AutoForwardDiff(; chunksize=0), AutoReverseDiff(false))
# @testset "Testing hmc.jl with $adbackend" for adbackend in (AutoForwardDiff(; chunksize=0), AutoReverseDiff(; compile=false))
@testset "Testing hmc.jl with $adbackend" for adbackend in (AutoEnzyme(),)
# Set a seed
rng = StableRNG(123)
@@ -327,7 +327,7 @@

# The discrepancies in the chains are in the tails, so we can't just compare the mean, etc.
# KS will compare the empirical CDFs, which seems like a reasonable thing to do here.
@test pvalue(ApproximateTwoSampleKSTest(vec(results), vec(results_prior))) > 0.01
@test pvalue(ApproximateTwoSampleKSTest(vec(results), vec(results_prior))) > 0.001
end
end

2 changes: 1 addition & 1 deletion test/mcmc/is.jl
@@ -46,7 +46,7 @@ using Turing
ref = reference(n)

Random.seed!(seed)
chain = sample(model, alg, n)
chain = sample(model, alg, n; check_model=false)
sampled = get(chain, [:a, :b, :lp])

@test vec(sampled.a) == ref.as
21 changes: 16 additions & 5 deletions test/mcmc/mh.jl
@@ -44,30 +44,41 @@ GKernel(var) = (x) -> Normal(x, sqrt.(var))
# c6 = sample(gdemo_default, s6, N)
end
@testset "mh inference" begin
# Set the initial parameters, because if we get unlucky with the initial state,
# these chains are too short to converge to reasonable numbers.
discard_initial = 1000
initial_params = [1.0, 1.0]

Random.seed!(125)
alg = MH()
chain = sample(gdemo_default, alg, 10_000)
chain = sample(gdemo_default, alg, 10_000; discard_initial, initial_params)
check_gdemo(chain; atol=0.1)

Random.seed!(125)
# MH with Gaussian proposal
alg = MH((:s, InverseGamma(2, 3)), (:m, GKernel(1.0)))
chain = sample(gdemo_default, alg, 10_000)
chain = sample(gdemo_default, alg, 10_000; discard_initial, initial_params)
check_gdemo(chain; atol=0.1)

Random.seed!(125)
# MH within Gibbs
alg = Gibbs(MH(:m), MH(:s))
chain = sample(gdemo_default, alg, 10_000)
chain = sample(gdemo_default, alg, 10_000; discard_initial, initial_params)
check_gdemo(chain; atol=0.1)

Random.seed!(125)
# MoGtest
gibbs = Gibbs(
CSMC(15, :z1, :z2, :z3, :z4), MH((:mu1, GKernel(1)), (:mu2, GKernel(1)))
)
chain = sample(MoGtest_default, gibbs, 500)
check_MoGtest_default(chain; atol=0.15)
chain = sample(
MoGtest_default,
gibbs,
500;
discard_initial=100,
initial_params=[1.0, 1.0, 0.0, 0.0, 1.0, 4.0],
)
check_MoGtest_default(chain; atol=0.2)
end
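The reworked MH tests above pin down the starting point through the `discard_initial` and `initial_params` keywords of `sample`, so that short chains are not dominated by an unlucky initial state. A hedged sketch of the same pattern outside the test suite (model and values are illustrative only):

using Turing

@model function gauss(y)
    s ~ InverseGamma(2, 3)
    m ~ Normal(0, sqrt(s))
    return y ~ Normal(m, sqrt(s))
end

# Start from s = 1.0, m = 1.0 and drop the first 1000 draws as burn-in.
chain = sample(gauss(1.5), MH(), 10_000; discard_initial=1000, initial_params=[1.0, 1.0])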

# Test MH shape passing.
4 changes: 2 additions & 2 deletions test/mcmc/sghmc.jl
@@ -17,7 +17,7 @@ Enzyme.API.typeWarning!(false)
# Enable runtime activity (workaround)
Enzyme.API.runtimeActivity!(true)

# @testset "Testing sghmc.jl with $adbackend" for adbackend in (AutoForwardDiff(; chunksize=0), AutoReverseDiff(false))
# @testset "Testing sghmc.jl with $adbackend" for adbackend in (AutoForwardDiff(; chunksize=0), AutoReverseDiff(; compile=false))
@testset "Testing sghmc.jl with $adbackend" for adbackend in (AutoEnzyme(),)
@testset "sghmc constructor" begin
alg = SGHMC(; learning_rate=0.01, momentum_decay=0.1, adtype=adbackend)
@@ -44,7 +44,7 @@ Enzyme.API.runtimeActivity!(true)
end
end

# @testset "Testing sgld.jl with $adbackend" for adbackend in (AutoForwardDiff(; chunksize=0), AutoReverseDiff(false))
# @testset "Testing sgld.jl with $adbackend" for adbackend in (AutoForwardDiff(; chunksize=0), AutoReverseDiff(; compile=false))
@testset "Testing sgld.jl with $adbackend" for adbackend in (AutoEnzyme(),)
@testset "sgld constructor" begin
alg = SGLD(; stepsize=PolynomialStepsize(0.25), adtype=adbackend)
27 changes: 26 additions & 1 deletion test/optimisation/Optimisation.jl
@@ -4,13 +4,14 @@ using ..Models: gdemo, gdemo_default
using Distributions
using Distributions.FillArrays: Zeros
using DynamicPPL: DynamicPPL
using LinearAlgebra: I
using LinearAlgebra: Diagonal, I
using Random: Random
using Optimization
using Optimization: Optimization
using OptimizationBBO: OptimizationBBO
using OptimizationNLopt: OptimizationNLopt
using OptimizationOptimJL: OptimizationOptimJL
using ReverseDiff: ReverseDiff
using StatsBase: StatsBase
using StatsBase: coef, coefnames, coeftable, informationmatrix, stderror, vcov
using Test: @test, @testset, @test_throws
@@ -591,6 +592,30 @@ using Turing
@test result.values[:x] ≈ 0 atol = 1e-1
@test result.values[:y] ≈ 100 atol = 1e-1
end

@testset "get ModeResult" begin
@model function demo_model(N)
half_N = N ÷ 2
a ~ arraydist(LogNormal.(fill(0, half_N), 1))
b ~ arraydist(LogNormal.(fill(0, N - half_N), 1))
covariance_matrix = Diagonal(vcat(a, b))
x ~ MvNormal(covariance_matrix)
return nothing
end

N = 12
m = demo_model(N) | (x=randn(N),)
result = maximum_a_posteriori(m)
get_a = get(result, :a)
get_b = get(result, :b)
get_ab = get(result, [:a, :b])
@assert keys(get_a) == (:a,)
@assert keys(get_b) == (:b,)
@assert keys(get_ab) == (:a, :b)
@assert get_b[:b] == get_ab[:b]
@assert vcat(get_a[:a], get_b[:b]) == result.values.array
@assert get(result, :c) == (; :c => Array{Float64}[])
end
end

end
