Adding implementation of Rstar statistic #238
Conversation
I think it might make sense to implement this statistic in a separate package but not in MCMCChains. Requires is pretty bad and actually not needed anymore in MCMCChains, we removed all optional package dependencies to decrease long compilation and loading times (see #239). It would also make it easier to play around with and support different classifiers and other packages in addition to XGBoost. Some additional high-level comments, without having checked the implementation in detail:
Good points, I’ll think about it. However, I’m sceptical it’s worth implementing in a separate repo, as it’s literally only a few lines. I agree that it would be good to make it more general; I’m personally not a fan of boosting. Maybe we can use the interfaces of MLJ for that.
Co-authored-by: David Widmann <[email protected]>
Ok, I agree that it would make sense to have an external repository containing more advanced diagnostics. Let's keep this PR open for now and merge the code into the new repo once we know what goes into the repo.
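The classifier-agnostic design discussed above could be sketched roughly as follows. This is a toy illustration, not code from this PR: `CentroidClassifier`, `fit!`, `predict`, and `rstar_sketch` are all hypothetical names, and the nearest-centroid model is a deliberately simple stand-in for e.g. an XGBoost classifier.

```julia
# Toy sketch of a classifier-agnostic R*: any object implementing
# `fit!`/`predict` can be plugged in. Not the PR's actual code.
using Statistics, Random

mutable struct CentroidClassifier
    centroids::Matrix{Float64}            # one column per class
    CentroidClassifier() = new(zeros(0, 0))
end

function fit!(c::CentroidClassifier, X::Matrix{Float64}, y::Vector{Int})
    K = maximum(y)
    c.centroids = hcat((mean(X[:, y .== k]; dims=2) for k in 1:K)...)
    return c
end

predict(c::CentroidClassifier, x::Vector{Float64}) =
    argmin([sum(abs2, x .- c.centroids[:, k]) for k in 1:size(c.centroids, 2)])

# R* = (number of chains) × (accuracy of predicting chain membership).
# Values near 1 mean the classifier cannot tell the chains apart.
function rstar_sketch(classifier, X::Matrix{Float64}, chain_idx::Vector{Int})
    fit!(classifier, X, chain_idx)
    acc = mean(predict(classifier, X[:, i]) == chain_idx[i] for i in eachindex(chain_idx))
    return maximum(chain_idx) * acc
end

# Two "chains" drawn from the same distribution should give R* near 1.
Random.seed!(1)
X = randn(2, 200)
R = rstar_sketch(CentroidClassifier(), X, repeat([1, 2]; inner=100))
```

Because any duck-typed classifier works here, swapping in a different model is a one-line change, which is the appeal of routing this through a common interface such as MLJ's.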
Codecov Report
@@ Coverage Diff @@
## master #238 +/- ##
==========================================
+ Coverage 72.59% 72.77% +0.17%
==========================================
Files 19 20 +1
Lines 1215 1245 +30
==========================================
+ Hits 882 906 +24
- Misses 333 339 +6
Continue to review full report at Codecov.
Thanks, @trappmartin and everyone else, for picking this up; glad that you are considering implementing R*. If you have any questions, let me know. @trappmartin, let me know if you continue to have issues with Algo 2 and I can cross-check with mine if you give me details of the test being done.
For now, let's just stick R* in a new package (AdvancedDiagnostics?) and get it registered. Excellent work so far!
It feels like slight overkill to register and maintain a new package for this single feature. Maybe consider adding this to a ’src/contrib’ folder? I appreciate the idea of keeping loading time minimal, but that’s a language-level optimisation task, so reducing loading time by splitting packages sounds a bit counter-productive IMO.
I had a closer look at the implementation and added some comments.
The problem reported in TuringLang/Turing.jl#1400 shows that providing optional features with Requires not only increases loading times but causes precompilation problems as well. The problem of optional/conditional dependencies definitely has to be addressed by Pkg on a global level; however, as long as it is not solved, IMO creating small packages is a more stable approach than adding optional dependencies with Requires.

There are also multiple examples of small packages in the Julia ecosystem (e.g., https://github.com/SciML/PoissonRandom.jl, https://github.com/SciML/IfElse.jl, https://github.com/SciML/RuntimeGeneratedFunctions.jl, https://github.com/SciML/MuladdMacro.jl), so it is not uncommon to publish packages with only a handful of functions or one source file. In this case, one could also bundle the Rstar statistic together with other more advanced diagnostics in MCMCChains in a separate package.

IMO one should be careful not to add too many dependencies to MCMCChains (as mentioned, we had problems with really massive package loading times before), but I'd actually rather take a dependency on XGBoost than Requires here. (Though I still think it would be nice to provide a general implementation that works with different classification algorithms, maybe by using an MLJ interface, as @trappmartin suggested above.)
No worries, I’ll have a look at the MLJ interface. I just didn’t have the time so far. |
No worries, it was just intended as a general comment regarding optional dependencies and small packages 🙂 |
OK, so this PR is now independent of XGBoost and instead uses a minimal classifier interface. Here is an example using XGBoost; note that this requires the user to have MLJ, MLJModels and XGBoost installed.
julia> using MLJ, MLJModels
julia> using Turing, MCMCChains
julia> # define some Turing model
julia> @model function high_dim_gauss(D)
m ~ filldist(Normal(0, 1), D)
end
julia> # instantiation
julia> model = high_dim_gauss(50)
julia> # sample with different settings
julia> chn1 = sample(model, NUTS(1, 0.8), MCMCDistributed(), 11, 4); # generate 10 samples (adaptation 1)
julia> chn2 = sample(model, NUTS(500, 0.8), MCMCDistributed(), 2000, 4); # generate 1500 samples (adaptation 500)
julia> # load classifier that should be used for the computation
julia> classif = @load XGBoostClassifier
julia> # estimate 10 samples of the statistic (10 is default).
julia> R1 = rstar(chn1, classif)
julia> R2 = rstar(chn2, classif)
julia> # some plotting
julia> using UnicodePlots
julia> boxplot(["N=10", "N=1500"], [R1, R2])
┌ ┐
╷ ┌─────╷
N=10 ├────────┤ ┤
╵ └─────╵
╷┐╷
N=1500 ├├┤
╵┘╵
└ ┘
1 2.5 4
julia> histogram(R2)
┌ ┐
[1.05, 1.1 ) ┤▇▇▇▇▇ 1
[1.1 , 1.15) ┤▇▇▇▇▇ 1
[1.15, 1.2 ) ┤▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 7
[1.2 , 1.25) ┤▇▇▇▇▇ 1
└ ┘
Frequency
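The vector of R* draws returned above can be summarized numerically as well. This is a hedged sketch: the `R_draws` values, the `stats` tuple, and the 1.1 cutoff are illustrative assumptions of mine, not output or thresholds from the PR or the paper.

```julia
# Summarizing a vector of R* draws (illustrative values, e.g. as returned
# by rstar(chn, classif) above). R* near 1 suggests well-mixed chains;
# values clearly above 1 indicate the chains are distinguishable.
using Statistics

R_draws = [1.08, 1.12, 1.17, 1.16, 1.18, 1.15, 1.19, 1.14, 1.22, 1.16]

stats = (mean = mean(R_draws),
         q05  = quantile(R_draws, 0.05),
         q95  = quantile(R_draws, 0.95))

# Heuristic cutoff (an assumption, not from the paper): flag mixing as
# acceptable only if the upper quantile of R* stays close to 1.
well_mixed = stats.q95 < 1.1
```

For the histogram above, the whole distribution sits above 1, so such a check would flag the run for a closer look even though the values are not dramatically large.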
@devmotion thanks for those useful comments! |
This PR should be ready to merge once all tests run through. |
Looks good, just left some additional comments/questions. I like that we can make use of all the models implementing the MLJ interface.
LGTM, great work!
Not if I beat you to it!
This PR adds an implementation of the R* statistic described in https://arxiv.org/pdf/2003.07900.pdf.
Todo:
- Add additional tests. My implementation of Algo. 2 (generating multiple draws of R*) seems to result in rather low values on the unit-test set; have to inspect if there is a subtle bug.
- Add documentation to the README.md
- Replace XGBoost with the MLJ interface if possible

Thanks @ben18785 and @avehtari for the nice paper.
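The "multiple draws" idea from Algo. 2 mentioned in the todo list can be sketched as follows: instead of scoring hard predictions once, sample labels from each draw's predicted chain-membership probabilities and repeat, giving a distribution over R*. This is a hedged illustration, not the PR's implementation; `rstar_draws` is a hypothetical name and the probability matrix stands in for the output of a real probabilistic classifier.

```julia
# Sketch of drawing multiple R* values from a classifier's predicted
# chain-membership probabilities (illustrative, not code from the PR).
using Statistics, Random

function rstar_draws(probs::Matrix{Float64}, chain_idx::Vector{Int};
                     ndraws::Int=10, rng=Random.default_rng())
    K, N = size(probs)  # probs[k, i] = predicted P(chain k | draw i)
    map(1:ndraws) do _
        acc = mean(1:N) do i
            # inverse-CDF sample of a label from draw i's predicted distribution
            u, k, c = rand(rng), 1, probs[1, i]
            while c < u && k < K
                k += 1
                c += probs[k, i]
            end
            k == chain_idx[i]
        end
        K * acc  # R* = number of chains × sampled classification accuracy
    end
end

# Perfectly separable chains: every sampled label matches, so each R* draw
# equals the number of chains (2 here).
idx = [1, 2, 1, 2]
probs = zeros(2, 4)
for i in eachindex(idx)
    probs[idx[i], i] = 1.0
end
Rs = rstar_draws(probs, idx; ndraws=5)
```

A subtle-bug symptom like "rather low values" could come from the sampling step systematically falling back to the first class, so checking this degenerate perfectly-separable case is a cheap unit test.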