Skip to content

Commit

Permalink
Merge pull request #1 from jahan-ai/early_stopping_rounds
Browse files Browse the repository at this point in the history
Early stopping rounds
  • Loading branch information
david-sun-1 authored Oct 8, 2023
2 parents d799a79 + f70ee6d commit 7ed00c4
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 3 deletions.
48 changes: 45 additions & 3 deletions src/booster.jl
Original file line number Diff line number Diff line change
Expand Up @@ -366,7 +366,6 @@ function updateone!(b::Booster, Xy::DMatrix;
update_feature_names::Bool=false,
)
xgbcall(XGBoosterUpdateOneIter, b.handle, round_number, Xy.handle)
isempty(watchlist) || logeval(b, watchlist, round_number)
_maybe_update_feature_names!(b, Xy, update_feature_names)
b
end
Expand Down Expand Up @@ -422,10 +421,44 @@ Run `num_round` rounds of gradient boosting on [`Booster`](@ref) `b`.
The first and second derivatives of the loss function (`ℓ′` and `ℓ″` respectively) can be provided
for custom loss.
"""
function update!(b::Booster, data, a...; num_round::Integer=1, kw...)
function update!(b::Booster, data, a...;
num_round::Integer=1,
watchlist=Dict("train"=>Xy),
early_stopping_rounds::Integer=0,
maximize=false,
kw...,
)

if !isempty(watchlist) && early_stopping_rounds > 0
early_stopping_set = collect(Iterators.map(string, keys(watchlist)))[end]
println("Will train until there has been no improvement in $early_stopping_rounds rounds.\n")
best_round = 0
best_score = maximize ? -Inf : Inf
end

for j 1:num_round
round_number = getnrounds(b) + 1
updateone!(b, data, a...; round_number, kw...)
if !isempty(watchlist)
msg = evaliter(b, watchlist, j)
@info(msg)
if early_stopping_rounds > 0
split_msg = split(msg, r"\s+|:")
metric_name_idx = findfirst(occursin.(early_stopping_set, split_msg))
score = parse(Float64, split_msg[metric_name_idx + 1])

if (maximize && score > best_score || (!maximize && score < best_score))
best_score = score
best_round = j
elseif j - best_round >= early_stopping_rounds
watchlist_metric = (split_msg[metric_name_idx])
@info(
"Xgboost: Stopping. \n\tBest iteration: $best_round. \n\t$(watchlist_metric) has not improved in $early_stopping_rounds rounds."
)
return (b)
end
end
end
end
b
end
Expand All @@ -441,6 +474,13 @@ followed by [`update!`](@ref) for `nrounds`.
`watchlist` is a dict the keys of which are strings giving the name of the data to watch
and the values of which are [`DMatrix`](@ref) objects containing the data.
`early_stopping_rounds` if 0, the early stopping function is not triggered. If set to a positive integer,
training with a validation set will stop if the performance doesn't improve for k rounds.
`maximize` If early_stopping_rounds is set, then this parameter must be set as well.
When it is false, it means the smaller the evaluation score the better. When set to true,
the larger the evaluation score the better.
All other keyword arguments are passed to [`Booster`](@ref). With few exceptions these are model
training hyper-parameters, see [here](https://xgboost.readthedocs.io/en/stable/parameter.html) for
a comprehensive list.
Expand All @@ -460,12 +500,14 @@ ŷ = predict(b, X)
function xgboost(dm::DMatrix, a...;
num_round::Integer=10,
watchlist=Dict("train"=>dm),
early_stopping_rounds::Integer=0,
maximize=false,
kw...
)
Xy = DMatrix(dm)
b = Booster(Xy; kw...)
isempty(watchlist) || @info("XGBoost: starting training.")
update!(b, Xy, a...; num_round, watchlist)
update!(b, Xy, a...; num_round, watchlist, early_stopping_rounds, maximize)
isempty(watchlist) || @info("Training rounds complete.")
b
end
Expand Down
35 changes: 35 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,41 @@ end
end
end


@testset "Early Stopping rounds" begin

dtrain = XGBoost.load(DMatrix, testfilepath("agaricus.txt.train"), format=:libsvm)
dtest = XGBoost.load(DMatrix, testfilepath("agaricus.txt.test"), format=:libsvm)
watchlist = Dict("eval"=>dtest, "train"=>dtrain)

bst = xgboost(dtrain,
num_round=30,
watchlist=watchlist,
η=1,
objective="binary:logistic",
eval_metric=["rmsle","rmse"]
)

bst_early_stopping = xgboost(dtrain,
num_round=30,
watchlist=watchlist,
η=1,
objective="binary:logistic",
eval_metric=["rmsle","rmse"],
early_stopping_rounds = 2
)

nrounds_bst = XGBoost.getnrounds(bst)
nrounds_bst_early_stopping = XGBoost.getnrounds(bst_early_stopping)
# Check to see that running with early stopping results in less rounds
@test nrounds_bst_early_stopping < nrounds_bst

# Check number of rounds > early stopping rounds
@test nrounds_bst_early_stopping > 2
end



@testset "Blobs training" begin
(X, y) = load_classification()

Expand Down

0 comments on commit 7ed00c4

Please sign in to comment.