Adding implementation of Rstar statistic #238

Merged · 20 commits · Sep 17, 2020
6 changes: 5 additions & 1 deletion Project.toml
@@ -14,6 +14,7 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Formatting = "59287772-0a20-5a39-b81b-1366585eb4c0"
IteratorInterfaceExtensions = "82899510-4779-5014-852e-03e436cf321d"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
NaturalSort = "c020b1a1-e9b0-503a-9c33-f039bfc54a85"
PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
@@ -47,9 +48,12 @@ DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
KernelDensity = "5ab0869b-81aa-558d-bb23-cbf5423bbe9b"
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
MLJ = "add582a8-e3ab-11e8-2d5e-e98b27df1bc7"
MLJModels = "d491faf4-2d78-11e9-2867-c94bc002c0b7"
StatsPlots = "f3b207a7-027a-5e70-b257-86293d7955fd"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
UnicodePlots = "b8865327-cd53-5732-bb35-84acbb429228"
XGBoost = "009559a3-9522-5dbb-924b-0b6ed2b22bb9"

[targets]
test = ["DataFrames", "FFTW", "KernelDensity", "Logging", "StatsPlots", "Test", "UnicodePlots"]
test = ["DataFrames", "FFTW", "KernelDensity", "Logging", "StatsPlots", "Test", "UnicodePlots", "MLJ", "MLJModels", "XGBoost"]
5 changes: 4 additions & 1 deletion src/MCMCChains.jl
@@ -12,7 +12,7 @@ using SpecialFunctions
using Formatting
import StatsBase: autocov, counts, sem, AbstractWeights,
autocor, describe, quantile, sample, summarystats, cov

using MLJModelInterface
import NaturalSort
import PrettyTables
import Tables
@@ -36,6 +36,8 @@ export summarize
export discretediag, gelmandiag, gewekediag, heideldiag, rafterydiag
export hpd, ess

export rstar

export ESSMethod, FFTESSMethod, BDAESSMethod

"""
@@ -73,5 +75,6 @@ include("stats.jl")
include("modelstats.jl")
include("plot.jl")
include("tables.jl")
include("rstar.jl")

end # module
125 changes: 125 additions & 0 deletions src/rstar.jl
@@ -0,0 +1,125 @@
# Data structure required to store samples for MLJ.
struct RStarTable{T<:AbstractFloat}
data::Matrix{T}
names::Vector{Symbol}
end
RStarTable(chn::Chains) = RStarTable(Array(chn), names(chn, :parameters))

Base.size(table::RStarTable) = size(table.data)
Base.size(table::RStarTable, i::Integer) = size(table.data, i)
Base.getindex(table::RStarTable, i::Integer) = RStarTable(table.data[i:i,:], table.names)
Base.getindex(table::RStarTable, i::AbstractVector{<:Integer}) = RStarTable(table.data[i,:], table.names)

Tables.istable(::Type{<:RStarTable}) = true

Tables.names(table::RStarTable) = table.names
Tables.matrix(table::RStarTable) = table.data
Tables.schema(table::RStarTable{T}) where {T} = Tables.Schema(table.names, fill(T, size(table.data, 2)))

Tables.columnaccess(::Type{<:RStarTable}) = true
Tables.columns(table::RStarTable) = table

Tables.columnnames(table::RStarTable) = Tables.names(table)
Tables.getcolumn(table::RStarTable, nm::Symbol) = table.data[:,findfirst(table.names.==nm)]
Tables.getcolumn(table::RStarTable, i::Int) = table.data[:,i]
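The column-access methods above resolve a column either by position or via a `findfirst` lookup over the stored names. A minimal self-contained sketch of that lookup pattern, using a hypothetical `MiniTable` stand-in (not part of MCMCChains, no Tables.jl dependency):

```julia
# MiniTable mirrors the matrix-plus-names layout of RStarTable above;
# it is a hypothetical illustration, not part of the package.
struct MiniTable{T<:AbstractFloat}
    data::Matrix{T}
    names::Vector{Symbol}
end

# Column by name: find the index of the symbol, then slice the matrix.
getcol(t::MiniTable, nm::Symbol) = t.data[:, findfirst(==(nm), t.names)]
# Column by position.
getcol(t::MiniTable, i::Int) = t.data[:, i]

t = MiniTable([1.0 2.0; 3.0 4.0], [:a, :b])
getcol(t, :b)  # [2.0, 4.0]
```

`Tables.getcolumn` in the diff follows the same two-method shape, which is all the Tables.jl column interface requires once `columnaccess` and `columnnames` are defined.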

"""
rstar(chains::Chains, model::MLJModelInterface.Supervised; subset = 0.8, iterations = 10, verbosity = 0)
rstar(rng::Random.AbstractRNG, chains::Chains, model::MLJModelInterface.Supervised; subset = 0.8, iterations = 10, verbosity = 0)
rstar(rng::Random.AbstractRNG, model::MLJModelInterface.Supervised, x::MCMCChains.RStarTable, y::AbstractVector; subset = 0.8, iterations = 10, verbosity = 0)

Compute the R* statistic as a convergence diagnostic for MCMC. This implementation is an adaptation of Algorithms 1 & 2 described in [Lambert & Vehtari]. Note that the correctness of the statistic depends on the convergence of the classifier used internally. You can check whether the training of the classifier converged by inspecting the printed RMSE values from the XGBoost backend. To adjust the number of iterations used to train the classifier, set `niter` accordingly.

# Keyword Arguments
* `subset = 0.8` ... Fraction of the samples used to train the classifier, i.e. 0.8 implies 80% of the samples are used for training.
* `iterations = 10` ... Number of iterations used to estimate the statistic. If the classifier is not probabilistic, i.e. does not return class probabilities, it is advisable to use a value of one.
* `verbosity = 0` ... Verbosity level used during fitting of the classifier.

# Usage

```julia-repl
using MLJ, MLJModels
# You need to load MLJBase and the respective package you are using for classification first.

# Select a classification model to compute the Rstar statistic.
# For example the XGBoost classifier.
model = @load XGBoostClassifier()

# Compute 20 samples of the R* statistic, sampling labels according to the predicted class probabilities.
Rs = rstar(chn, model, iterations = 20)

# estimate Rstar
R = mean(Rs)

# visualize distribution
histogram(Rs)
```

## References:
[Lambert & Vehtari] Ben Lambert and Aki Vehtari. "R∗: A robust MCMC convergence diagnostic with uncertainty using gradient-boosted machines." arXiv, 2020.
"""
function rstar(rng::Random.AbstractRNG, classif::MLJModelInterface.Supervised, x::RStarTable, y::AbstractVector{Int}; iterations = 10, subset = 0.8, verbosity = 0)

size(x, 1) != length(y) && throw(DimensionMismatch())
iterations >= 1 || throw(ArgumentError("Number of iterations has to be positive!"))

if iterations > 1 && classif isa Deterministic
@warn("Classifier is not a probabilistic classifier but number of iterations is > 1.")
elseif iterations == 1 && classif isa Probabilistic
@warn("Classifier is probabilistic but number of iterations is equal to one.")
end

N = length(y)
K = length(unique(y))

# randomly sub-select training and testing set
Ntrain = round(Int, N*subset)
Ntest = N - Ntrain

ids = Random.randperm(rng, N)
train_ids = view(ids, 1:Ntrain)
test_ids = view(ids, (Ntrain+1):N)

# train classifier using XGBoost
fitresult, _ = MLJModelInterface.fit(classif, verbosity, x[train_ids], MLJModelInterface.categorical(y[train_ids]))

Rstats = fill(Inf, iterations)
xtest = x[test_ids]
ytest = MLJModelInterface.categorical(view(y, test_ids))

for i in 1:iterations

# predict labels for "test" data
p = MLJModelInterface.predict(classif, fitresult, xtest)
pred = similar(ytest)

if classif isa Deterministic
pred[:] = p
elseif classif isa Probabilistic
pred[:] = get.(rand.(Ref(rng), p))  # sample a label from each predicted distribution
else
throw(ErrorException("Unknown type of classifier"))
end

# compute statistic
a = mean(((p,y),) -> p == y, zip(pred, ytest))
Rstats[i] = K*a
end

return Rstats
end
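The loop above reduces to a simple relationship: with `K` chains and held-out classification accuracy `a`, the statistic is `R* = K * a`, so a chance-level classifier (`a ≈ 1/K`) yields `R* ≈ 1`. A toy sketch of that relationship, using simulated labels in place of a real classifier:

```julia
using Random, Statistics

# Illustration of R* = K * accuracy: a classifier that guesses chain
# labels uniformly at random attains accuracy ≈ 1/K, giving R* ≈ 1.
rng = MersenneTwister(42)
K = 4
ytest = rand(rng, 1:K, 100_000)   # true chain labels
pred  = rand(rng, 1:K, 100_000)   # chance-level "predictions"
a = mean(pred .== ytest)
Rstar = K * a                     # ≈ 1 for a chance-level classifier
```

Conversely, a classifier that separates the chains perfectly has `a = 1` and `R* = K`, which is why the non-mixed test case below expects a value near the number of distinguishable chains.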

function rstar(chn::Chains, model::MLJModelInterface.Supervised; kwargs...)
return rstar(Random.default_rng(), chn, model; kwargs...)
end

function rstar(rng::Random.AbstractRNG, chn::Chains, model::MLJModelInterface.Supervised; kwargs...)
nchains = size(chn, 3)
nchains <= 1 && throw(DimensionMismatch())

# collect data
x = RStarTable(chn)
y = repeat(chains(chn); inner = size(chn,1))

return rstar(rng, model, x, y; kwargs...)
end
5 changes: 3 additions & 2 deletions test/diagnostic_tests.jl
@@ -82,7 +82,7 @@ end
end

@testset "function tests" begin
tchain = Chains(rand(n_iter, n_name, n_chain), ["a", "b", "c"], Dict(:internals => ["c"]))
tchain = Chains(rand(niter, nparams, nchains), ["a", "b", "c"], Dict(:internals => ["c"]))

# the following tests only check if the function calls work!
@test MCMCChains.diag_all(rand(50, 2), :weiss, 1, 1, 1) != nothing
@@ -137,9 +137,10 @@
end

@testset "sorting" begin
chn_unsorted = Chains(rand(100,3,1), ["2", "1", "3"])
chn_unsorted = Chains(rand(100, nparams, 1), ["2", "1", "3"])
chn_sorted = sort(chn_unsorted)

@test names(chn_sorted) == Symbol.([1, 2, 3])
@test names(chn_unsorted) == Symbol.([2, 1, 3])
end

39 changes: 39 additions & 0 deletions test/rstar_tests.jl
@@ -0,0 +1,39 @@
using MCMCChains
using Tables
using MLJ, MLJModels
using Test

val = rand(1000, 8, 4)
colnames = ["a", "b", "c", "d", "e", "f", "g", "h"]
internal_colnames = ["c", "d", "e", "f", "g", "h"]
chn = Chains(val, colnames, Dict(:internals => internal_colnames))

model = @load XGBoostClassifier()

@testset "RStarTable test" begin
t = MCMCChains.RStarTable(chn)

@test Tables.istable(typeof(t))
@test Tables.columnaccess(typeof(t))
@test Tables.matrix(t) === t.data
@test t[[1,2]] isa MCMCChains.RStarTable
end

@testset "R star test" begin

# Compute R* statistic for a mixed chain.
R = rstar(chn, model; iterations = 10)

# Resulting R value should be close to one, i.e. the classifier does not perform better than random guessing.
@test mean(R) ≈ 1 atol=0.1

# Compute R* statistic for a non-mixed chain.
niter = 1000
val = hcat(sin.(1:niter), cos.(1:niter))
val = cat(val, hcat(cos.(1:niter)*100, sin.(1:niter)*100), dims=3)
chn_notmixed = Chains(val)

# Resulting R value should be close to two, i.e. the classifier should be able to learn an almost perfect decision boundary between the chains.
R = rstar(chn_notmixed, model; iterations = 10)
@test mean(R) ≈ 2 atol=0.1
end
8 changes: 7 additions & 1 deletion test/runtests.jl
@@ -1,8 +1,14 @@
using Test

@testset "MCMCChains" begin

# run tests related to rstar statistic
println("Rstar")
@time include("rstar_tests.jl")

# run tests for effective sample size
include("ess_tests.jl")
println("ESS")
@time include("ess_tests.jl")

# run plotting tests
println("Plotting")