Skip to content

Commit

Permalink
Merge pull request #4 from jahan-ai/early_stopping_rounds
Browse files Browse the repository at this point in the history
Moved log message parsing to update! instead of updateone and converted various warnings to exceptions
  • Loading branch information
wilan-wong-1 authored Nov 9, 2023
2 parents 3bee176 + b23bffd commit 375fbb0
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 81 deletions.
2 changes: 1 addition & 1 deletion docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ update!(bst, (X, y), num_round=10)
To help prevent overfitting to the training set, it is helpful to use a validation set to evaluate against to ensure that the XGBoost iterations continue to generalise outside training loss reduction. Early stopping provides a convenient way to automatically stop the
boosting process if it's observed that the generalisation capability of the model does not improve for `k` rounds.

If there is more than one element in watchlist, by default the last element will be used. This makes it important to use an ordered data structure (`OrderedDict`) compared to a standard unordered dictionary; as you might not be guaranteed deterministic behaviour. There will be
If there is more than one element in watchlist, by default the last element will be used. In this case, you must use an ordered data structure (`OrderedDict`) compared to a standard unordered dictionary otherwise an exception will be generated. There will be
a warning if you want to execute early stopping mechanism (`early_stopping_rounds > 0`) but have provided a watchlist with type `Dict` with
more than 1 element.

Expand Down
74 changes: 33 additions & 41 deletions src/booster.jl
Original file line number Diff line number Diff line change
Expand Up @@ -371,16 +371,8 @@ function updateone!(b::Booster, Xy::DMatrix;
update_feature_names::Bool=false,
)
xgbcall(XGBoosterUpdateOneIter, b.handle, round_number, Xy.handle)
# obtain the logs if watchlist is present (for early stopping and/or info)
if isempty(watchlist)
msg = nothing
else
msg = evaliter(b, watchlist, round_number)
@info msg
end
#isempty(watchlist) || (msg = evaliter(b, watchlist, round_number))
_maybe_update_feature_names!(b, Xy, update_feature_names)
b, msg
b
end

function updateone!(b::Booster, Xy::DMatrix, g::AbstractVector{<:Real}, h::AbstractVector{<:Real};
Expand All @@ -394,16 +386,8 @@ function updateone!(b::Booster, Xy::DMatrix, g::AbstractVector{<:Real}, h::Abstr
g = convert(Vector{Cfloat}, g)
h = convert(Vector{Cfloat}, h)
xgbcall(XGBoosterBoostOneIter, b.handle, Xy.handle, g, h, length(g))
# obtain the logs if watchlist is present (for early stopping and/or info)
if isempty(watchlist)
msg = nothing
else
msg = evaliter(b, watchlist, round_number)
@info msg
end
#isempty(watchlist) || (msg = evaliter(b, watchlist, round_number))
_maybe_update_feature_names!(b, Xy, update_feature_names)
b, msg
b
end

"""
Expand Down Expand Up @@ -458,20 +442,26 @@ function update!(b::Booster, data, a...;
for j 1:num_round
round_number = getnrounds(b) + 1

b, msg = updateone!(b, data, a...; round_number, watchlist, kw...)
if !isempty(watchlist) && early_stopping_rounds > 0
score, dataset, metric = extract_metric_value(msg)
if (maximize && score > best_score || (!maximize && score < best_score))
best_score = score
best_round = j
elseif j - best_round >= early_stopping_rounds
@info(
"Xgboost: Stopping. \n\tBest iteration: $best_round. \n\tNo improvement in $dataset-$metric result in $early_stopping_rounds rounds."
)
# add additional fields to record the best iteration
b.best_iteration = best_round
b.best_score = best_score
return b
updateone!(b, data, a...; round_number, watchlist, kw...)

# Evaluate if watchlist is not empty
if !isempty(watchlist)
msg = evaliter(b, watchlist, round_number)
@info msg
if early_stopping_rounds > 0
score, dataset, metric = extract_metric_value(msg)
if (maximize && score > best_score || (!maximize && score < best_score))
best_score = score
best_round = j
elseif j - best_round >= early_stopping_rounds
@info(
"Xgboost: Stopping. \n\tBest iteration: $best_round. \n\tNo improvement in $dataset-$metric result in $early_stopping_rounds rounds."
)
# add additional fields to record the best iteration
b.best_iteration = best_round
b.best_score = best_score
return b
end
end
end
end
Expand Down Expand Up @@ -528,10 +518,10 @@ function extract_metric_value(msg, dataset=nothing, metric=nothing)
if match_result != nothing
parsed_value = parse(Float64, match_result.captures[1])
return parsed_value, dataset, metric
else
@warn "No match found for pattern: $dataset-$metric in message: $msg"
return nothing
end

# there was no match result - should error out
error("No match found for pattern: $dataset-$metric in message: $msg")
end

"""
Expand All @@ -543,13 +533,15 @@ This is essentially an alias for constructing a [`Booster`](@ref) with `data` an
followed by [`update!`](@ref) for `nrounds`.
`watchlist` is a dict the keys of which are strings giving the name of the data to watch
and the values of which are [`DMatrix`](@ref) objects containing the data. It is critical to use an OrderedDict
when utilising early_stopping_rounds to ensure XGBoost uses the correct and intended dataset to perform early stop.
and the values of which are [`DMatrix`](@ref) objects containing the data. It is mandatory to use an OrderedDict
when utilising early_stopping_rounds and there is more than 1 element in watchlist to ensure XGBoost uses the
correct and intended dataset to perform early stop.
`early_stopping_rounds` activates early stopping if set to > 0. Validation metric needs to improve at
least once in every k rounds. If `watchlist` is not explicitly provided, it will use the training dataset
to evaluate the stopping criterion. Otherwise, it will use the last data element in `watchlist` and the
last metric in `eval_metric` (if more than one). Note that early stopping is ignored if `watchlist` is empty.
last metric in `eval_metric` (if more than one). Note that `watchlist` cannot be empty if
`early_stopping_rounds` is enabled.
`maximize` If early_stopping_rounds is set, then this parameter must be set as well.
When it is false, it means the smaller the evaluation score the better. When set to true,
Expand Down Expand Up @@ -585,7 +577,7 @@ ŷ = predict(b, dvalid, ntree_limit = b.best_iteration)
"""
function xgboost(dm::DMatrix, a...;
num_round::Integer=10,
watchlist::Any = Dict("train" => dm),
watchlist::AbstractDict = Dict("train" => dm),
early_stopping_rounds::Integer=0,
maximize=false,
kw...
Expand All @@ -597,12 +589,12 @@ function xgboost(dm::DMatrix, a...;
# We have a watchlist - give a warning if early stopping is provided and watchlist is a Dict type with length > 1
if isa(watchlist, Dict)
if early_stopping_rounds > 0 && length(watchlist) > 1
@warn "Early stopping rounds activated whilst watchlist has more than 1 element. Recommended to provide watchlist as an OrderedDict to ensure deterministic behaviour."
error("You must supply an OrderedDict type for watchlist if early stopping rounds is enabled and there is more than one element in watchlist.")
end
end

if isempty(watchlist) && early_stopping_rounds > 0
@warn "Early stopping is ignored as provided watchlist is empty."
error("Watchlist must be supplied if early_stopping_rounds is enabled.")
end

isempty(watchlist) || @info("XGBoost: starting training.")
Expand Down
77 changes: 38 additions & 39 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -146,23 +146,30 @@ end
objective="binary:logistic",
eval_metric=["rmsle","rmse"]
)

# test if it ran all the way till the end (baseline)
nrounds_bst = XGBoost.getnrounds(bst)
@test nrounds_bst == 30

let err = nothing
try
# Check to see that xgboost will error out when watchlist supplied is a dictionary with early_stopping_rounds enabled
bst_early_stopping = xgboost(dtrain,
num_round=30,
watchlist=watchlist,
η=1,
objective="binary:logistic",
eval_metric=["rmsle","rmse"],
early_stopping_rounds = 2
)

bst_early_stopping = xgboost(dtrain,
num_round=30,
watchlist=watchlist,
η=1,
objective="binary:logistic",
eval_metric=["rmsle","rmse"],
early_stopping_rounds = 2
)

nrounds_bst = XGBoost.getnrounds(bst)
nrounds_bst_early_stopping = XGBoost.getnrounds(bst_early_stopping)
# Check to see that running with early stopping results in less than or equal rounds
@test nrounds_bst_early_stopping <= nrounds_bst
nrounds_bst = XGBoost.getnrounds(bst)
nrounds_bst_early_stopping = XGBoost.getnrounds(bst_early_stopping)
catch err
end

# Check number of rounds > early stopping rounds
@test nrounds_bst_early_stopping > 2
@test err isa Exception
end

# test the early stopping rounds interface with an OrderedDict data type in the watchlist
watchlist_ordered = OrderedDict("train"=>dtrain, "eval"=>dtest)
Expand All @@ -176,6 +183,8 @@ end
early_stopping_rounds = 2
)



@test XGBoost.getnrounds(bst_early_stopping) > 2
@test XGBoost.getnrounds(bst_early_stopping) <= nrounds_bst

Expand Down Expand Up @@ -220,31 +229,21 @@ end
# ensure that the results are the same (as numerically possible) with the best round
@test abs(bst_early_stopping.best_score - calc_metric) < 1e-9

# test the interface with no watchlist provided (defaults to the training dataset)
bst_early_stopping = xgboost(dtrain,
num_round=30,
η=1,
objective="binary:logistic",
eval_metric=["rmsle","rmse"],
early_stopping_rounds = 2
)

@test XGBoost.getnrounds(bst_early_stopping) > 2
@test XGBoost.getnrounds(bst_early_stopping) <= nrounds_bst

# test the interface with an empty watchlist (no output)
# this should trigger no early stopping rounds
bst_empty_watchlist = xgboost(dtrain,
num_round=30,
η=1,
watchlist = Dict(),
objective="binary:logistic",
eval_metric=["rmsle","rmse"],
early_stopping_rounds = 2 # this should be ignored
)

@test XGBoost.getnrounds(bst_empty_watchlist) == nrounds_bst
# Test the interface with no watchlist provided (it'll default to training watchlist)
let err = nothing
try
bst_early_stopping = xgboost(dtrain,
num_round=30,
η=1,
objective="binary:logistic",
eval_metric=["rmsle","rmse"],
early_stopping_rounds = 2
)
catch err
end

@test !(err isa Exception)
end
end


Expand Down

0 comments on commit 375fbb0

Please sign in to comment.