Adding implementation of Rstar statistic #238

Merged (20 commits, Sep 17, 2020)

Conversation

@trappmartin (Member) commented Sep 3, 2020

This PR adds an implementation of the R* statistic described in https://arxiv.org/pdf/2003.07900.pdf.

Todo:

  • Add additional tests.
  • My implementation of Algo. 2 (generating multiple draws of R*) seems to produce rather low values on the unit-test set; I have to check whether there is a subtle bug.
  • Add documentation to the README.md.
  • Replace XGBoost with the MLJ interface if possible.

Thanks @ben18785 and @avehtari for the nice paper.
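
For reference, the basic idea of the statistic: pool the draws from all chains, train a classifier to predict which chain each draw came from, and report the classification accuracy rescaled so that chance-level accuracy corresponds to 1. Below is a minimal, self-contained sketch of that idea, not the code in this PR: a 1-nearest-neighbour rule stands in for the gradient-boosted classifier used in the paper, and the name rstar_sketch is purely illustrative.

using Random, Statistics

function rstar_sketch(samples::AbstractArray{<:Real,3}; rng=Random.default_rng(), subset=0.8)
    niter, nparams, nchains = size(samples)
    # flatten the draws into feature rows, with the chain index as the label
    X = reshape(permutedims(samples, (1, 3, 2)), niter * nchains, nparams)
    y = repeat(1:nchains; inner=niter)
    # random train/test split
    perm = randperm(rng, length(y))
    ntrain = round(Int, subset * length(y))
    train, test = perm[1:ntrain], perm[(ntrain + 1):end]
    # 1-nearest-neighbour prediction: label of the closest training point
    nn_predict(x) = y[train[argmin([sum(abs2, x .- X[i, :]) for i in train])]]
    accuracy = mean(nn_predict(X[i, :]) == y[i] for i in test)
    # rescale so that chance-level accuracy gives a value around 1
    return accuracy * nchains
end

# for draws from identical, well-mixed chains this should be close to 1
rstar_sketch(randn(1000, 10, 4))

Values well above 1 indicate that the classifier can distinguish the chains, i.e. a lack of mixing.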

@devmotion (Member)

I think it might make sense to implement this statistic in a separate package rather than in MCMCChains. Requires is pretty bad and actually not needed in MCMCChains anymore; we removed all optional package dependencies to reduce the long compilation and loading times (see #239). A separate package would also make it easier to play around with and support different classifiers and other packages in addition to XGBoost. Some additional high-level comments, without having checked the implementation in detail:

  • I think it would be good to implement statistics for general array-type inputs instead of Chains, and only add a wrapper/dispatch for Chains if needed. We already changed this for a bunch of statistics in MCMCChains since then the implementation of the statistics can be focused on the actual algorithm and can be reused for other chain types (or just arrays). Moreover, it usually simplified the code and made working with Chains more efficient (since, e.g., the relevant information can be extracted once in the beginning instead of subsetting the chain multiple times).
  • Following standard Julia conventions, one might want to name the function rstar instead of Rstar.
  • As mentioned above, I think it would be nice to write the implementation in a generic way that supports different classifiers/classifier packages (one possible shape of such an API is sketched after this list).
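
A purely illustrative sketch of the array-first pattern suggested above, using a toy diagnostic instead of R* and a made-up container type instead of Chains (ToyChains and toy_diagnostic are hypothetical names): the core method works on a plain iterations × parameters × chains array, and the wrapper method only unwraps the array once and forwards to it.

using Statistics

# core implementation on raw arrays (iterations × parameters × chains)
function toy_diagnostic(x::AbstractArray{<:Real,3})
    # spread of the per-chain means, as a stand-in computation
    chain_means = vec(mean(x; dims=(1, 2)))
    return maximum(chain_means) - minimum(chain_means)
end

# hypothetical container standing in for a chain type
struct ToyChains
    value::Array{Float64,3}   # iterations × parameters × chains
end

# the wrapper just extracts the array and reuses the array method
toy_diagnostic(chn::ToyChains) = toy_diagnostic(chn.value)

toy_diagnostic(ToyChains(randn(1000, 5, 4)))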

@trappmartin (Member, Author)

Good points, I'll think about it. However, I'm sceptical that it's worth implementing in a separate repo, as it's literally only a few lines. I agree that it would be good to make it more general; I'm personally not a fan of boosting. Maybe we can use the MLJ interfaces for that.

@trappmartin (Member, Author)

Ok, I agree that it would make sense to have an external repository containing more advanced diagnostics. Let's keep this PR open for now and merge the code into the new repo once we know what goes into the repo.

@codecov (bot) commented Sep 3, 2020

Codecov Report

Merging #238 into master will increase coverage by 0.17%.
The diff coverage is 83.33%.

@@            Coverage Diff             @@
##           master     #238      +/-   ##
==========================================
+ Coverage   72.59%   72.77%   +0.17%     
==========================================
  Files          19       20       +1     
  Lines        1215     1245      +30     
==========================================
+ Hits          882      906      +24     
- Misses        333      339       +6     

Impacted Files        Coverage   Δ
src/MCMCChains.jl      100.00%   <ø> (ø)
src/rstar.jl            83.33%   <83.33%> (ø)
src/rafterydiag.jl      92.85%   <0.00%> (-4.77%) ⬇️
src/heideldiag.jl      100.00%   <0.00%> (+3.44%) ⬆️

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data

@ben18785 commented Sep 3, 2020

@trappmartin and everyone else, thanks for picking this up; I'm glad you are considering implementing R*. If you have any questions, let me know. @trappmartin, if you continue to have issues with Algorithm 2, I can cross-check against my implementation if you give me the details of the test being run.

@cpfiffer (Member) commented Sep 3, 2020

For now, let's just stick R* in a new package (AdvancedDiagnostics?) and get it registered. Excellent work so far!

@yebai (Member) commented Sep 3, 2020

It feels like slight overkill to register and maintain a new package for this single feature. Maybe consider adding this to a 'src/contrib' folder?

I appreciate the idea of keeping loading time minimal, but that is really a language-level optimisation task. Reducing loading time by splitting packages therefore sounds a bit counter-productive IMO.

@devmotion (Member) left a review comment

I had a closer look at the implementation and added some comments.

@devmotion (Member)

> I appreciate the idea of keeping loading time minimal, but that is really a language-level optimisation task. Reducing loading time by splitting packages therefore sounds a bit counter-productive IMO.

The problem reported in TuringLang/Turing.jl#1400 shows that providing optional features with Requires not only increases loading times but also causes precompilation problems. The problem of optional/conditional dependencies definitely has to be addressed by Pkg at a global level; however, as long as it is not solved, IMO creating small packages is a more stable approach than adding optional dependencies with Requires. There are also multiple examples of small packages in the Julia ecosystem (e.g., https://github.com/SciML/PoissonRandom.jl, https://github.com/SciML/IfElse.jl, https://github.com/SciML/RuntimeGeneratedFunctions.jl, https://github.com/SciML/MuladdMacro.jl), so it is not uncommon to publish packages with only a handful of functions or one source file.

In this case, one could also bundle the Rstar statistic together with other, more advanced diagnostics from MCMCChains in a separate package. IMO one should be careful not to add too many dependencies to MCMCChains (as mentioned, we had problems with really massive package loading times before), but I would actually rather take a dependency on XGBoost than on Requires here. (Though I still think it would be nice if one could provide a general implementation that works with different classification algorithms, maybe by using an MLJ interface, as @trappmartin suggested above.)

@trappmartin (Member, Author)

No worries, I'll have a look at the MLJ interface. I just haven't had the time so far.

@devmotion (Member)

No worries, it was just intended as a general comment regarding optional dependencies and small packages 🙂

@trappmartin (Member, Author)

OK, so this PR is now independent of XGBoost and uses the minimal interface provided by MLJModelInterface. Thus, the user can effectively use any MLJ-supported classifier to run the statistic on inference results.

Here is an example using XGBoost. Note that this requires the user to have MLJ, MLJModels and XGBoost installed.

julia> using MLJ, MLJModels
julia> using Turing, MCMCChains

julia> # define some Turing model
julia> @model function high_dim_gauss(D)
    m ~ filldist(Normal(0, 1), D)
end

julia> # instantiation
julia> model = high_dim_gauss(50)

julia> # sample with different settings
julia> chn1 = sample(model, NUTS(1, 0.8), MCMCDistributed(), 11, 4); # generate 10 samples (adaptation 1)
julia> chn2 = sample(model, NUTS(500, 0.8), MCMCDistributed(), 2000, 4); # generate 1500 samples (adaptation 500)

julia> # load classifier that should be used for the computation
julia> classif = @load XGBoostClassifier

julia> # estimate 10 samples of the statistic (10 is default).
julia> R1 = rstar(chn1, classif)
julia> R2 = rstar(chn2, classif)

julia> # some plotting
julia> using UnicodePlots
julia> boxplot(["N=10", "N=1500"], [R1, R2])
          ┌                                        ┐
                                   ╷        ┌─────╷
     N=10                          ├────────┤     ┤
                                   ╵        └─────╵
           ╷┐╷
   N=1500  ├├┤
           ╵┘╵
          └                                        ┘
          1                  2.5                   4

julia> histogram(R2)
                ┌                                        ┐
   [1.05, 1.1 ) ┤▇▇▇▇▇ 1
   [1.1 , 1.15) ┤▇▇▇▇▇ 1
   [1.15, 1.2 ) ┤▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 7
   [1.2 , 1.25) ┤▇▇▇▇▇ 1
                └                                        ┘
                                Frequency
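
Since the implementation only relies on the generic MLJ model interface, other classifiers can be swapped in. For example, a decision tree could be used instead of XGBoost (illustrative only; this assumes DecisionTree.jl and its MLJ interface are installed):

julia> classif2 = @load DecisionTreeClassifier pkg=DecisionTree
julia> R3 = rstar(chn2, classif2)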

@trappmartin self-assigned this Sep 15, 2020

@trappmartin (Member, Author)

@devmotion thanks for those useful comments!

@trappmartin (Member, Author)

This PR should be ready to merge once all tests pass.

@devmotion (Member) left a review comment

Looks good, just left some additional comments/questions. I like that we can make use of all the models implementing the MLJ interface.

trappmartin and others added 2 commits September 16, 2020 14:10
trappmartin and others added 2 commits September 16, 2020 21:02
trappmartin and others added 2 commits September 17, 2020 12:38

@devmotion (Member) left a review comment

LGTM, great work!

@trappmartin (Member, Author)

I'll merge this in approx. 6 hours if no one stops me before then. :)

cc: @yebai and @cpfiffer

@cpfiffer (Member)

Not if I beat you to it!
