Adding implementation of Rstar statistic #238

Merged
merged 20 commits into from
Sep 17, 2020
Changes from 13 commits
6 changes: 5 additions & 1 deletion Project.toml
@@ -14,6 +14,7 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Formatting = "59287772-0a20-5a39-b81b-1366585eb4c0"
IteratorInterfaceExtensions = "82899510-4779-5014-852e-03e436cf321d"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
NaturalSort = "c020b1a1-e9b0-503a-9c33-f039bfc54a85"
PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
@@ -47,9 +48,12 @@ DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
KernelDensity = "5ab0869b-81aa-558d-bb23-cbf5423bbe9b"
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
MLJ = "add582a8-e3ab-11e8-2d5e-e98b27df1bc7"
MLJModels = "d491faf4-2d78-11e9-2867-c94bc002c0b7"
StatsPlots = "f3b207a7-027a-5e70-b257-86293d7955fd"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
UnicodePlots = "b8865327-cd53-5732-bb35-84acbb429228"
XGBoost = "009559a3-9522-5dbb-924b-0b6ed2b22bb9"

[targets]
-test = ["DataFrames", "FFTW", "KernelDensity", "Logging", "StatsPlots", "Test", "UnicodePlots"]
+test = ["DataFrames", "FFTW", "KernelDensity", "Logging", "StatsPlots", "Test", "UnicodePlots", "MLJ", "MLJModels", "XGBoost"]
26 changes: 26 additions & 0 deletions README.md
@@ -225,6 +225,32 @@ heideldiag(c::Chains; alpha=0.05, eps=0.1, etype=:imse)
rafterydiag(c::Chains; q=0.025, r=0.005, s=0.95, eps=0.001)
```

#### Rstar Diagnostic
The R* diagnostic is described in [https://arxiv.org/pdf/2003.07900.pdf](https://arxiv.org/pdf/2003.07900.pdf).
Note that using it requires MLJ and MLJModels to be installed.

Usage:

```julia
using MLJ, MLJModels

# chn = ... (sampling results of multiple chains)

# select classifier used to compute the diagnostic
classif = @load XGBoostClassifier

# estimate diagnostic
Rs = rstar(chn, classif)
R = mean(Rs)

# visualize distribution
using Plots
histogram(Rs)
```

See `? rstar` for more details.


### Model Selection
#### Deviance Information Criterion (DIC)
```julia
5 changes: 4 additions & 1 deletion src/MCMCChains.jl
@@ -12,7 +12,7 @@ using SpecialFunctions
using Formatting
import StatsBase: autocov, counts, sem, AbstractWeights,
autocor, describe, quantile, sample, summarystats, cov

using MLJModelInterface
import NaturalSort
import PrettyTables
import Tables
@@ -36,6 +36,8 @@ export summarize
export discretediag, gelmandiag, gewekediag, heideldiag, rafterydiag
export hpd, ess

export rstar

export ESSMethod, FFTESSMethod, BDAESSMethod

"""
@@ -73,5 +75,6 @@ include("stats.jl")
include("modelstats.jl")
include("plot.jl")
include("tables.jl")
include("rstar.jl")

end # module
91 changes: 91 additions & 0 deletions src/rstar.jl
@@ -0,0 +1,91 @@
"""
rstar(chains::Chains, model::MLJModelInterface.Supervised; subset = 0.8, iterations = 10, verbosity = 0)
rstar(rng::Random.AbstractRNG, chains::Chains, model::MLJModelInterface.Supervised; subset = 0.8, iterations = 10, verbosity = 0)
rstar(rng::Random.AbstractRNG, model::MLJModelInterface.Supervised, x::AbstractMatrix, y::AbstractVector; subset = 0.8, iterations = 10, verbosity = 0)

Compute the R* statistic for convergence diagnostics of MCMC. This implementation is an adaptation of Algorithms 1 & 2 described in [Lambert & Vehtari]. Note that the correctness of the statistic depends on the convergence of the classifier used internally. You can check whether training of the classifier converged by inspecting the printed RMSE values from the XGBoost backend. To adjust the number of iterations used to train the classifier, set `niter` accordingly.

# Keyword Arguments
* `subset = 0.8` ... Subset used to train the classifier, i.e. 0.8 implies 80% of the samples are used.
* `iterations = 10` ... Number of iterations used to estimate the statistic. If the classifier is not probabilistic, i.e. does not return class probabilities, it is advisable to use a value of one.
* `verbosity = 0` ... Verbosity level used during fitting of the classifier.

# Usage

```julia-repl
using MLJ, MLJModels
# You need to load MLJBase and the respective package you are using for classification first.

# Select a classification model to compute the Rstar statistic.
# For example the XGBoost classifier.
model = @load XGBoostClassifier()

# Compute 20 samples of the R* statistic by sampling class labels according to the prediction probabilities.
Rs = rstar(chn, model, iterations = 20)

# estimate Rstar
R = mean(Rs)

# visualize distribution
histogram(Rs)
```

## References:
[Lambert & Vehtari] Ben Lambert and Aki Vehtari. "R∗: A robust MCMC convergence diagnostic with uncertainty using gradient-boosted machines." arXiv 2020.
"""
function rstar(rng::Random.AbstractRNG, classif::MLJModelInterface.Supervised, x::AbstractMatrix, y::AbstractVector{Int}; iterations = 10, subset = 0.8, verbosity = 0)

size(x,1) != length(y) && throw(DimensionMismatch())
iterations >= 1 || throw(ArgumentError("Number of iterations has to be positive!"))

if iterations > 1 && classif isa Deterministic
@warn("Classifier is not a probabilistic classifier but number of iterations is > 1.")
elseif iterations == 1 && classif isa Probabilistic
@warn("Classifier is probabilistic but number of iterations is equal to one.")
end

N = length(y)
K = length(unique(y))

# randomly sub-select training and testing set
Ntrain = round(Int, N*subset)
Ntest = N - Ntrain

ids = Random.randperm(rng, N)
train_ids = view(ids, 1:Ntrain)
test_ids = view(ids, (Ntrain+1):N)

# train the supplied classifier on the training subset
fitresult, _ = MLJModelInterface.fit(classif, verbosity, Tables.table(x[train_ids,:]), MLJModelInterface.categorical(y[train_ids]))

xtest = Tables.table(x[test_ids,:])
ytest = view(y, test_ids)

Rstats = map(i -> K*rstar_score(rng, classif, fitresult, xtest, ytest), 1:iterations)
return Rstats
end

function rstar(chn::Chains, model::MLJModelInterface.Supervised; kwargs...)
return rstar(Random.GLOBAL_RNG, chn, model; kwargs...)
end

function rstar(rng::Random.AbstractRNG, chn::Chains, model::MLJModelInterface.Supervised; kwargs...)
nchains = size(chn, 3)
nchains <= 1 && throw(DimensionMismatch())

# collect data
x = Array(chn)
y = repeat(chains(chn); inner = size(chn,1))

return rstar(rng, model, x, y; kwargs...)
end

function rstar_score(rng::Random.AbstractRNG, classif::MLJModelInterface.Probabilistic, fitresult, xtest, ytest)
pred = get.(rand.(Ref(rng), MLJModelInterface.predict(classif, fitresult, xtest)))
return mean(((p,y),) -> p == y, zip(pred, ytest))
end

function rstar_score(rng::Random.AbstractRNG, classif::MLJModelInterface.Deterministic, fitresult, xtest, ytest)
pred = MLJModelInterface.predict(classif, fitresult, xtest)
return mean(((p,y),) -> p == y, zip(pred, ytest))
end
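
For readers who want to see the scoring step in isolation, here is a minimal sketch that does not depend on MLJ. It assumes the classifier output is available as a plain matrix of class probabilities (one row per test sample, one column per chain). The helpers `sample_label` and `rstar_draw` are hypothetical names introduced for illustration and are not part of this PR. The sketch also takes the number of chains from the columns of the matrix, whereas the code in the diff uses `length(unique(y))`.

```julia
using Random, Statistics

# Hypothetical helper (not part of MCMCChains): draw one class label from a
# vector of class probabilities by inverting the cumulative distribution.
function sample_label(rng::AbstractRNG, probs::AbstractVector{<:Real})
    u = rand(rng)
    c = 0.0
    for (k, p) in enumerate(probs)
        c += p
        u < c && return k
    end
    return length(probs)  # guard against round-off in the cumulative sum
end

# Hypothetical helper: one draw of R*, i.e. the number of chains times the
# accuracy of labels sampled from the predicted class probabilities.
function rstar_draw(rng::AbstractRNG, probs::AbstractMatrix{<:Real},
                    ytest::AbstractVector{<:Integer})
    K = size(probs, 2)
    pred = [sample_label(rng, view(probs, i, :)) for i in axes(probs, 1)]
    return K * mean(pred .== ytest)
end

rng = MersenneTwister(1)
ytest = repeat(1:2; inner = 500)

# A classifier that cannot tell the chains apart predicts uniform probabilities.
uniform_probs = fill(0.5, length(ytest), 2)
# A classifier that separates the chains perfectly puts all mass on the true label.
perfect_probs = [Float64(k == y) for y in ytest, k in 1:2]

r_mixed = mean(rstar_draw(rng, uniform_probs, ytest) for _ in 1:20)
r_separated = mean(rstar_draw(rng, perfect_probs, ytest) for _ in 1:20)
```

With uniform probabilities the sampled labels are right about half the time, so `r_mixed` lands near 1. With perfect probabilities every sampled label is correct, so `r_separated` equals the number of chains, here 2.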
5 changes: 3 additions & 2 deletions test/diagnostic_tests.jl
@@ -82,7 +82,7 @@ end
end

@testset "function tests" begin
-tchain = Chains(rand(n_iter, n_name, n_chain), ["a", "b", "c"], Dict(:internals => ["c"]))
+tchain = Chains(rand(niter, nparams, nchains), ["a", "b", "c"], Dict(:internals => ["c"]))

# the following tests only check if the function calls work!
@test MCMCChains.diag_all(rand(50, 2), :weiss, 1, 1, 1) != nothing
@@ -137,9 +137,10 @@ end
end

@testset "sorting" begin
-chn_unsorted = Chains(rand(100,3,1), ["2", "1", "3"])
+chn_unsorted = Chains(rand(100, nparams, 1), ["2", "1", "3"])
chn_sorted = sort(chn_unsorted)

@test names(chn_sorted) == Symbol.([1, 2, 3])
@test names(chn_unsorted) == Symbol.([2, 1, 3])
end

30 changes: 30 additions & 0 deletions test/rstar_tests.jl
@@ -0,0 +1,30 @@
using MCMCChains
using Tables
using MLJ, MLJModels
using Test

val = rand(1000, 8, 4)
colnames = ["a", "b", "c", "d", "e", "f", "g", "h"]
internal_colnames = ["c", "d", "e", "f", "g", "h"]
chn = Chains(val, colnames, Dict(:internals => internal_colnames))

model = @load XGBoostClassifier()

@testset "R star test" begin

# Compute R* statistic for a mixed chain.
R = rstar(chn, model; iterations = 10)

# Resulting R value should be close to one, i.e. the classifier does not perform better than random guessing.
@test mean(R) ≈ 1 atol=0.1

# Compute R* statistic for a non-mixed chain.
niter = 1000
val = hcat(sin.(1:niter), cos.(1:niter))
val = cat(val, hcat(cos.(1:niter)*100, sin.(1:niter)*100), dims=3)
chn_notmixed = Chains(val)

# Resulting R value should be close to two, i.e. the classifier should be able to learn an almost perfect decision boundary between chains.
R = rstar(chn_notmixed, model; iterations = 10)
@test mean(R) ≈ 2 atol=0.1
end
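
The test arrays above have the shape (iterations, parameters, chains). As a rough sketch of how such an array turns into a classification problem, the block below builds the feature matrix, the chain labels, and the random train/test split with plain arrays. It assumes the samples are stacked chain by chain, which is what the `repeat(...; inner = ...)` labels in `src/rstar.jl` imply. The sizes and variable names are illustrative only.

```julia
using Random

# Illustrative sizes only (assumption, not taken from the PR).
niter, nparams, nchains = 100, 3, 4
rng = MersenneTwister(42)
val = rand(rng, niter, nparams, nchains)

# Stack the chains row-wise: rows 1:niter come from chain 1,
# the next niter rows from chain 2, and so on.
x = reduce(vcat, [val[:, :, c] for c in 1:nchains])

# Label every row with the index of the chain it was drawn from.
y = repeat(1:nchains; inner = niter)

# Random train/test split, mirroring the keyword `subset = 0.8`.
subset = 0.8
N = length(y)
Ntrain = round(Int, N * subset)
ids = randperm(rng, N)
train_ids = ids[1:Ntrain]
test_ids = ids[(Ntrain + 1):N]

xtrain, ytrain = x[train_ids, :], y[train_ids]
xtest, ytest = x[test_ids, :], y[test_ids]
```

The classifier is then fitted on `(xtrain, ytrain)` and scored on `(xtest, ytest)`. If the chains have mixed, the chain index cannot be predicted from the sample values.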
8 changes: 7 additions & 1 deletion test/runtests.jl
@@ -1,8 +1,14 @@
using Test

@testset "MCMCChains" begin

# run tests related to rstar statistic
println("Rstar")
@time include("rstar_tests.jl")

# run tests for effective sample size
-include("ess_tests.jl")
+println("ESS")
+@time include("ess_tests.jl")

# run plotting tests
println("Plotting")