Adding implementation of Rstar statistic #238

Merged: 20 commits, Sep 17, 2020
3 changes: 2 additions & 1 deletion Project.toml
@@ -52,6 +52,7 @@ Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
 StatsPlots = "f3b207a7-027a-5e70-b257-86293d7955fd"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 UnicodePlots = "b8865327-cd53-5732-bb35-84acbb429228"
+XGBoost = "009559a3-9522-5dbb-924b-0b6ed2b22bb9"

 [targets]
-test = ["DataFrames", "FFTW", "KernelDensity", "Logging", "StatsPlots", "Test", "UnicodePlots"]
+test = ["DataFrames", "FFTW", "KernelDensity", "Logging", "StatsPlots", "Test", "UnicodePlots", "XGBoost"]
5 changes: 5 additions & 0 deletions src/MCMCChains.jl
@@ -36,6 +36,11 @@ export summarize
export discretediag, gelmandiag, gewekediag, heideldiag, rafterydiag
export hpd, ess

@init @require XGBoost="009559a3-9522-5dbb-924b-0b6ed2b22bb9" @eval begin
    include("rstar.jl")
    export rstar
end
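The `@require` block above relies on Requires.jl: its body is evaluated only once the user actually loads XGBoost, which keeps XGBoost an optional dependency. A minimal standalone sketch of the same pattern (the module and function names here are illustrative, not part of MCMCChains):

```julia
module OptionalDemo

using Requires

function __init__()
    # This body runs only after the user executes `using XGBoost`;
    # until then, nothing XGBoost-related is defined or loaded.
    @require XGBoost="009559a3-9522-5dbb-924b-0b6ed2b22bb9" @eval begin
        has_xgboost() = true
    end
end

end # module
```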

export ESSMethod, FFTESSMethod, BDAESSMethod

"""
93 changes: 93 additions & 0 deletions src/rstar.jl
@@ -0,0 +1,93 @@
import XGBoost
"""
    rstar(chains::Chains; subset = 0.8, niter = 1_000, eta = 0.5, xgboostparams...)
    rstar(chains::Chains, iterations::Int; subset = 0.8, niter = 1_000, eta = 0.5, xgboostparams...)
    rstar(x::AbstractMatrix, y::AbstractVector, nchains::Int, iterations::Int; subset = 0.8, niter = 1_000, eta = 0.5, xgboostparams...)

Compute the R* statistic, a convergence diagnostic for MCMC. This implementation is an adaptation of Algorithms 1 and 2 described in [Lambert & Vehtari]. Note that the correctness of the statistic depends on the convergence of the classifier used internally. Whether the training of the classifier converged can be checked by inspecting the RMSE values printed by the XGBoost backend. To adjust the number of iterations used to train the classifier, set `niter` accordingly.

# Usage

```julia-repl
# XGBoost needs to be loaded before `rstar` becomes available.
using XGBoost

...

chn = ...

# Compute R⋆ using the default settings of the gradient-boosted classifier
# used internally to compute the statistic. This is the recommended usage.
R = rstar(chn)

# Compute 100 samples of the R⋆ statistic by sampling chain labels according
# to the classifier's prediction probabilities.
# This approach can be slow and results in a less accurate estimate of R⋆.
# See the discussion in Section 3.1.3 of the paper.
Rs = rstar(chn, 100)

# estimate Rstar
R = mean(Rs)

# visualize distribution
histogram(Rs)
```

## References
[Lambert & Vehtari] Ben Lambert and Aki Vehtari. "R∗: A robust MCMC convergence diagnostic with uncertainty using gradient-boosted machines." arXiv, 2020.
"""
function rstar(x::AbstractMatrix, y::AbstractVector, nchains::Int, iterations::Int; subset = 0.8, niter = 1_000, eta = 0.5, xgboostparams...)
    N = length(y)

    # randomly split the data into a training and a testing set
    Ntrain = round(Int, N * subset)
    Ntest = N - Ntrain

    ids = Random.shuffle(collect(1:N))
    train_ids = ids[1:Ntrain]
    test_ids = ids[(Ntrain + 1):end]

    @assert length(test_ids) == Ntest

    # use predicted probabilities?
    mode = iterations > 1 ? "multi:softprob" : "multi:softmax"

    @info "Training classifier"
    # train the classifier using XGBoost
    classif = XGBoost.xgboost(x[train_ids, :], niter; label = y[train_ids],
                              objective = mode, num_class = nchains,
                              xgboostparams...)

    @info "Computing R* statistics"
    Rstats = fill(Inf, iterations)
    for i in 1:iterations
        # predict labels for the test data
        p = XGBoost.predict(classif, x[test_ids, :])

        pred = if length(p) == Ntest * nchains
            # sample labels from the predicted probabilities;
            # subtract one since the labels in `y` are zero-based
            probs = reshape(p, Ntest, nchains)
            map(s -> rand(Categorical(s / sum(s))) - 1, eachrow(probs))
        else
            p
        end

        # compute the statistic: nchains times the classification accuracy
        a = mean(pred .== y[test_ids])
        Rstats[i] = nchains * a
    end

    return Rstats
end
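At its core, the statistic computed above is simply the number of chains times the held-out classification accuracy: a classifier at chance level gives R* ≈ 1, a perfect classifier gives R* = nchains. A standalone sketch (the helper `rstar_from_predictions` is illustrative, not part of the package):

```julia
using Statistics

# R* from predicted vs. true chain labels: nchains times the accuracy.
rstar_from_predictions(pred, y, nchains) = nchains * mean(pred .== y)

y = repeat(0:3, inner = 100)  # 4 chains, 100 held-out draws each

# Perfectly separable chains: accuracy 1, so R* = 4 (clear non-convergence).
rstar_from_predictions(y, y, 4)              # 4.0

# A constant classifier is at chance level: accuracy 1/4, so R* = 1.
rstar_from_predictions(fill(0, 400), y, 4)   # 1.0
```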

function rstar(chn::Chains, iterations; kwargs...)
    nchains = length(chains(chn))
    @assert nchains > 1

    # collect the data: draws from all chains and the zero-based chain indices
    x = mapreduce(c -> Array(chn[:, :, c]), vcat, chains(chn))
    y = mapreduce(c -> ones(Int, length(chn)) * c, vcat, chains(chn)) .- 1

    return rstar(x, y, nchains, iterations; kwargs...)
end

rstar(chn::Chains; kwargs...) = first(rstar(chn, 1; kwargs...))
20 changes: 20 additions & 0 deletions test/diagnostic_tests.jl
@@ -1,3 +1,4 @@
using XGBoost
using MCMCChains
using AbstractMCMC: AbstractChains
using Test
@@ -143,3 +144,22 @@ end
@test names(chn_sorted) == Symbol.([1, 2, 3])
@test names(chn_unsorted) == Symbol.([2, 1, 3])
end

@testset "R-star test" begin

niter = 4000
nparams = 2
nchains = 4

# some sample experiment results
val = randn(niter, nparams, nchains) .+ [1, 2]'
val = hcat(val, rand(1:2, niter, 1, nchains))

# construct a Chains object
chn = Chains(val, start = 1, thin = 2)

# compute r star statistic using 1k iterations of training
R = rstar(chn)

@test R ≈ 1.0 atol=0.1
end