Check model by default (#2218)
* check model by default

* removed check_model kwargs from non-leaf method

* uncomment tests

* removed incorrect usage of check_model

* fixed IS tests

* relax gibbs tests

* Give the MH inference tests some burn-in to see if that can help

* made the MH inference tests a bit more predictable by providing initial params

* Relaxed HMC tests a bit

---------

Co-authored-by: Hong Ge <[email protected]>
torfjelde and yebai authored Jun 28, 2024
1 parent 927abcd commit cbd5d79
Showing 6 changed files with 56 additions and 8 deletions.
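In user-facing terms (mirroring the tests added below), sample now runs DynamicPPL.check_model before sampling starts, so a buggy model errors up front unless the check is explicitly disabled. A minimal sketch, using the repeated-varname model from this commit's tests:

using Turing

# A model with a repeated variable name, the kind of bug the new
# default check catches (taken from the tests in this commit).
@model function demo_repeated_varname()
    x ~ Normal(0, 1)
    x ~ Normal(x, 1)
end

sample(demo_repeated_varname(), NUTS(), 1000)                    # now errors up front
sample(demo_repeated_varname(), Prior(), 10; check_model=false)  # opts out of the check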
13 changes: 13 additions & 0 deletions src/mcmc/Inference.jl
@@ -238,6 +238,15 @@ DynamicPPL.getlogp(t::Transition) = t.lp
 # Metadata of VarInfo object
 metadata(vi::AbstractVarInfo) = (lp = getlogp(vi),)

+# TODO: Implement additional checks for certain samplers, e.g.
+# HMC not supporting discrete parameters.
+function _check_model(model::DynamicPPL.Model)
+    return DynamicPPL.check_model(model; error_on_failure=true)
+end
+function _check_model(model::DynamicPPL.Model, alg::InferenceAlgorithm)
+    return _check_model(model)
+end
+
 #########################################
 # Default definitions for the interface #
 #########################################
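The two-argument method is the hook for the sampler-specific checks the TODO mentions. A hypothetical sketch of such a specialization, assuming Turing's Hamiltonian abstract algorithm type and an invented helper has_discrete_parameters that is not part of Turing:

# Hypothetical sketch only: `has_discrete_parameters` does not exist in
# Turing; it stands in for whatever discrete-parameter detection might
# eventually be implemented.
function _check_model(model::DynamicPPL.Model, alg::Hamiltonian)
    has_discrete_parameters(model) &&
        error("Hamiltonian samplers such as HMC/NUTS do not support discrete parameters.")
    return _check_model(model)  # fall back to the generic model check
end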
@@ -256,8 +265,10 @@ function AbstractMCMC.sample(
     model::AbstractModel,
     alg::InferenceAlgorithm,
     N::Integer;
+    check_model::Bool=true,
     kwargs...
 )
+    check_model && _check_model(model, alg)
     return AbstractMCMC.sample(rng, model, Sampler(alg, model), N; kwargs...)
 end

@@ -280,8 +291,10 @@ function AbstractMCMC.sample(
     ensemble::AbstractMCMC.AbstractMCMCEnsemble,
     N::Integer,
     n_chains::Integer;
+    check_model::Bool=true,
     kwargs...
 )
+    check_model && _check_model(model, alg)
     return AbstractMCMC.sample(rng, model, Sampler(alg, model), ensemble, N, n_chains;
         kwargs...)
 end
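A hedged usage sketch of the two extended entry points; the model here is a placeholder, not part of the commit:

using Turing

@model function demo()
    s ~ InverseGamma(2, 3)
    m ~ Normal(0, sqrt(s))
end

# Single-chain entry point: the model is checked by default.
chain = sample(demo(), NUTS(), 1_000)

# Ensemble entry point: check_model=false skips the up-front check.
chains = sample(demo(), NUTS(), MCMCThreads(), 1_000, 4; check_model=false)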
22 changes: 22 additions & 0 deletions test/mcmc/Inference.jl
@@ -559,6 +559,28 @@ using Turing
         @test all(xs[:, 1] .=== [1, missing, 3])
         @test all(xs[:, 2] .=== [missing, 2, 4])
     end
+
+    @testset "check model" begin
+        @model function demo_repeated_varname()
+            x ~ Normal(0, 1)
+            x ~ Normal(x, 1)
+        end
+
+        @test_throws ErrorException sample(
+            demo_repeated_varname(), NUTS(), 1000; check_model=true
+        )
+        # Make sure that disabling the check also works.
+        @test (sample(
+            demo_repeated_varname(), Prior(), 10; check_model=false
+        ); true)
+
+        @model function demo_incorrect_missing(y)
+            y[1:1] ~ MvNormal(zeros(1), 1)
+        end
+        @test_throws ErrorException sample(
+            demo_incorrect_missing([missing]), NUTS(), 1000; check_model=true
+        )
+    end
 end

 end
4 changes: 3 additions & 1 deletion test/mcmc/gibbs.jl
@@ -50,7 +50,9 @@ using Turing.RandomMeasures: ChineseRestaurantProcess, DirichletProcess
     Random.seed!(100)
     alg = Gibbs(CSMC(15, :s), HMC(0.2, 4, :m; adtype=adbackend))
     chain = sample(gdemo(1.5, 2.0), alg, 10_000)
-    check_numerical(chain, [:s, :m], [49 / 24, 7 / 6]; atol=0.15)
+    check_numerical(chain, [:m], [7 / 6]; atol=0.15)
+    # Be more relaxed with the tolerance of the variance.
+    check_numerical(chain, [:s], [49 / 24]; atol=0.35)

     Random.seed!(100)
2 changes: 1 addition & 1 deletion test/mcmc/hmc.jl
@@ -319,7 +319,7 @@ using Turing

         # The discrepancies in the chains are in the tails, so we can't just compare the mean, etc.
         # KS will compare the empirical CDFs, which seems like a reasonable thing to do here.
-        @test pvalue(ApproximateTwoSampleKSTest(vec(results), vec(results_prior))) > 0.01
+        @test pvalue(ApproximateTwoSampleKSTest(vec(results), vec(results_prior))) > 0.001
     end
 end
2 changes: 1 addition & 1 deletion test/mcmc/is.jl
@@ -46,7 +46,7 @@ using Turing
     ref = reference(n)

     Random.seed!(seed)
-    chain = sample(model, alg, n)
+    chain = sample(model, alg, n; check_model=false)
     sampled = get(chain, [:a, :b, :lp])

     @test vec(sampled.a) == ref.as
21 changes: 16 additions & 5 deletions test/mcmc/mh.jl
@@ -44,30 +44,41 @@ GKernel(var) = (x) -> Normal(x, sqrt.(var))
         # c6 = sample(gdemo_default, s6, N)
     end
     @testset "mh inference" begin
+        # Set the initial parameters, because if we get unlucky with the initial state,
+        # these chains are too short to converge to reasonable numbers.
+        discard_initial = 1000
+        initial_params = [1.0, 1.0]
+
         Random.seed!(125)
         alg = MH()
-        chain = sample(gdemo_default, alg, 10_000)
+        chain = sample(gdemo_default, alg, 10_000; discard_initial, initial_params)
         check_gdemo(chain; atol=0.1)

         Random.seed!(125)
         # MH with Gaussian proposal
         alg = MH((:s, InverseGamma(2, 3)), (:m, GKernel(1.0)))
-        chain = sample(gdemo_default, alg, 10_000)
+        chain = sample(gdemo_default, alg, 10_000; discard_initial, initial_params)
         check_gdemo(chain; atol=0.1)

         Random.seed!(125)
         # MH within Gibbs
         alg = Gibbs(MH(:m), MH(:s))
-        chain = sample(gdemo_default, alg, 10_000)
+        chain = sample(gdemo_default, alg, 10_000; discard_initial, initial_params)
         check_gdemo(chain; atol=0.1)

         Random.seed!(125)
         # MoGtest
         gibbs = Gibbs(
             CSMC(15, :z1, :z2, :z3, :z4), MH((:mu1, GKernel(1)), (:mu2, GKernel(1)))
         )
-        chain = sample(MoGtest_default, gibbs, 500)
-        check_MoGtest_default(chain; atol=0.15)
+        chain = sample(
+            MoGtest_default,
+            gibbs,
+            500;
+            discard_initial=100,
+            initial_params=[1.0, 1.0, 0.0, 0.0, 1.0, 4.0],
+        )
+        check_MoGtest_default(chain; atol=0.2)
     end

     # Test MH shape passing.

8 comments on commit cbd5d79

@torfjelde (Member, Author)

@yebai @mhauru is there a reason why there hasn't been a release after this and DPPL bump to 0.28 yet?

@yebai (Member) commented on Jul 1, 2024

not that I am aware

@mhauru (Member) commented on Jul 2, 2024

Same, not that I am aware. I didn't really follow this PR.

@torfjelde (Member, Author)

Didn't mean anything specifically related to this PR :) E.g. the previous commit contains the DPPL bump to 0.28 (which is how I discovered that #master hadn't been released yet). Currently we can't run integration tests with Turing.jl in DPPL, etc., due to this compat bound not being released.

So it's more a question of: are there any breaking changes here? And if so, has the version entry been bumped?

@mhauru (Member) commented on Jul 2, 2024

Oh sorry, not making a release after the DPPL bump was just an oversight on my part. Let me check the current master/latest-release diff.

@mhauru (Member) commented on Jul 2, 2024
@mhauru (Member) commented on Jul 3, 2024

New release is out

@torfjelde (Member, Author)

Lovely; cheers @mhauru !
