diff --git a/src/booster.jl b/src/booster.jl
index bebb054..2769b89 100644
--- a/src/booster.jl
+++ b/src/booster.jl
@@ -366,7 +366,6 @@ function updateone!(b::Booster, Xy::DMatrix;
                     update_feature_names::Bool=false,
                    )
     xgbcall(XGBoosterUpdateOneIter, b.handle, round_number, Xy.handle)
-    isempty(watchlist) || logeval(b, watchlist, round_number)
     _maybe_update_feature_names!(b, Xy, update_feature_names)
     b
 end
@@ -422,10 +421,59 @@ Run `num_round` rounds of gradient boosting on [`Booster`](@ref) `b`.
 The first and second derivatives of the loss function (`ℓ′` and `ℓ″` respectively) can be provided
 for custom loss.
+
+If `watchlist` (a dict from names to evaluation data, `"train" => data` by default) is not empty,
+the evaluation results are logged after every round.  If additionally `early_stopping_rounds > 0`,
+training stops once the monitored score has not improved for that many rounds.  The monitored
+score is the first one whose label contains the last key of `watchlist`; `maximize` selects
+whether larger (`true`) or smaller (`false`, the default) scores count as improvements.
 """
-function update!(b::Booster, data, a...; num_round::Integer=1, kw...)
+function update!(b::Booster, data, a...;
+                 num_round::Integer=1,
+                 watchlist=Dict("train"=>data),
+                 early_stopping_rounds::Integer=0,
+                 maximize=false,
+                 kw...,
+                )
+    # Early stopping needs evaluation output to monitor, hence the non-empty watchlist.
+    early_stopping = !isempty(watchlist) && early_stopping_rounds > 0
+    if early_stopping
+        # NOTE: the key order of a plain `Dict` is unspecified, so with several entries
+        # the monitored evaluation set is effectively arbitrary.
+        early_stopping_set = string(last(collect(keys(watchlist))))
+        @info("Will train until there has been no improvement in $early_stopping_rounds rounds.")
+        best_round = getnrounds(b)
+        best_score = maximize ? -Inf : Inf
+    end
+
     for j ∈ 1:num_round
         round_number = getnrounds(b) + 1
         updateone!(b, data, a...; round_number, kw...)
+        isempty(watchlist) && continue
+        # Log with the absolute round number, as the `logeval` call in `updateone!` used to.
+        msg = evaliter(b, watchlist, round_number)
+        @info(msg)
+        early_stopping || continue
+
+        # Presumably `msg` looks like "[1]\ttrain-rmse:0.5\teval-rmse:0.6", so that after
+        # splitting on whitespace and colons each label is directly followed by its score.
+        split_msg = split(msg, r"\s+|:")
+        metric_name_idx = findfirst(tok -> occursin(early_stopping_set, tok), split_msg)
+        if metric_name_idx === nothing || metric_name_idx == lastindex(split_msg)
+            error("XGBoost: no score for $(repr(early_stopping_set)) found in $(repr(msg))")
+        end
+        score = parse(Float64, split_msg[metric_name_idx + 1])
+
+        improved = maximize ? score > best_score : score < best_score
+        if improved
+            best_score = score
+            best_round = round_number
+        elseif round_number - best_round >= early_stopping_rounds
+            watchlist_metric = split_msg[metric_name_idx]
+            @info(
+                "Xgboost: Stopping. \n\tBest iteration: $best_round. \n\t$(watchlist_metric) has not improved in $early_stopping_rounds rounds."
+            )
+            return b
+        end
     end
     b
 end
@@ -441,6 +489,13 @@ followed by [`update!`](@ref) for `nrounds`.
 `watchlist` is a dict the keys of which are strings giving the name of the data to watch
 and the values of which are [`DMatrix`](@ref) objects containing the data.
 
+`early_stopping_rounds`: if `0` (the default), early stopping is disabled.  If set to a positive
+integer `k`, training stops once the monitored evaluation score has not improved for `k` rounds.
+This requires a non-empty `watchlist`, the last key of which selects the data that is monitored.
+
+`maximize` is only consulted when `early_stopping_rounds` is positive.  If `false` (the default)
+smaller evaluation scores are better; if `true`, larger evaluation scores are better.
+
 All other keyword arguments are passed to [`Booster`](@ref).  With few exceptions these are model
 training hyper-parameters, see [here](https://xgboost.readthedocs.io/en/stable/parameter.html) for
 a comprehensive list.
@@ -460,12 +515,14 @@ ŷ = predict(b, X)
 function xgboost(dm::DMatrix, a...;
                  num_round::Integer=10,
                  watchlist=Dict("train"=>dm),
+                 early_stopping_rounds::Integer=0,
+                 maximize=false,
                  kw...
                 )
     Xy = DMatrix(dm)
     b = Booster(Xy; kw...)
     isempty(watchlist) || @info("XGBoost: starting training.")
-    update!(b, Xy, a...; num_round, watchlist)
+    update!(b, Xy, a...; num_round, watchlist, early_stopping_rounds, maximize)
     isempty(watchlist) || @info("Training rounds complete.")
     b
 end
diff --git a/test/runtests.jl b/test/runtests.jl
index cade6d5..39dd5cf 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -130,6 +130,41 @@ end
     end
 end
 
+
+@testset "Early Stopping rounds" begin
+
+    dtrain = XGBoost.load(DMatrix, testfilepath("agaricus.txt.train"), format=:libsvm)
+    dtest = XGBoost.load(DMatrix, testfilepath("agaricus.txt.test"), format=:libsvm)
+    watchlist = Dict("eval"=>dtest, "train"=>dtrain)
+
+    bst = xgboost(dtrain,
+                  num_round=30,
+                  watchlist=watchlist,
+                  η=1,
+                  objective="binary:logistic",
+                  eval_metric=["rmsle","rmse"]
+                 )
+
+    bst_early_stopping = xgboost(dtrain,
+                                 num_round=30,
+                                 watchlist=watchlist,
+                                 η=1,
+                                 objective="binary:logistic",
+                                 eval_metric=["rmsle","rmse"],
+                                 early_stopping_rounds=2
+                                )
+
+    nrounds_bst = XGBoost.getnrounds(bst)
+    nrounds_bst_early_stopping = XGBoost.getnrounds(bst_early_stopping)
+    # Training with early stopping must finish in fewer rounds than training without it
+    @test nrounds_bst_early_stopping < nrounds_bst
+
+    # At least `early_stopping_rounds` rounds without improvement must elapse before stopping
+    @test nrounds_bst_early_stopping > 2
+end
+
+
+
 @testset "Blobs training" begin
     (X, y) = load_classification()
 