diff --git a/.JuliaFormatter.toml b/.JuliaFormatter.toml index d0e00b45f..15ecbc5c3 100644 --- a/.JuliaFormatter.toml +++ b/.JuliaFormatter.toml @@ -5,17 +5,10 @@ import_to_using = false # We ignore these files because when formatting was first put in place they were being worked on. # These ignores should be removed once the relevant PRs are merged/closed. ignore = [ - # https://github.com/TuringLang/Turing.jl/pull/2231/files + # https://github.com/TuringLang/Turing.jl/pull/2328/files "src/experimental/gibbs.jl", - "src/mcmc/abstractmcmc.jl", "test/experimental/gibbs.jl", - "test/test_utils/numerical_tests.jl", - # https://github.com/TuringLang/Turing.jl/pull/2218/files - "src/mcmc/Inference.jl", - "test/mcmc/Inference.jl", # https://github.com/TuringLang/Turing.jl/pull/1887 # Enzyme PR - "test/mcmc/Inference.jl", "test/mcmc/hmc.jl", "test/mcmc/sghmc.jl", - "test/runtests.jl", ] diff --git a/.github/ISSUE_TEMPLATE/01-bug-report.yml b/.github/ISSUE_TEMPLATE/01-bug-report.yml index 598371476..cc271e419 100644 --- a/.github/ISSUE_TEMPLATE/01-bug-report.yml +++ b/.github/ISSUE_TEMPLATE/01-bug-report.yml @@ -8,7 +8,7 @@ body: attributes: value: | Thank you for submitting a bug report to Turing.jl! - + To make sure we can pinpoint the issue and fix it as quickly as possible, we ask you to provide some information about the bug you encountered. Please fill out the form below. - type: textarea @@ -35,7 +35,7 @@ body: description: Paste the output of `versioninfo()` between the triple backticks value: |
versioninfo() - + ``` (Paste here) ``` @@ -50,7 +50,7 @@ body: description: Paste the output of `]st --manifest` between the triple backticks. value: |
]st --manifest - + ``` (Paste here) ``` diff --git a/.github/workflows/Tests.yml b/.github/workflows/Tests.yml index 8de296e5e..9416cb68a 100644 --- a/.github/workflows/Tests.yml +++ b/.github/workflows/Tests.yml @@ -79,7 +79,7 @@ jobs: - uses: julia-actions/julia-buildpkg@v1 # TODO: Use julia-actions/julia-runtest when test_args are supported # Custom calls of Pkg.test tend to miss features such as e.g. adjustments for CompatHelper PRs - # Ref https://github.com/julia-actions/julia-runtest/pull/73 + # Ref https://github.com/julia-actions/julia-runtest/pull/73 - name: Call Pkg.test run: julia --color=yes --inline=yes --depwarn=yes --check-bounds=yes --threads=${{ matrix.num_threads }} --project=@. -e 'import Pkg; Pkg.test(; coverage=parse(Bool, ENV["COVERAGE"]), test_args=ARGS)' -- ${{ matrix.test-args }} - uses: julia-actions/julia-processcoverage@v1 diff --git a/src/mcmc/Inference.jl b/src/mcmc/Inference.jl index b7bdf206b..fb9d5d4ef 100644 --- a/src/mcmc/Inference.jl +++ b/src/mcmc/Inference.jl @@ -1,16 +1,33 @@ module Inference using ..Essential -using DynamicPPL: Metadata, VarInfo, TypedVarInfo, - islinked, invlink!, link!, - setindex!!, push!!, - setlogp!!, getlogp, - VarName, getsym, vectorize, - _getvns, getdist, - Model, Sampler, SampleFromPrior, SampleFromUniform, - DefaultContext, PriorContext, - LikelihoodContext, set_flag!, unset_flag!, - getspace, inspace +using DynamicPPL: + Metadata, + VarInfo, + TypedVarInfo, + islinked, + invlink!, + link!, + setindex!!, + push!!, + setlogp!!, + getlogp, + VarName, + getsym, + vectorize, + _getvns, + getdist, + Model, + Sampler, + SampleFromPrior, + SampleFromUniform, + DefaultContext, + PriorContext, + LikelihoodContext, + set_flag!, + unset_flag!, + getspace, + inspace using Distributions, Libtask, Bijectors using DistributionsAD: VectorOfMultivariate using LinearAlgebra @@ -25,8 +42,10 @@ using Accessors: Accessors import ADTypes import AbstractMCMC -import AdvancedHMC; const AHMC = AdvancedHMC -import AdvancedMH; const AMH = AdvancedMH +import AdvancedHMC +const AHMC = AdvancedHMC +import AdvancedMH +const AMH = AdvancedMH import AdvancedPS import Accessors import EllipticalSliceSampling @@ -36,35 +55,35 @@ import Random import MCMCChains import StatsBase: predict -export InferenceAlgorithm, - Hamiltonian, - StaticHamiltonian, - AdaptiveHamiltonian, - SampleFromUniform, - SampleFromPrior, - MH, - ESS, - Emcee, - Gibbs, # classic sampling - GibbsConditional, - HMC, - SGLD, - PolynomialStepsize, - SGHMC, - HMCDA, - NUTS, # Hamiltonian-like sampling - IS, - SMC, - CSMC, - PG, - Prior, - assume, - dot_assume, - observe, - dot_observe, - predict, - isgibbscomponent, - externalsampler +export InferenceAlgorithm, + Hamiltonian, + StaticHamiltonian, + AdaptiveHamiltonian, + SampleFromUniform, + SampleFromPrior, + MH, + ESS, + Emcee, + Gibbs, # classic sampling + GibbsConditional, + HMC, + SGLD, + PolynomialStepsize, + SGHMC, + HMCDA, + NUTS, # Hamiltonian-like sampling + IS, + SMC, + CSMC, + PG, + Prior, + assume, + dot_assume, + observe, + dot_observe, + predict, + isgibbscomponent, + externalsampler ####################### # Sampler abstraction # @@ -86,7 +105,8 @@ The `Unconstrained` type-parameter is to indicate whether the sampler requires u # Fields $(TYPEDFIELDS) """ -struct ExternalSampler{S<:AbstractSampler,AD<:ADTypes.AbstractADType,Unconstrained} <: InferenceAlgorithm +struct ExternalSampler{S<:AbstractSampler,AD<:ADTypes.AbstractADType,Unconstrained} <: + InferenceAlgorithm "the sampler to wrap" sampler::S "the automatic differentiation (AD) backend to use" @@ -105,10 +125,12 @@ struct ExternalSampler{S<:AbstractSampler,AD<:ADTypes.AbstractADType,Unconstrain function ExternalSampler( sampler::AbstractSampler, adtype::ADTypes.AbstractADType, - ::Val{unconstrained}=Val(true) + ::Val{unconstrained}=Val(true), ) where {unconstrained} if !(unconstrained isa Bool) - throw(ArgumentError("Expected Val{true} or Val{false}, got Val{$unconstrained}")) + throw( + ArgumentError("Expected Val{true} or Val{false}, got Val{$unconstrained}") + ) end return new{typeof(sampler),typeof(adtype),unconstrained}(sampler, adtype) end @@ -121,7 +143,9 @@ DynamicPPL.getspace(::ExternalSampler) = () Return `true` if the sampler requires unconstrained space, and `false` otherwise. """ -requires_unconstrained_space(::ExternalSampler{<:Any,<:Any,Unconstrained}) where {Unconstrained} = Unconstrained +requires_unconstrained_space( + ::ExternalSampler{<:Any,<:Any,Unconstrained} +) where {Unconstrained} = Unconstrained """ externalsampler(sampler::AbstractSampler; adtype=AutoForwardDiff(), unconstrained=true) @@ -135,18 +159,21 @@ Wrap a sampler so it can be used as an inference algorithm. - `adtype::ADTypes.AbstractADType=ADTypes.AutoForwardDiff()`: The automatic differentiation (AD) backend to use. - `unconstrained::Bool=true`: Whether the sampler requires unconstrained space. """ -function externalsampler(sampler::AbstractSampler; adtype=Turing.DEFAULT_ADTYPE, unconstrained::Bool=true) +function externalsampler( + sampler::AbstractSampler; adtype=Turing.DEFAULT_ADTYPE, unconstrained::Bool=true +) return ExternalSampler(sampler, adtype, Val(unconstrained)) end - getADType(spl::Sampler) = getADType(spl.alg) getADType(::SampleFromPrior) = Turing.DEFAULT_ADTYPE getADType(ctx::DynamicPPL.SamplingContext) = getADType(ctx.sampler) getADType(ctx::DynamicPPL.AbstractContext) = getADType(DynamicPPL.NodeTrait(ctx), ctx) getADType(::DynamicPPL.IsLeaf, ctx::DynamicPPL.AbstractContext) = Turing.DEFAULT_ADTYPE -getADType(::DynamicPPL.IsParent, ctx::DynamicPPL.AbstractContext) = getADType(DynamicPPL.childcontext(ctx)) +function getADType(::DynamicPPL.IsParent, ctx::DynamicPPL.AbstractContext) + return getADType(DynamicPPL.childcontext(ctx)) +end getADType(alg::Hamiltonian) = alg.adtype @@ -156,7 +183,7 @@ end function LogDensityProblems.logdensity( f::Turing.LogDensityFunction{<:AbstractVarInfo,<:Model,<:DynamicPPL.DefaultContext}, - x::NamedTuple + x::NamedTuple, ) return DynamicPPL.logjoint(f.model, DynamicPPL.unflatten(f.varinfo, x)) end @@ -166,7 +193,9 @@ function DynamicPPL.unflatten(vi::TypedVarInfo, θ::NamedTuple) set_namedtuple!(deepcopy(vi), θ) return vi end -DynamicPPL.unflatten(vi::SimpleVarInfo, θ::NamedTuple) = SimpleVarInfo(θ, vi.logp, vi.transformation) +function DynamicPPL.unflatten(vi::SimpleVarInfo, θ::NamedTuple) + return SimpleVarInfo(θ, vi.logp, vi.transformation) +end # Algorithm for sampling from the prior struct Prior <: InferenceAlgorithm end @@ -178,13 +207,13 @@ function AbstractMCMC.step( state=nothing; kwargs..., ) - vi = last(DynamicPPL.evaluate!!( - model, - VarInfo(), - SamplingContext( - rng, DynamicPPL.SampleFromPrior(), DynamicPPL.PriorContext() - ) - )) + vi = last( + DynamicPPL.evaluate!!( + model, + VarInfo(), + SamplingContext(rng, DynamicPPL.SampleFromPrior(), DynamicPPL.PriorContext()), + ), + ) return vi, nothing end @@ -215,10 +244,10 @@ getstats(t) = nothing abstract type AbstractTransition end -struct Transition{T, F<:AbstractFloat, S<:Union{NamedTuple, Nothing}} <: AbstractTransition - θ :: T - lp :: F # TODO: merge `lp` with `stat` - stat :: S +struct Transition{T,F<:AbstractFloat,S<:Union{NamedTuple,Nothing}} <: AbstractTransition + θ::T + lp::F # TODO: merge `lp` with `stat` + stat::S end Transition(θ, lp) = Transition(θ, lp, nothing) @@ -231,16 +260,16 @@ end function metadata(t::Transition) stat = t.stat if stat === nothing - return (lp = t.lp,) + return (lp=t.lp,) else - return merge((lp = t.lp,), stat) + return merge((lp=t.lp,), stat) end end DynamicPPL.getlogp(t::Transition) = t.lp # Metadata of VarInfo object -metadata(vi::AbstractVarInfo) = (lp = getlogp(vi),) +metadata(vi::AbstractVarInfo) = (lp=getlogp(vi),) # TODO: Implement additional checks for certain samplers, e.g. # HMC not supporting discrete parameters. @@ -256,10 +285,7 @@ end ######################################### function AbstractMCMC.sample( - model::AbstractModel, - alg::InferenceAlgorithm, - N::Integer; - kwargs... + model::AbstractModel, alg::InferenceAlgorithm, N::Integer; kwargs... ) return AbstractMCMC.sample(Random.default_rng(), model, alg, N; kwargs...) end @@ -270,7 +296,7 @@ function AbstractMCMC.sample( alg::InferenceAlgorithm, N::Integer; check_model::Bool=true, - kwargs... + kwargs..., ) check_model && _check_model(model, alg) return AbstractMCMC.sample(rng, model, Sampler(alg, model), N; kwargs...) @@ -282,10 +308,11 @@ function AbstractMCMC.sample( ensemble::AbstractMCMC.AbstractMCMCEnsemble, N::Integer, n_chains::Integer; - kwargs... + kwargs..., ) - return AbstractMCMC.sample(Random.default_rng(), model, alg, ensemble, N, n_chains; - kwargs...) + return AbstractMCMC.sample( + Random.default_rng(), model, alg, ensemble, N, n_chains; kwargs... + ) end function AbstractMCMC.sample( @@ -296,11 +323,12 @@ function AbstractMCMC.sample( N::Integer, n_chains::Integer; check_model::Bool=true, - kwargs... + kwargs..., ) check_model && _check_model(model, alg) - return AbstractMCMC.sample(rng, model, Sampler(alg, model), ensemble, N, n_chains; - kwargs...) + return AbstractMCMC.sample( + rng, model, Sampler(alg, model), ensemble, N, n_chains; kwargs... + ) end function AbstractMCMC.sample( @@ -312,10 +340,19 @@ function AbstractMCMC.sample( n_chains::Integer; chain_type=MCMCChains.Chains, progress=PROGRESS[], - kwargs... + kwargs..., ) - return AbstractMCMC.mcmcsample(rng, model, sampler, ensemble, N, n_chains; - chain_type=chain_type, progress=progress, kwargs...) + return AbstractMCMC.mcmcsample( + rng, + model, + sampler, + ensemble, + N, + n_chains; + chain_type=chain_type, + progress=progress, + kwargs..., + ) end ########################## @@ -349,7 +386,6 @@ function getparams(model::DynamicPPL.Model, vi::DynamicPPL.VarInfo) return mapreduce(collect, vcat, iters) end - function _params_to_array(model::DynamicPPL.Model, ts::Vector) names_set = OrderedSet{VarName}() # Extract the parameter names and values from each transition. @@ -364,9 +400,9 @@ function _params_to_array(model::DynamicPPL.Model, ts::Vector) return OrderedDict(zip(nms, vs)) end names = collect(names_set) - vals = [get(dicts[i], key, missing) for i in eachindex(dicts), - (j, key) in enumerate(names)] - + vals = [ + get(dicts[i], key, missing) for i in eachindex(dicts), (j, key) in enumerate(names) + ] return names, vals end @@ -382,7 +418,7 @@ function get_transition_extras(ts::AbstractVector) return names_values(extra_data) end -function names_values(extra_data::AbstractVector{<:NamedTuple{names}}) where names +function names_values(extra_data::AbstractVector{<:NamedTuple{names}}) where {names} values = [getfield(data, name) for data in extra_data, name in names] return collect(names), values end @@ -398,10 +434,7 @@ function names_values(xs::AbstractVector{<:NamedTuple}) names_unique = collect(names_set) # Extract all values as matrix. - values = [ - haskey(x, name) ? x[name] : missing - for x in xs, name in names_unique - ] + values = [haskey(x, name) ? x[name] : missing for x in xs, name in names_unique] return names_unique, values end @@ -416,13 +449,13 @@ function AbstractMCMC.bundle_samples( spl::Union{Sampler{<:InferenceAlgorithm},SampleFromPrior}, state, chain_type::Type{MCMCChains.Chains}; - save_state = false, - stats = missing, - sort_chain = false, - include_varname_to_symbol = true, - discard_initial = 0, - thinning = 1, - kwargs... + save_state=false, + stats=missing, + sort_chain=false, + include_varname_to_symbol=true, + discard_initial=0, + thinning=1, + kwargs..., ) # Convert transitions to array format. # Also retrieve the variable names. @@ -443,11 +476,11 @@ function AbstractMCMC.bundle_samples( info = NamedTuple() if include_varname_to_symbol - info = merge(info, (varname_to_symbol = OrderedDict(zip(varnames, varnames_symbol)),)) + info = merge(info, (varname_to_symbol=OrderedDict(zip(varnames, varnames_symbol)),)) end if save_state - info = merge(info, (model = model, sampler = spl, samplerstate = state)) + info = merge(info, (model=model, sampler=spl, samplerstate=state)) end # Merge in the timing info, if available @@ -462,7 +495,7 @@ function AbstractMCMC.bundle_samples( chain = MCMCChains.Chains( parray, nms, - (internals = extra_params,); + (internals=extra_params,); evidence=le, info=info, start=discard_initial + 1, @@ -479,7 +512,7 @@ function AbstractMCMC.bundle_samples( spl::Union{Sampler{<:InferenceAlgorithm},SampleFromPrior}, state, chain_type::Type{Vector{NamedTuple}}; - kwargs... + kwargs..., ) return map(ts) do t # Construct a dictionary of pairs `vn => value`. @@ -545,15 +578,13 @@ for alg in (:SMC, :PG, :MH, :IS, :ESS, :Gibbs, :Emcee) @eval DynamicPPL.getspace(::$alg{space}) where {space} = space end for alg in (:HMC, :HMCDA, :NUTS, :SGLD, :SGHMC) - @eval DynamicPPL.getspace(::$alg{<:Any, space}) where {space} = space + @eval DynamicPPL.getspace(::$alg{<:Any,space}) where {space} = space end function DynamicPPL.get_matching_type( - spl::Sampler{<:Union{PG, SMC}}, - vi, - ::Type{TV}, -) where {T, N, TV <: Array{T, N}} - return Array{T, N} + spl::Sampler{<:Union{PG,SMC}}, vi, ::Type{TV} +) where {T,N,TV<:Array{T,N}} + return Array{T,N} end ############## @@ -636,32 +667,34 @@ true function predict(model::Model, chain::MCMCChains.Chains; kwargs...) return predict(Random.default_rng(), model, chain; kwargs...) end -function predict(rng::AbstractRNG, model::Model, chain::MCMCChains.Chains; include_all = false) +function predict( + rng::AbstractRNG, model::Model, chain::MCMCChains.Chains; include_all=false +) # Don't need all the diagnostics chain_parameters = MCMCChains.get_sections(chain, :parameters) spl = DynamicPPL.SampleFromPrior() # Sample transitions using `spl` conditioned on values in `chain` - transitions = transitions_from_chain(rng, model, chain_parameters; sampler = spl) + transitions = transitions_from_chain(rng, model, chain_parameters; sampler=spl) # Let the Turing internals handle everything else for you chain_result = reduce( - MCMCChains.chainscat, [ + MCMCChains.chainscat, + [ AbstractMCMC.bundle_samples( - transitions[:, chain_idx], - model, - spl, - nothing, - MCMCChains.Chains - ) for chain_idx = 1:size(transitions, 2) - ] + transitions[:, chain_idx], model, spl, nothing, MCMCChains.Chains + ) for chain_idx in 1:size(transitions, 2) + ], ) parameter_names = if include_all names(chain_result, :parameters) else - filter(k -> ∉(k, names(chain_parameters, :parameters)), names(chain_result, :parameters)) + filter( + k -> ∉(k, names(chain_parameters, :parameters)), + names(chain_result, :parameters), + ) end return chain_result[parameter_names] @@ -716,11 +749,7 @@ julia> [first(t.θ.x) for t in transitions] # extract samples for `x` [-1.704630494695469] ``` """ -function transitions_from_chain( - model::Turing.Model, - chain::MCMCChains.Chains; - kwargs... -) +function transitions_from_chain(model::Turing.Model, chain::MCMCChains.Chains; kwargs...) return transitions_from_chain(Random.default_rng(), model, chain; kwargs...) end @@ -728,7 +757,7 @@ function transitions_from_chain( rng::Random.AbstractRNG, model::Turing.Model, chain::MCMCChains.Chains; - sampler = DynamicPPL.SampleFromPrior() + sampler=DynamicPPL.SampleFromPrior(), ) vi = Turing.VarInfo(model) diff --git a/test/mcmc/Inference.jl b/test/mcmc/Inference.jl index 15ec6149c..7ee23fc7a 100644 --- a/test/mcmc/Inference.jl +++ b/test/mcmc/Inference.jl @@ -70,7 +70,9 @@ ADUtils.install_tapir && import Tapir # Smoke test for default sample call. Random.seed!(100) - chain = sample(gdemo_default, HMC(0.1, 7; adtype=adbackend), MCMCThreads(), 1000, 4) + chain = sample( + gdemo_default, HMC(0.1, 7; adtype=adbackend), MCMCThreads(), 1000, 4 + ) check_gdemo(chain) # run sampler: progress logging should be disabled and @@ -114,7 +116,7 @@ ADUtils.install_tapir && import Tapir a ~ Beta() lp1 = getlogp(__varinfo__) x[1] ~ Bernoulli(a) - global loglike = getlogp(__varinfo__) - lp1 + return global loglike = getlogp(__varinfo__) - lp1 end model = testmodel1([1.0]) varinfo = Turing.VarInfo(model) @@ -124,13 +126,17 @@ ADUtils.install_tapir && import Tapir # Test MiniBatchContext @model function testmodel2(x) a ~ Beta() - x[1] ~ Bernoulli(a) + return x[1] ~ Bernoulli(a) end model = testmodel2([1.0]) varinfo1 = Turing.VarInfo(model) varinfo2 = deepcopy(varinfo1) model(varinfo1, Turing.SampleFromPrior(), Turing.LikelihoodContext()) - model(varinfo2, Turing.SampleFromPrior(), Turing.MiniBatchContext(Turing.LikelihoodContext(), 10)) + model( + varinfo2, + Turing.SampleFromPrior(), + Turing.MiniBatchContext(Turing.LikelihoodContext(), 10), + ) @test isapprox(getlogp(varinfo2) / getlogp(varinfo1), 10) end @testset "Prior" begin @@ -141,24 +147,24 @@ ADUtils.install_tapir && import Tapir chains = sample(gdemo_d(), Prior(), N) @test chains isa MCMCChains.Chains @test size(chains) == (N, 3, 1) - @test mean(chains, :s) ≈ 3 atol=0.1 - @test mean(chains, :m) ≈ 0 atol=0.1 + @test mean(chains, :s) ≈ 3 atol = 0.1 + @test mean(chains, :m) ≈ 0 atol = 0.1 Random.seed!(100) chains = sample(gdemo_d(), Prior(), MCMCThreads(), N, 4) @test chains isa MCMCChains.Chains @test size(chains) == (N, 3, 4) - @test mean(chains, :s) ≈ 3 atol=0.1 - @test mean(chains, :m) ≈ 0 atol=0.1 + @test mean(chains, :s) ≈ 3 atol = 0.1 + @test mean(chains, :m) ≈ 0 atol = 0.1 Random.seed!(100) - chains = sample(gdemo_d(), Prior(), N; chain_type = Vector{NamedTuple}) + chains = sample(gdemo_d(), Prior(), N; chain_type=Vector{NamedTuple}) @test chains isa Vector{<:NamedTuple} @test length(chains) == N @test all(length(x) == 3 for x in chains) @test all(haskey(x, :lp) for x in chains) - @test mean(x[:s][1] for x in chains) ≈ 3 atol=0.1 - @test mean(x[:m][1] for x in chains) ≈ 0 atol=0.1 + @test mean(x[:s][1] for x in chains) ≈ 3 atol = 0.1 + @test mean(x[:m][1] for x in chains) ≈ 0 atol = 0.1 @testset "#2169" begin # Not exactly the same as the issue, but similar. @@ -178,10 +184,10 @@ ADUtils.install_tapir && import Tapir @testset "chain ordering" begin for alg in (Prior(), Emcee(10, 2.0)) - chain_sorted = sample(gdemo_default, alg, 1, sort_chain=true) + chain_sorted = sample(gdemo_default, alg, 1; sort_chain=true) @test names(MCMCChains.get_sections(chain_sorted, :parameters)) == [:m, :s] - chain_unsorted = sample(gdemo_default, alg, 1, sort_chain=false) + chain_unsorted = sample(gdemo_default, alg, 1; sort_chain=false) @test names(MCMCChains.get_sections(chain_unsorted, :parameters)) == [:s, :m] end end @@ -293,8 +299,12 @@ ADUtils.install_tapir && import Tapir @test_throws ErrorException chain = sample(gauss2(; x=x), PG(10), 10) @test_throws ErrorException chain = sample(gauss2(; x=x), SMC(), 10) - @test_throws ErrorException chain = sample(gauss2(DynamicPPL.TypeWrap{Vector{Float64}}(); x=x), PG(10), 10) - @test_throws ErrorException chain = sample(gauss2(DynamicPPL.TypeWrap{Vector{Float64}}(); x=x), SMC(), 10) + @test_throws ErrorException chain = sample( + gauss2(DynamicPPL.TypeWrap{Vector{Float64}}(); x=x), PG(10), 10 + ) + @test_throws ErrorException chain = sample( + gauss2(DynamicPPL.TypeWrap{Vector{Float64}}(); x=x), SMC(), 10 + ) @model function gauss3(x, ::Type{TV}=Vector{Float64}) where {TV} priors = TV(undef, 2) @@ -324,7 +334,9 @@ ADUtils.install_tapir && import Tapir end sample( - newinterface(obs), HMC(0.75, 3, :p, :x; adtype = Turing.AutoForwardDiff(; chunksize=2)), 100 + newinterface(obs), + HMC(0.75, 3, :p, :x; adtype=Turing.AutoForwardDiff(; chunksize=2)), + 100, ) end @testset "no return" begin @@ -528,7 +540,9 @@ ADUtils.install_tapir && import Tapir t_loop = @elapsed res = sample(vdemo1(DynamicPPL.TypeWrap{Float64}()), alg, 250) vdemo1kw(; T) = vdemo1(T) - t_loop = @elapsed res = sample(vdemo1kw(; T=DynamicPPL.TypeWrap{Float64}()), alg, 250) + t_loop = @elapsed res = sample( + vdemo1kw(; T=DynamicPPL.TypeWrap{Float64}()), alg, 250 + ) @model function vdemo2(::Type{T}=Float64) where {T<:Real} x = Vector{T}(undef, N) @@ -539,7 +553,9 @@ ADUtils.install_tapir && import Tapir t_vec = @elapsed res = sample(vdemo2(DynamicPPL.TypeWrap{Float64}()), alg, 250) vdemo2kw(; T) = vdemo2(T) - t_vec = @elapsed res = sample(vdemo2kw(; T=DynamicPPL.TypeWrap{Float64}()), alg, 250) + t_vec = @elapsed res = sample( + vdemo2kw(; T=DynamicPPL.TypeWrap{Float64}()), alg, 250 + ) @model function vdemo3(::Type{TV}=Vector{Float64}) where {TV<:AbstractVector} x = TV(undef, N) @@ -554,11 +570,7 @@ ADUtils.install_tapir && import Tapir end @testset "names_values" begin - ks, xs = Turing.Inference.names_values([ - (a=1,), - (b=2,), - (a=3, b=4) - ]) + ks, xs = Turing.Inference.names_values([(a=1,), (b=2,), (a=3, b=4)]) @test all(xs[:, 1] .=== [1, missing, 3]) @test all(xs[:, 2] .=== [missing, 2, 4]) end @@ -566,19 +578,18 @@ ADUtils.install_tapir && import Tapir @testset "check model" begin @model function demo_repeated_varname() x ~ Normal(0, 1) - x ~ Normal(x, 1) + return x ~ Normal(x, 1) end @test_throws ErrorException sample( demo_repeated_varname(), NUTS(), 1000; check_model=true ) # Make sure that disabling the check also works. - @test (sample( - demo_repeated_varname(), Prior(), 10; check_model=false - ); true) + @test (sample(demo_repeated_varname(), Prior(), 10; check_model=false); + true) @model function demo_incorrect_missing(y) - y[1:1] ~ MvNormal(zeros(1), 1) + return y[1:1] ~ MvNormal(zeros(1), 1) end @test_throws ErrorException sample( demo_incorrect_missing([missing]), NUTS(), 1000; check_model=true diff --git a/test/test_utils/numerical_tests.jl b/test/test_utils/numerical_tests.jl index c44c502c1..97d174014 100644 --- a/test/test_utils/numerical_tests.jl +++ b/test/test_utils/numerical_tests.jl @@ -5,10 +5,10 @@ using MCMCChains: namesingroup using Test: @test, @testset using HypothesisTests: HypothesisTests -export check_MoGtest_default, check_MoGtest_default_z_vector, check_dist_numerical, - check_gdemo, check_numerical +export check_MoGtest_default, + check_MoGtest_default_z_vector, check_dist_numerical, check_gdemo, check_numerical -function check_dist_numerical(dist, chn; mean_tol = 0.1, var_atol = 1.0, var_tol = 0.5) +function check_dist_numerical(dist, chn; mean_tol=0.1, var_atol=1.0, var_tol=0.5) @testset "numerical" begin # Extract values. chn_xs = Array(chn[1:2:end, namesingroup(chn, :x), :]) @@ -17,14 +17,14 @@ function check_dist_numerical(dist, chn; mean_tol = 0.1, var_atol = 1.0, var_tol dist_mean = mean(dist) mean_shape = size(dist_mean) if !all(isnan, dist_mean) && !all(isinf, dist_mean) - chn_mean = vec(mean(chn_xs, dims=1)) - chn_mean = length(chn_mean) == 1 ? - chn_mean[1] : - reshape(chn_mean, mean_shape) - atol_m = length(chn_mean) > 1 ? - mean_tol * length(chn_mean) : + chn_mean = vec(mean(chn_xs; dims=1)) + chn_mean = length(chn_mean) == 1 ? chn_mean[1] : reshape(chn_mean, mean_shape) + atol_m = if length(chn_mean) > 1 + mean_tol * length(chn_mean) + else max(mean_tol, mean_tol * chn_mean) - @test chn_mean ≈ dist_mean atol=atol_m + end + @test chn_mean ≈ dist_mean atol = atol_m end # Check variances. @@ -34,52 +34,52 @@ function check_dist_numerical(dist, chn; mean_tol = 0.1, var_atol = 1.0, var_tol dist_var = var(dist) var_shape = size(dist_var) if !all(isnan, dist_var) && !all(isinf, dist_var) - chn_var = vec(var(chn_xs, dims=1)) - chn_var = length(chn_var) == 1 ? - chn_var[1] : - reshape(chn_var, var_shape) - atol_v = length(chn_mean) > 1 ? - mean_tol * length(chn_mean) : + chn_var = vec(var(chn_xs; dims=1)) + chn_var = length(chn_var) == 1 ? chn_var[1] : reshape(chn_var, var_shape) + atol_v = if length(chn_mean) > 1 + mean_tol * length(chn_mean) + else max(mean_tol, mean_tol * chn_mean) - @test chn_mean ≈ dist_mean atol=atol_v + end + @test chn_mean ≈ dist_mean atol = atol_v end end end end # Helper function for numerical tests -function check_numerical(chain, - symbols::Vector, - exact_vals::Vector; - atol=0.2, - rtol=0.0) +function check_numerical(chain, symbols::Vector, exact_vals::Vector; atol=0.2, rtol=0.0) for (sym, val) in zip(symbols, exact_vals) - E = val isa Real ? - mean(chain[sym]) : - vec(mean(chain[sym], dims=1)) + E = val isa Real ? mean(chain[sym]) : vec(mean(chain[sym]; dims=1)) @info (symbol=sym, exact=val, evaluated=E) - @test E ≈ val atol=atol rtol=rtol + @test E ≈ val atol = atol rtol = rtol end end # Wrapper function to quickly check gdemo accuracy. function check_gdemo(chain; atol=0.2, rtol=0.0) - check_numerical(chain, [:s, :m], [49/24, 7/6], atol=atol, rtol=rtol) + return check_numerical(chain, [:s, :m], [49 / 24, 7 / 6]; atol=atol, rtol=rtol) end # Wrapper function to check MoGtest. function check_MoGtest_default(chain; atol=0.2, rtol=0.0) - check_numerical(chain, + return check_numerical( + chain, [:z1, :z2, :z3, :z4, :mu1, :mu2], - [1.0, 1.0, 2.0, 2.0, 1.0, 4.0], - atol=atol, rtol=rtol) + [1.0, 1.0, 2.0, 2.0, 1.0, 4.0]; + atol=atol, + rtol=rtol, + ) end function check_MoGtest_default_z_vector(chain; atol=0.2, rtol=0.0) - check_numerical(chain, + return check_numerical( + chain, [Symbol("z[1]"), Symbol("z[2]"), Symbol("z[3]"), Symbol("z[4]"), :mu1, :mu2], - [1.0, 1.0, 2.0, 2.0, 1.0, 4.0], - atol=atol, rtol=rtol) + [1.0, 1.0, 2.0, 2.0, 1.0, 4.0]; + atol=atol, + rtol=rtol, + ) end """ @@ -104,9 +104,12 @@ function two_sample_test(xs_left, xs_right; α=1e-3, warn_on_fail=false) if HypothesisTests.pvalue(t) > α true else - warn_on_fail && @warn "Two-sample AD test failed with p-value $(HypothesisTests.pvalue(t))" - warn_on_fail && @warn "Means of the two samples: $(mean(xs_left)), $(mean(xs_right))" - warn_on_fail && @warn "Variances of the two samples: $(var(xs_left)), $(var(xs_right))" + warn_on_fail && + @warn "Two-sample AD test failed with p-value $(HypothesisTests.pvalue(t))" + warn_on_fail && + @warn "Means of the two samples: $(mean(xs_left)), $(mean(xs_right))" + warn_on_fail && + @warn "Variances of the two samples: $(var(xs_left)), $(var(xs_right))" false end end