diff --git a/Project.toml b/Project.toml
index 1a3cedec..59a8707f 100644
--- a/Project.toml
+++ b/Project.toml
@@ -3,7 +3,7 @@ uuid = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
 keywords = ["markov chain monte carlo", "probablistic programming"]
 license = "MIT"
 desc = "Chain types and utility functions for MCMC simulations."
-version = "4.1.0"
+version = "4.2.0"
 
 [deps]
 AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
@@ -14,6 +14,7 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
 Formatting = "59287772-0a20-5a39-b81b-1366585eb4c0"
 IteratorInterfaceExtensions = "82899510-4779-5014-852e-03e436cf321d"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
+MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
 NaturalSort = "c020b1a1-e9b0-503a-9c33-f039bfc54a85"
 PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
@@ -33,6 +34,7 @@ Compat = "2.2, 3"
 Distributions = "0.21, 0.22, 0.23"
 Formatting = "0.4"
 IteratorInterfaceExtensions = "0.1.1, 1"
+MLJModelInterface = "0.3.5"
 NaturalSort = "1"
 PrettyTables = "0.9"
 RecipesBase = "0.7, 0.8, 1.0"
@@ -47,9 +49,12 @@ DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
 FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
 KernelDensity = "5ab0869b-81aa-558d-bb23-cbf5423bbe9b"
 Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
+MLJ = "add582a8-e3ab-11e8-2d5e-e98b27df1bc7"
+MLJModels = "d491faf4-2d78-11e9-2867-c94bc002c0b7"
 StatsPlots = "f3b207a7-027a-5e70-b257-86293d7955fd"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 UnicodePlots = "b8865327-cd53-5732-bb35-84acbb429228"
+XGBoost = "009559a3-9522-5dbb-924b-0b6ed2b22bb9"
 
 [targets]
-test = ["DataFrames", "FFTW", "KernelDensity", "Logging", "StatsPlots", "Test", "UnicodePlots"]
+test = ["DataFrames", "FFTW", "KernelDensity", "Logging", "StatsPlots", "Test", "UnicodePlots", "MLJ", "MLJModels", "XGBoost"]
diff --git a/README.md b/README.md
index ea9d3f15..fe9a4bef 100644
--- a/README.md
+++ b/README.md
@@ -225,6 +225,32 @@ heideldiag(c::Chains; alpha=0.05, eps=0.1, etype=:imse)
 rafterydiag(c::Chains; q=0.025, r=0.005, s=0.95, eps=0.001)
 ```
+#### Rstar Diagnostic
+The R* diagnostic is described in [https://arxiv.org/pdf/2003.07900.pdf](https://arxiv.org/pdf/2003.07900.pdf).
+Note that using this diagnostic requires MLJ and MLJModels to be installed.
+
+Usage:
+
+```julia
+using MLJ, MLJModels
+
+chn = ... # sampling results of multiple chains
+
+# select a classifier to compute the diagnostic
+classif = @load XGBoostClassifier
+
+# estimate the diagnostic
+Rs = rstar(classif, chn)
+R = mean(Rs)
+
+# visualize the distribution
+using Plots
+histogram(Rs)
+```
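+
+The statistic can also be computed directly from a matrix of pooled draws and an
+integer vector assigning each draw to a chain. A minimal sketch (the arrays below
+are illustrative placeholders, not prescribed by the API):
+
+```julia
+x = randn(1000, 2)   # pooled draws, one row per draw
+y = rand(1:3, 1000)  # chain index of each draw
+Rs = rstar(classif, x, y)
+```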
+
+See `?rstar` for more details.
+
 ### Model Selection
 #### Deviance Information Criterion (DIC)
 ```julia
diff --git a/src/MCMCChains.jl b/src/MCMCChains.jl
index 2d1e500a..27a24a49 100644
--- a/src/MCMCChains.jl
+++ b/src/MCMCChains.jl
@@ -12,7 +12,7 @@ using SpecialFunctions
 using Formatting
 
 import StatsBase: autocov, counts, sem, AbstractWeights, autocor, describe, quantile, sample, summarystats, cov
-
+import MLJModelInterface
 import NaturalSort
 import PrettyTables
 import Tables
@@ -36,6 +36,8 @@ export summarize
 export discretediag, gelmandiag, gewekediag, heideldiag, rafterydiag
 export hpd, ess
 
+export rstar
+
 export ESSMethod, FFTESSMethod, BDAESSMethod
 
 """
@@ -73,5 +75,6 @@ include("stats.jl")
 include("modelstats.jl")
 include("plot.jl")
 include("tables.jl")
+include("rstar.jl")
 
 end # module
diff --git a/src/rstar.jl b/src/rstar.jl
new file mode 100644
index 00000000..609dd5e0
--- /dev/null
+++ b/src/rstar.jl
@@ -0,0 +1,96 @@
+"""
+    rstar([rng,] classif::Supervised, chains::Chains; kwargs...)
+    rstar([rng,] classif::Supervised, x::AbstractMatrix, y::AbstractVector; kwargs...)
+
+Compute the R* convergence diagnostic of MCMC.
+
+This implementation is an adaptation of Algorithms 1 and 2 described in [Lambert & Vehtari].
+Note that the correctness of the statistic depends on the convergence of the classifier used
+internally. You can track whether the training of the classifier converged by inspecting the
+printed RMSE values from the XGBoost backend. To adjust the number of iterations used to train
+the classifier, set `niter` accordingly.
+
+# Keyword Arguments
+* `subset = 0.8`: Fraction of samples used to train the classifier, i.e. 0.8 implies 80% of the samples are used for training.
+* `iterations = 10`: Number of iterations used to estimate the statistic. If the classifier is not probabilistic, i.e. does not return class probabilities, it is advisable to use a value of one.
+* `verbosity = 0`: Verbosity level used during fitting of the classifier.
+
+# Usage
+
+```julia
+using MLJ, MLJModels
+# You need to load MLJ and the respective package you are using for classification first.
+
+# Select a classifier to compute the Rstar statistic.
+# For example the XGBoost classifier.
+classif = @load XGBoostClassifier()
+
+# Compute 20 samples of the R* statistic, sampling predictions according to the predicted class probabilities.
+Rs = rstar(classif, chn, iterations = 20)
+
+# estimate Rstar
+R = mean(Rs)
+
+# visualize the distribution
+histogram(Rs)
+```
+
+## References:
+[Lambert & Vehtari] Ben Lambert and Aki Vehtari. "R∗: A robust MCMC convergence diagnostic with uncertainty using gradient-boosted machines." arXiv, 2020.
+"""
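+# The methods below all follow the same recipe from Algorithms 1 and 2 of
+# Lambert & Vehtari: randomly split the pooled draws into a training and a
+# testing set, fit the classifier to predict chain labels on the training set,
+# and take the classifier's mean test accuracy, scaled by the number of chains
+# K, as one sample of the R* statistic.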
+""" +function rstar(rng::Random.AbstractRNG, classif::MLJModelInterface.Supervised, x::AbstractMatrix, y::AbstractVector{Int}; iterations = 10, subset = 0.8, verbosity = 0) + + size(x,1) != length(y) && throw(DimensionMismatch()) + iterations >= 1 && ArgumentError("Number of iterations has to be positive!") + + if iterations > 1 && classif isa MLJModelInterface.Deterministic + @warn("Classifier is not a probabilistic classifier but number of iterations is > 1.") + elseif iterations == 1 && classif isa MLJModelInterface.Probabilistic + @warn("Classifier is probabilistic but number of iterations is equal to one.") + end + + N = length(y) + K = length(unique(y)) + + # randomly sub-select training and testing set + Ntrain = round(Int, N*subset) + Ntest = N - Ntrain + + ids = Random.randperm(rng, N) + train_ids = view(ids, 1:Ntrain) + test_ids = view(ids, (Ntrain+1):N) + + # train classifier using XGBoost + fitresult, _ = MLJModelInterface.fit(classif, verbosity, Tables.table(x[train_ids,:]), MLJModelInterface.categorical(y[train_ids])) + + xtest = Tables.table(x[test_ids,:]) + ytest = view(y, test_ids) + + Rstats = map(i -> K*rstar_score(rng, classif, fitresult, xtest, ytest), 1:iterations) + return Rstats +end + +function rstar(classif::MLJModelInterface.Supervised, x::AbstractMatrix, y::AbstractVector{Int}; kwargs...) + rstar(Random.GLOBAL_RNG, classif, x, y; kwargs...) +end + +function rstar(classif::MLJModelInterface.Supervised, chn::Chains; kwargs...) + return rstar(Random.GLOBAL_RNG, classif, chn; kwargs...) +end + +function rstar(rng::Random.AbstractRNG, classif::MLJModelInterface.Supervised, chn::Chains; kwargs...) + nchains = size(chn, 3) + nchains <= 1 && throw(DimensionMismatch()) + + # collect data + x = Array(chn) + y = repeat(chains(chn); inner = size(chn,1)) + + return rstar(rng, classif, x, y; kwargs...) +end + +function rstar_score(rng::Random.AbstractRNG, classif::MLJModelInterface.Probabilistic, fitresult, xtest, ytest) + pred = get.(rand.(Ref(rng), MLJModelInterface.predict(classif, fitresult, xtest))) + return mean(((p,y),) -> p == y, zip(pred, ytest)) +end + +function rstar_score(rng::Random.AbstractRNG, classif::MLJModelInterface.Deterministic, fitresult, xtest, ytest) + pred = MLJModelInterface.predict(classif, fitresult, xtest) + return mean(((p,y),) -> p == y, zip(pred, ytest)) +end diff --git a/test/diagnostic_tests.jl b/test/diagnostic_tests.jl index 1ab921e2..2d455a64 100644 --- a/test/diagnostic_tests.jl +++ b/test/diagnostic_tests.jl @@ -82,7 +82,7 @@ end end @testset "function tests" begin - tchain = Chains(rand(n_iter, n_name, n_chain), ["a", "b", "c"], Dict(:internals => ["c"])) + tchain = Chains(rand(niter, nparams, nchains), ["a", "b", "c"], Dict(:internals => ["c"])) # the following tests only check if the function calls work! 
diff --git a/test/diagnostic_tests.jl b/test/diagnostic_tests.jl
index 1ab921e2..2d455a64 100644
--- a/test/diagnostic_tests.jl
+++ b/test/diagnostic_tests.jl
@@ -82,7 +82,7 @@ end
 end
 
 @testset "function tests" begin
-    tchain = Chains(rand(n_iter, n_name, n_chain), ["a", "b", "c"], Dict(:internals => ["c"]))
+    tchain = Chains(rand(niter, nparams, nchains), ["a", "b", "c"], Dict(:internals => ["c"]))
 
     # the following tests only check if the function calls work!
     @test MCMCChains.diag_all(rand(50, 2), :weiss, 1, 1, 1) != nothing
@@ -137,9 +137,10 @@ end
 end
 
 @testset "sorting" begin
-    chn_unsorted = Chains(rand(100,3,1), ["2", "1", "3"])
+    chn_unsorted = Chains(rand(100, nparams, 1), ["2", "1", "3"])
     chn_sorted = sort(chn_unsorted)
 
     @test names(chn_sorted) == Symbol.([1, 2, 3])
     @test names(chn_unsorted) == Symbol.([2, 1, 3])
 end
+
diff --git a/test/rstar_tests.jl b/test/rstar_tests.jl
new file mode 100644
index 00000000..4780d260
--- /dev/null
+++ b/test/rstar_tests.jl
@@ -0,0 +1,37 @@
+using MCMCChains
+using Tables
+using MLJ, MLJModels
+using Test
+
+N = 1000
+val = rand(N, 8, 4)
+colnames = ["a", "b", "c", "d", "e", "f", "g", "h"]
+internal_colnames = ["c", "d", "e", "f", "g", "h"]
+chn = Chains(val, colnames, Dict(:internals => internal_colnames))
+
+classif = @load XGBoostClassifier()
+
+@testset "R star test" begin
+
+    # Compute the R* statistic for iid draws with random chain labels.
+    R = rstar(classif, randn(N, 2), rand(1:3, N))
+
+    # The resulting R value should be close to one, i.e. the classifier does not perform better than random guessing.
+    @test mean(R) ≈ 1 atol=0.1
+
+    # Compute the R* statistic for a mixed chain.
+    R = rstar(classif, chn)
+
+    # The resulting R value should be close to one, i.e. the classifier does not perform better than random guessing.
+    @test mean(R) ≈ 1 atol=0.1
+
+    # Compute the R* statistic for a non-mixed chain.
+    niter = 1000
+    val = hcat(sin.(1:niter), cos.(1:niter))
+    val = cat(val, hcat(cos.(1:niter) * 100, sin.(1:niter) * 100), dims=3)
+    chn_notmixed = Chains(val)
+
+    # The resulting R value should be close to two, i.e. the classifier should be able to learn an almost perfect decision boundary between the two chains.
+    R = rstar(classif, chn_notmixed)
+    @test mean(R) ≈ 2 atol=0.1
+end
diff --git a/test/runtests.jl b/test/runtests.jl
index 90f48922..eed51cc4 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -1,8 +1,14 @@
 using Test
 
 @testset "MCMCChains" begin
+
+    # run tests related to the rstar statistic
+    println("Rstar")
+    @time include("rstar_tests.jl")
+
     # run tests for effective sample size
-    include("ess_tests.jl")
+    println("ESS")
+    @time include("ess_tests.jl")
 
     # run plotting tests
     println("Plotting")